# tests/testthat/test-general_families.R

# Label this group of tests and fix the RNG seed so the simulated data
# generated below is reproducible.
context("HAL with general families.")
set.seed(45791)

# Mean squared error between a vector of predictions and a reference vector.
#
# preds: numeric predictions
# y:     numeric reference values compared against preds
# Returns a single number: the average of the squared differences.
mse <- function(preds, y) {
  squared_error <- (preds - y)^2
  mean(squared_error)
}

# generate simple test data
# Training sample: n rows of p independent standard-normal covariates. The
# success probability depends on the first two columns only, through
# plogis(3 * sin(x1) + sin(x2)); the outcome is a Bernoulli draw from it.
n <- 100
p <- 3
x <- xmat <- matrix(rnorm(n * p), n, p)
y_prob <- plogis(3 * sin(x[, 1]) + sin(x[, 2]))
y <- rbinom(n = n, size = 1, prob = y_prob)

# Held-out sample generated by the same mechanism.
test_n <- 100
test_x <- matrix(rnorm(test_n * p), test_n, p)
test_y_prob <- plogis(3 * sin(test_x[, 1]) + sin(test_x[, 2]))
# draw the held-out outcomes from the held-out probabilities (test_y_prob),
# not from the training-sample probabilities (y_prob)
test_y <- rbinom(n = test_n, size = 1, prob = test_y_prob)

# control list handed to fit_hal() by the binomial fits further down
fit_control <- list(prediction_bounds = c(0.01, 0.99))
# Compare the two ways of requesting a logistic HAL fit: the family name
# ("binomial") versus the equivalent family object (binomial()).
test_that("MSE for logistic regression close to logistic family object pred", {
  # Fit HAL on the training data under the given family specification and
  # return the training-sample MSE of its predictions against y_prob.
  train_mse <- function(family) {
    # reseed before every fit so both fits start from the same RNG state and
    # differ only in how the family was specified
    set.seed(45791)
    hal_fit <- suppressWarnings(
      fit_hal(X = x, Y = y, family = family, fit_control = fit_control)
    )
    # smoke check: the design matrix can be rebuilt from the fitted basis list
    make_design_matrix(x, hal_fit$basis_list)
    mse(predict(hal_fit, new_data = x), y_prob)
  }

  mse_family_name <- train_mse("binomial")
  mse_family_object <- train_mse(binomial())

  # expect_lt reports both numbers on failure, unlike expect_true(a < b)
  expect_lt(abs(mse_family_name - mse_family_object), 0.01)
})

# Compare the two ways of requesting a Poisson HAL fit: the family name
# ("poisson") versus the equivalent family object (poisson()).
test_that("MSE for Poisson regression close to Poisson family object pred", {
  # Fit HAL on the training data under the given family specification and
  # return the training-sample MSE of its predictions against y_prob.
  train_mse <- function(family) {
    # reseed before every fit so both fits start from the same RNG state and
    # differ only in how the family was specified
    set.seed(45791)
    hal_fit <- suppressWarnings(fit_hal(X = x, Y = y, family = family))
    # smoke check: the design matrix can be rebuilt from the fitted basis list
    make_design_matrix(x, hal_fit$basis_list)
    mse(predict(hal_fit, new_data = x), y_prob)
  }

  mse_family_name <- train_mse("poisson")
  mse_family_object <- train_mse(poisson())

  # expect_lt reports both numbers on failure, unlike expect_true(a < b)
  expect_lt(abs(mse_family_name - mse_family_object), 0.01)
})

test_that("Error when prediction_bounds is incorrectly formatted", {
  # prediction_bounds is given as a single number here; the call is expected
  # to raise an error
  expect_error(
    fit_hal(X = x, Y = y, fit_control = list(prediction_bounds = 9))
  )
})

test_that("Message when standardize set to TRUE", {
  # requesting standardization through fit_control is expected to make
  # fit_hal emit a message
  expect_message(
    fit_hal(X = x, Y = y, fit_control = list(standardize = TRUE))
  )
})

test_that("Warning when reduce_basis without zero-order smoothness", {
  # reduce_basis is supplied without setting any smoothness argument; the
  # call is expected to warn
  # NOTE(review): presumably fit_hal's default smoothness is not zero-order,
  # which is what triggers the warning -- confirm against fit_hal's defaults
  basis_cutoff <- 0.95
  expect_warning(fit_hal(X = x, Y = y, reduce_basis = basis_cutoff))
})

# Try the hal9001 package in your browser
#
# Any scripts or data that you put into this service are public.
#
# hal9001 documentation built on Nov. 14, 2023, 5:08 p.m.