# tests/testthat/test-sl_fold.R

library(testthat)
# Label for this test file's output. It previously named test_sl.R, which
# appears to be a copy-paste leftover; this file tests fold-specific prediction.
# NOTE(review): context() is deprecated under testthat's 3rd edition; drop
# this line if the package opts into edition 3.
context("test-sl_fold.R -- Lrnr_sl fold-specific prediction")

library(sl3)
library(origami)
library(SuperLearner)

test_that("sl prediction for each fold works", {
  # Load the example data into this test's own environment rather than the
  # global environment (data()'s default), so the test leaves no objects
  # behind that other tests could come to depend on.
  data(cpp_imputed, envir = environment())
  covars <- c(
    "apgar1", "apgar5", "parity", "gagebrth", "mage", "meducyrs", "sexn"
  )
  outcome <- "haz"
  # Deep-copy the data before handing it to the task, presumably so any
  # by-reference modification by sl3_Task cannot touch cpp_imputed itself.
  task <- sl3_Task$new(
    data.table::copy(cpp_imputed),
    covariates = covars,
    outcome = outcome
  )

  # Candidate learners: a GLM, glmnet via the SuperLearner wrapper, and a
  # learner restricted to the two APGAR covariates.
  glm_learner <- Lrnr_glm$new()
  glmnet_learner <- Lrnr_pkg_SuperLearner$new("SL.glmnet")
  subset_apgar <- Lrnr_subset_covariates$new(covariates = c("apgar1", "apgar5"))
  learners <- list(glm_learner, glmnet_learner, subset_apgar)
  # glm_learner is passed positionally after the learner list; presumably it
  # is the metalearner for Lrnr_sl -- confirm against Lrnr_sl's signature.
  sl1 <- make_learner(Lrnr_sl, learners, glm_learner)

  sl_fit <- sl1$train(task)

  # Compare fold 1's fit with the out-of-fold ("validation") predictions.
  # Presumably the two coincide on rows held out in fold 1 and differ on the
  # rest, so they must be neither all equal nor all different. Exact `==` is
  # intended: matching rows should come from the very same fitted model.
  fold1_predict <- sl_fit$predict_fold(task, 1)
  validation_predict <- sl_fit$predict_fold(task, "validation")
  expect_false(all(fold1_predict == validation_predict))
  expect_true(any(fold1_predict == validation_predict))

  # A learner trained directly (not through cross-validation) presumably has
  # no per-fold fits, so requesting fold 1 is expected to warn. The returned
  # predictions are not inspected, so they are not assigned.
  glm_fit <- glm_learner$train(task)
  expect_warning(glm_fit$predict_fold(task, 1))
})
# jeremyrcoyle/sl3 documentation built on April 30, 2024, 10:16 p.m.