library(testthat)
context("test_sl.R -- Basic Lrnr_sl functionality")
library(sl3)
library(origami)
library(SuperLearner)
test_that("sl prediction for each fold works", {
data(cpp_imputed)
covars <- c("apgar1", "apgar5", "parity", "gagebrth", "mage", "meducyrs", "sexn")
outcome <- "haz"
task <- sl3_Task$new(data.table::copy(cpp_imputed), covariates = covars, outcome = outcome)
glm_learner <- Lrnr_glm$new()
glmnet_learner <- Lrnr_pkg_SuperLearner$new("SL.glmnet")
subset_apgar <- Lrnr_subset_covariates$new(covariates = c("apgar1", "apgar5"))
learners <- list(glm_learner, glmnet_learner, subset_apgar)
sl1 <- make_learner(Lrnr_sl, learners, glm_learner)
sl_fit <- sl1$train(task)
fold1_predict <- sl_fit$predict_fold(task, 1)
validation_predict <- sl_fit$predict_fold(task, "validation")
expect_false(all(fold1_predict == validation_predict))
expect_true(any(fold1_predict == validation_predict))
glm_fit <- glm_learner$train(task)
expect_warning(glm_fold1_predict <- glm_fit$predict_fold(task, 1))
})
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.