# tests/testthat/test-MarginalPermutationSampler.R

# Tests for MarginalPermutationSampler

test_that("MarginalPermutationSampler initialization works", {
	# Construct a sampler on a small generated task and check that it exposes
	# the expected class hierarchy, parameter set and human-readable label.
	circle_task = tgen("circle", d = 5)$generate(n = 100)
	perm_sampler = MarginalPermutationSampler$new(circle_task)

	# The sampler must be both the concrete class and its marginal parent class
	for (expected_class in c("MarginalPermutationSampler", "MarginalSampler")) {
		expect_s3_class(perm_sampler, expected_class)
	}

	expect_s3_class(perm_sampler$param_set, "ParamSet")
	expect_equal(perm_sampler$label, "Permutation sampler")
})

test_that("MarginalPermutationSampler sampling works", {
	task = tgen("circle", d = 5)$generate(n = 100)
	sampler = MarginalPermutationSampler$new(task)
	data = task$data()

	# Test single feature sampling
	sampled_data = sampler$sample("x1")

	expect_sampler_output_structure(sampled_data, task, nrows = task$nrow)
	expect_feature_type_consistency(sampled_data, task)

	# A permutation reorders values without adding, dropping or duplicating any.
	# Compare sorted vectors (multiset equality): expect_setequal() ignores how
	# often each value occurs and would not catch sampling with replacement.
	expect_equal(sort(sampled_data$x1), sort(data$x1))

	# The column should actually be shuffled. With 100 rows, returning it in its
	# original order would require drawing the identity permutation, which is
	# vanishingly unlikely for a real permutation.
	expect_false(identical(sampled_data$x1, data$x1))

	# All other features must be left untouched
	expect_non_sampled_unchanged(sampled_data, data, setdiff(task$feature_names, "x1"))
})

test_that("MarginalPermutationSampler handles multiple features", {
	task = tgen("circle", d = 5)$generate(n = 100)
	sampler = MarginalPermutationSampler$new(task)
	data = task$data()

	features = c("x1", "x2", "x3")
	sampled_data = sampler$sample(features)

	expect_sampler_output_structure(sampled_data, task, nrows = task$nrow)
	expect_feature_type_consistency(sampled_data, task)

	for (feat in features) {
		# Each permuted column holds exactly the original values (same multiset);
		# expect_setequal() would ignore value multiplicities.
		expect_equal(sort(sampled_data[[feat]]), sort(data[[feat]]))
		# ...but not in the original order (identity permutation is vanishingly
		# unlikely with 100 rows).
		expect_false(identical(sampled_data[[feat]], data[[feat]]))
	}

	# Non-sampled features unchanged
	expect_non_sampled_unchanged(sampled_data, data, setdiff(task$feature_names, features))
})

test_that("MarginalPermutationSampler works with different task types", {
	# Regression task. Note: tgen("circle") generates a two-class
	# *classification* task, so a genuine regression task is needed here to
	# actually exercise the regression path.
	task_regr = tsk("mtcars")
	expect_identical(task_regr$task_type, "regr")
	sampler_regr = MarginalPermutationSampler$new(task_regr)
	sampled_regr = sampler_regr$sample("disp")
	expect_sampler_output_structure(sampled_regr, task_regr, nrows = task_regr$nrow)

	# Binary classification task
	task_classif = tsk("sonar")
	expect_identical(task_classif$task_type, "classif")
	expect_length(task_classif$class_names, 2)
	sampler_classif = MarginalPermutationSampler$new(task_classif)
	sampled_classif = sampler_classif$sample("V1")
	expect_sampler_output_structure(sampled_classif, task_classif, nrows = task_classif$nrow)

	# Multiclass classification task
	task_multi = tsk("iris")
	expect_identical(task_multi$task_type, "classif")
	expect_length(task_multi$class_names, 3)
	sampler_multi = MarginalPermutationSampler$new(task_multi)
	sampled_multi = sampler_multi$sample("Sepal.Length")
	expect_sampler_output_structure(sampled_multi, task_multi, nrows = task_multi$nrow)
})

test_that("MarginalPermutationSampler preserves feature types", {
	# Feature-type checks are shared across sampler tests; delegate to the
	# common helper, handing it the sampler class (not an instance).
	sampler_class = MarginalPermutationSampler
	test_sampler_feature_types(sampler_class)
})

# Try the xplainfi package in your browser
#
# Any scripts or data that you put into this service are public.
#
# xplainfi documentation built on Feb. 27, 2026, 1:08 a.m.