Nothing
context("Multivariate outcome prediction with HAL")
library(glmnet)
data(MultiGaussianExample)

# Shared fixture used by every test below: a single HAL fit on the
# MultiGaussianExample data with a multivariate Gaussian family. The basis
# matrix is kept on the fit (return_x_basis = TRUE) because the glmnet
# comparison test refits glmnet on hal_fit$x_basis.
set.seed(74296)
hal_fit <- fit_hal(
  X = MultiGaussianExample[["x"]],
  Y = MultiGaussianExample[["y"]],
  family = "mgaussian",
  return_x_basis = TRUE
)
hal_summary <- summary(hal_fit)
test_that("HAL and glmnet predictions match for multivariate outcome", {
  # Predictions from the shared HAL fit on the training covariates.
  preds_hal <- predict(hal_fit, new_data = MultiGaussianExample$x)

  # Refit glmnet directly on HAL's basis matrix. The seed is reset to the
  # value used for the HAL fit before cross-validating.
  set.seed(74296)
  cv_fit <- cv.glmnet(
    x = hal_fit$x_basis,
    y = MultiGaussianExample$y,
    family = "mgaussian",
    standardize = FALSE,
    lambda.min.ratio = 1e-4
  )

  # Predict at HAL's selected lambda. The result is indexed as a 3-D array
  # (observations x outcomes x lambda values), and only the single
  # lambda_star slice is kept.
  preds_glmnet <- predict(cv_fit, hal_fit$x_basis, s = hal_fit$lambda_star)
  preds_glmnet <- preds_glmnet[, , 1]

  # Align column names, then compare values.
  colnames(preds_glmnet) <- colnames(preds_hal)
  expect_equivalent(preds_glmnet, preds_hal)
})
test_that("HAL summarizes coefs for each multivariate outcome prediction", {
  # summary() of a multivariate fit should contain one element per outcome
  # column of Y. The summary is the object under test, so it is passed first
  # and failure messages report its length as the actual value.
  expect_length(hal_summary, ncol(MultiGaussianExample$y))
})
test_that("HAL summarizes coefs appropriately for multivariate outcome", {
  # For each outcome, the first cell of the summary table must equal the
  # first coefficient stored on the fit. Both are presumably the intercept
  # (first row of each object); confirm against summary()'s table layout.
  # A for loop is used because only the expectations' side effects matter.
  for (i in seq_along(hal_summary)) {
    summary_intercept <- as.numeric(hal_summary[[i]]$table[1, 1])
    fit_intercept <- hal_fit$coefs[[i]][1, ]
    expect_equal(summary_intercept, fit_intercept)
  }
})
test_that("Error when prediction_bounds is incorrectly formatted", {
  # A bare scalar for prediction_bounds is treated as malformed input, so
  # fitting is expected to fail. Any error satisfies this expectation.
  bad_control <- list(prediction_bounds = 9)
  expect_error(
    fit_hal(
      X = MultiGaussianExample$x,
      Y = MultiGaussianExample$y,
      family = "mgaussian",
      fit_control = bad_control
    )
  )
})
test_that("HAL summary for multivariate outcome predictions prints", {
  # Exercise print() on the default summary and on a summary built with
  # only_nonzero_coefs = FALSE. Each is printed with length = 2 and then with
  # default arguments, and every call must produce output.
  summary_all_coefs <- summary(hal_fit, only_nonzero_coefs = FALSE)
  for (smry in list(hal_summary, summary_all_coefs)) {
    expect_output(print(smry, length = 2))
    expect_output(print(smry))
  }
})
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.