# Nothing
#' @title Shapley Additive Global Importance (SAGE) Base Class
#'
#' @description Base class for SAGE (Shapley Additive Global Importance)
#' feature importance based on Shapley values with marginalization.
#' This is an abstract class - use [MarginalSAGE] or [ConditionalSAGE].
#'
#' @details
#' SAGE uses Shapley values to fairly distribute the total prediction
#' performance among all features. Unlike perturbation-based methods,
#' SAGE marginalizes features by integrating over their distribution.
#' This is approximated by averaging predictions over a reference dataset.
#'
#' **Standard Error Calculation**: The standard errors (SE) reported in
#' `$convergence_history` reflect the uncertainty in Shapley value estimation
#' across different random permutations within a single resampling iteration.
#' These SEs quantify the Monte Carlo sampling error for a fixed trained model
#' and are only valid for inference about the importance of features for that
#' specific model. They do not capture broader uncertainty from model variability
#' across different train/test splits or resampling iterations.
#'
#' @references
#' `r print_bib("lundberg_2020")`
#'
#' @seealso [MarginalSAGE] [ConditionalSAGE]
#'
#' @export
SAGE = R6Class(
  "SAGE",
  inherit = FeatureImportanceMethod,
  public = list(
    #' @field n_permutations (`integer(1)`) Number of permutations to sample.
    n_permutations = NULL,
    #' @field convergence_history ([`data.table`][data.table::data.table]) History of SAGE values during computation.
    convergence_history = NULL,
    #' @field converged (`logical(1)`) Whether convergence was detected.
    converged = FALSE,
    #' @field n_permutations_used (`integer(1)`) Actual number of permutations used.
    n_permutations_used = NULL,

    #' @description
    #' Creates a new instance of the SAGE class.
    #' @param task,learner,measure,resampling,features Passed to FeatureImportanceMethod.
    #' @param n_permutations (`integer(1)`: `10L`) Number of permutations to sample for SAGE value estimation.
    #'   The total number of evaluated coalitions is `1 (empty) + n_permutations * n_features`.
    #' @param batch_size (`integer(1)`: `5000L`) Maximum number of observations to process in a single prediction call.
    #' @param n_samples (`integer(1)`: `100L`) Number of samples to use for marginalizing out-of-coalition features.
    #'   For [MarginalSAGE], this is the number of marginal data samples ("background data" in other implementations).
    #'   For [ConditionalSAGE], this is the number of conditional samples per test instance retrieved from `sampler`.
    #' @param early_stopping (`logical(1)`: `TRUE`) Whether to enable early stopping based on convergence detection.
    #' @param se_threshold (`numeric(1)`: `0.01`) Convergence threshold for relative standard error.
    #'   Convergence is detected when the maximum relative SE across all features falls below this threshold.
    #'   Relative SE is calculated as SE divided by the range of importance values (max - min),
    #'   making it scale-invariant across different loss metrics.
    #'   Default of `0.01` means convergence when relative SE is below 1% of the importance range.
    #' @param min_permutations (`integer(1)`: `10L`) Minimum permutations before checking for convergence.
    #'   Convergence is judged based on the standard errors of the estimated SAGE values,
    #'   which requires a sufficiently large number of samples (i.e., evaluated coalitions).
    #' @param check_interval (`integer(1)`: `1L`) Check convergence every N permutations.
    initialize = function(
      task,
      learner,
      measure = NULL,
      resampling = NULL,
      features = NULL,
      n_permutations = 10L,
      batch_size = 5000L,
      n_samples = 100L,
      early_stopping = TRUE,
      se_threshold = 0.01,
      min_permutations = 10L,
      check_interval = 1L
    ) {
      super$initialize(
        task = task,
        learner = learner,
        measure = measure,
        resampling = resampling,
        features = features,
        label = "Shapley Additive Global Importance"
      )

      self$n_permutations = checkmate::assert_int(n_permutations, lower = 1L)

      # SAGE averages predictions over marginalized samples, which for
      # classification is done on class probabilities rather than hard labels.
      if (self$task$task_type == "classif" && learner$predict_type != "prob") {
        cli::cli_abort(c(
          "Classification learners require probability predictions for SAGE.",
          "i" = "Please set {.code learner$configure(predict_type = \"prob\")} before using SAGE."
        ))
      }

      # Declared defaults mirror the defaults of this constructor's signature.
      param_set = paradox::ps(
        n_permutations = paradox::p_int(lower = 1L, default = 10L),
        batch_size = paradox::p_int(lower = 1L, default = 5000L),
        n_samples = paradox::p_int(lower = 1L, default = 100L),
        early_stopping = paradox::p_lgl(default = TRUE),
        se_threshold = paradox::p_dbl(lower = 0, upper = 1, default = 0.01),
        min_permutations = paradox::p_int(lower = 1L, default = 10L),
        check_interval = paradox::p_int(lower = 1L, default = 1L)
      )
      param_set$values$n_permutations = n_permutations
      param_set$values$batch_size = batch_size
      param_set$values$n_samples = n_samples
      param_set$values$early_stopping = early_stopping
      param_set$values$se_threshold = se_threshold
      param_set$values$min_permutations = min_permutations
      param_set$values$check_interval = check_interval
      self$param_set = param_set
    },

    #' @description
    #' Compute SAGE values.
    #' @param store_backends (`logical(1)`) Whether to store data backends.
    #' @param batch_size (`integer(1)`: `5000L`) Maximum number of observations to process in a single prediction call.
    #' @param early_stopping (`logical(1)`: `TRUE`) Whether to check for convergence and stop early.
    #' @param se_threshold (`numeric(1)`: `0.01`) Convergence threshold for relative standard error.
    #'   SE is normalized by the range of importance values (max - min) to make convergence
    #'   detection scale-invariant. Default `0.01` means convergence when relative SE < 1%.
    #' @param min_permutations (`integer(1)`: `10L`) Minimum permutations before checking convergence.
    #' @param check_interval (`integer(1)`: `1L`) Check convergence every N permutations.
    compute = function(
      store_backends = TRUE,
      batch_size = NULL,
      early_stopping = NULL,
      se_threshold = NULL,
      min_permutations = NULL,
      check_interval = NULL
    ) {
      # Reset convergence tracking left over from a previous call.
      self$convergence_history = NULL
      self$converged = FALSE
      self$n_permutations_used = NULL

      # Resolve each setting from: explicit argument, stored parameter value,
      # hard-coded fallback. The fallbacks match the documented defaults.
      stored = self$param_set$values
      batch_size = resolve_param(batch_size, stored$batch_size, 5000L)
      early_stopping = resolve_param(early_stopping, stored$early_stopping, TRUE)
      se_threshold = resolve_param(se_threshold, stored$se_threshold, 0.01)
      min_permutations = resolve_param(min_permutations, stored$min_permutations, 10L)
      check_interval = resolve_param(check_interval, stored$check_interval, 1L)

      # Initial resampling to get trained learners.
      rr = assemble_rr(
        task = self$task,
        learner = self$learner,
        resampling = self$resampling,
        store_models = TRUE,
        store_backends = store_backends
      )
      self$resample_result = rr

      # Convergence is about the permutation count, not about resampling, so it
      # is tracked on the first resampling iteration only.
      iter_for_convergence = 1L
      first_result = private$.compute_sage_scores(
        learner = rr$learners[[iter_for_convergence]],
        test_dt = self$task$data(rows = rr$resampling$test_set(iter_for_convergence)),
        n_permutations = self$n_permutations,
        batch_size = batch_size,
        early_stopping = early_stopping,
        se_threshold = se_threshold,
        min_permutations = min_permutations,
        check_interval = check_interval
      )

      # `convergence_data` is returned even if early_stopping = FALSE.
      convergence_data = first_result$convergence_data
      self$convergence_history = convergence_data$convergence_history
      self$converged = convergence_data$converged
      self$n_permutations_used = convergence_data$n_permutations_used

      # Remaining resampling iterations reuse the permutation count reached in
      # the first iteration and do not track convergence themselves.
      # With a single iteration this index vector is empty.
      remaining_iters = seq_len(self$resampling$iters)[-iter_for_convergence]
      remaining_results = lapply(remaining_iters, \(iter) {
        private$.compute_sage_scores(
          learner = rr$learners[[iter]],
          test_dt = self$task$data(rows = rr$resampling$test_set(iter)),
          # n_permutations_used equals n_permutations unless the first
          # iteration stopped early; the fallback should not be needed.
          n_permutations = self$n_permutations_used %||% self$n_permutations,
          batch_size = batch_size,
          early_stopping = FALSE
        )
      })
      all_scores = c(
        list(first_result$scores),
        lapply(remaining_results, function(x) x$scores)
      )

      # Combined columns: iter_rsmp, feature, importance.
      scores = rbindlist(all_scores, idcol = "iter_rsmp")
      private$.scores = scores
    },

    #' @description
    #' Plot convergence history of SAGE values.
    #' @param features (`character` | `NULL`) Features to plot. If NULL, plots all features.
    #' @return A [ggplot2][ggplot2::ggplot] object
    plot_convergence = function(features = NULL) {
      require_package("ggplot2")
      if (is.null(self$convergence_history)) {
        cli::cli_abort("No convergence history available. Run $compute() first.")
      }

      # Work on a copy to avoid modifying the stored history.
      plot_data = copy(self$convergence_history)
      if (!is.null(features)) {
        plot_data = plot_data[feature %in% features]
      }

      subtitle = if (self$converged) {
        sprintf(
          "Converged after %d permutations (saved %d)",
          self$n_permutations_used,
          self$n_permutations - self$n_permutations_used
        )
      } else {
        sprintf("Completed all %d permutations", self$n_permutations)
      }

      p = ggplot2::ggplot(
        plot_data,
        ggplot2::aes(x = n_permutations, y = importance, fill = feature, color = feature)
      ) +
        # Ribbon shows +/- one standard error around the running estimate.
        ggplot2::geom_ribbon(
          ggplot2::aes(ymin = importance - se, ymax = importance + se),
          alpha = 1 / 3
        ) +
        ggplot2::geom_line(linewidth = 1) +
        ggplot2::geom_point(size = 2) +
        ggplot2::labs(
          title = "SAGE Value Convergence",
          subtitle = subtitle,
          x = "Number of Permutations",
          y = "SAGE Value",
          color = "Feature",
          fill = "Feature"
        ) +
        ggplot2::theme_minimal(base_size = 14)

      # Mark the permutation count at which early stopping triggered.
      if (self$converged) {
        p = p +
          ggplot2::geom_vline(
            xintercept = self$n_permutations_used,
            linetype = "dashed",
            color = "red",
            alpha = 0.5
          )
      }
      p
    }
  ),
  private = list(
    # Computes SAGE values for a single resampling iteration.
    #
    # Feature permutations are processed in checkpoints of `check_interval`
    # permutations. For every permutation the growing coalitions are evaluated,
    # and the loss reduction of each added feature is accumulated as that
    # feature's marginal contribution. After each checkpoint the running mean
    # and standard error are recorded; with `early_stopping = TRUE` the loop
    # ends once the maximum relative SE drops below `se_threshold`.
    #
    # Returns a list with
    # - `scores`: data.table(feature, importance)
    # - `convergence_data`: list(convergence_history, converged, n_permutations_used)
    .compute_sage_scores = function(
      learner,
      test_dt,
      n_permutations,
      batch_size = NULL,
      early_stopping = FALSE,
      se_threshold = 0.01,
      min_permutations = 10L,
      check_interval = 1L
    ) {
      n_features = length(self$features)

      # Running sum and sum of squares of marginal contributions per feature;
      # both are needed for the running variance / standard error.
      sage_values = numeric(n_features)
      sage_values_sq = numeric(n_features)
      names(sage_values) = self$features
      names(sage_values_sq) = self$features

      # Pre-generate all permutations upfront so the random draws do not depend
      # on when (or whether) early stopping triggers.
      # Example: features c("x1", "x2", "x3"), n_permutations = 2 might give
      # list(c("x2", "x1", "x3"), c("x3", "x1", "x2")).
      all_permutations = replicate(n_permutations, sample(self$features), simplify = FALSE)

      convergence_history = list() # One data.table per checkpoint
      n_completed = 0 # Number of permutations processed so far
      converged = FALSE
      baseline_loss = NULL # Loss of the empty coalition (all features marginalized)

      show_progress = xplain_opt("progress")
      if (show_progress) {
        cli::cli_progress_bar(
          "Computing SAGE values",
          total = ceiling(n_permutations / check_interval)
        )
      }

      while (n_completed < n_permutations && !converged) {
        # The last checkpoint may hold fewer than `check_interval` permutations.
        checkpoint_size = min(check_interval, n_permutations - n_completed)
        checkpoint_perms = (n_completed + 1):(n_completed + checkpoint_size)
        checkpoint_permutations = all_permutations[checkpoint_perms]

        # Coalitions are laid out in a fixed order so that the loss of step `j`
        # of permutation `i` sits at position `offset + (i - 1) * n_features + j`.
        # The empty coalition is evaluated once, in the very first checkpoint,
        # and then occupies position 1 (hence offset 1).
        # Note: coalitions are not deduplicated; each step is evaluated.
        is_first_checkpoint = n_completed == 0
        offset = if (is_first_checkpoint) 1L else 0L
        checkpoint_coalitions = vector("list", offset + checkpoint_size * n_features)
        if (is_first_checkpoint) {
          checkpoint_coalitions[[1L]] = character(0)
        }
        for (i in seq_along(checkpoint_permutations)) {
          perm_features = checkpoint_permutations[[i]]
          # For P = (x2, x1, x3) this yields {x2}, {x2, x1}, {x2, x1, x3}.
          for (j in seq_along(perm_features)) {
            checkpoint_coalitions[[offset + (i - 1L) * n_features + j]] = perm_features[seq_len(j)]
          }
        }

        # Evaluate all coalitions of this checkpoint in one batched call.
        checkpoint_losses = private$.evaluate_coalitions_batch(
          learner,
          test_dt,
          checkpoint_coalitions,
          batch_size
        )

        if (show_progress) {
          cli::cli_progress_update(inc = 1)
        }

        if (is_first_checkpoint) {
          baseline_loss = checkpoint_losses[1]
        }

        for (i in seq_along(checkpoint_permutations)) {
          perm_features = checkpoint_permutations[[i]]
          perm_losses = checkpoint_losses[offset + (i - 1L) * n_features + seq_len(n_features)]
          # Contribution of the feature added at step j is
          # loss(coalition before adding it) - loss(coalition after adding it);
          # the coalition before step 1 is the empty one (baseline).
          # Smaller loss is better, so positive values mean the feature helped.
          previous_losses = c(baseline_loss, perm_losses[-n_features])
          contributions = previous_losses - perm_losses
          # Each feature occurs exactly once per permutation, so a name-indexed
          # update adds exactly one contribution per feature.
          sage_values[perm_features] = sage_values[perm_features] + contributions
          sage_values_sq[perm_features] = sage_values_sq[perm_features] + contributions^2
        }

        n_completed = n_completed + checkpoint_size

        # Running mean, variance (E[X^2] - E[X]^2) and SE (sqrt(Var / n)).
        current_avg = sage_values / n_completed
        current_variance = (sage_values_sq / n_completed) - (current_avg^2)
        # Guard against tiny negative values caused by floating point error.
        current_variance[current_variance < 0] = 0
        current_se = sqrt(current_variance / n_completed)

        if (xplain_opt("debug")) {
          cli::cli_alert_info("SAGE values after {.val {n_completed}} permutations")
          cli::cli_ol(c(
            "SAGE values: {.val {round(current_avg, 4)}}",
            "current SE: {.val {round(current_se, 3)}}",
            "Completed: {.val {n_completed}}"
          ))
        }

        # Record the checkpoint; used for plotting and uncertainty reporting.
        convergence_history[[length(convergence_history) + 1]] = data.table(
          n_permutations = n_completed,
          feature = names(current_avg),
          importance = as.numeric(current_avg),
          se = as.numeric(current_se)
        )

        if (early_stopping && n_completed >= min_permutations && length(convergence_history) > 1) {
          # Relative SE = SE / range of importance values (fippy's criterion):
          # https://github.com/gcskoenig/fippy/blob/a7a37aa5511f7074ead3289c89b1ae80036982cb/src/fippy/explainers/utils.py#L40-L42
          importance_range = max(current_avg, na.rm = TRUE) - min(current_avg, na.rm = TRUE)
          if (importance_range > 0 && is.finite(importance_range)) {
            max_relative_se = max(current_se / importance_range, na.rm = TRUE)
          } else {
            # Range is 0 or invalid: fall back to the absolute SE.
            max_relative_se = max(current_se, na.rm = TRUE)
          }

          converged = is.finite(max_relative_se) && max_relative_se < se_threshold
          if (xplain_opt("verbose") && converged) {
            cli::cli_inform(c(
              "v" = "SAGE converged after {.val {n_completed}} permutations",
              "i" = "Maximum relative SE: {.val {round(max_relative_se, 4)}} (threshold: {.val {se_threshold}})",
              "i" = "Saved {.val {n_permutations - n_completed}} permutations"
            ))
          }
        }
      }

      if (show_progress) {
        cli::cli_progress_done()
      }

      # Final estimate: mean marginal contribution over completed permutations.
      final_sage_values = sage_values / n_completed
      list(
        scores = data.table(
          feature = names(final_sage_values),
          importance = as.numeric(final_sage_values)
        ),
        convergence_data = list(
          convergence_history = if (length(convergence_history) > 0) {
            rbindlist(convergence_history)
          } else {
            NULL
          },
          converged = converged,
          n_permutations_used = n_completed
        )
      )
    },

    # Template method: prediction and aggregation pipeline shared by subclasses.
    # Subclasses only implement .expand_coalitions_data().
    # Returns the per-coalition losses from .calculate_coalition_losses().
    .evaluate_coalitions_batch = function(learner, test_dt, all_coalitions, batch_size = NULL) {
      n_test = nrow(test_dt)
      if (xplain_opt("debug")) {
        cli::cli_inform("Evaluating {.val {length(all_coalitions)}} coalitions")
      }

      # Step 1: subclass-specific data expansion (abstract method).
      combined_data = private$.expand_coalitions_data(test_dt, all_coalitions)

      # Step 2: batched prediction on the expanded data.
      predictions = sage_batch_predict(
        learner,
        combined_data,
        self$task,
        batch_size,
        self$task$task_type
      )
      if (anyNA(predictions)) {
        cli::cli_warn("Encountered missing values in model prediction")
      }

      # Step 3: aggregate predictions, then score each coalition.
      avg_preds = sage_aggregate_predictions(
        combined_data,
        predictions,
        self$task$task_type,
        self$task$class_names
      )
      private$.calculate_coalition_losses(avg_preds, n_test, test_dt)
    },

    # Abstract method - must be implemented by subclasses.
    # Returns: data.table with all feature columns plus .coalition_id and .test_instance_id
    .expand_coalitions_data = function(test_dt, all_coalitions) {
      cli::cli_abort(c(
        "Abstract method not implemented",
        "i" = "Subclasses must implement {.fn .expand_coalitions_data}",
        "i" = "This method should return a data.table with:",
        "*" = "All feature columns (with marginalized features replaced/sampled)",
        "*" = "{.field .coalition_id}: integer identifying which coalition",
        "*" = "{.field .test_instance_id}: integer identifying original test instance"
      ))
    },

    # Scores the averaged predictions of every coalition with self$measure.
    # Returns a numeric vector with one loss per coalition id.
    # NOTE(review): assumes coalition ids are exactly 1..n_coalitions and that
    # rows within a coalition are ordered like `test_dt` - confirm against
    # sage_aggregate_predictions().
    .calculate_coalition_losses = function(avg_preds, n_test, test_dt) {
      n_coalitions = length(unique(avg_preds$.coalition_id))
      is_classif = self$task$task_type == "classif"
      truth = test_dt[[self$task$target_names]]
      row_ids = seq_len(n_test)

      coalition_losses = numeric(n_coalitions)
      for (i in seq_len(n_coalitions)) {
        coalition_data = avg_preds[.coalition_id == i]
        pred_obj = if (is_classif) {
          PredictionClassif$new(
            row_ids = row_ids,
            truth = truth,
            prob = as.matrix(coalition_data[, .SD, .SDcols = self$task$class_names])
          )
        } else {
          PredictionRegr$new(
            row_ids = row_ids,
            truth = truth,
            response = coalition_data$avg_pred
          )
        }
        coalition_losses[i] = pred_obj$score(self$measure)
      }
      coalition_losses
    }
  )
)
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.