# tests/testthat/test-hal_multivariate.R

context("Multivariate outcome prediction with HAL")

library(glmnet)
# Load glmnet's bundled example data into this file's own evaluation
# environment. data() defaults to envir = .GlobalEnv, which would leave
# MultiGaussianExample behind in the user's workspace after the tests run.
# Naming the package also avoids depending on the search path.
data("MultiGaussianExample", package = "glmnet", envir = environment())

# Fit HAL once and share the fit and its summary with the tests below.
# The basis matrix is retained (return_x_basis = TRUE) so that a glmnet
# model can be refit directly on it for comparison.
set.seed(74296)
hal_fit <- fit_hal(
  X = MultiGaussianExample$x, Y = MultiGaussianExample$y, family = "mgaussian",
  return_x_basis = TRUE
)
hal_summary <- summary(hal_fit)

test_that("HAL and glmnet predictions match for multivariate outcome", {
  # Predictions from the fitted HAL model on its own training covariates.
  pred_hal <- predict(hal_fit, new_data = MultiGaussianExample$x)

  # Refit the lasso directly with glmnet on HAL's basis matrix, reseeding
  # with the same seed used for the HAL fit before cross-validating.
  set.seed(74296)
  cv_fit <- cv.glmnet(
    x = hal_fit$x_basis,
    y = MultiGaussianExample$y,
    family = "mgaussian",
    standardize = FALSE,
    lambda.min.ratio = 1e-4
  )

  # predict() yields a 3-d array here; keep the slice for the single
  # requested penalty value, lambda_star, and align its column names.
  pred_array <- predict(cv_fit, hal_fit$x_basis, s = hal_fit$lambda_star)
  pred_glmnet <- pred_array[, , 1]
  colnames(pred_glmnet) <- colnames(pred_hal)

  # Values must agree; attributes are not compared.
  expect_equivalent(pred_glmnet, pred_hal)
})

test_that("HAL summarizes coefs for each multivariate outcome prediction", {
  # One summary element is expected per outcome column of Y. expect_length()
  # puts the object under test first, so failure messages name it correctly
  # (the previous expect_equal() call had object and expected swapped).
  expect_length(hal_summary, ncol(MultiGaussianExample$y))
})

test_that("HAL summarizes coefs appropriately for multivariate outcome", {
  # For each outcome, check that the first row of the fitted coefficients
  # (the intercept) matches the first entry of that outcome's summary table.
  # Only the intercept is compared here. A for loop is used because the
  # expectations are run purely for their side effects.
  for (i in seq_along(hal_summary)) {
    expect_equal(
      hal_fit$coefs[[i]][1, ],
      as.numeric(hal_summary[[i]]$table[1, 1])
    )
  }
})

test_that("Error when prediction_bounds is incorrectly formatted", {
  # A bare scalar is passed as prediction_bounds; fit_hal() should reject it.
  bad_control <- list(prediction_bounds = 9)
  expect_error(
    fit_hal(
      X = MultiGaussianExample$x,
      Y = MultiGaussianExample$y,
      family = "mgaussian",
      fit_control = bad_control
    )
  )
})

test_that("HAL summary for multivariate outcome predictions prints", {
  # Second summary that also keeps zero-valued coefficients.
  summary_all_coefs <- summary(hal_fit, only_nonzero_coefs = FALSE)

  # Default summary (non-zero coefficients only): truncated and full print.
  expect_output(print(hal_summary, length = 2))
  expect_output(print(hal_summary))

  # Summary including zero coefficients: truncated and full print.
  expect_output(print(summary_all_coefs, length = 2))
  expect_output(print(summary_all_coefs))
})
# Source: jeremyrcoyle/mangolassi documentation built on Nov. 18, 2023, 6:22 p.m.