Nothing
# =============================================================================
# Tests for pretrained learner support
# =============================================================================
# -----------------------------------------------------------------------------
# assert_pretrained() unit tests
# -----------------------------------------------------------------------------
# A learner that has never been trained is reported as not pretrained.
test_that("assert_pretrained returns FALSE for untrained learner", {
  mtcars_task = tsk("mtcars")
  untrained = lrn("regr.rpart")
  holdout = rsmp("holdout")
  holdout$instantiate(mtcars_task)
  expect_false(assert_pretrained(untrained, mtcars_task, holdout))
})
# A trained learner combined with an instantiated single-split (holdout)
# resampling is accepted as a pretrained setup.
test_that("assert_pretrained returns TRUE for valid pretrained setup", {
  mtcars_task = tsk("mtcars")
  fitted = lrn("regr.rpart")
  fitted$train(mtcars_task)
  holdout = rsmp("holdout")
  holdout$instantiate(mtcars_task)
  expect_true(assert_pretrained(fitted, mtcars_task, holdout))
})
# A trained learner paired with a resampling that has several folds
# (here 3-fold CV) is rejected with a "not compatible" error.
test_that("assert_pretrained errors when trained learner has multi-fold resampling", {
  mtcars_task = tsk("mtcars")
  fitted = lrn("regr.rpart")
  fitted$train(mtcars_task)
  cv3 = rsmp("cv", folds = 3)
  cv3$instantiate(mtcars_task)
  expect_error(
    assert_pretrained(fitted, mtcars_task, cv3),
    "not compatible"
  )
})
# -----------------------------------------------------------------------------
# PFI with pretrained learner
# -----------------------------------------------------------------------------
# PFI on a pretrained regression learner: it yields a valid importance
# table, exposes a ResampleResult, and leaves the caller's learner with its
# fitted model.
test_that("PFI works with pretrained learner", {
  mtcars_task = tsk("mtcars")
  fitted = lrn("regr.rpart")
  holdout = rsmp("holdout")$instantiate(mtcars_task)
  fitted$train(mtcars_task)

  method = PFI$new(
    task = mtcars_task,
    learner = fitted,
    measure = msr("regr.mse"),
    resampling = holdout,
    n_repeats = 3L
  )
  method$compute()

  expect_importance_dt(method$importance(), features = method$features)
  checkmate::expect_r6(method$resample_result, "ResampleResult")
  # compute() must not strip the model from the learner the caller passed in
  expect_false(is.null(fitted$model))
})
# Runs PFI twice on the same instantiated holdout split, once with a
# pretrained learner and once with a fresh one. Both runs must produce valid
# importance tables covering the same features.
test_that("PFI pretrained and non-pretrained produce comparable results", {
  set.seed(123)
  mtcars_task = tsk("mtcars")
  holdout = rsmp("holdout")$instantiate(mtcars_task)
  mse = msr("regr.mse")

  # Shared constructor so both paths differ only in the learner passed in.
  make_pfi = function(learner) {
    PFI$new(
      task = mtcars_task,
      learner = learner,
      measure = mse,
      resampling = holdout,
      n_repeats = 5L
    )
  }

  # Pretrained path
  fitted = lrn("regr.rpart")
  fitted$train(mtcars_task)
  set.seed(1)
  pfi_pretrained = make_pfi(fitted)
  pfi_pretrained$compute()

  # Non-pretrained path: fresh learner, identical resampling
  set.seed(1)
  pfi_untrained = make_pfi(lrn("regr.rpart"))
  pfi_untrained$compute()

  for (run in list(pfi_pretrained, pfi_untrained)) {
    expect_importance_dt(run$importance(), features = run$features)
  }
  expect_setequal(
    pfi_pretrained$importance()$feature,
    pfi_untrained$importance()$feature
  )
})
# With a pretrained learner, PFI's constructor rejects a resampling that has
# not been instantiated.
test_that("PFI with pretrained learner errors on non-instantiated resampling", {
  mtcars_task = tsk("mtcars")
  fitted = lrn("regr.rpart")
  fitted$train(mtcars_task)
  mse = msr("regr.mse")
  raw_holdout = rsmp("holdout")
  expect_error(
    PFI$new(
      task = mtcars_task,
      learner = fitted,
      measure = mse,
      resampling = raw_holdout
    ),
    "instantiated"
  )
})
# With a pretrained learner, PFI's constructor rejects a multi-fold
# (3-fold CV) resampling.
test_that("PFI with pretrained learner errors on multi-fold resampling at construction", {
  mtcars_task = tsk("mtcars")
  fitted = lrn("regr.rpart")
  fitted$train(mtcars_task)
  mse = msr("regr.mse")
  expect_error(
    PFI$new(
      task = mtcars_task,
      learner = fitted,
      measure = mse,
      resampling = rsmp("cv", folds = 3)$instantiate(mtcars_task),
      n_repeats = 2L
    ),
    "not compatible"
  )
})
# Classification smoke test: PFI on a pretrained rpart classifier with
# probability predictions, scored by classification error.
test_that("PFI with pretrained learner works for classification", {
  penguin_task = tsk("penguins")
  tree = lrn("classif.rpart", predict_type = "prob")
  holdout = rsmp("holdout")$instantiate(penguin_task)
  tree$train(penguin_task)

  method = PFI$new(
    task = penguin_task,
    learner = tree,
    measure = msr("classif.ce"),
    resampling = holdout,
    n_repeats = 3L
  )
  method$compute()

  expect_importance_dt(method$importance(), features = method$features)
})
# -----------------------------------------------------------------------------
# MarginalSAGE with pretrained learner
# -----------------------------------------------------------------------------
# MarginalSAGE on a pretrained regression learner: it yields a valid
# importance table, exposes a ResampleResult, and (mirroring the PFI test)
# leaves the caller's learner with its fitted model after compute().
test_that("MarginalSAGE works with pretrained learner", {
  task = tsk("mtcars")
  learner = lrn("regr.rpart")
  resampling = rsmp("holdout")$instantiate(task)
  measure = msr("regr.mse")
  learner$train(task)

  sage = MarginalSAGE$new(
    task = task,
    learner = learner,
    measure = measure,
    resampling = resampling,
    n_permutations = 2L,
    n_samples = 20L
  )
  sage$compute()

  expect_importance_dt(sage$importance(), features = sage$features)
  checkmate::expect_r6(sage$resample_result, "ResampleResult")
  # compute() must not strip the model from the learner the caller passed in
  expect_false(is.null(learner$model))
})
# Runs MarginalSAGE twice on the same instantiated holdout split, once with a
# pretrained learner and once with a fresh one. Both runs must produce valid
# importance tables covering the same features.
test_that("MarginalSAGE pretrained and non-pretrained produce comparable results", {
  set.seed(123)
  mtcars_task = tsk("mtcars")
  holdout = rsmp("holdout")$instantiate(mtcars_task)
  mse = msr("regr.mse")

  # Shared constructor so both paths differ only in the learner passed in.
  make_sage = function(learner) {
    MarginalSAGE$new(
      task = mtcars_task,
      learner = learner,
      measure = mse,
      resampling = holdout,
      n_permutations = 3L,
      n_samples = 20L
    )
  }

  # Pretrained path
  fitted = lrn("regr.rpart")
  fitted$train(mtcars_task)
  set.seed(1)
  sage_pretrained = make_sage(fitted)
  sage_pretrained$compute()

  # Non-pretrained path
  set.seed(1)
  sage_untrained = make_sage(lrn("regr.rpart"))
  sage_untrained$compute()

  for (run in list(sage_pretrained, sage_untrained)) {
    expect_importance_dt(run$importance(), features = run$features)
  }
  expect_setequal(
    sage_pretrained$importance()$feature,
    sage_untrained$importance()$feature
  )
})
# With a pretrained learner, MarginalSAGE's constructor rejects a resampling
# that has not been instantiated.
test_that("MarginalSAGE with pretrained learner errors on non-instantiated resampling", {
  mtcars_task = tsk("mtcars")
  fitted = lrn("regr.rpart")
  fitted$train(mtcars_task)
  mse = msr("regr.mse")
  raw_holdout = rsmp("holdout")
  expect_error(
    MarginalSAGE$new(
      task = mtcars_task,
      learner = fitted,
      measure = mse,
      resampling = raw_holdout,
      n_permutations = 2L,
      n_samples = 20L
    ),
    "instantiated"
  )
})
# With a pretrained learner, MarginalSAGE's constructor rejects a multi-fold
# (3-fold CV) resampling.
test_that("MarginalSAGE with pretrained learner errors on multi-fold resampling at construction", {
  mtcars_task = tsk("mtcars")
  fitted = lrn("regr.rpart")
  fitted$train(mtcars_task)
  mse = msr("regr.mse")
  expect_error(
    MarginalSAGE$new(
      task = mtcars_task,
      learner = fitted,
      measure = mse,
      resampling = rsmp("cv", folds = 3)$instantiate(mtcars_task),
      n_permutations = 2L,
      n_samples = 20L
    ),
    "not compatible"
  )
})
# Classification smoke test: MarginalSAGE on a pretrained rpart classifier
# with probability predictions, scored by classification error.
test_that("MarginalSAGE with pretrained learner works for classification", {
  penguin_task = tsk("penguins")
  tree = lrn("classif.rpart", predict_type = "prob")
  holdout = rsmp("holdout")$instantiate(penguin_task)
  tree$train(penguin_task)

  method = MarginalSAGE$new(
    task = penguin_task,
    learner = tree,
    measure = msr("classif.ce"),
    resampling = holdout,
    n_permutations = 2L,
    n_samples = 20L
  )
  method$compute()

  expect_importance_dt(method$importance(), features = method$features)
})
# -----------------------------------------------------------------------------
# LOCO warns for pretrained learner (refit-based methods)
# -----------------------------------------------------------------------------
# LOCO is a refit-based method, so constructing it with an already-trained
# learner is expected to emit an "already trained" warning.
test_that("LOCO warns when given a pretrained learner", {
  mtcars_task = tsk("mtcars")
  fitted = lrn("regr.rpart")
  fitted$train(mtcars_task)
  mse = msr("regr.mse")
  expect_warning(
    LOCO$new(task = mtcars_task, learner = fitted, measure = mse),
    "already trained"
  )
})
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.