tests/testthat/test-xgboost.R

library(parsnip)

# Log-population offset. Below it is handed to plain xgboost as
# `base_margin`, i.e. added to the booster's margin as-is.
us_deaths[["off"]] <- log(us_deaths[["population"]])

# One-hot encode the categorical predictors. The recipe has no outcome; `off`
# is carried through unchanged so it can be picked out of the matrix by name.
us_deaths2 <- recipes::recipe(~ age_group + gender + year + off, us_deaths) |>
  recipes::step_dummy(age_group, gender, one_hot = TRUE) |>
  recipes::prep() |>
  recipes::bake(new_data = NULL)

x <- as.matrix(us_deaths2)

# Reference DMatrix for plain xgboost: every column except `off` is a
# feature, and the offset travels separately as the base margin.
xgtrain <- xgboost::xgb.DMatrix(
  data = x[, colnames(x) != "off"],
  label = us_deaths$deaths,
  base_margin = us_deaths$off
)

# Reference fit: a plain xgboost Poisson booster on `xgtrain`, where the offset
# is already attached as the base margin. The seed is fixed for
# reproducibility; with subsample = 1 and colsample_bynode = 1 the fit
# presumably does not depend on it.
set.seed(42)
mod <- xgboost::xgb.train(
  params = list(
    objective = "count:poisson",
    eval_metric = "rmse",
    eta = 1,
    subsample = 1,
    colsample_bynode = 1,
    min_child_weight = 1,
    max_depth = 2
  ),
  data = xgtrain,
  nrounds = 25
)

# The same model through the package wrapper. `x` still contains the `off`
# column, and its name is passed as the third argument.
mod2 <- xgb_train_offset(
  x,
  us_deaths$deaths,
  "off",
  eta = 1,
  subsample = 1,
  colsample_bynode = 1,
  min_child_weight = 1,
  max_depth = 2,
  nrounds = 25,
  counts = FALSE
)

test_that("xgb_train_offset matches xgboost", {
  reference <- predict(mod, xgtrain)

  # Plain predict() on the wrapper's booster, given the same DMatrix.
  expect_equal(reference, predict(mod2, xgtrain))
  # Package helper, given a DMatrix that already carries the base margin ...
  expect_equal(reference, xgb_predict_offset(mod2, xgtrain))
  # ... and given the raw matrix plus the name of its offset column.
  expect_equal(reference, xgb_predict_offset(mod2, x, "off"))
})

test_that("xgb_train_offset throws the correct errors and warnings", {
  y <- us_deaths$deaths
  off_col <- "off"

  # Invalid or adjusted `validation` / `early_stop` settings.
  expect_error(xgb_train_offset(x, y, off_col, validation = -1),
               regexp = "`validation` should be")
  expect_error(xgb_train_offset(x, y, off_col, early_stop = 1),
               regexp = "`early_stop` should be")
  expect_warning(xgb_train_offset(x, y, off_col, early_stop = 99),
                 regexp = "`early_stop` was reduced to")

  # Invalid or adjusted sampling / leaf-size settings.
  expect_error(xgb_train_offset(x, y, off_col, subsample = 1.01),
               regexp = "`subsample` should be")
  expect_warning(xgb_train_offset(x, y, off_col, min_child_weight = 1E3),
                 regexp = "1000 samples were requested")

  # No offset column name given, and `x` has no column called `offset`.
  expect_error(xgb_train_offset(x, y),
               regexp = "A column named `offset` must be present")

  # A fractional colsample_bynode is rejected under the default `counts`
  # setting (presumably it is then read as a column count).
  expect_error(xgb_train_offset(x, y, off_col, colsample_bynode = 0.5),
               regexp = "Please use a value >= 1")

  # Passing a guarded argument (here `objective`) or packing settings into
  # `params` triggers a warning.
  expect_warning(xgb_train_offset(x, y, off_col,
                                  objective = "reg:squarederror"),
                 regexp = "The following arguments are guarded")
  expect_warning(xgb_train_offset(x, y, off_col, params = list(eta = 1)),
                 regexp = "Please supply elements of the `params` list")

  # Predicting from a DMatrix built directly from `x`, with no base margin
  # attached, is refused.
  expect_error(xgb_predict_offset(mod2, xgboost::xgb.DMatrix(x)),
               regexp = "If `new_data` is an `xgb.DMatrix`,")
})

# Model formula for the parsnip-interface test. `off` is the offset column;
# the engine is pointed at it with `offset_col = "off"` in set_engine().
f <- deaths ~ age_group + gender + year + off

test_that("boost_tree_offset() works", {
  # Hyperparameters mirror the reference booster `mod` (eta = 1,
  # subsample = 1, min_child_weight = 1, max_depth = 2, 25 rounds). mtry = 11
  # is presumably the full predictor count, matching colsample_bynode = 1.
  spec <- boost_tree_offset(
    trees = 25,
    tree_depth = 2,
    min_n = 1,
    mtry = 11,
    learn_rate = 1,
    sample_size = 1
  ) |>
    set_engine("xgboost_offset", offset_col = "off")
  xgb_off <- fit(spec, f, data = us_deaths)

  reference <- predict(mod, xgtrain)
  # Tidy prediction tibble ...
  expect_identical(reference, predict(xgb_off, us_deaths)$.pred)
  # ... and the raw engine prediction vector.
  expect_identical(reference, predict(xgb_off, us_deaths, type = "raw"))
})

# Recipe shared by the workflow-based tests: one-hot encode the categorical
# predictors and rename the offset column from `off` to `offset`, the name
# presumably expected by the engine when no `offset_col` is supplied.
rec <- recipes::recipe(
  deaths ~ age_group + gender + year + off,
  data = us_deaths
) |>
  recipes::step_dummy(age_group, gender, one_hot = TRUE) |>
  recipes::step_rename(offset = off)

test_that("boost_tree_offset() works with recipes", {
  # workflows is only needed by this test; skip rather than error when it is
  # not installed (e.g. when it is listed under Suggests).
  skip_if_not_installed("workflows")

  # xgboost-offset engine fitted through a workflow. No `offset_col` is given
  # to set_engine(); instead `rec` renames `off` to `offset`, presumably the
  # default column name (xgb_train_offset() without an offset column name
  # errors asking for a column named `offset`).
  xgb_off <- workflows::workflow() |>
    workflows::add_recipe(rec) |>
    workflows::add_model(boost_tree_offset(learn_rate = 1,
                                           sample_size = 1,
                                           mtry = 11,
                                           min_n = 1,
                                           tree_depth = 2,
                                           trees = 25) |>
                           set_engine("xgboost_offset")) |>
    fit(data = us_deaths)

  # Predictions must match the plain-xgboost reference booster exactly.
  expect_identical(predict(mod, xgtrain),
                   predict(xgb_off, us_deaths)$.pred)
})

test_that("finalize works", {
  # tune and workflows are only needed by this test; skip rather than error
  # when they are not installed (e.g. when listed under Suggests).
  skip_if_not_installed("tune")
  skip_if_not_installed("workflows")

  # Spec with every main argument marked for tuning.
  mod_spec <- boost_tree_offset(mtry = tune(),
                                trees = tune(),
                                min_n = tune(),
                                tree_depth = tune(),
                                learn_rate = tune(),
                                loss_reduction = tune(),
                                sample_size = tune(),
                                stop_iter = tune()) |>
    set_engine("xgboost_offset")

  wf <- workflows::workflow() |>
    workflows::add_model(mod_spec) |>
    workflows::add_recipe(rec)

  # One value per tuned argument, listed in the same order as in `mod_spec`.
  param_grid <- data.frame(mtry = 4,
                           trees = 11,
                           min_n = 2,
                           tree_depth = 3,
                           learn_rate = 0.3,
                           loss_reduction = 0,
                           sample_size = 0.7,
                           stop_iter = 7)

  # The finalized workflow must fit without error ...
  expect_no_error(tune::finalize_workflow(wf, param_grid) |> fit(us_deaths))

  # ... and the finalized spec's arguments must evaluate to exactly the grid
  # values. Both sides are named lists, so names and order are compared too.
  expect_equal(tune::finalize_model(mod_spec, param_grid)$args |>
                 lapply(rlang::eval_tidy),
               as.list(param_grid))
})

Try the offsetreg package in your browser

Any scripts or data that you put into this service are public.

offsetreg documentation built on April 4, 2025, 1:35 a.m.