tests/testthat/test-stack_broken.R

context("test-stack_broken.R -- Stack robustness to sub-learner errors")
library(R6)

data(cpp_imputed)
covars <- c(
  "apgar1", "apgar5", "parity", "gagebrth", "mage", "meducyrs",
  "sexn"
)
outcome <- "haz"
task <- sl3_Task$new(cpp_imputed, covariates = covars, outcome = outcome)

library(uuid)
Lrnr_broken <- R6Class(
  classname = "Lrnr_broken", inherit = Lrnr_base, portable = TRUE,
  public = list(
    initialize = function(pbreak = 1) {
      private$.params <- list(pbreak = pbreak)

      invisible(self)
    }
  ),
  private = list(
    .train = function(task) {
      if (rbinom(1, 1, self$params$pbreak)) {
        stop("this Learner often returns an error on training")
      }
    }
  )
)

broken_learner <- Lrnr_broken$new()
glm_learner <- Lrnr_glm$new()
broken_stack <- Stack$new(broken_learner, glm_learner)

test_that("Stack produces warning for learners that return errors", {
  expect_warning(broken_fit <- broken_stack$train(task))
})

test_that("Stack predicts on remaining good learners", {
  broken_fit <- suppressWarnings(broken_stack$train(task))
  predictions <- broken_fit$predict()
  expect_equal(length(predictions[[1]]), nrow(cpp_imputed))
  # expect_equal(names(predictions), "Lrnr_glm_TRUE")
})

test_that("Stack fails if all learners return errors", {
  all_broken_stack <- Stack$new(broken_learner, broken_learner)
  expect_error(suppressWarnings(all_broken_stack$train(task)))
})

test_that("Lrnr_cv on stack warns when learners that error on any fold", {
  broken_cv <- Lrnr_cv$new(broken_stack)
  expect_warning(broken_cv$train(task))
})

test_that("Lrnr_cv on stack drops all learners that error on any fold", {
  broken_cv <- Lrnr_cv$new(broken_stack)
  broken_cv_fit <- suppressWarnings(broken_cv$train(task))
  cv_preds <- broken_cv_fit$predict()
  expect_equal(length(cv_preds[[1]]), nrow(cpp_imputed))
  # expect_equal(names(cv_preds), "Lrnr_glm_TRUE")
})

test_that("Lrnr_sl propagates errors to full refit", {
  broken_stack <- Stack$new(broken_learner, glm_learner, Lrnr_glmnet$new())
  broken_sl <- Lrnr_sl$new(broken_stack, glm_learner)
  broken_sl_fit <- suppressWarnings(broken_sl$train(task))
  sl_preds <- broken_sl_fit$predict()
  expect_equal(length(sl_preds), nrow(task$data))
})
jeremyrcoyle/sl3 documentation built on April 30, 2024, 10:16 p.m.