# tests/testthat/test-xgboost.R

library(parsnip)

# Add a log-population column `off`; it is the offset used by every model in
# this file.
us_deaths$off <- log(us_deaths$population)

# Predictor table with the factors one-hot encoded, so the plain xgboost
# reference model and the offsetreg wrappers are fitted on the same columns.
# `bake(new_data = NULL)` returns the processed training data and is the
# non-superseded replacement for `juice()`.
us_deaths2 <- recipes::recipe(~ age_group + gender + year + off, us_deaths) |>
  recipes::step_dummy(age_group, gender, one_hot = TRUE) |>
  recipes::prep() |>
  recipes::bake(new_data = NULL)

# Numeric matrix that still contains the `off` column; the offsetreg helpers
# below are told which column holds the offset by name.
x <- as.matrix(us_deaths2)

# DMatrix for plain xgboost: the offset is removed from the features and
# supplied through `base_margin` instead.
xgtrain <- xgboost::xgb.DMatrix(
  x[, colnames(x) != "off"],
  label = us_deaths$deaths,
  base_margin = us_deaths$off
)

# Reference fit: plain xgboost Poisson regression on `xgtrain`, whose offset
# travels as `base_margin`. The seed is fixed and every sampling rate is 1, so
# the wrapper-based fits in the tests can be compared against these
# predictions.
set.seed(42)
mod <- xgboost::xgb.train(
  data = xgtrain,
  nrounds = 25,
  params = list(
    objective = "count:poisson",
    eval_metric = "rmse",
    eta = 1,
    subsample = 1,
    colsample_bynode = 1,
    min_child_weight = 1,
    max_depth = 2
  )
)

# The same model fitted through the package wrapper. `x` still contains the
# offset column, which is identified by name ("off"). The tuning values mirror
# the `params` used for `mod` above.
# NOTE(review): `counts = FALSE` presumably makes `colsample_bynode` a
# proportion rather than a column count (with the default, 0.5 is rejected
# with "Please use a value >= 1" in the tests below) -- confirm against
# xgb_train_offset().
mod2 <- xgb_train_offset(
  x,
  us_deaths$deaths,
  "off",
  eta = 1,
  subsample = 1,
  colsample_bynode = 1,
  min_child_weight = 1,
  max_depth = 2,
  nrounds = 25,
  counts = FALSE
)

test_that("xgb_train_offset matches xgboost", {
  # Predictions of the plain xgboost fit are the target for every entry point.
  reference <- predict(mod, xgtrain)

  # predict() on the wrapper's fit, given the DMatrix that carries base_margin.
  expect_equal(reference, predict(mod2, xgtrain))
  # The package predictor, given that same DMatrix.
  expect_equal(reference, xgb_predict_offset(mod2, xgtrain))
  # The package predictor, given the raw matrix plus the offset column name.
  expect_equal(reference, xgb_predict_offset(mod2, x, "off"))
})

# Argument validation in xgb_train_offset(): each call must fail with the
# quoted message prefix.
test_that("xgb_train_offset rejects invalid arguments", {
  expect_error(
    xgb_train_offset(x, us_deaths$deaths, "off", validation = -1),
    regexp = "`validation` should be"
  )
  expect_error(
    xgb_train_offset(x, us_deaths$deaths, "off", early_stop = 1),
    regexp = "`early_stop` should be"
  )
  expect_error(
    xgb_train_offset(x, us_deaths$deaths, "off", subsample = 1.01),
    regexp = "`subsample` should be"
  )
  # No offset column name given: a column literally named `offset` is needed.
  expect_error(
    xgb_train_offset(x, us_deaths$deaths),
    regexp = "A column named `offset` must be present"
  )
  # With the default `counts`, a fractional `colsample_bynode` is rejected.
  expect_error(
    xgb_train_offset(x, us_deaths$deaths, "off", colsample_bynode = 0.5),
    regexp = "Please use a value >= 1"
  )
})

# Arguments that xgb_train_offset() adjusts or refuses to pass through must
# produce a warning rather than an error.
test_that("xgb_train_offset warns about adjusted or ignored arguments", {
  expect_warning(
    xgb_train_offset(x, us_deaths$deaths, "off", early_stop = 99),
    regexp = "`early_stop` was reduced to"
  )
  expect_warning(
    xgb_train_offset(x, us_deaths$deaths, "off", min_child_weight = 1e3),
    regexp = "1000 samples were requested"
  )
  # `objective` is set by the wrapper itself, so a user value is guarded.
  expect_warning(
    xgb_train_offset(
      x,
      us_deaths$deaths,
      "off",
      objective = "reg:squarederror"
    ),
    regexp = "The following arguments are guarded"
  )
  expect_warning(
    xgb_train_offset(x, us_deaths$deaths, "off", params = list(eta = 1)),
    regexp = "Please supply elements of the `params` list"
  )
})

# Kept separate from the xgb_train_offset() tests: this checks the predictor.
# `xgboost::xgb.DMatrix(x)` is built without `base_margin`, unlike `xgtrain`.
test_that("xgb_predict_offset rejects an xgb.DMatrix without base_margin", {
  expect_error(
    xgb_predict_offset(mod2, xgboost::xgb.DMatrix(x)),
    regexp = "If `new_data` is an `xgb.DMatrix`,"
  )
})

# Standard formula for the parsnip interface. The offset column `off` is
# listed on the right-hand side and singled out via the engine's
# `offset_col` argument.
f <- deaths ~ age_group + gender + year + off

test_that("boost_tree_offset() works", {
  # Hyperparameters chosen to mirror the `params` of the reference fit `mod`.
  # NOTE(review): mtry = 11 presumably equals the number of predictor
  # columns, i.e. colsample_bynode = 1 -- confirm if the predictors change.
  xgb_off <- boost_tree_offset(
    learn_rate = 1,
    sample_size = 1,
    mtry = 11,
    min_n = 1,
    tree_depth = 2,
    trees = 25
  ) |>
    set_engine("xgboost_offset", offset_col = "off") |>
    fit(f, data = us_deaths)

  reference <- predict(mod, xgtrain)
  # Floating-point predictions from two code paths: compare with tolerance,
  # as the xgb_train_offset comparison test does, rather than bit-for-bit.
  expect_equal(reference, predict(xgb_off, us_deaths)$.pred)
  expect_equal(
    reference,
    predict(xgb_off, us_deaths, type = "raw")
  )
})

# Shared preprocessing for the workflow tests: one-hot encode the factors and
# rename the offset column to `offset`.
# NOTE(review): `offset` appears to be the default offset column name (the
# workflow tests call set_engine() without `offset_col`) -- confirm.
rec <- recipes::recipe(deaths ~ age_group + gender + year + off, us_deaths)
rec <- recipes::step_dummy(rec, age_group, gender, one_hot = TRUE)
rec <- recipes::step_rename(rec, offset = off)

test_that("boost_tree_offset() works with recipes", {
  # Workflow path: `rec` renames the offset column to `offset`, so the engine
  # is set without an `offset_col` argument.
  xgb_off <- workflows::workflow() |>
    workflows::add_recipe(rec) |>
    workflows::add_model(
      boost_tree_offset(
        learn_rate = 1,
        sample_size = 1,
        mtry = 11,
        min_n = 1,
        tree_depth = 2,
        trees = 25
      ) |>
        set_engine("xgboost_offset")
    ) |>
    fit(data = us_deaths)
  # Floating-point predictions from two code paths: compare with tolerance,
  # as the xgb_train_offset comparison test does, rather than bit-for-bit.
  expect_equal(predict(mod, xgtrain), predict(xgb_off, us_deaths)$.pred)
})

test_that("finalize works", {
  # One concrete candidate value for each tunable argument.
  param_grid <- data.frame(
    mtry = 4,
    trees = 11,
    min_n = 2,
    tree_depth = 3,
    learn_rate = 0.3,
    loss_reduction = 0,
    sample_size = 0.7,
    stop_iter = 7
  )

  # A spec in which all eight arguments are tuning placeholders.
  mod_spec <- boost_tree_offset(
    mtry = tune(),
    trees = tune(),
    min_n = tune(),
    tree_depth = tune(),
    learn_rate = tune(),
    loss_reduction = tune(),
    sample_size = tune(),
    stop_iter = tune()
  ) |>
    set_engine("xgboost_offset")

  wf <- workflows::add_recipe(
    workflows::add_model(workflows::workflow(), mod_spec),
    rec
  )

  # A workflow finalized with the candidate values must fit cleanly.
  expect_no_error(tune::finalize_workflow(wf, param_grid) |> fit(us_deaths))

  # After finalizing, the spec's arguments evaluate to exactly the grid values.
  finalized_args <- lapply(
    tune::finalize_model(mod_spec, param_grid)$args,
    rlang::eval_tidy
  )
  expect_equal(finalized_args, as.list(param_grid))
})

# Try the offsetreg package in your browser

# Any scripts or data that you put into this service are public.

# offsetreg documentation built on March 23, 2026, 9:07 a.m.