# tests/testthat/test-ConditionalSAGE.R

# =============================================================================
# ConditionalSAGE Tests
# =============================================================================
#
# Strategy for test speed:
# - Use ConditionalGaussianSampler (fast) for most tests
# - Use sim_dgp_* functions (4-5 features) instead of friedman1 (10 features)
# - Only use ARF sampler in tests specifically testing ARF behavior
# - ARF-specific tests get skip_on_cran()

# -----------------------------------------------------------------------------
# Basic functionality
# -----------------------------------------------------------------------------

test_that("ConditionalSAGE basic workflow with Gaussian sampler", {
	# Smoke test: a 4-feature simulated task plus the cheap Gaussian
	# conditional sampler keeps construction and compute() fast.
	dgp_task = sim_dgp_correlated(n = 100)
	gauss_sampler = ConditionalGaussianSampler$new(dgp_task)

	method = ConditionalSAGE$new(
		task = dgp_task,
		learner = lrn("regr.rpart"),
		sampler = gauss_sampler,
		n_permutations = 2L,
		n_samples = 20L
	)

	# The object must carry the full class chain down to ConditionalSAGE
	checkmate::expect_r6(method, c("FeatureImportanceMethod", "SAGE", "ConditionalSAGE"))

	method$compute()
	scores = method$importance()
	expect_importance_dt(scores, features = method$features)
})

test_that("ConditionalSAGE works with classification tasks", {
	# Two-class problem; the 2dnormals generator yields 2 features
	binary_task = tgen("2dnormals")$generate(n = 50)
	gauss_sampler = ConditionalGaussianSampler$new(binary_task)
	# SAGE on classification needs probability predictions
	prob_learner = lrn("classif.rpart", predict_type = "prob")

	binary_sage = ConditionalSAGE$new(
		task = binary_task,
		learner = prob_learner,
		sampler = gauss_sampler,
		n_permutations = 2L,
		n_samples = 20L
	)
	checkmate::expect_r6(binary_sage, c("FeatureImportanceMethod", "SAGE", "ConditionalSAGE"))

	binary_sage$compute()
	binary_scores = binary_sage$importance()
	expect_importance_dt(binary_scores, features = binary_sage$features)
})

test_that("ConditionalSAGE multiclass classification", {
	skip_on_cran() # the 3-class problem is slower than the binary case
	# cassini generates a 3-class task (2 input features in mlr3's generator)
	task_multi = tgen("cassini")$generate(n = 50)
	# Precondition: the test is only meaningful if the task really is multiclass,
	# so check it before spending time on compute()
	expect_length(task_multi$class_names, 3L)

	sampler = ConditionalGaussianSampler$new(task_multi)

	sage_multi = ConditionalSAGE$new(
		task = task_multi,
		learner = lrn("classif.rpart", predict_type = "prob"),
		sampler = sampler,
		n_permutations = 2L,
		n_samples = 20L
	)
	sage_multi$compute()
	expect_importance_dt(sage_multi$importance(), features = sage_multi$features)
})

test_that("ConditionalSAGE featureless learner produces zero importance", {
	# regr.featureless predicts without looking at the features, so replacing
	# features with conditional samples cannot move its loss: every SAGE value
	# is expected to be (numerically) zero.
	sim_task = sim_dgp_correlated(n = 100)
	cond_sampler = ConditionalGaussianSampler$new(sim_task)

	featureless_sage = ConditionalSAGE$new(
		task = sim_task,
		learner = lrn("regr.featureless"),
		measure = msr("regr.mse"),
		sampler = cond_sampler,
		n_permutations = 2L,
		n_samples = 20L
	)
	featureless_sage$compute()

	imp = featureless_sage$importance()
	expect_importance_dt(imp, features = featureless_sage$features)

	zero_tol = 1e-10
	checkmate::expect_numeric(imp$importance, lower = -zero_tol, upper = zero_tol)
})

# -----------------------------------------------------------------------------
# Sensible results
# -----------------------------------------------------------------------------

test_that("ConditionalSAGE produces sensible ranking", {
	skip_if_not_installed("withr")
	# Pin the RNG for the whole test. With only 2 permutations the estimates are
	# noisy, so an unseeded run makes the ranking check depend on the session seed.
	withr::local_seed(42)

	# Use sim_dgp_independent (5 features) with Gaussian sampler
	task = sim_dgp_independent(n = 200)
	sampler = ConditionalGaussianSampler$new(task)

	sage = ConditionalSAGE$new(
		task = task,
		learner = lrn("regr.rpart"),
		measure = msr("regr.mse"),
		sampler = sampler,
		n_permutations = 2L,
		n_samples = 20L
	)

	sage$compute()
	result = sage$importance()
	expect_importance_dt(result, features = sage$features)

	# Split scores by the DGP's feature naming ("important*" vs "unimportant*")
	important_scores = result[grepl("^important", feature)]$importance
	unimportant_scores = result[grepl("^unimportant", feature)]$importance

	# Both groups must exist, otherwise the mean comparison below is on NaN
	# and fails with an unhelpful message
	expect_gt(length(important_scores), 0L)
	expect_gt(length(unimportant_scores), 0L)

	# Important features should have higher mean importance than unimportant
	expect_gt(mean(important_scores), mean(unimportant_scores))
})

# -----------------------------------------------------------------------------
# Sampler behavior
# -----------------------------------------------------------------------------

test_that("ConditionalSAGE uses ConditionalARFSampler by default", {
	skip_if_not_installed("arf")

	# No sampler is supplied, so the constructor has to pick its default.
	# Only the constructed object is inspected; compute() is never called.
	corr_task = sim_dgp_correlated(n = 50)

	default_sage = ConditionalSAGE$new(
		task = corr_task,
		learner = lrn("regr.rpart"),
		n_permutations = 2L,
		n_samples = 20L
	)

	checkmate::expect_r6(default_sage$sampler, "ConditionalARFSampler")
	expected_label = "Conditional SAGE"
	expect_equal(default_sage$label, expected_label)
})

test_that("ConditionalSAGE with ARF sampler computes correctly", {
	skip_on_cran() # ARF sampling is slow
	skip_if_not_installed("arf")

	# End-to-end run with the default (ARF) sampler on a deliberately small task
	arf_task = sim_dgp_correlated(n = 50)

	arf_sage = ConditionalSAGE$new(
		task = arf_task,
		learner = lrn("regr.rpart"),
		n_permutations = 2L,
		n_samples = 20L
	)

	arf_sage$compute()
	arf_scores = arf_sage$importance()
	expect_importance_dt(arf_scores, features = arf_sage$features)
})

test_that("ConditionalSAGE with custom ARF sampler settings", {
	skip_on_cran() # ARF sampling is slow
	skip_if_not_installed("arf")

	task = sim_dgp_correlated(n = 50)
	# Caller-configured sampler, passed explicitly to ConditionalSAGE below
	custom_sampler = ConditionalARFSampler$new(task, finite_bounds = "local")

	sage = ConditionalSAGE$new(
		task = task,
		learner = lrn("regr.rpart"),
		measure = msr("regr.mse"),
		sampler = custom_sampler,
		n_permutations = 2L,
		n_samples = 20L
	)

	# Checking only the generic base class would also accept some other
	# conditional sampler; require the concrete ARF class that was supplied.
	checkmate::expect_r6(sage$sampler, c("ConditionalSampler", "ConditionalARFSampler"))
	sage$compute()
	expect_importance_dt(sage$importance(), features = sage$features)
})

# -----------------------------------------------------------------------------
# Parameter validation
# -----------------------------------------------------------------------------

test_that("ConditionalSAGE requires predict_type='prob' for classification", {
	clf_task = tgen("2dnormals")$generate(n = 50)
	expected_msg = "Classification learners require probability predictions for SAGE."

	# A response-only classifier must be rejected at construction time
	expect_error(
		ConditionalSAGE$new(
			task = clf_task,
			learner = lrn("classif.rpart", predict_type = "response"),
			sampler = ConditionalGaussianSampler$new(clf_task)
		),
		regexp = expected_msg
	)
})

# -----------------------------------------------------------------------------
# Batching (skip_on_cran - tests implementation detail)
# -----------------------------------------------------------------------------

test_that("ConditionalSAGE batching produces consistent results", {
	skip_on_cran() # tests implementation detail, not core functionality
	skip_if_not_installed("withr")

	# Small task with Gaussian sampler for a faster batching test
	task = sim_dgp_correlated(n = 50)

	# Build a fresh ConditionalSAGE (with its own sampler) and run it under a
	# fixed seed, so any random draws are identical between the two runs.
	# `batch_size = NULL` means "call compute() with its default batching".
	run_sage = function(batch_size = NULL) {
		withr::with_seed(42, {
			sage = ConditionalSAGE$new(
				task = task,
				learner = lrn("regr.rpart"),
				sampler = ConditionalGaussianSampler$new(task),
				n_permutations = 2L,
				n_samples = 20L
			)
			if (is.null(batch_size)) {
				sage$compute()
			} else {
				sage$compute(batch_size = batch_size)
			}
			sage$importance()
		})
	}

	# Results should be identical with or without batching
	result_batch = run_sage(batch_size = 1)
	result_normal = run_sage()

	expect_equal(result_batch$importance, result_normal$importance, tolerance = 1e-10)
})

# -----------------------------------------------------------------------------
# n_samples parameter
# -----------------------------------------------------------------------------

test_that("ConditionalSAGE with custom n_samples", {
	n_samples_custom = 20L

	# Small task with the fast Gaussian sampler
	small_task = sim_dgp_correlated(n = 50)
	gauss_sampler = ConditionalGaussianSampler$new(small_task)

	sage = ConditionalSAGE$new(
		task = small_task,
		learner = lrn("regr.rpart"),
		sampler = gauss_sampler,
		n_permutations = 2L,
		n_samples = n_samples_custom
	)
	# The constructor argument must be reflected in the parameter set
	expect_equal(sage$param_set$values$n_samples, n_samples_custom)

	sage$compute()
	expect_importance_dt(sage$importance(), features = sage$features)
})

# -----------------------------------------------------------------------------
# Convergence tracking (skip_on_cran - tests advanced feature)
# -----------------------------------------------------------------------------

test_that("ConditionalSAGE SE tracking in convergence_history", {
	skip_on_cran() # tests convergence tracking feature, not core SAGE

	# Small task with the fast Gaussian sampler
	small_task = sim_dgp_correlated(n = 50)
	gauss_sampler = ConditionalGaussianSampler$new(small_task)

	sage = ConditionalSAGE$new(
		task = small_task,
		learner = lrn("regr.rpart"),
		measure = msr("regr.mse"),
		sampler = gauss_sampler,
		n_permutations = 6L,
		n_samples = 20L
	)

	# Running with early stopping is what produces a convergence history
	sage$compute(early_stopping = TRUE, se_threshold = 0.05, check_interval = 2L)

	history = sage$convergence_history
	expect_false(is.null(history))

	# Exactly these columns, including the standard error column
	history_cols = colnames(history)
	expect_contains(history_cols, "se")
	expect_setequal(history_cols, c("n_permutations", "feature", "importance", "se"))

	# Standard errors must be non-negative and finite
	checkmate::expect_numeric(history$se, lower = 0, finite = TRUE)

	# Every feature of the method must show up in the history
	expect_setequal(unique(history$feature), sage$features)
})

# Try the xplainfi package in your browser
#
# Any scripts or data that you put into this service are public.
#
# xplainfi documentation built on Feb. 27, 2026, 1:08 a.m.