Nothing
# Unit tests: fitting and predicting with HAL through the SuperLearner package.
context("Fits and prediction of SuperLearner package.")
library("SuperLearner")
# Compute the mean squared error between predictions and observed outcomes.
#
# @param preds Numeric vector of predictions.
# @param y Numeric vector of observed outcomes; must have the same length as
#   `preds`.
# @return A single number, mean((preds - y)^2).
mse <- function(preds, y) {
  # Guard against R's silent vector recycling, which would otherwise return a
  # meaningless MSE when the two inputs have different lengths.
  stopifnot(length(preds) == length(y))
  mean((preds - y)^2)
}
# Simulation constants. The seed makes the simulated data below (and any later
# random draws in this file) reproducible, so statement order matters here.
set.seed(479512)
p <- 3 # dimensionality (number of covariates)
n <- 100 # observations
# Simulate training data: p independent standard-normal covariates; the
# outcome depends only on the second covariate, plus Gaussian noise (sd 0.2).
x <- as.data.frame(replicate(p, rnorm(n)))
y <- sin(1 / x[, 2]) + rnorm(n, mean = 0, sd = 0.2)
# Simulate an independent test set of the same size.
# NOTE(review): unlike `y`, `test_y` also includes cos(test_x[, 3]), so the
# training and test outcomes do not share a data-generating process. This may
# be unintentional -- confirm before reading the test-set MSE as a measure of
# fit quality.
test_x <- as.data.frame(replicate(p, rnorm(n)))
test_y <- sin(1 / test_x[, 2]) + cos(test_x[, 3]) +
rnorm(n, mean = 0, sd = 0.2)
# Fit HAL on its own (with yolo = FALSE) and predict on the training and test
# covariates; these predictions are the reference that the
# SuperLearner-wrapped HAL predictions are compared against in the tests below.
hal <- fit_hal(X = x, Y = y, yolo = FALSE)
pred_hal_train <- predict(hal, new_data = x)
pred_hal_test <- predict(hal, new_data = test_x)
# Run a SuperLearner whose library contains only the SL.hal9001 learner and
# get its predictions. hal_sl$validRows is reused further down so that the
# second SuperLearner is fit on the same cross-validation folds.
hal_sl <- SuperLearner(Y = y, X = x, SL.lib = "SL.hal9001")
# Also call the SL.hal9001 learner function directly, outside SuperLearner.
# NOTE(review): sl_hal_fit is not used by any assertion in this file (only by
# the commented-out line below); it appears to act as a smoke test that the
# learner runs with newX = NULL -- confirm whether it is still needed.
sl_hal_fit <- SL.hal9001(
Y = y, X = x, newX = NULL,
family = stats::gaussian(),
obsWeights = rep(1, length(y)),
id = seq_along(y)
)
# Commented-out call to the prediction method for the direct fit; kept for
# reference only.
# hal9001:::predict.SL.hal9001(sl_hal_fit$fit,newX=x,newdata=x)
# Take the `pred` element of SuperLearner's prediction output and coerce it to
# plain numeric vectors for comparison with the standalone HAL predictions.
pred_hal_sl_train <- as.numeric(predict(hal_sl, newX = x)$pred)
pred_hal_sl_test <- as.numeric(predict(hal_sl, newX = test_x)$pred)
# Run a SuperLearner with HAL alongside SL.mean (presumably the outcome-mean
# benchmark learner -- confirm), reusing hal_sl's cross-validation folds so
# both SuperLearner fits are evaluated on identical splits.
sl <- SuperLearner(
Y = y, X = x, SL.lib = c("SL.mean", "SL.hal9001"),
cvControl = list(validRows = hal_sl$validRows)
)
# Test HAL vs. SL-HAL: standalone HAL and SuperLearner-wrapped HAL must return
# the same number of predictions on both the training and the test covariates.
test_that("HAL and SuperLearner-HAL produce results of same shape", {
  expect_length(pred_hal_sl_train, length(pred_hal_train))
  expect_length(pred_hal_sl_test, length(pred_hal_test))
})
# Test that the MSEs are close: a SuperLearner dominated by HAL should score
# almost exactly like SL-HAL alone. Because `scale` is SL-HAL's own test MSE,
# the 0.05 tolerance is relative (roughly 5% of that MSE), not additive --
# assuming the 2nd-edition, all.equal()-based comparison in expect_equal().
test_that("HAL dominates other algorithms when used in SuperLearner", {
  pred_sl_test <- as.numeric(predict(sl, newX = test_x)$pred)
  mse_sl <- mse(pred_sl_test, test_y)
  mse_hal_sl <- mse(pred_hal_sl_test, test_y)
  expect_equal(
    mse_sl,
    expected = mse_hal_sl,
    scale = mse_hal_sl,
    tolerance = 0.05
  )
})
# Test SL-HAL's risk: among the learners in `sl`, the HAL learner must have
# the smallest cross-validated risk.
test_that("HAL has the lowest CV-risk amongst algorithms in Super Learner", {
  best_learner <- names(which.min(sl$cvRisk))
  expect_equivalent(best_learner, "SL.hal9001_All")
})
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.