# tests/testthat/test-sl_ecpolley.R

context("Fits and prediction of SuperLearner package.")
library(SuperLearner)

# Mean squared error of predictions against observed outcomes.
#
# preds: numeric predictions.
# y: numeric observed values, subtracted elementwise from `preds`.
# Returns the mean of the squared differences.
mse <- function(preds, y) {
  squared_error <- (preds - y)^2
  mean(squared_error)
}

# simulation constants; the seed is fixed so every draw below is reproducible
set.seed(479512)
n <- 100 # observations
p <- 3 # dimensionality

# simulate data: p standard-normal covariates, outcome driven by column 2.
# The random draws happen in the order x, noise for y, test_x, noise for
# test_y, so the seeded values depend on this statement order.
x <- as.data.frame(replicate(p, rnorm(n)))
y <- sin(1 / x[[2]]) + rnorm(n, sd = 0.2)
test_x <- as.data.frame(replicate(p, rnorm(n)))
# NOTE(review): test_y includes a cos(test_x[[3]]) term that the training
# outcome y does not have -- confirm this mismatch is intended.
test_y <- sin(1 / test_x[[2]]) + cos(test_x[[3]]) + rnorm(n, sd = 0.2)

# fit HAL on its own, then predict on the training and the test covariates
hal <- fit_hal(Y = y, X = x, yolo = FALSE)
pred_hal_train <- predict(object = hal, new_data = x)
pred_hal_test <- predict(object = hal, new_data = test_x)

# run SL-classic with HAL as the only learner and get predictions
# (SL.library is spelled out in full rather than relying on partial
# argument matching of `SL.lib`)
hal_sl <- SuperLearner(Y = y, X = x, SL.library = "SL.hal9001")

# call the SL.hal9001 wrapper directly as well; its result is not referenced
# again below, so this only checks that the call completes without error
sl_hal_fit <- SL.hal9001(
  Y = y, X = x, newX = NULL,
  family = stats::gaussian(),
  obsWeights = rep(1, length(y)),
  id = seq_along(y)
)

# SuperLearner predictions on the training and the test covariates
pred_hal_sl_train <- as.numeric(predict(hal_sl, newX = x)$pred)
pred_hal_sl_test <- as.numeric(predict(hal_sl, newX = test_x)$pred)

# run an SL with HAL and the SL.mean learner, passing the validation rows of
# the HAL-only fit so that both SuperLearner fits use the same CV splits
# (SL.library is spelled out in full rather than relying on partial
# argument matching of `SL.lib`)
sl <- SuperLearner(
  Y = y, X = x, SL.library = c("SL.mean", "SL.hal9001"),
  cvControl = list(validRows = hal_sl$validRows)
)

# HAL vs. SL-HAL: the prediction vectors must have matching lengths on both
# the training and the test data
test_that("HAL and SuperLearner-HAL produce results of same shape", {
  n_hal_train <- length(pred_hal_train)
  n_hal_test <- length(pred_hal_test)
  expect_equal(n_hal_train, length(pred_hal_sl_train))
  expect_equal(n_hal_test, length(pred_hal_sl_test))
})

# test of MSEs being close: the SL containing HAL and the HAL-only SL should
# give nearly the same test-set MSE.
# NOTE(review): `scale` is presumably forwarded to all.equal(), making the
# 0.05 tolerance relative to the HAL-only MSE -- confirm for the testthat
# edition in use.
test_that("HAL dominates other algorithms when used in SuperLearner", {
  pred_sl_test <- as.numeric(predict(sl, newX = test_x)$pred)
  mse_sl <- mse(pred_sl_test, test_y)
  mse_hal_sl <- mse(pred_hal_sl_test, test_y)
  expect_equal(
    mse_sl,
    expected = mse_hal_sl,
    scale = mse_hal_sl, tolerance = 0.05
  )
})

# test of SL-HAL risk: HAL has lowest CV-risk in the learner library.
# names() yields a plain character vector with no attributes, so expect_equal
# checks the same thing as the deprecated expect_equivalent did here.
test_that("HAL has the lowest CV-risk amongst algorithms in Super Learner", {
  best_learner <- names(which.min(sl$cvRisk))
  expect_equal(best_learner, "SL.hal9001_All")
})

# Try the hal9001 package in your browser
#
# Any scripts or data that you put into this service are public.
#
# hal9001 documentation built on Nov. 14, 2023, 5:08 p.m.