# tests/testthat/test-MarginalSAGE.R

# =============================================================================
# MarginalSAGE Tests
# =============================================================================

# -----------------------------------------------------------------------------
# Basic functionality
# -----------------------------------------------------------------------------

test_that("MarginalSAGE default behavior with minimal parameters", {
	# Tiny permutation/sample budget keeps the test_default_behavior() helper fast
	test_default_behavior(
		MarginalSAGE,
		task_type = "regr",
		n_permutations = 2L,
		n_samples = 20L
	)
})

test_that("MarginalSAGE works with classification tasks", {
	# --- Two-class problem ---
	binary_task = tgen("2dnormals")$generate(n = 100)
	binary_learner = lrn("classif.rpart", predict_type = "prob")
	binary_sage = MarginalSAGE$new(
		task = binary_task,
		learner = binary_learner,
		n_permutations = 2L
	)
	# The object must carry the full class hierarchy
	checkmate::expect_r6(binary_sage, c("FeatureImportanceMethod", "SAGE", "MarginalSAGE"))
	binary_sage$compute()
	expect_importance_dt(binary_sage$importance(), features = binary_sage$features)

	# --- Three-class problem ---
	multi_task = tgen("cassini")$generate(n = 100)
	# A separate learner instance, so nothing is shared with the binary case
	multi_learner = lrn("classif.rpart", predict_type = "prob")
	multi_sage = MarginalSAGE$new(
		task = multi_task,
		learner = multi_learner,
		n_permutations = 2L
	)
	multi_sage$compute()
	expect_importance_dt(multi_sage$importance(), features = multi_sage$features)
	# Confirm the generated task really is a three-class problem
	expect_length(multi_task$class_names, 3L)
})

test_that("MarginalSAGE featureless learner produces zero importance", {
	# Small permutation and sample counts: this only needs to run quickly
	test_featureless_zero_importance(
		MarginalSAGE,
		task_type = "regr",
		n_permutations = 2L, n_samples = 20L
	)
})

# -----------------------------------------------------------------------------
# Sensible results
# -----------------------------------------------------------------------------

test_that("MarginalSAGE friedman1 produces sensible ranking", {
	# Small permutation and sample counts keep the ranking check fast
	test_friedman1_sensible_ranking(
		MarginalSAGE,
		n = 200L,
		n_permutations = 2L,
		n_samples = 20L
	)
})

# -----------------------------------------------------------------------------
# Resampling
# -----------------------------------------------------------------------------

test_that("MarginalSAGE with cross-validation resampling", {
	friedman_task = tgen("friedman1")$generate(n = 200)
	cv_sage = MarginalSAGE$new(
		task = friedman_task,
		learner = lrn("regr.rpart"),
		resampling = rsmp("cv", folds = 3),
		n_permutations = 2L
	)
	cv_sage$compute()

	# Aggregated importances: checked with the shared importance-table expectation
	expect_importance_dt(cv_sage$importance(), features = cv_sage$features)

	# Per-iteration scores: one row for every (resampling iteration, feature) pair
	fold_scores = cv_sage$scores()
	n_expected_rows = cv_sage$resampling$iters * length(cv_sage$features)
	checkmate::expect_data_table(
		fold_scores,
		types = c("integer", "character", "numeric"),
		nrows = n_expected_rows,
		ncols = 3,
		any.missing = FALSE
	)
})

# -----------------------------------------------------------------------------
# Single feature
# -----------------------------------------------------------------------------

test_that("MarginalSAGE with single feature", {
	friedman_task = tgen("friedman1")$generate(n = 100)
	# Restrict the computation to this one feature
	only_feature = "important4"

	single_sage = MarginalSAGE$new(
		task = friedman_task,
		learner = lrn("regr.rpart"),
		features = only_feature,
		n_permutations = 2L
	)
	single_sage$compute()

	# The importance table must describe exactly the requested feature
	expect_importance_dt(single_sage$importance(), features = only_feature)
	expect_equal(nrow(single_sage$importance()), 1L)
})

# -----------------------------------------------------------------------------
# n_samples parameter
# -----------------------------------------------------------------------------

test_that("MarginalSAGE with custom n_samples", {
	friedman_task = tgen("friedman1")$generate(n = 200)
	tree_learner = lrn("regr.rpart")

	# Explicit n_samples instead of the constructor default
	custom_sage = MarginalSAGE$new(
		task = friedman_task,
		learner = tree_learner,
		n_samples = 30L,
		n_permutations = 2L
	)
	custom_sage$compute()

	expect_importance_dt(custom_sage$importance(), features = custom_sage$features)
})

# -----------------------------------------------------------------------------
# Reproducibility
# -----------------------------------------------------------------------------

test_that("MarginalSAGE reproducibility with same seed", {
	task = tgen("2dnormals")$generate(n = 100)
	learner = lrn("classif.rpart", predict_type = "prob")
	measure = msr("classif.ce")

	# Seed the RNG, then build, compute and return the importance table.
	# Both runs use the same task, learner and measure objects.
	seeded_importance = function() {
		set.seed(42)
		sage = MarginalSAGE$new(
			task = task,
			learner = learner,
			measure = measure,
			n_permutations = 3L
		)
		sage$compute()
		sage$importance()
	}

	first_run = seeded_importance()
	second_run = seeded_importance()

	# Identical seed => identical importance values (up to numerical tolerance)
	expect_equal(first_run$importance, second_run$importance, tolerance = 1e-10)
})

# -----------------------------------------------------------------------------
# Parameter validation
# -----------------------------------------------------------------------------

test_that("MarginalSAGE parameter validation", {
	small_task = tgen("friedman1")$generate(n = 50)
	tree_learner = lrn("regr.rpart")

	# Zero permutations must be rejected by the constructor
	expect_error(
		MarginalSAGE$new(task = small_task, learner = tree_learner, n_permutations = 0L)
	)
	# Negative permutation counts must be rejected as well
	expect_error(
		MarginalSAGE$new(task = small_task, learner = tree_learner, n_permutations = -1L)
	)
})

test_that("MarginalSAGE requires predict_type='prob' for classification", {
	classif_task = tgen("2dnormals")$generate(n = 50)
	# Pattern the constructor's error message is matched against
	expected_message = "Classification learners require probability predictions for SAGE."

	# A learner that predicts class labels only ("response") must be refused
	expect_error(
		MarginalSAGE$new(
			task = classif_task,
			learner = lrn("classif.rpart", predict_type = "response")
		),
		expected_message
	)
})

# -----------------------------------------------------------------------------
# Convergence tracking
# -----------------------------------------------------------------------------

test_that("MarginalSAGE SE tracking in convergence_history", {
	small_task = tgen("friedman1")$generate(n = 30)
	tree_learner = lrn("regr.rpart")
	mse = msr("regr.mse")

	sage = MarginalSAGE$new(
		task = small_task,
		learner = tree_learner,
		measure = mse,
		n_permutations = 6L,
		n_samples = 20L
	)

	# Early stopping is switched on so that a convergence history gets recorded
	sage$compute(early_stopping = TRUE, se_threshold = 0.05, check_interval = 2L)

	# The history must exist and expose an "se" column
	conv_hist = sage$convergence_history
	expect_false(is.null(conv_hist))
	expect_contains(colnames(conv_hist), "se")

	# ... and consist of exactly these four columns
	expect_setequal(
		colnames(conv_hist),
		c("n_permutations", "feature", "importance", "se")
	)

	# Standard errors are finite and never negative
	checkmate::expect_numeric(conv_hist$se, lower = 0, finite = TRUE)

	# Per feature: follow the SE along increasing n_permutations and make sure
	# it neither reaches an implausible level nor jumps wildly between rows
	for (feat in unique(conv_hist$feature)) {
		se_path = conv_hist[feature == feat][order(n_permutations)]$se

		if (length(se_path) > 1) {
			expect_lt(max(se_path), 10)
			expect_lt(max(abs(diff(se_path))), 5)
		}
	}

	# Every feature of the method shows up in the history
	expect_setequal(unique(conv_hist$feature), sage$features)
})

test_that("MarginalSAGE SE-based convergence detection", {
	skip_on_cran() # ~1s - tests early stopping feature, not core SAGE

	friedman_task = tgen("friedman1")$generate(n = 100)
	tree_learner = lrn("regr.rpart")
	mse = msr("regr.mse")

	sage = MarginalSAGE$new(
		task = friedman_task,
		learner = tree_learner,
		measure = mse,
		n_permutations = 10L,
		n_samples = 20L
	)

	# Run compute() with early stopping; only the SE threshold varies per scenario
	compute_with_threshold = function(threshold) {
		sage$compute(
			early_stopping = TRUE,
			se_threshold = threshold,
			min_permutations = 5L,
			check_interval = 1L
		)
	}

	# Scenario 1: a very loose threshold, which the SE is expected to stay below
	compute_with_threshold(100.0)
	expect_true(sage$converged)
	expect_lte(sage$n_permutations_used, 10L)

	# Scenario 2: a very strict threshold, so convergence should not be reported
	sage$reset()
	compute_with_threshold(0.001)
	expect_false(sage$converged)

	# Scenario 3: a moderate threshold; whether or not it converges, the history
	# must be present and include the SE column
	sage$reset()
	compute_with_threshold(0.1)
	expect_false(is.null(sage$convergence_history))
	expect_contains(colnames(sage$convergence_history), "se")
})

# Try the xplainfi package in your browser

# Any scripts or data that you put into this service are public.

# xplainfi documentation built on Feb. 27, 2026, 1:08 a.m.