# tests/testthat/test-importance.R

library(testthat)
context("test_importance.R -- Variable Importance")

library(sl3)
library(origami)
library(data.table)
# Load the imputed CPP data and convert it to a data.table by reference.
data(cpp_imputed)
setDT(cpp_imputed)

# Categorical parity: values below 4 are kept as-is and every parity of 4 or
# more is pooled into the single top level "4".
cpp_imputed[, `:=`(parity_cat = factor(ifelse(parity < 4, parity, 4)))]

# Baseline problem for the first set of importance checks.
outcome <- "haz"
covars <- c("apgar1", "parity_cat", "sexn")

# Task on the baseline covariates, with 3 cross-validation folds built by
# origami::make_folds().
task <- sl3_Task$new(
  cpp_imputed,
  covariates = covars,
  outcome = outcome,
  folds = origami::make_folds(cpp_imputed, V = 3)
)

# Candidate library (glmnet with nfolds = 3, mean, GLM) stacked and combined
# by a super learner.
lrnr_glmnet <- Lrnr_glmnet$new(nfolds = 3)
lrnr_mean <- Lrnr_mean$new()
lrnr_glm <- Lrnr_glm$new()
learners <- Stack$new(lrnr_glmnet, lrnr_mean, lrnr_glm)
sl <- Lrnr_sl$new(learners)

test_that("sl3 importance fails when fit isn't trained", {
  # `sl` is still an untrained learner at this point; importance() must
  # refuse it.
  expect_error(importance(sl))
})
fit <- sl$train(task)

# Ensure various implementations of sl3 importance run. Each combination of
# removal/permutation, ratio/difference metric and validation/full/single-fold
# risk is evaluated inside test_that(), so a failure is reported as a test
# failure instead of aborting the whole file. Each result must be a table
# holding one importance value per covariate.
test_that("sl3 importance runs for each type, metric and fold choice", {
  results <- list(
    remove_validation_risk_ratio = importance(fit),
    remove_full_risk_ratio = importance(fit, fold_number = "full"),
    remove_validation_risk_diff = importance(
      fit,
      importance_metric = "difference"
    ),
    remove_full_risk_diff = importance(
      fit,
      fold_number = "full",
      importance_metric = "difference"
    ),
    remove_fold1_risk_ratio = importance(fit, fold_number = 1),
    permute_validation_risk_ratio = importance(fit, type = "permute"),
    permute_full_risk_ratio = importance(
      fit,
      type = "permute", fold_number = "full"
    ),
    permute_validation_risk_diff = importance(
      fit,
      type = "permute",
      importance_metric = "difference"
    ),
    permute_full_risk_diff = importance(
      fit,
      type = "permute", fold_number = "full",
      importance_metric = "difference"
    ),
    permute_fold1_risk_ratio = importance(
      fit,
      type = "permute", fold_number = 1
    )
  )
  for (nm in names(results)) {
    res <- results[[nm]]
    expect_s3_class(res, "data.frame")
    # one row per covariate of the training task
    expect_equal(nrow(res), length(covars), info = nm)
  }
})
########## test covariate groups
# Refit on a larger covariate set so multi-covariate groups can be formed.
# `covars`, `task`, `fit` and `groups` stay at top level because later tests
# in this file reuse `fit` and `groups`.
covars <- c(
  "apgar1", "apgar5", "parity", "gagebrth", "mage", "meducyrs", "sexn"
)
task <- sl3_Task$new(cpp_imputed,
  covariates = covars, outcome = "haz",
  folds = origami::make_folds(cpp_imputed, V = 3)
)
fit <- sl$train(task)
groups <- list(
  scores = c("apgar1", "apgar5"),
  maternal = c("parity", "mage", "meducyrs")
)

# Grouped importance (removal and permutation, validation and fold-1 risk)
# runs inside test_that() so failures are reported per test. Each result must
# be a table that reports every named group.
test_that("sl3 importance runs with named covariate groups", {
  results <- list(
    varimp = importance(fit, covariate_groups = groups),
    varimp_permute = importance(
      fit,
      type = "permute", covariate_groups = groups
    ),
    varimp_fold1 = importance(
      fit,
      covariate_groups = groups, fold_number = 1
    ),
    varimp_permute_fold1 = importance(
      fit,
      type = "permute", covariate_groups = groups, fold_number = 1
    )
  )
  for (nm in names(results)) {
    res <- results[[nm]]
    expect_s3_class(res, "data.frame")
    expect_true(
      all(names(groups) %in% as.character(unlist(res))),
      info = nm
    )
  }
})
test_that("sl3 importance fails when groups with > 1 covariate unnamed", {
  # Blank the name of `scores`, a group holding two covariates.
  unnamed_groups <- setNames(groups, c("", names(groups)[-1]))
  expect_error(importance(fit, covariate_groups = unnamed_groups))
})

test_that("sl3 importance fails when groups don't contain covariates", {
  # Add a group whose only member is not a covariate of the task.
  bogus_groups <- append(groups, list(not_a_covariate = "not_a_covariate"))
  expect_error(importance(fit, covariate_groups = bogus_groups))
})
# jeremyrcoyle/sl3 documentation built on April 30, 2024, 10:16 p.m.