Nothing
#' @title Conditional SAGE
#'
#' @description [SAGE] with conditional sampling (features are "marginalized" conditionally).
#' Uses [ConditionalARFSampler] as default [ConditionalSampler].
#'
#' @seealso [MarginalSAGE]
#'
#' @examplesIf requireNamespace("ranger", quietly = TRUE) && requireNamespace("mlr3learners", quietly = TRUE) && requireNamespace("arf", quietly = TRUE)
#' library(mlr3)
#' task = tgen("friedman1")$generate(n = 100)
#'
#' \donttest{
#' # Using default ConditionalARFSampler (also handles all mixed data)
#' sage = ConditionalSAGE$new(
#' task = task,
#' learner = lrn("regr.ranger", num.trees = 50),
#' measure = msr("regr.mse"),
#' n_permutations = 3L,
#' n_samples = 20
#' )
#' sage$compute()
#' }
#' \donttest{
#' # For alternative conditional samplers:
#' custom_sampler = ConditionalGaussianSampler$new(
#' task = task
#' )
#' sage_custom = ConditionalSAGE$new(
#' task = task,
#' learner = lrn("regr.ranger", num.trees = 50),
#' measure = msr("regr.mse"),
#' n_permutations = 5L,
#' n_samples = 20,
#' sampler = custom_sampler
#' )
#' sage_custom$compute()
#' }
#' @export
ConditionalSAGE = R6Class(
"ConditionalSAGE",
inherit = SAGE,
public = list(
#' @field sampler ([ConditionalSampler]) Sampler for conditional marginalization.
sampler = NULL,
#' @description
#' Creates a new instance of the ConditionalSAGE class.
#' @param task,learner,measure,resampling,features,n_permutations,batch_size,n_samples,early_stopping,se_threshold,min_permutations,check_interval Passed to [SAGE].
#' @param sampler ([ConditionalSampler]) Optional custom sampler. Defaults to [ConditionalARFSampler].
initialize = function(
task,
learner,
measure = NULL,
resampling = NULL,
features = NULL,
n_permutations = 10L,
sampler = NULL,
batch_size = 5000L,
n_samples = 100L,
early_stopping = FALSE,
se_threshold = 0.01,
min_permutations = 10L,
check_interval = 1L
) {
# Use ConditionalARFSampler by default
if (is.null(sampler)) {
sampler = ConditionalARFSampler$new(task)
if (xplain_opt("verbose")) {
cli::cli_alert_info(
"No {.cls ConditionalSampler} provided, using {.cls ConditionalARFSampler} with default settings."
)
}
} else {
checkmate::assert_class(sampler, "ConditionalSampler")
}
super$initialize(
task = task,
learner = learner,
measure = measure,
resampling = resampling,
features = features,
n_permutations = n_permutations,
batch_size = batch_size,
n_samples = n_samples,
early_stopping = early_stopping,
se_threshold = se_threshold,
min_permutations = min_permutations,
check_interval = check_interval
)
# Store sampler (specific to ConditionalSAGE)
self$sampler = sampler
self$label = "Conditional SAGE"
}
),
private = list(
# Conditional-specific data expansion: Sample from conditional distribution
# For each coalition, sample marginalized features conditioned on coalition features
# Uses n_samples to create multiple samples per test instance for averaging
.expand_coalitions_data = function(test_dt, all_coalitions) {
n_test = nrow(test_dt)
all_coalition_data = vector("list", length(all_coalitions))
# For each coalition, sample from conditional distribution P(X_{-S} | X_S)
for (i in seq_along(all_coalitions)) {
coalition = all_coalitions[[i]]
marginalize_features = setdiff(self$features, coalition)
if (length(marginalize_features) > 0) {
# Expand test data to sample multiple times for averaging
# Each test instance is replicated n_samples times
test_dt_expanded = test_dt[rep(
seq_len(n_test),
each = self$param_set$values$n_samples
)]
# Sample conditionally using a sampler
# Returns test_dt with "marginalized" features replaced by conditional samples
# sampler: P(marginalize_features | coalition_features)
marginalized_test = self$sampler$sample_newdata(
feature = marginalize_features,
newdata = test_dt_expanded,
conditioning_set = coalition
)
# Add test instance ID for tracking which original test instance each row belongs to
# Must be done AFTER sampling since sampler does not preserve extra columns
marginalized_test[,
.test_instance_id := rep(
seq_len(n_test),
each = self$param_set$values$n_samples
)
]
} else {
# If marginalize_features is empty then we evaluate the coalition of all features
# in the task, so there's no room for any sampling and we just take all data
# no need to duplicate+average anything since prediction is deterministic (usually (I hope (right?)))
marginalized_test = copy(test_dt)
marginalized_test[, .test_instance_id := seq_len(n_test)]
}
# Add coalition ID for tracking which coalition this data belongs to
marginalized_test[, .coalition_id := i]
all_coalition_data[[i]] = marginalized_test
}
# Combine all coalitions into single data.table for batch prediction
rbindlist(all_coalition_data)
}
)
)
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.