# tests/testthat/test-pretrained.R

# =============================================================================
# Tests for pretrained learner support
# =============================================================================

# -----------------------------------------------------------------------------
# assert_pretrained() unit tests
# -----------------------------------------------------------------------------

test_that("assert_pretrained returns FALSE for untrained learner", {
	# No model has been fit yet, so the helper has nothing to validate and
	# should report "not pretrained" by returning FALSE.
	mtcars_task = tsk("mtcars")
	holdout = rsmp("holdout")
	holdout$instantiate(mtcars_task)
	untrained = lrn("regr.rpart")

	is_pretrained = assert_pretrained(untrained, mtcars_task, holdout)
	expect_false(is_pretrained)
})

test_that("assert_pretrained returns TRUE for valid pretrained setup", {
	# A fitted learner paired with a single-split (holdout) resampling is the
	# configuration the helper accepts as pretrained.
	mtcars_task = tsk("mtcars")
	fitted = lrn("regr.rpart")
	fitted$train(mtcars_task)

	holdout = rsmp("holdout")
	holdout$instantiate(mtcars_task)

	expect_true(assert_pretrained(fitted, mtcars_task, holdout))
})

test_that("assert_pretrained errors when trained learner has multi-fold resampling", {
	mtcars_task = tsk("mtcars")
	fitted = lrn("regr.rpart")
	fitted$train(mtcars_task)
	three_fold = rsmp("cv", folds = 3)
	three_fold$instantiate(mtcars_task)

	# One fitted model cannot stand in for three separately trained folds, so
	# this combination must be rejected with an explanatory error.
	expect_error(
		assert_pretrained(fitted, mtcars_task, three_fold),
		regexp = "not compatible"
	)
})

# -----------------------------------------------------------------------------
# PFI with pretrained learner
# -----------------------------------------------------------------------------

test_that("PFI works with pretrained learner", {
	task = tsk("mtcars")
	learner = lrn("regr.rpart")
	resampling = rsmp("holdout")$instantiate(task)
	measure = msr("regr.mse")

	learner$train(task)
	# Snapshot the fitted model so we can verify that compute() neither drops
	# nor refits/replaces the model held by the caller's learner object.
	model_before = learner$model

	pfi = PFI$new(
		task = task,
		learner = learner,
		measure = measure,
		resampling = resampling,
		n_repeats = 3L
	)

	pfi$compute()
	expect_importance_dt(pfi$importance(), features = pfi$features)
	checkmate::expect_r6(pfi$resample_result, "ResampleResult")

	# Original learner must still have its model after compute, and it must be
	# the very same model (a non-NULL check alone would miss a silent refit).
	expect_false(is.null(learner$model))
	expect_identical(learner$model, model_before)
})

test_that("PFI pretrained and non-pretrained produce comparable results", {
	set.seed(123)
	task = tsk("mtcars")
	resampling = rsmp("holdout")$instantiate(task)
	measure = msr("regr.mse")

	# Pretrained path. Fit only on the holdout training rows so this is the same
	# model the non-pretrained path fits internally. Fitting on the full task
	# would leak the test rows into the model and make the two runs evaluate
	# different models, so they would not be comparable.
	learner_pre = lrn("regr.rpart")
	learner_pre$train(task, row_ids = resampling$train_set(1))

	set.seed(1)
	pfi_pre = PFI$new(
		task = task,
		learner = learner_pre,
		measure = measure,
		resampling = resampling,
		n_repeats = 5L
	)
	pfi_pre$compute()

	# Non-pretrained path (fresh learner, same resampling)
	set.seed(1)
	pfi_fresh = PFI$new(
		task = task,
		learner = lrn("regr.rpart"),
		measure = measure,
		resampling = resampling,
		n_repeats = 5L
	)
	pfi_fresh$compute()

	# Both should produce valid importance tables with same features
	expect_importance_dt(pfi_pre$importance(), features = pfi_pre$features)
	expect_importance_dt(pfi_fresh$importance(), features = pfi_fresh$features)
	expect_setequal(pfi_pre$importance()$feature, pfi_fresh$importance()$feature)

	# Both paths now score an rpart tree fit on identical rows against the same
	# test rows, so the unpermuted baseline score must agree.
	expect_equal(
		pfi_pre$resample_result$aggregate(measure),
		pfi_fresh$resample_result$aggregate(measure)
	)
})

test_that("PFI with pretrained learner errors on non-instantiated resampling", {
	fitted = lrn("regr.rpart")
	mtcars_task = tsk("mtcars")
	fitted$train(mtcars_task)
	mse = msr("regr.mse")

	# The holdout is deliberately left un-instantiated: combined with an
	# already trained learner, construction must fail and mention it.
	uninstantiated = rsmp("holdout")
	expect_error(
		PFI$new(
			task = mtcars_task,
			learner = fitted,
			measure = mse,
			resampling = uninstantiated
		),
		regexp = "instantiated"
	)
})

test_that("PFI with pretrained learner errors on multi-fold resampling at construction", {
	mtcars_task = tsk("mtcars")
	fitted = lrn("regr.rpart")
	fitted$train(mtcars_task)
	mse = msr("regr.mse")

	three_fold = rsmp("cv", folds = 3)
	three_fold$instantiate(mtcars_task)

	# The mismatch between one fitted model and three folds should be caught
	# in the constructor, before any computation is attempted.
	expect_error(
		PFI$new(
			task = mtcars_task,
			learner = fitted,
			measure = mse,
			resampling = three_fold,
			n_repeats = 2L
		),
		regexp = "not compatible"
	)
})

test_that("PFI with pretrained learner works for classification", {
	penguin_task = tsk("penguins")
	tree = lrn("classif.rpart", predict_type = "prob")
	holdout = rsmp("holdout")
	holdout$instantiate(penguin_task)
	error_rate = msr("classif.ce")

	# Fit on the full task before handing the learner to PFI (pretrained path).
	tree$train(penguin_task)

	method = PFI$new(
		task = penguin_task,
		learner = tree,
		measure = error_rate,
		resampling = holdout,
		n_repeats = 3L
	)
	method$compute()

	scores = method$importance()
	expect_importance_dt(scores, features = method$features)
})

# -----------------------------------------------------------------------------
# MarginalSAGE with pretrained learner
# -----------------------------------------------------------------------------

test_that("MarginalSAGE works with pretrained learner", {
	task = tsk("mtcars")
	learner = lrn("regr.rpart")
	resampling = rsmp("holdout")$instantiate(task)
	measure = msr("regr.mse")

	learner$train(task)
	# Snapshot the fitted model so we can verify that compute() neither drops
	# nor refits/replaces the model held by the caller's learner object.
	model_before = learner$model

	sage = MarginalSAGE$new(
		task = task,
		learner = learner,
		measure = measure,
		resampling = resampling,
		n_permutations = 2L,
		n_samples = 20L
	)

	sage$compute()
	expect_importance_dt(sage$importance(), features = sage$features)
	checkmate::expect_r6(sage$resample_result, "ResampleResult")

	# Same guarantee as the PFI pretrained test: the caller's learner keeps the
	# exact model it was trained with.
	expect_false(is.null(learner$model))
	expect_identical(learner$model, model_before)
})

test_that("MarginalSAGE pretrained and non-pretrained produce comparable results", {
	set.seed(123)
	task = tsk("mtcars")
	resampling = rsmp("holdout")$instantiate(task)
	measure = msr("regr.mse")

	# Pretrained path. Fit only on the holdout training rows so this is the same
	# model the non-pretrained path fits internally. Fitting on the full task
	# would leak the test rows into the model and make the two runs evaluate
	# different models, so they would not be comparable.
	learner_pre = lrn("regr.rpart")
	learner_pre$train(task, row_ids = resampling$train_set(1))

	set.seed(1)
	sage_pre = MarginalSAGE$new(
		task = task,
		learner = learner_pre,
		measure = measure,
		resampling = resampling,
		n_permutations = 3L,
		n_samples = 20L
	)
	sage_pre$compute()

	# Non-pretrained path
	set.seed(1)
	sage_fresh = MarginalSAGE$new(
		task = task,
		learner = lrn("regr.rpart"),
		measure = measure,
		resampling = resampling,
		n_permutations = 3L,
		n_samples = 20L
	)
	sage_fresh$compute()

	expect_importance_dt(sage_pre$importance(), features = sage_pre$features)
	expect_importance_dt(sage_fresh$importance(), features = sage_fresh$features)
	expect_setequal(sage_pre$importance()$feature, sage_fresh$importance()$feature)

	# Both paths now score an rpart tree fit on identical rows against the same
	# test rows, so the unpermuted baseline score must agree.
	expect_equal(
		sage_pre$resample_result$aggregate(measure),
		sage_fresh$resample_result$aggregate(measure)
	)
})

test_that("MarginalSAGE with pretrained learner errors on non-instantiated resampling", {
	mtcars_task = tsk("mtcars")
	fitted = lrn("regr.rpart")
	fitted$train(mtcars_task)
	mse = msr("regr.mse")

	# The holdout is deliberately left un-instantiated: combined with an
	# already trained learner, construction must fail and mention it.
	uninstantiated = rsmp("holdout")
	expect_error(
		MarginalSAGE$new(
			task = mtcars_task,
			learner = fitted,
			measure = mse,
			resampling = uninstantiated,
			n_permutations = 2L,
			n_samples = 20L
		),
		regexp = "instantiated"
	)
})

test_that("MarginalSAGE with pretrained learner errors on multi-fold resampling at construction", {
	mtcars_task = tsk("mtcars")
	fitted = lrn("regr.rpart")
	fitted$train(mtcars_task)
	mse = msr("regr.mse")

	three_fold = rsmp("cv", folds = 3)
	three_fold$instantiate(mtcars_task)

	# One fitted model versus three folds: the constructor must reject this
	# before any SAGE sampling takes place.
	expect_error(
		MarginalSAGE$new(
			task = mtcars_task,
			learner = fitted,
			measure = mse,
			resampling = three_fold,
			n_permutations = 2L,
			n_samples = 20L
		),
		regexp = "not compatible"
	)
})

test_that("MarginalSAGE with pretrained learner works for classification", {
	penguin_task = tsk("penguins")
	tree = lrn("classif.rpart", predict_type = "prob")
	holdout = rsmp("holdout")
	holdout$instantiate(penguin_task)
	error_rate = msr("classif.ce")

	# Fit on the full task before handing the learner to SAGE (pretrained path).
	tree$train(penguin_task)

	method = MarginalSAGE$new(
		task = penguin_task,
		learner = tree,
		measure = error_rate,
		resampling = holdout,
		n_permutations = 2L,
		n_samples = 20L
	)
	method$compute()

	scores = method$importance()
	expect_importance_dt(scores, features = method$features)
})

# -----------------------------------------------------------------------------
# LOCO warns for pretrained learner (refit-based methods)
# -----------------------------------------------------------------------------

test_that("LOCO warns when given a pretrained learner", {
	# LOCO is a refit-based method, so handing it a learner that has already
	# been trained is expected to produce a warning at construction time.
	mtcars_task = tsk("mtcars")
	fitted = lrn("regr.rpart")
	fitted$train(mtcars_task)
	mse = msr("regr.mse")

	expect_warning(
		LOCO$new(task = mtcars_task, learner = fitted, measure = mse),
		regexp = "already trained"
	)
})

# Try the xplainfi package in your browser

# Any scripts or data that you put into this service are public.

# xplainfi documentation built on Feb. 27, 2026, 1:08 a.m.