context("test-stack_broken.R -- Stack robustness to sub-learner errors")

# Packages attached by this file. R6 supplies R6Class(), used below to define
# the mock learner.
library(R6)
library(uuid)

# Shared fixture: an sl3_Task built from the cpp_imputed example data, with
# `haz` as the outcome and the covariates listed here.
data(cpp_imputed)
covars <- c(
  "apgar1",
  "apgar5",
  "parity",
  "gagebrth",
  "mage",
  "meducyrs",
  "sexn"
)
outcome <- "haz"
task <- sl3_Task$new(cpp_imputed, covariates = covars, outcome = outcome)
# Mock learner whose training step can fail at random.
#
# `pbreak` is handed to rbinom() as the success probability of a single
# Bernoulli draw; when the draw is 1, `.train` raises an error. With the
# default `pbreak = 1` every training attempt fails.
Lrnr_broken <- R6Class(
  classname = "Lrnr_broken",
  inherit = Lrnr_base,
  portable = TRUE,
  private = list(
    # Draw once from Bernoulli(pbreak) and stop with an error on a 1.
    # Nothing is fitted: when no error is raised the result is NULL.
    .train = function(task) {
      failure_draw <- rbinom(1, 1, self$params$pbreak)
      if (failure_draw) {
        stop("this Learner often returns an error on training")
      }
    }
  ),
  public = list(
    # Record the failure probability as the learner's only parameter and
    # return the object invisibly.
    initialize = function(pbreak = 1) {
      private$.params <- list(pbreak = pbreak)
      invisible(self)
    }
  )
)
# File-level fixtures: a GLM learner, a mock learner that errors on training
# (default pbreak = 1), and a stack holding both with the failing one first.
glm_learner <- Lrnr_glm$new()
broken_learner <- Lrnr_broken$new()
broken_stack <- Stack$new(broken_learner, glm_learner)
test_that("Stack produces warning for learners that return errors", {
  # Training a stack that contains a failing learner is expected to signal a
  # warning; the fit itself is not inspected here.
  expect_warning(broken_stack$train(task))
})
test_that("Stack predicts on remaining good learners", {
  # The training warning is silenced here; it is checked in its own test.
  fitted_stack <- suppressWarnings(broken_stack$train(task))
  stack_preds <- fitted_stack$predict()

  # The first element of the predictions must have one entry per data row.
  n_expected <- nrow(cpp_imputed)
  first_pred_column <- stack_preds[[1]]
  expect_equal(length(first_pred_column), n_expected)
  # Disabled check on the prediction names:
  # expect_equal(names(stack_preds), "Lrnr_glm_TRUE")
})
test_that("Stack fails if all learners return errors", {
  # A stack made only of the failing learner must error on training; any
  # warnings raised along the way are silenced.
  doomed_stack <- Stack$new(broken_learner, broken_learner)
  expect_error(suppressWarnings(doomed_stack$train(task)))
})
test_that("Lrnr_cv on stack warns when learners that error on any fold", {
  # Wrap the partially failing stack in Lrnr_cv and expect a warning on
  # training.
  cv_wrapper <- Lrnr_cv$new(broken_stack)
  expect_warning(cv_wrapper$train(task))
})
test_that("Lrnr_cv on stack drops all learners that error on any fold", {
  # Train the Lrnr_cv wrapper with warnings silenced, then predict.
  cv_wrapper <- Lrnr_cv$new(broken_stack)
  cv_fit <- suppressWarnings(cv_wrapper$train(task))
  fold_preds <- cv_fit$predict()

  # The first element of the predictions must have one entry per data row.
  n_expected <- nrow(cpp_imputed)
  first_pred_column <- fold_preds[[1]]
  expect_equal(length(first_pred_column), n_expected)
  # Disabled check on the prediction names:
  # expect_equal(names(fold_preds), "Lrnr_glm_TRUE")
})
test_that("Lrnr_sl propagates errors to full refit", {
  # Lrnr_glmnet presumably depends on the optional glmnet package; skip the
  # test instead of erroring when that package is not installed.
  skip_if_not_installed("glmnet")

  # Three-learner stack local to this test. It gets its own name so the
  # file-level `broken_stack` fixture is not shadowed.
  sl_stack <- Stack$new(broken_learner, glm_learner, Lrnr_glmnet$new())

  # Lrnr_sl takes the stack plus a second learner (the GLM); training
  # warnings are silenced because the stack contains the failing learner.
  broken_sl <- Lrnr_sl$new(sl_stack, glm_learner)
  broken_sl_fit <- suppressWarnings(broken_sl$train(task))

  # Predictions must contain one value per row of the task data.
  sl_preds <- broken_sl_fit$predict()
  expect_equal(length(sl_preds), nrow(task$data))
})
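# NOTE(review): the two lines below appear to be boilerplate from the web page
# this file was copied from, not part of the test suite; kept as comments so
# the file parses.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.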