R/gsaot_indices.R

Defines functions plot.gsaot_indices print.gsaot_indices gsaot_indices

Documented in plot.gsaot_indices print.gsaot_indices

# Register `.data` (the pronoun used inside the ggplot2 aesthetics of
# plot.gsaot_indices) as a global variable so that R CMD check does not emit
# "no visible binding for global variable" NOTEs.
utils::globalVariables(".data")

# gsaot_indices object constructor
#
# Bundles the results of an Optimal Transport sensitivity analysis into a list
# of class "gsaot_indices". The mandatory elements are always present, in this
# order: method, indices, bound, x, y, inner_statistics (from `IS`),
# partitions and boot (initially FALSE).
#
# Optional elements are appended only when supplied:
# - `solver_optns` is stored when it is non-NULL;
# - `adv` and `diff` are stored when `Adv` is non-NULL;
# - when `indices_ci` is non-NULL, `boot` is set to TRUE and the bootstrap
#   fields indices_ci, inner_statistics_ci (from `IS_ci`), bound_ci, R, type
#   and conf are stored.
# Assigning NULL with `[[<-` does not create an element, so any optional
# argument left NULL is simply absent from the result.
gsaot_indices <- function(method,
                          indices,
                          bound,
                          IS,
                          partitions,
                          x, y,
                          solver_optns = NULL,
                          Adv = NULL,
                          Diff = NULL,
                          indices_ci = NULL,
                          bound_ci = NULL,
                          IS_ci = NULL,
                          R = NULL,
                          type = NULL,
                          conf = NULL) {
  value <- list(method = method,
                indices = indices,
                bound = bound,
                x = x, y = y,
                inner_statistics = IS,
                partitions = partitions,
                boot = FALSE)

  if (!is.null(solver_optns)) {
    value[["solver_optns"]] <- solver_optns
  }

  # Advective/diffusive components (Wasserstein-Bures decomposition)
  if (!is.null(Adv)) {
    value[["adv"]] <- Adv
    value[["diff"]] <- Diff
  }

  # Bootstrap results: the presence of `indices_ci` marks a bootstrapped run
  if (!is.null(indices_ci)) {
    value[["boot"]] <- TRUE
    value[["indices_ci"]] <- indices_ci
    value[["inner_statistics_ci"]] <- IS_ci
    value[["bound_ci"]] <- bound_ci
    value[["R"]] <- R
    value[["type"]] <- type
    value[["conf"]] <- conf
  }

  class(value) <- "gsaot_indices"
  value
}

#' Print Optimal Transport Sensitivity indices information
#'
#' Prints the method, the solver options (if stored), the indices, optionally
#' the input and output data, the advective and diffusive components (if
#' stored), the bootstrap settings and confidence intervals (if the indices
#' were bootstrapped), and the upper bound. For bootstrapped results, the
#' printed upper bound is the mean of the stored bound values.
#'
#' @param x An object generated by \code{\link{ot_indices}},
#'   \code{\link{ot_indices_1d}}, or \code{\link{ot_indices_wb}}.
#' @param data Logical, indicating whether or not the input and output data
#'   should be printed.
#' @param ... Further arguments passed to or from other methods.
#'
#' @return The argument `x`, invisibly. The function is called for its side
#'   effect of printing the information contained in `x`.
#' @export
#'
#' @examples
#' N <- 1000
#'
#' mx <- c(1, 1, 1)
#' Sigmax <- matrix(data = c(1, 0.5, 0.5, 0.5, 1, 0.5, 0.5, 0.5, 1), nrow = 3)
#'
#' x1 <- rnorm(N)
#' x2 <- rnorm(N)
#' x3 <- rnorm(N)
#'
#' x <- cbind(x1, x2, x3)
#' x <- mx + x %*% chol(Sigmax)
#'
#' A <- matrix(data = c(4, -2, 1, 2, 5, -1), nrow = 2, byrow = TRUE)
#' y <- t(A %*% t(x))
#'
#' x <- data.frame(x)
#'
#' M <- 25
#'
#' # Calculate sensitivity indices
#' sensitivity_indices <- ot_indices(x, y, M)
#' print(sensitivity_indices)
#'
print.gsaot_indices <- function(x, data = FALSE, ...) {
  cat("Method:", x$method, "\n")
  # Optional elements are only stored by the constructor when supplied, so
  # their presence is tested by exact name.
  if ("solver_optns" %in% names(x)) {
    cat("\nSolver Options:\n")
    print(x$solver_optns)
  }
  cat("\nIndices:\n")
  print(x$indices)

  if (data) {
    cat("Data:\n")
    print(x$x)
    print(x$y)
  }
  if ("adv" %in% names(x)) {
    cat("\nAdvective component:\n")
    print(x$adv)
    cat("\nDiffusive component:\n")
    print(x$diff)
  }
  if (x$boot) {
    cat("\nType of confidence interval:", x$type, "\n")
    cat("Number of replicates:", x$R, "\n")
    cat("Confidence level:", x$conf, "\n")
    cat("Indices confidence intervals:\n")
    print(x$indices_ci)

    cat("\nUpper bound:", mean(x$bound), "\n")
  } else {
    cat("\nUpper bound:", x$bound, "\n")
  }

  # Print methods conventionally return their argument invisibly
  invisible(x)
}

#' Plot Optimal Transport sensitivity indices
#'
#' Plot Optimal Transport based sensitivity indices using `ggplot2` package.
#' Inputs are shown in decreasing order of their index value. When the indices
#' were bootstrapped, confidence intervals are added as error bars.
#'
#' @param x An object generated by \code{\link{ot_indices}},
#'   \code{\link{ot_indices_1d}}, or \code{\link{ot_indices_wb}}.
#' @param ranking A non-zero integer with absolute value less than or equal to
#'   the number of inputs. If positive, plot the `ranking` most important
#'   inputs. If negative, plot the `abs(ranking)` least important inputs.
#' @param wb_all (default `FALSE`) Logical that defines whether or not to plot
#'   the Advective and Diffusive components of the Wasserstein-Bures indices
#' @param dummy (default `NULL`) A number or an object of class
#'   `gsaot_indices` that represents a lower bound. It is drawn as a dashed red
#'   horizontal line; when an object is given, its first index is used.
#' @param ... Further arguments passed to or from other methods.
#'
#'
#' @returns A \code{ggplot} object; printing it draws the plot.
#' @export
#'
#' @examples
#' N <- 1000
#'
#' mx <- c(1, 1, 1)
#' Sigmax <- matrix(data = c(1, 0.5, 0.5, 0.5, 1, 0.5, 0.5, 0.5, 1), nrow = 3)
#'
#' x1 <- rnorm(N)
#' x2 <- rnorm(N)
#' x3 <- rnorm(N)
#'
#' x <- cbind(x1, x2, x3)
#' x <- mx + x %*% chol(Sigmax)
#'
#' A <- matrix(data = c(4, -2, 1, 2, 5, -1), nrow = 2, byrow = TRUE)
#' y <- t(A %*% t(x))
#'
#' x <- data.frame(x)
#'
#' M <- 25
#'
#' # Calculate sensitivity indices
#' sensitivity_indices <- ot_indices_wb(x, y, M)
#' sensitivity_indices
#'
#' plot(sensitivity_indices)
#'
plot.gsaot_indices <- function(x,
                               ranking = NULL,
                               wb_all = FALSE,
                               dummy = NULL, ...) {
  # Number of inputs. `length()` is used instead of `nrow()` because `nrow()`
  # returns NULL for a plain (dimensionless) named vector.
  K <- length(x$indices)

  # Select the positions, in decreasing order of importance, to be plotted
  if (is.null(ranking)) {
    inputs_to_plot <- seq_len(K)
  } else {
    valid_ranking <- is.numeric(ranking) && length(ranking) == 1L &&
      is.finite(ranking) && ranking %% 1 == 0 && ranking != 0 &&
      abs(ranking) <= K
    if (!valid_ranking)
      stop("`ranking` should be a non-zero integer with absolute value less than or equal to the number of inputs")
    # Positive: the `ranking` most important inputs;
    # negative: the `abs(ranking)` least important ones
    inputs_to_plot <- if (ranking > 0) {
      seq_len(ranking)
    } else {
      seq(from = K + ranking + 1, to = K)
    }
  }

  # Validate `dummy` and reduce it to the threshold to be plotted
  if (!is.null(dummy) && !inherits(dummy, "gsaot_indices") &&
      !is.numeric(dummy))
    stop("`dummy` should be an object of class `gsaot_indices` or a number")
  if (inherits(dummy, "gsaot_indices"))
    dummy <- dummy$indices[1]

  # Ordering of the inputs by decreasing index value, shared by all components
  ord <- order(x$indices, decreasing = TRUE)
  input_names <- names(x$indices[ord])[inputs_to_plot]
  ranked <- function(v) unname(v[ord])[inputs_to_plot]

  has_wb <- "adv" %in% names(x)
  plot_title <- paste("Indices computed using", x$method, "solver")

  # If the indices are not from ot_indices_wb (or `wb_all` is FALSE), plot only
  # the indices. Otherwise, plot the indices together with the advective and
  # diffusive components.
  if (!has_wb || !wb_all) {
    x_indices <- data.frame(
      Inputs = input_names,
      Component = x$method,
      Indices = ranked(x$indices)
    )
    x_indices$Inputs <- factor(x_indices$Inputs,
                               levels = unique(x_indices$Inputs))

    # Plot the indices ordered by magnitude
    p <- ggplot2::ggplot(data = x_indices, ggplot2::aes(x = .data[["Inputs"]],
                                                   y = .data[["Indices"]],
                                                   fill = .data[["Component"]])) +
      ggplot2::geom_bar(stat = "identity") +
      ggplot2::labs(
        title = plot_title,
        x = "Inputs",
        y = "Indices"
      ) +
      ggplot2::guides(fill = "none")

    if (x$boot) {
      ci_data <- merge(x$indices_ci, x_indices, by = "Inputs")
      # Wasserstein-Bures intervals also cover the components: keep only the
      # intervals of the indices themselves
      if ("Index" %in% names(ci_data))
        ci_data <- ci_data[!(ci_data$Index %in% c("Advective", "Diffusive")), ]
      p <- p +
        ggplot2::geom_errorbar(data = ci_data,
                               ggplot2::aes(ymin = .data[["low.ci"]],
                                            ymax = .data[["high.ci"]]),
                               position = ggplot2::position_dodge2(padding = 0.2),
                               width = .5)
    }
  } else {
    x_indices <- data.frame(
      Inputs = rep(input_names, times = 3),
      Component = rep(c("WB", "Advective", "Diffusive"),
                      each = length(inputs_to_plot)),
      Indices = c(ranked(x$indices), ranked(x$adv), ranked(x$diff))
    )
    x_indices$Inputs <- factor(x_indices$Inputs,
                               levels = unique(x_indices$Inputs))
    x_indices$Component <- factor(x_indices$Component,
                                  levels = unique(x_indices$Component))

    # Dodged bars: one bar per component for each input
    p <- ggplot2::ggplot(data = x_indices, ggplot2::aes(x = .data[["Inputs"]],
                                                   y = .data[["Indices"]],
                                                   fill = .data[["Component"]])) +
      ggplot2::geom_bar(
        stat = "identity",
        position = ggplot2::position_dodge2(padding = 0.2),
        width = .5
      ) +
      ggplot2::labs(
        title = plot_title,
        x = "Inputs",
        y = "Indices"
      )

    if (x$boot) {
      ci_data <- merge(x_indices, x$indices_ci,
                       by.x = c("Inputs", "Component"),
                       by.y = c("Inputs", "Index"))
      p <- p +
        ggplot2::geom_errorbar(data = ci_data,
                               ggplot2::aes(ymin = .data[["low.ci"]],
                                            ymax = .data[["high.ci"]]),
                               position = ggplot2::position_dodge2(padding = 0.5),
                               width = .5)
    }
  }

  if (!is.null(dummy))
    p <- p +
      ggplot2::geom_hline(yintercept = dummy, linetype = 2, color = "red")

  p
}

Try the gsaot package in your browser

Any scripts or data that you put into this service are public.

gsaot documentation built on April 3, 2025, 8:55 p.m.