R/gsaot_indices.R

Defines functions confint.gsaot_indices summary.gsaot_indices plot.gsaot_indices print.gsaot_indices gsaot_indices

Documented in confint.gsaot_indices plot.gsaot_indices print.gsaot_indices summary.gsaot_indices

# Declare `.data` (used as `.data[["..."]]` inside the ggplot2::aes() calls
# of this file) as a known global, so that R CMD check does not report it as
# a variable without a visible binding
utils::globalVariables(".data")

# gsaot_indices object constructor
#
# Internal helper that bundles the results of an optimal-transport sensitivity
# analysis into a list of class "gsaot_indices".
#
# Arguments:
#   method        Name of the method used; stored as `method`.
#   indices       Sensitivity indices; stored as `indices`.
#   bound         Stored as `bound`.
#   IS            Stored as `separation_measures`.
#   partitions    Stored as `partitions`.
#   x, y          Stored unchanged as `x` and `y`.
#   solver_optns  Optional; stored as `solver_optns` only when not NULL.
#   Adv, Diff     Optional advective/diffusive components; stored as `adv` and
#                 `diff` only when `Adv` is not NULL.
#   indices_ci, bound_ci, IS_ci, R, type, conf, W_boot
#                 Optional bootstrap results; stored (and `boot` set to TRUE)
#                 only when `indices_ci` is not NULL.
#
# Returns a list of class "gsaot_indices". Optional arguments that are NULL
# are not added to the list (assigning NULL with `[[<-` is a no-op/removal).
gsaot_indices <- function(method,
                          indices,
                          bound,
                          IS,
                          partitions,
                          x, y,
                          solver_optns = NULL,
                          Adv = NULL,
                          Diff = NULL,
                          indices_ci = NULL,
                          bound_ci = NULL,
                          IS_ci = NULL,
                          R = NULL,
                          type = NULL,
                          conf = NULL,
                          W_boot = NULL) {
  value <- list(method = method,
                indices = indices,
                bound = bound,
                x = x, y = y,
                separation_measures = IS,
                partitions = partitions,
                boot = FALSE)

  if (!is.null(solver_optns)) {
    value[["solver_optns"]] <- solver_optns
  }

  # Advective/diffusive components are only provided for the
  # Wasserstein-Bures indices
  if (!is.null(Adv)) {
    value[["adv"]] <- Adv
    value[["diff"]] <- Diff
  }

  # Bootstrap results are attached only when confidence intervals exist
  if (!is.null(indices_ci)) {
    value[["boot"]] <- TRUE
    value[["indices_ci"]] <- indices_ci
    value[["separation_measures_ci"]] <- IS_ci
    value[["bound_ci"]] <- bound_ci
    value[["R"]] <- R
    value[["type"]] <- type
    value[["conf"]] <- conf
    value[["W_boot"]] <- W_boot
  }

  class(value) <- "gsaot_indices"
  value
}

#' Print optimal transport sensitivity indices information
#'
#' Prints the method, the sensitivity indices and, when available, the
#' advective and diffusive components and the bootstrap statistics.
#'
#' @param x An object generated by \code{\link{ot_indices}},
#'   \code{\link{ot_indices_1d}}, or \code{\link{ot_indices_wb}}.
#' @param ... Further arguments passed to or from other methods.
#'
#' @return The argument `x`, invisibly. The function is called for its side
#'   effect of printing.
#' @export
#'
#' @examples
#' N <- 1000
#'
#' mx <- c(1, 1, 1)
#' Sigmax <- matrix(data = c(1, 0.5, 0.5, 0.5, 1, 0.5, 0.5, 0.5, 1), nrow = 3)
#'
#' x1 <- rnorm(N)
#' x2 <- rnorm(N)
#' x3 <- rnorm(N)
#'
#' x <- cbind(x1, x2, x3)
#' x <- mx + x %*% chol(Sigmax)
#'
#' A <- matrix(data = c(4, -2, 1, 2, 5, -1), nrow = 2, byrow = TRUE)
#' y <- t(A %*% t(x))
#'
#' x <- data.frame(x)
#'
#' M <- 25
#'
#' # Calculate sensitivity indices
#' sensitivity_indices <- ot_indices(x, y, M)
#' print(sensitivity_indices)
#'
print.gsaot_indices <- function(x, ...) {
  cat("Method:", x$method, "\n")
  cat("\nIndices:\n")
  print(x$indices)

  # Advective/diffusive components are present only for Wasserstein-Bures
  # indices
  if ("adv" %in% names(x)) {
    cat("\nAdvective component:\n")
    print(x$adv)
    cat("\nDiffusive component:\n")
    print(x$diff)
  }

  if (isTRUE(x$boot)) {
    cat("\nType of confidence interval:", x$type, "\n")
    cat("Number of replicates:", x$R, "\n")
    cat("Confidence level:", x$conf, "\n")
    cat("Bootstrap statistics:\n")
    print(x$indices_ci)
  }

  # Print methods conventionally return their argument invisibly
  invisible(x)
}

#' Plot optimal transport sensitivity indices
#'
#' Plot Optimal Transport based sensitivity indices using `ggplot2` package.
#'
#' @param x An object generated by \code{\link{ot_indices}},
#'   \code{\link{ot_indices_1d}}, or \code{\link{ot_indices_wb}}.
#' @param ranking An integer with absolute value less than or equal to the
#'   number of inputs. If positive, select the first `ranking` inputs per
#'   importance. If negative, select the last `abs(ranking)` inputs per
#'   importance.
#' @param wb_all (default `FALSE`) Logical that defines whether or not to plot
#'   the Advective and Diffusive components of the Wasserstein-Bures indices.
#' @param threshold (default `NULL`) A number or an object of class
#'   `gsaot_indices` that represents a lower threshold. For an object, its
#'   first index is used as the threshold.
#' @param ... Further arguments passed to or from other methods.
#'
#'
#' @returns A \code{ggplot} object that, if called, will print
#' @export
#'
#' @examples
#' N <- 1000
#'
#' mx <- c(1, 1, 1)
#' Sigmax <- matrix(data = c(1, 0.5, 0.5, 0.5, 1, 0.5, 0.5, 0.5, 1), nrow = 3)
#'
#' x1 <- rnorm(N)
#' x2 <- rnorm(N)
#' x3 <- rnorm(N)
#'
#' x <- cbind(x1, x2, x3)
#' x <- mx + x %*% chol(Sigmax)
#'
#' A <- matrix(data = c(4, -2, 1, 2, 5, -1), nrow = 2, byrow = TRUE)
#' y <- t(A %*% t(x))
#'
#' x <- data.frame(x)
#'
#' M <- 25
#'
#' # Calculate sensitivity indices
#' sensitivity_indices <- ot_indices_wb(x, y, M)
#' sensitivity_indices
#'
#' plot(sensitivity_indices)
#'
plot.gsaot_indices <- function(x,
                               ranking = NULL,
                               wb_all = FALSE,
                               threshold = NULL, ...) {
  # Number of inputs. `length()` works whether `indices` is a plain named
  # vector or a 1-d array, whereas `nrow()` is NULL for plain vectors.
  K <- length(x$indices)

  # Select only the inputs requested by `ranking`, as positions in the
  # ordering by decreasing index value
  if (is.null(ranking)) {
    inputs_to_plot <- seq_len(K)
  } else {
    valid_ranking <- is.numeric(ranking) && length(ranking) == 1L &&
      is.finite(ranking) && ranking %% 1 == 0 && abs(ranking) <= K
    if (!valid_ranking)
      stop("`ranking` should be an integer with absolute value less than or equal to the number of inputs")
    inputs_to_plot <- if (ranking >= 0) {
      seq_len(ranking)
    } else {
      seq(from = K + ranking + 1, to = K)
    }
  }

  # `threshold` may be a number or a `gsaot_indices` object; for the latter,
  # its first index is used as the reference level
  if (!is.null(threshold)) {
    if (inherits(threshold, "gsaot_indices")) {
      threshold <- as.vector(threshold$indices)[1]
    } else if (!is.numeric(threshold)) {
      stop("`threshold` should be an object of class `gsaot_indices` or a double")
    }
  }

  # Positions of the selected inputs, sorted by decreasing index value
  ord <- order(x$indices, decreasing = TRUE)[inputs_to_plot]
  input_names <- names(x$indices)[ord]

  # The advective and diffusive components are plotted only when requested
  # and available (i.e. for indices computed by ot_indices_wb)
  show_components <- isTRUE(wb_all) && "adv" %in% names(x)

  if (!show_components) {
    # Create a data.frame to store all the indices
    x_indices <-
      data.frame(
        Inputs = input_names,
        Component = x$method,
        Indices = as.vector(x$indices)[ord]
      )
    x_indices$Inputs <- factor(x_indices$Inputs,
                               levels = unique(x_indices$Inputs))

    # Plot the indices ordered by magnitude
    p <- ggplot2::ggplot(data = x_indices, ggplot2::aes(x = .data[["Inputs"]],
                                                   y = .data[["Indices"]],
                                                   fill = .data[["Component"]])) +
      ggplot2::geom_bar(stat = "identity") +
      ggplot2::labs(
        title = paste("Indices computed using", x$method, "solver"),
        x = "Inputs",
        y = "Indices"
      ) +
      ggplot2::guides(fill = "none")

    if (isTRUE(x$boot)) {
      ci_data <- merge(x_indices, x$indices_ci,
                       by.x = "Inputs", by.y = "input")
      # Keep only the error bars of the main index when the bootstrap table
      # also contains the advective/diffusive components
      if ("component" %in% names(ci_data))
        ci_data <- ci_data[!(ci_data$component %in% c("advective", "diffusive")), ]
      p <- p +
        ggplot2::geom_errorbar(data = ci_data,
                               ggplot2::aes(ymin = .data[["low.ci"]],
                                            ymax = .data[["high.ci"]]),
                               position = ggplot2::position_dodge2(padding = 0.2),
                               width = .5)
    }
  } else {
    # Create a data.frame to store all the indices and their components
    x_indices <- data.frame(
      Inputs = rep(input_names, times = 3),
      Component = rep(c("wass-bures", "advective", "diffusive"),
                      each = length(ord)),
      Indices = c(as.vector(x$indices)[ord],
                  as.vector(x$adv)[ord],
                  as.vector(x$diff)[ord])
    )
    x_indices$Inputs <- factor(x_indices$Inputs,
                               levels = unique(x_indices$Inputs))
    x_indices$Component <- factor(x_indices$Component,
                                  levels = unique(x_indices$Component))

    p <- ggplot2::ggplot(data = x_indices, ggplot2::aes(x = .data[["Inputs"]],
                                                   y = .data[["Indices"]],
                                                   fill = .data[["Component"]])) +
      ggplot2::geom_bar(
        stat = "identity",
        position = ggplot2::position_dodge2(padding = 0.2),
        width = .5
      ) +
      ggplot2::labs(
        title = paste("Indices computed using", x$method, "solver"),
        x = "Inputs",
        y = "Indices"
      )

    if (isTRUE(x$boot)) {
      ci_data <- merge(x_indices, x$indices_ci,
                       by.x = c("Inputs", "Component"),
                       by.y = c("input", "component"))
      p <- p +
        ggplot2::geom_errorbar(data = ci_data,
                               ggplot2::aes(ymin = .data[["low.ci"]],
                                            ymax = .data[["high.ci"]]),
                               position = ggplot2::position_dodge2(padding = 0.5),
                               width = .5)
    }
  }

  if (!is.null(threshold))
    p <- p +
      ggplot2::geom_hline(yintercept = threshold, linetype = 2)

  p
}

#' Summary method for `gsaot_indices` objects
#'
#' @param object  An object of class \code{"gsaot_indices"}.
#' @param digits  (default: \code{3}) Number of significant digits to print for
#'  numeric values.
#' @param ranking An integer with absolute value less or equal than the number
#'   of inputs. If positive, select the first `ranking` inputs per importance
#'   (for Wasserstein-Bures indices, all three components of those inputs are
#'   shown). If negative, drop the last `abs(ranking)` inputs.
#' @param ...     Further arguments (currently ignored).
#'
#' @return (Invisibly) a named \code{list} containing the main elements
#'         summarised on screen.
#' @export
summary.gsaot_indices <- function(object, digits = 3, ranking = NULL, ...) {

  ## ---- basic meta-information --------------------------------------------
  method <- object$method
  # For the "transport" method report the solver that was used instead
  if (identical(method, "transport")) {
    method <- if (is.null(object$solver_optns$method)) {
      "networksimplex"
    } else {
      object$solver_optns$method
    }
  }
  n_inputs <- length(object$indices)

  has_wb    <- all(c("adv", "diff") %in% names(object))
  has_boot  <- isTRUE(object$boot)

  ## ---- build a data frame of indices, by decreasing importance -----------
  ord <- order(object$indices, decreasing = TRUE)
  ranked_inputs <- names(object$indices)[ord]

  if (has_wb) {
    df <- data.frame(input = rep(ranked_inputs, times = 3),
                     component = rep(c("wass-bures", "advective", "diffusive"),
                                     each = n_inputs),
                     index = signif(c(unname(object$indices)[ord],
                                      unname(object$adv)[ord],
                                      unname(object$diff)[ord]), digits))
  } else {
    df <- data.frame(
      input  = ranked_inputs,
      index  = signif(object$indices[ord], digits)
    )
  }

  if (has_boot) {
    ci  <- object$indices_ci
    ci$original <- NULL
    ci$bias <- NULL
    ci$low.ci <- signif(ci$low.ci, digits)
    ci$high.ci <- signif(ci$high.ci, digits)

    by_cols <- if (has_wb) c("input", "component") else "input"

    # merge(sort = FALSE) returns rows in an unspecified order: remember the
    # ranking order and restore it after the merge
    df$.rank_order <- seq_len(nrow(df))
    df <- merge(df, ci, by = by_cols, all.x = TRUE, sort = FALSE)
    df <- df[order(df$.rank_order), , drop = FALSE]
    df$.rank_order <- NULL
    rownames(df) <- NULL
  }

  ## ---- printing -----------------------------------------------------------
  cat("\n--- gsaot_indices summary -----------------------------------------\n")
  cat("Method        :", method, "\n")
  cat("No. of inputs :", n_inputs, "\n")
  if (has_boot) {
    cat("Bootstrap     : TRUE (R =", object$R, ", conf =", object$conf, ")\n")
  }
  if (has_wb) {
    cat("Components    : WB, Advective, Diffusive\n")
  }

  cat("--------------------------------------------------------------------\n")

  ## show only the rows (all components) of the requested inputs
  if (is.null(ranking)) {
    print(df, row.names = FALSE)
  } else {
    shown_inputs <- utils::head(ranked_inputs, ranking)
    print(df[df$input %in% shown_inputs, , drop = FALSE], row.names = FALSE)
  }

  cat("--------------------------------------------------------------------\n")

  ## ---- return invisibly ---------------------------------------------------
  invisible(list(
    method      = method,
    indices_tbl = df,
    boot        = if (has_boot) list(R = object$R,
                                     conf = object$conf,
                                     type = object$type) else NULL,
    wb          = has_wb
  ))
}

#' Compute confidence intervals for sensitivity indices
#'
#' Computes confidence intervals for a \code{gsaot_indices} object using
#' bootstrap results.
#'
#' @param object An object of class \code{gsaot_indices}, with bootstrap results
#'   included.
#' @param parm A specification of which parameters are to be given confidence
#'   intervals, either a vector of numbers or a vector of names. If missing, all
#'   parameters are considered. Unknown names or out-of-range positions raise
#'   an error.
#' @param level (default is 0.95) Confidence level for the interval.
#' @param type (default is \code{"norm"}) Method to compute the confidence interval.
#'   For more information, check the `type` option of [boot::boot.ci()].
#' @param ... Additional arguments (currently unused).
#'
#' @return A data frame with the following columns:
#'   * `input`: Name of the input variable.
#'   * `component`: The index component (only for Wasserstein-Bures).
#'   * `index`: Estimated indices
#'   * `original`: Original estimates.
#'   * `bias`: Bootstrap bias estimate.
#'   * `low.ci`: Lower bound of the confidence interval.
#'   * `high.ci`: Upper bound of the confidence interval.
#'
#' @examples
#' N <- 1000
#'
#' mx <- c(1, 1, 1)
#' Sigmax <- matrix(data = c(1, 0.5, 0.5, 0.5, 1, 0.5, 0.5, 0.5, 1), nrow = 3)
#'
#' x1 <- rnorm(N)
#' x2 <- rnorm(N)
#' x3 <- rnorm(N)
#'
#' x <- cbind(x1, x2, x3)
#' x <- mx + x %*% chol(Sigmax)
#'
#' A <- matrix(data = c(4, -2, 1, 2, 5, -1), nrow = 2, byrow = TRUE)
#' y <- t(A %*% t(x))
#'
#' x <- data.frame(x)
#' y <- y
#'
#' res <- ot_indices_wb(x, y, 10, boot = TRUE, R = 100)
#' confint(res, parm = c(1,3), level = 0.9)
#'
#' @export
confint.gsaot_indices <- function(object,
                                  parm = NULL,
                                  level = 0.95,
                                  type = "norm", ...) {
  # INPUT CHECKS
  # ----------------------------------------------------------------------------
  if (!isTRUE(object$boot))
    stop("The 'gsaot_indices' object has no bootstrap")

  if (!(is.null(parm) || is.numeric(parm) || is.character(parm)))
    stop("'parm' should be a vector of numbers or of names")

  input_names <- names(object$indices)
  n_inputs <- length(object$indices)

  # Identify the parameters for CI computation, as positions in `indices`
  # ----------------------------------------------------------------------------
  if (is.null(parm))
    parm <- seq_len(n_inputs)

  if (is.character(parm)) {
    unknown <- setdiff(parm, input_names)
    if (length(unknown) > 0)
      stop("'parm' contains unknown input names: ",
           paste(unknown, collapse = ", "))
    parm <- match(parm, input_names)
  }

  # Reject positions that do not identify an input (out of range, fractional
  # or NA): they would otherwise fail later with an obscure error
  if (!all(parm %in% seq_len(n_inputs)))
    stop("'parm' should only contain positions between 1 and ", n_inputs)

  K <- length(parm)
  is_wb <- all(c("adv", "diff") %in% names(object))
  selected <- input_names[parm]
  stat_cols <- c("original", "bias", "low.ci", "high.ci")
  wb_components <- c("wass-bures", "advective", "diffusive")

  # Initialize the return objects
  # ----------------------------------------------------------------------------
  W <- rep(NA_real_, K)
  names(W) <- selected

  if (is_wb) {
    Adv <- W
    Diff <- W
    # Row blocks: 1..K wass-bures, K+1..2K advective, 2K+1..3K diffusive
    W_ci <- data.frame(input = rep(selected, times = 3),
                       component = rep(wb_components, each = K),
                       stringsAsFactors = FALSE)
  } else {
    W_ci <- data.frame(input = selected, stringsAsFactors = FALSE)
  }
  for (col in stat_cols)
    W_ci[[col]] <- rep(NA_real_, nrow(W_ci))

  # Compute the statistics
  # ----------------------------------------------------------------------------
  for (k in seq_len(K)) {
    W_stats <- bootstats(object$W_boot[[parm[k]]], type = type, conf = level)

    # Save indices
    W[k] <- W_stats$index[1]

    if (is_wb) {
      # Save indices decomposition
      Adv[k] <- W_stats$index[2]
      Diff[k] <- W_stats$index[3]

      # Bootstrap estimates of the index and of its two components
      for (j in seq_along(wb_components)) {
        W_ci[(j - 1) * K + k, stat_cols] <- c(W_stats$original[j],
                                              W_stats$bias[j],
                                              W_stats$low.ci[j],
                                              W_stats$high.ci[j])
      }
    } else {
      W_ci[k, stat_cols] <- c(W_stats$original[1], W_stats$bias[1],
                              W_stats$low.ci[1], W_stats$high.ci[1])
    }
  }

  # Save indices along with the rest of the statistics
  # ----------------------------------------------------------------------------
  if (is_wb) {
    W_ci$index <- c(W, Adv, Diff)
    W_ci <- W_ci[, c("input", "component", "index", stat_cols)]
  } else {
    W_ci$index <- W
    W_ci <- W_ci[, c("input", "index", stat_cols)]
  }

  W_ci
}

Try the gsaot package in your browser

Any scripts or data that you put into this service are public.

gsaot documentation built on Aug. 8, 2025, 7:52 p.m.