# tests/testthat/test-SampleSplitSuperLearner.R

library(SuperLearner)
# Tests for SampleSplitSuperLearner(): checks that the validation rows are
# chosen as requested and that no learner in the library reports an error.
#
# The tests only run when every optional dependency is installed: testthat
# supplies the expectations; randomForest and glmnet are required by the
# SL.randomForest and SL.glmnet wrappers named in the library below.
# vapply() guarantees a logical vector regardless of input, and the `else`
# branch makes a skipped run visible instead of silently doing nothing.
required_pkgs <- c("testthat", "randomForest", "glmnet")
pkg_available <- vapply(
  required_pkgs, requireNamespace, logical(1),
  quietly = TRUE
)

if (all(pkg_available)) {

  ########################
  # Simulate data
  set.seed(23432)
  # Create small training set for quick tests.
  n <- 80
  p <- 6
  X <- data.frame(matrix(rnorm(n * p), nrow = n, ncol = p))
  colnames(X) <- paste0("X", seq_len(p))
  # Outcome depends on the first three columns only (linear terms plus a
  # square-root interaction of columns 2 and 3), with standard normal noise.
  gen_y <- function(x) {
    x[, 1] + sqrt(abs(x[, 2] * x[, 3])) + x[, 2] - x[, 3] + rnorm(nrow(x))
  }
  Y <- gen_y(X)

  ########################
  # Run the SampleSplitSuperLearner
  SL.library <- c("SL.randomForest", "SL.glmnet", "SL.glm", "SL.mean")

  testthat::test_that("default split sizes the validation set from `split`", {
    # default split: 80% in training, 20% in validation
    default_split <- formals(SampleSplitSuperLearner)$split
    fit <- SampleSplitSuperLearner(Y = Y, X = X, SL.library = SL.library)
    # expect_equal() compares with a numeric tolerance, which absorbs the
    # floating-point error in (1 - default_split) * n.
    testthat::expect_equal(length(fit$validRows), (1 - default_split) * n)
    testthat::expect_false(any(fit$errorsInLibrary))
    testthat::expect_false(any(fit$errorsInCVLibrary))
  })

  testthat::test_that("split accepts explicit validation row numbers", {
    # split by specifying row numbers in validation set
    valid_rows <- seq(from = round(n / 5), to = n, by = round(n / 10))
    fit <- SampleSplitSuperLearner(
      Y = Y, X = X, SL.library = SL.library,
      split = valid_rows
    )
    testthat::expect_equal(fit$validRows, valid_rows)
    testthat::expect_false(any(fit$errorsInLibrary))
    testthat::expect_false(any(fit$errorsInCVLibrary))
  })

  testthat::test_that("split accepts a single validation row", {
    # split with only one observation in validation set
    # NOTE(review): a length-one `split` could also be read as a proportion;
    # this test expects the value 1 to be treated as row number 1 -- confirm
    # against the SampleSplitSuperLearner implementation.
    valid_rows <- 1
    fit <- SampleSplitSuperLearner(
      Y = Y, X = X, SL.library = SL.library,
      split = valid_rows
    )
    testthat::expect_equal(fit$validRows, valid_rows)
    testthat::expect_false(any(fit$errorsInLibrary))
    testthat::expect_false(any(fit$errorsInCVLibrary))
  })

} else {
  # Report why nothing ran so a missing dependency is not mistaken for a pass.
  message(
    "Skipping SampleSplitSuperLearner tests; missing package(s): ",
    paste(required_pkgs[!pkg_available], collapse = ", ")
  )
}

# Try the SuperLearner package in your browser
#
# Any scripts or data that you put into this service are public.
#
# SuperLearner documentation built on July 26, 2023, 6:05 p.m.