# Nothing (placeholder text left over from the web page this file was copied from)
# =============================================================================
# ConditionalSAGE Tests
# =============================================================================
#
# Strategy for test speed:
# - Use ConditionalGaussianSampler (fast) for most tests
# - Use sim_dgp_* functions (4-5 features) instead of friedman1 (10 features)
# - Only use ARF sampler in tests specifically testing ARF behavior
# - ARF-specific tests get skip_on_cran()
# -----------------------------------------------------------------------------
# Basic functionality
# -----------------------------------------------------------------------------
test_that("ConditionalSAGE basic workflow with Gaussian sampler", {
  # Small simulated task paired with the Gaussian sampler to keep this fast
  task = sim_dgp_correlated(n = 100)
  gaussian_sampler = ConditionalGaussianSampler$new(task)
  tree_learner = lrn("regr.rpart")

  sage = ConditionalSAGE$new(
    task = task,
    learner = tree_learner,
    sampler = gaussian_sampler,
    n_permutations = 2L,
    n_samples = 20L
  )

  # The object has to inherit from all three classes
  expected_classes = c("FeatureImportanceMethod", "SAGE", "ConditionalSAGE")
  checkmate::expect_r6(sage, expected_classes)

  # Computing must yield a well-formed importance table over all features
  sage$compute()
  importance = sage$importance()
  expect_importance_dt(importance, features = sage$features)
})
test_that("ConditionalSAGE works with classification tasks", {
  # Binary classification on a generated 2dnormals task
  task_binary = tgen("2dnormals")$generate(n = 50)
  gaussian_sampler = ConditionalGaussianSampler$new(task_binary)
  prob_learner = lrn("classif.rpart", predict_type = "prob")

  sage_binary = ConditionalSAGE$new(
    task = task_binary,
    learner = prob_learner,
    sampler = gaussian_sampler,
    n_permutations = 2L,
    n_samples = 20L
  )

  # The object has to inherit from all three classes
  expected_classes = c("FeatureImportanceMethod", "SAGE", "ConditionalSAGE")
  checkmate::expect_r6(sage_binary, expected_classes)

  sage_binary$compute()
  importance = sage_binary$importance()
  expect_importance_dt(importance, features = sage_binary$features)
})
test_that("ConditionalSAGE multiclass classification", {
  skip_on_cran() # the multiclass case is slower

  task_multi = tgen("cassini")$generate(n = 50)
  gaussian_sampler = ConditionalGaussianSampler$new(task_multi)
  prob_learner = lrn("classif.rpart", predict_type = "prob")

  sage_multi = ConditionalSAGE$new(
    task = task_multi,
    learner = prob_learner,
    sampler = gaussian_sampler,
    n_permutations = 2L,
    n_samples = 20L
  )
  sage_multi$compute()

  importance = sage_multi$importance()
  expect_importance_dt(importance, features = sage_multi$features)

  # Confirm the task really is multiclass (three classes)
  expect_length(task_multi$class_names, 3L)
})
test_that("ConditionalSAGE featureless learner produces zero importance", {
  # Small simulated task with the fast Gaussian sampler
  task = sim_dgp_correlated(n = 100)
  gaussian_sampler = ConditionalGaussianSampler$new(task)

  sage = ConditionalSAGE$new(
    task = task,
    learner = lrn("regr.featureless"),
    measure = msr("regr.mse"),
    sampler = gaussian_sampler,
    n_permutations = 2L,
    n_samples = 20L
  )
  sage$compute()

  result = sage$importance()
  expect_importance_dt(result, features = sage$features)

  # Every importance value has to sit within a tiny band around zero
  zero_band = 1e-10
  checkmate::expect_numeric(
    result$importance,
    lower = -zero_band,
    upper = zero_band
  )
})
# -----------------------------------------------------------------------------
# Sensible results
# -----------------------------------------------------------------------------
test_that("ConditionalSAGE produces sensible ranking", {
  # sim_dgp_independent task with the Gaussian sampler
  task = sim_dgp_independent(n = 200)
  gaussian_sampler = ConditionalGaussianSampler$new(task)

  sage = ConditionalSAGE$new(
    task = task,
    learner = lrn("regr.rpart"),
    measure = msr("regr.mse"),
    sampler = gaussian_sampler,
    n_permutations = 2L,
    n_samples = 20L
  )
  sage$compute()

  result = sage$importance()
  expect_importance_dt(result, features = sage$features)

  # Split scores by feature-name prefix; the "important*" group must have the
  # larger mean importance
  is_important = grepl("^important", result$feature)
  is_unimportant = grepl("^unimportant", result$feature)
  mean_important = mean(result$importance[is_important])
  mean_unimportant = mean(result$importance[is_unimportant])
  expect_gt(mean_important, mean_unimportant)
})
# -----------------------------------------------------------------------------
# Sampler behavior
# -----------------------------------------------------------------------------
test_that("ConditionalSAGE uses ConditionalARFSampler by default", {
  skip_if_not_installed("arf")

  # Construction only: inspect the default sampler and label, no compute step
  task = sim_dgp_correlated(n = 50)
  tree_learner = lrn("regr.rpart")

  sage = ConditionalSAGE$new(
    task = task,
    learner = tree_learner,
    n_permutations = 2L,
    n_samples = 20L
  )

  default_sampler = sage$sampler
  checkmate::expect_r6(default_sampler, "ConditionalARFSampler")
  expect_equal(sage$label, "Conditional SAGE")
})
test_that("ConditionalSAGE with ARF sampler computes correctly", {
  skip_on_cran() # ARF sampling is slow
  skip_if_not_installed("arf")

  # No sampler is passed, so the default sampler is used on a small task
  task = sim_dgp_correlated(n = 50)
  tree_learner = lrn("regr.rpart")

  sage = ConditionalSAGE$new(
    task = task,
    learner = tree_learner,
    n_permutations = 2L,
    n_samples = 20L
  )
  sage$compute()

  importance = sage$importance()
  expect_importance_dt(importance, features = sage$features)
})
test_that("ConditionalSAGE with custom ARF sampler settings", {
  skip_on_cran() # ARF sampling is slow
  skip_if_not_installed("arf")

  task = sim_dgp_correlated(n = 50)

  # ARF sampler constructed with a non-default `finite_bounds` setting
  custom_sampler = ConditionalARFSampler$new(task, finite_bounds = "local")

  sage = ConditionalSAGE$new(
    task = task,
    learner = lrn("regr.rpart"),
    measure = msr("regr.mse"),
    sampler = custom_sampler,
    n_permutations = 2L,
    n_samples = 20L
  )

  # The stored sampler must be a ConditionalSampler
  stored_sampler = sage$sampler
  checkmate::expect_r6(stored_sampler, "ConditionalSampler")

  sage$compute()
  importance = sage$importance()
  expect_importance_dt(importance, features = sage$features)
})
# -----------------------------------------------------------------------------
# Parameter validation
# -----------------------------------------------------------------------------
test_that("ConditionalSAGE requires predict_type='prob' for classification", {
  task = tgen("2dnormals")$generate(n = 50)

  # A learner configured for "response" predictions must be rejected at
  # construction time with the message below
  expected_message =
    "Classification learners require probability predictions for SAGE."

  expect_error(
    ConditionalSAGE$new(
      task = task,
      learner = lrn("classif.rpart", predict_type = "response"),
      sampler = ConditionalGaussianSampler$new(task)
    ),
    expected_message
  )
})
# -----------------------------------------------------------------------------
# Batching (skip_on_cran - tests implementation detail)
# -----------------------------------------------------------------------------
test_that("ConditionalSAGE batching produces consistent results", {
  skip_on_cran() # tests implementation detail, not core functionality
  skip_if_not_installed("withr")

  # Small task with the Gaussian sampler keeps the batching comparison fast
  task = sim_dgp_correlated(n = 50)

  # Builds a fresh ConditionalSAGE (with its own sampler) under a fixed seed,
  # forwards `...` to `$compute()`, and returns the importance table.
  # Construction and computation both happen inside the seeded block, so the
  # two runs below differ only in the arguments given to `$compute()`.
  seeded_importance = function(...) {
    withr::with_seed(42, {
      sage = ConditionalSAGE$new(
        task = task,
        learner = lrn("regr.rpart"),
        sampler = ConditionalGaussianSampler$new(task),
        n_permutations = 2L,
        n_samples = 20L
      )
      sage$compute(...)
      sage$importance()
    })
  }

  # Results should be identical with or without batching
  result_batch = seeded_importance(batch_size = 1)
  result_normal = seeded_importance()

  expect_equal(
    result_batch$importance,
    result_normal$importance,
    tolerance = 1e-10
  )
})
# -----------------------------------------------------------------------------
# n_samples parameter
# -----------------------------------------------------------------------------
test_that("ConditionalSAGE with custom n_samples", {
  # Small task with the Gaussian sampler
  task = sim_dgp_correlated(n = 50)
  gaussian_sampler = ConditionalGaussianSampler$new(task)
  tree_learner = lrn("regr.rpart")

  sage = ConditionalSAGE$new(
    task = task,
    learner = tree_learner,
    sampler = gaussian_sampler,
    n_permutations = 2L,
    n_samples = 20L
  )

  # The constructor argument must be stored in the parameter set
  stored_n_samples = sage$param_set$values$n_samples
  expect_equal(stored_n_samples, 20L)

  sage$compute()
  importance = sage$importance()
  expect_importance_dt(importance, features = sage$features)
})
# -----------------------------------------------------------------------------
# Convergence tracking (skip_on_cran - tests advanced feature)
# -----------------------------------------------------------------------------
test_that("ConditionalSAGE SE tracking in convergence_history", {
  skip_on_cran() # tests convergence tracking feature, not core SAGE

  # Small task with the Gaussian sampler
  task = sim_dgp_correlated(n = 50)
  gaussian_sampler = ConditionalGaussianSampler$new(task)

  sage = ConditionalSAGE$new(
    task = task,
    learner = lrn("regr.rpart"),
    measure = msr("regr.mse"),
    sampler = gaussian_sampler,
    n_permutations = 6L,
    n_samples = 20L
  )

  # Early stopping is enabled so that a convergence history gets recorded
  sage$compute(early_stopping = TRUE, se_threshold = 0.05, check_interval = 2L)

  # The history must exist and expose an "se" column
  expect_false(is.null(sage$convergence_history))
  history_cols = colnames(sage$convergence_history)
  expect_contains(history_cols, "se")

  # The history must have exactly these columns
  expected_cols = c("n_permutations", "feature", "importance", "se")
  expect_setequal(colnames(sage$convergence_history), expected_cols)

  # Standard errors must be finite and non-negative
  se_values = sage$convergence_history$se
  checkmate::expect_numeric(se_values, lower = 0, finite = TRUE)

  # Each feature of the task must show up in the history
  tracked_features = unique(sage$convergence_history$feature)
  expect_setequal(tracked_features, sage$features)
})
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.