Nothing
# Declare `.data` (used inside the ggplot2::aes() calls of the plot method) as
# a known global so that R CMD check reports no "no visible binding" NOTE.
utils::globalVariables(names = ".data")
# Constructor for `gsaot_indices` objects.
#
# Bundles the estimated indices and related quantities into a named list of
# class "gsaot_indices". Optional pieces are only stored when supplied:
# `solver_optns` when non-NULL; `adv` and `diff` when `Adv` is non-NULL; and
# the bootstrap fields (`indices_ci`, `separation_measures_ci`, `bound_ci`,
# `R`, `type`, `conf`, `W_boot`) when `indices_ci` is non-NULL, in which case
# `boot` is TRUE. Arguments left as NULL never create a list element.
gsaot_indices <- function(method,
                          indices,
                          bound,
                          IS,
                          partitions,
                          x, y,
                          solver_optns = NULL,
                          Adv = NULL,
                          Diff = NULL,
                          indices_ci = NULL,
                          bound_ci = NULL,
                          IS_ci = NULL,
                          R = NULL,
                          type = NULL,
                          conf = NULL,
                          W_boot = NULL) {
  has_ci <- !is.null(indices_ci)
  obj <- list(
    method = method,
    indices = indices,
    bound = bound,
    x = x,
    y = y,
    separation_measures = IS,
    partitions = partitions,
    boot = has_ci
  )
  # `[[<-` with a NULL value on a name that is not present leaves the list
  # unchanged, so absent solver options simply add nothing.
  obj[["solver_optns"]] <- solver_optns
  if (!is.null(Adv)) {
    obj[["adv"]] <- Adv
    obj[["diff"]] <- Diff
  }
  if (has_ci) {
    boot_fields <- list(
      indices_ci = indices_ci,
      separation_measures_ci = IS_ci,
      bound_ci = bound_ci,
      R = R,
      type = type,
      conf = conf,
      W_boot = W_boot
    )
    # Assigned one at a time so that NULL entries are skipped, as above
    for (field in names(boot_fields)) {
      obj[[field]] <- boot_fields[[field]]
    }
  }
  class(obj) <- "gsaot_indices"
  obj
}
#' Print optimal transport sensitivity indices information
#'
#' Prints the method, the estimated indices and, when available, the
#' advective and diffusive components. For bootstrapped objects, it also
#' prints the confidence-interval type, the number of replicates, the
#' confidence level and the bootstrap statistics.
#'
#' @param x An object generated by \code{\link{ot_indices}},
#' \code{\link{ot_indices_1d}}, or \code{\link{ot_indices_wb}}.
#' @param ... Further arguments passed to or from other methods.
#'
#' @return The argument `x`, invisibly.
#' @export
#'
#' @examples
#' N <- 1000
#'
#' mx <- c(1, 1, 1)
#' Sigmax <- matrix(data = c(1, 0.5, 0.5, 0.5, 1, 0.5, 0.5, 0.5, 1), nrow = 3)
#'
#' x1 <- rnorm(N)
#' x2 <- rnorm(N)
#' x3 <- rnorm(N)
#'
#' x <- cbind(x1, x2, x3)
#' x <- mx + x %*% chol(Sigmax)
#'
#' A <- matrix(data = c(4, -2, 1, 2, 5, -1), nrow = 2, byrow = TRUE)
#' y <- t(A %*% t(x))
#'
#' x <- data.frame(x)
#'
#' M <- 25
#'
#' # Calculate sensitivity indices
#' sensitivity_indices <- ot_indices(x, y, M)
#' print(sensitivity_indices)
#'
print.gsaot_indices <- function(x, ...) {
  cat("Method:", x$method, "\n")
  cat("\nIndices:\n")
  print(x$indices)
  # Only objects built with an advective component carry the decomposition
  if ("adv" %in% names(x)) {
    cat("\nAdvective component:\n")
    print(x$adv)
    cat("\nDiffusive component:\n")
    print(x$diff)
  }
  # isTRUE() keeps printing working even if `boot` is missing
  if (isTRUE(x$boot)) {
    cat("\nType of confidence interval:", x$type, "\n")
    cat("Number of replicates:", x$R, "\n")
    cat("Confidence level:", x$conf, "\n")
    cat("Bootstrap statistics:\n")
    print(x$indices_ci)
  }
  # print methods conventionally return their argument invisibly
  invisible(x)
}
#' Plot optimal transport sensitivity indices
#'
#' Plot Optimal Transport based sensitivity indices using `ggplot2` package.
#'
#' @param x An object generated by \code{\link{ot_indices}},
#' \code{\link{ot_indices_1d}}, or \code{\link{ot_indices_wb}}.
#' @param ranking An integer with absolute value less than or equal to the
#' number of inputs. If positive, select the first `ranking` inputs per
#' importance. If negative, select the last `abs(ranking)` inputs per
#' importance.
#' @param wb_all (default `FALSE`) Logical that defines whether or not to plot
#' the Advective and Diffusive components of the Wasserstein-Bures indices.
#' @param threshold (default `NULL`) A number or an object of class
#' `gsaot_indices` that represents a lower threshold. When an object is given,
#' its first index is used.
#' @param ... Further arguments passed to or from other methods.
#'
#'
#' @returns A \code{ggplot} object that, if called, will print
#' @export
#'
#' @examples
#' N <- 1000
#'
#' mx <- c(1, 1, 1)
#' Sigmax <- matrix(data = c(1, 0.5, 0.5, 0.5, 1, 0.5, 0.5, 0.5, 1), nrow = 3)
#'
#' x1 <- rnorm(N)
#' x2 <- rnorm(N)
#' x3 <- rnorm(N)
#'
#' x <- cbind(x1, x2, x3)
#' x <- mx + x %*% chol(Sigmax)
#'
#' A <- matrix(data = c(4, -2, 1, 2, 5, -1), nrow = 2, byrow = TRUE)
#' y <- t(A %*% t(x))
#'
#' x <- data.frame(x)
#'
#' M <- 25
#'
#' # Calculate sensitivity indices
#' sensitivity_indices <- ot_indices_wb(x, y, M)
#' sensitivity_indices
#'
#' plot(sensitivity_indices)
#'
plot.gsaot_indices <- function(x,
                               ranking = NULL,
                               wb_all = FALSE,
                               threshold = NULL, ...) {
  # Number of inputs. length() works for plain named vectors as well as for
  # one-dimensional arrays, whereas nrow() is NULL for plain vectors.
  K <- length(x$indices)

  # Positions (in decreasing order of importance) of the inputs to plot
  if (is.null(ranking)) {
    inputs_to_plot <- seq_len(K)
  } else {
    valid_ranking <- is.numeric(ranking) && length(ranking) == 1 &&
      !is.na(ranking) && ranking %% 1 == 0 && abs(ranking) <= K
    if (!valid_ranking)
      stop("`ranking` should be an integer with absolute value less than or equal to the number of inputs")
    inputs_to_plot <- if (ranking >= 0) {
      seq_len(ranking)
    } else {
      seq(from = K + ranking + 1, to = K)
    }
  }

  # Validate the threshold; any numeric value is accepted directly
  if (!is.null(threshold) && !is.numeric(threshold) &&
      !inherits(threshold, "gsaot_indices"))
    stop("`threshold` should be an object of class `gsaot_indices` or a double")
  # A `gsaot_indices` threshold contributes its first index
  if (inherits(threshold, "gsaot_indices"))
    threshold <- threshold$indices[1]

  ord <- order(x$indices, decreasing = TRUE)
  plot_names <- names(x$indices)[ord][inputs_to_plot]
  has_wb <- "adv" %in% names(x)

  # Without the decomposition (or when not requested) plot only the indices;
  # otherwise plot the indices with their advective and diffusive components
  if (!has_wb || !wb_all) {
    x_indices <- data.frame(
      Inputs = plot_names,
      Component = x$method,
      Indices = unname(x$indices[ord])[inputs_to_plot]
    )
    x_indices$Inputs <- factor(x_indices$Inputs,
                               levels = unique(x_indices$Inputs))
    # Plot the indices ordered by magnitude
    p <- ggplot2::ggplot(data = x_indices, ggplot2::aes(x = .data[["Inputs"]],
                                                        y = .data[["Indices"]],
                                                        fill = .data[["Component"]])) +
      ggplot2::geom_bar(stat = "identity") +
      ggplot2::labs(
        title = paste("Indices computed using", x$method, "solver"),
        x = "Inputs",
        y = "Indices"
      ) +
      ggplot2::guides(fill = "none")
    if (isTRUE(x$boot)) {
      ci_data <- merge(x_indices, x$indices_ci,
                       by.x = "Inputs", by.y = "input")
      # Bootstrap tables of decomposed objects hold one row per component:
      # keep only the rows of the index itself
      if ("component" %in% names(ci_data))
        ci_data <- ci_data[!(ci_data$component %in% c("advective", "diffusive")), ]
      p <- p +
        ggplot2::geom_errorbar(data = ci_data,
                               ggplot2::aes(ymin = .data[["low.ci"]],
                                            ymax = .data[["high.ci"]]),
                               position = ggplot2::position_dodge2(padding = 0.2),
                               width = .5)
    }
  } else {
    n_plot <- length(inputs_to_plot)
    x_indices <- data.frame(
      Inputs = rep(plot_names, times = 3),
      Component = rep(c("wass-bures", "advective", "diffusive"), each = n_plot),
      Indices = c(unname(x$indices[ord])[inputs_to_plot],
                  unname(x$adv[ord])[inputs_to_plot],
                  unname(x$diff[ord])[inputs_to_plot])
    )
    x_indices$Inputs <- factor(x_indices$Inputs,
                               levels = unique(x_indices$Inputs))
    x_indices$Component <- factor(x_indices$Component,
                                  levels = unique(x_indices$Component))
    p <- ggplot2::ggplot(data = x_indices, ggplot2::aes(x = .data[["Inputs"]],
                                                        y = .data[["Indices"]],
                                                        fill = .data[["Component"]])) +
      ggplot2::geom_bar(
        stat = "identity",
        position = ggplot2::position_dodge2(padding = 0.2),
        width = .5
      ) +
      ggplot2::labs(
        title = paste("Indices computed using", x$method, "solver"),
        x = "Inputs",
        y = "Indices"
      )
    if (isTRUE(x$boot)) {
      ci_data <- merge(x_indices, x$indices_ci,
                       by.x = c("Inputs", "Component"),
                       by.y = c("input", "component"))
      p <- p +
        ggplot2::geom_errorbar(data = ci_data,
                               ggplot2::aes(ymin = .data[["low.ci"]],
                                            ymax = .data[["high.ci"]]),
                               position = ggplot2::position_dodge2(padding = 0.5),
                               width = .5)
    }
  }
  if (!is.null(threshold))
    p <- p +
      ggplot2::geom_hline(yintercept = threshold, linetype = 2)
  p
}
#' Summary method for `gsaot_indices` objects
#'
#' Prints a table of the indices sorted in decreasing order of importance. For
#' Wasserstein-Bures objects, the table also contains the advective and
#' diffusive components. For bootstrapped objects, it adds the confidence
#' bounds.
#'
#' @param object An object of class \code{"gsaot_indices"}.
#' @param digits (default: \code{3}) Number of significant digits to print for
#' numeric values.
#' @param ranking An integer with absolute value less or equal than the number
#' of inputs. If positive, show only the first `ranking` inputs per importance.
#' If negative, drop the last `abs(ranking)` inputs. For Wasserstein-Bures
#' objects, all components of the selected inputs are shown.
#' @param ... Further arguments (currently ignored).
#'
#' @return (Invisibly) a named \code{list} containing the main elements
#' summarised on screen. `indices_tbl` always holds the full table, regardless
#' of `ranking`.
#' @export
summary.gsaot_indices <- function(object, digits = 3, ranking = NULL, ...) {
  ## ---- basic meta-information --------------------------------------------
  method <- object$method
  if (identical(method, "transport")) {
    # Report the solver used; "networksimplex" when no method was recorded
    method <- if (!is.null(object$solver_optns$method)) {
      object$solver_optns$method
    } else {
      "networksimplex"
    }
  }
  n_inputs <- length(object$indices)
  has_wb <- all(c("adv", "diff") %in% names(object))
  has_boot <- isTRUE(object$boot)

  ## ---- build a data frame of indices -------------------------------------
  ord <- order(object$indices, decreasing = TRUE)
  ranked_inputs <- names(object$indices)[ord]
  if (has_wb) {
    df <- data.frame(input = rep(ranked_inputs, times = 3),
                     component = rep(c("wass-bures", "advective", "diffusive"),
                                     each = n_inputs),
                     index = signif(c(unname(object$indices)[ord],
                                      unname(object$adv)[ord],
                                      unname(object$diff)[ord]), digits))
  } else {
    df <- data.frame(
      input = ranked_inputs,
      index = signif(object$indices[ord], digits)
    )
  }
  if (has_boot) {
    ci <- object$indices_ci
    ci$original <- NULL
    ci$bias <- NULL
    ci$low.ci <- signif(ci$low.ci, digits)
    ci$high.ci <- signif(ci$high.ci, digits)
    keys <- if (has_wb) c("input", "component") else "input"
    # merge(sort = FALSE) returns rows in an unspecified order: remember the
    # ranking position of every row and restore it after merging
    df$.rank_order <- seq_len(nrow(df))
    df <- merge(df, ci, by = keys, all.x = TRUE, sort = FALSE)
    df <- df[order(df$.rank_order), , drop = FALSE]
    df$.rank_order <- NULL
    rownames(df) <- NULL
  }

  ## ---- printing -----------------------------------------------------------
  cat("\n--- gsaot_indices summary -----------------------------------------\n")
  cat("Method :", method, "\n")
  cat("No. of inputs :", n_inputs, "\n")
  if (has_boot) {
    cat("Bootstrap : TRUE (R =", object$R, ", conf =", object$conf, ")\n")
  }
  if (has_wb)
    cat("Components : WB, Advective, Diffusive\n")
  cat("--------------------------------------------------------------------\n")
  ## show only the rows of the requested inputs (head() semantics on inputs)
  shown <- df
  if (!is.null(ranking)) {
    selected <- utils::head(ranked_inputs, ranking)
    shown <- df[df$input %in% selected, , drop = FALSE]
  }
  print(shown, row.names = FALSE)
  cat("--------------------------------------------------------------------\n")

  ## ---- return invisibly ---------------------------------------------------
  invisible(list(
    method = method,
    indices_tbl = df,
    boot = if (has_boot) list(R = object$R,
                              conf = object$conf,
                              type = object$type) else NULL,
    wb = has_wb
  ))
}
#' Compute confidence intervals for sensitivity indices
#'
#' Computes confidence intervals for a \code{gsaot_indices} object using
#' bootstrap results.
#'
#' @param object An object of class \code{gsaot_indices}, with bootstrap results
#' included.
#' @param parm A specification of which parameters are to be given confidence
#' intervals, either a vector of numbers or a vector of names. If missing, all
#' parameters are considered. Unknown names and out-of-range positions raise
#' an error.
#' @param level (default is 0.95) Confidence level for the interval, a single
#' number strictly between 0 and 1.
#' @param type (default is \code{"norm"}) Method to compute the confidence interval.
#' For more information, check the `type` option of [boot::boot.ci()].
#' @param ... Additional arguments (currently unused).
#'
#' @return A data frame with the following columns:
#' * `input`: Name of the input variable.
#' * `component`: The index component (only for Wasserstein-Bures).
#' * `index`: Estimated indices
#' * `original`: Original estimates.
#' * `bias`: Bootstrap bias estimate.
#' * `low.ci`: Lower bound of the confidence interval.
#' * `high.ci`: Upper bound of the confidence interval.
#'
#' @examples
#' N <- 1000
#'
#' mx <- c(1, 1, 1)
#' Sigmax <- matrix(data = c(1, 0.5, 0.5, 0.5, 1, 0.5, 0.5, 0.5, 1), nrow = 3)
#'
#' x1 <- rnorm(N)
#' x2 <- rnorm(N)
#' x3 <- rnorm(N)
#'
#' x <- cbind(x1, x2, x3)
#' x <- mx + x %*% chol(Sigmax)
#'
#' A <- matrix(data = c(4, -2, 1, 2, 5, -1), nrow = 2, byrow = TRUE)
#' y <- t(A %*% t(x))
#'
#' x <- data.frame(x)
#' y <- y
#'
#' res <- ot_indices_wb(x, y, 10, boot = TRUE, R = 100)
#' confint(res, parm = c(1,3), level = 0.9)
#'
#' @export
confint.gsaot_indices <- function(object,
                                  parm = NULL,
                                  level = 0.95,
                                  type = "norm", ...) {
  # INPUT CHECKS
  # ----------------------------------------------------------------------------
  if (!isTRUE(object$boot))
    stop("The 'gsaot_indices' object has no bootstrap")
  if (!(is.numeric(parm) || is.character(parm) || is.null(parm)))
    stop("'parm' should be a vector of numbers or of names")
  if (!is.numeric(level) || length(level) != 1 || is.na(level) ||
      level <= 0 || level >= 1)
    stop("'level' should be a single number strictly between 0 and 1")

  # Identify the parameters for CI computation
  # ----------------------------------------------------------------------------
  input_names <- names(object$indices)
  n_inputs <- length(object$indices)
  if (is.null(parm))
    parm <- seq_len(n_inputs)
  if (is.character(parm)) {
    unknown <- parm[!(parm %in% input_names)]
    if (length(unknown) > 0)
      stop("Unknown input name(s) in 'parm': ",
           paste(unknown, collapse = ", "))
    parm <- match(parm, input_names)
  }
  # Guard against positions that would silently index past the bootstrap list
  if (anyNA(parm) || any(parm < 1 | parm > n_inputs | parm %% 1 != 0))
    stop("'parm' should contain integers between 1 and the number of inputs")
  K <- length(parm)
  is_wb <- all(c("adv", "diff") %in% names(object))

  # Initialize the return containers
  # ----------------------------------------------------------------------------
  W <- array(dim = K)
  names(W) <- input_names[parm]
  if (is_wb) {
    # Same (empty, named) layout as W for the two decomposition components
    Adv <- W
    Diff <- W
    W_ci <- data.frame(matrix(nrow = K * 3,
                              ncol = 6,
                              dimnames = list(NULL,
                                              c("input", "component",
                                                "original", "bias",
                                                "low.ci", "high.ci"))))
    W_ci$input <- rep(names(W), times = 3)
    W_ci$component <- rep(c("wass-bures", "advective", "diffusive"), each = K)
    stat_cols <- 3:6
    n_comp <- 3
  } else {
    W_ci <- data.frame(matrix(nrow = K,
                              ncol = 5,
                              dimnames = list(NULL,
                                              c("input", "original", "bias",
                                                "low.ci", "high.ci"))))
    W_ci$input <- names(W)
    stat_cols <- 2:5
    n_comp <- 1
  }

  # Compute the statistics
  # ----------------------------------------------------------------------------
  for (k in seq_len(K)) {
    W_stats <- bootstats(object$W_boot[[parm[k]]], type = type, conf = level)
    # Save indices (and, for Wasserstein-Bures, their decomposition)
    W[k] <- W_stats$index[1]
    if (is_wb) {
      Adv[k] <- W_stats$index[2]
      Diff[k] <- W_stats$index[3]
    }
    # Component j of input k lives in row (j - 1) * K + k: rows are grouped by
    # component (wass-bures, advective, diffusive) in the same input order
    for (j in seq_len(n_comp)) {
      W_ci[(j - 1) * K + k, stat_cols] <- c(W_stats$original[j],
                                           W_stats$bias[j],
                                           W_stats$low.ci[j],
                                           W_stats$high.ci[j])
    }
  }

  # Save indices along with the rest of the statistics
  # ----------------------------------------------------------------------------
  if (is_wb) {
    W_ci$index <- c(W, Adv, Diff)
    W_ci <- W_ci[, c(1, 2, 7, 3:6)]
  } else {
    W_ci$index <- W
    W_ci <- W_ci[, c(1, 6, 2:5)]
  }
  W_ci
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.