# tests/testthat/test-rule-fit-regression.R

# Fits a rule_fit regression through the formula interface and checks that the
# boosting log, the single-penalty predictions, and the multi-penalty
# predictions agree with a reference model fit directly with xrf::xrf().
test_that("formula method", {
  skip_on_cran()
  skip_if_not_installed("xrf")
  skip_if_not_installed("modeldata")

  chi_data <- make_chi_data()

  # Reference fit made directly with xrf.
  set.seed(4526)
  rf_fit_exp <-
    xrf::xrf(
      ridership ~ .,
      data = chi_data$chi_mod,
      family = "gaussian",
      xgb_control = list(nrounds = 3, min_child_weight = 3),
      verbose = 0
    )
  rf_pred_exp <- predict(rf_fit_exp, chi_data$chi_pred, lambda = 1)[, 1]

  # The spec uses the same values (3, 3, and a penalty of 1) as the reference
  # call and its lambda above.
  expect_no_error(
    rf_mod <-
      rule_fit(trees = 3, min_n = 3, penalty = 1) |>
      set_engine("xrf") |>
      set_mode("regression")
  )

  # Re-seed so this fit starts from the same RNG state as the reference fit.
  set.seed(4526)
  expect_no_error(
    rf_fit <- fit(rf_mod, ridership ~ ., data = chi_data$chi_mod)
  )
  rf_pred <- predict(rf_fit, chi_data$chi_pred)

  # Compare the reference log against the log of the parsnip fit. Comparing
  # the reference with itself would pass no matter what `fit()` produced.
  expect_equal(
    unname(rf_fit_exp$xgb$evaluation_log),
    unname(rf_fit$fit$xgb$evaluation_log)
  )
  expect_equal(names(rf_pred), ".pred")
  expect_true(tibble::is_tibble(rf_pred))
  expect_equal(rf_pred$.pred, unname(rf_pred_exp))

  expect_no_error(
    rf_m_pred <- multi_predict(
      rf_fit,
      chi_data$chi_pred,
      penalty = chi_data$vals
    )
  )
  # Tag each row before unnesting so that, within a penalty, predictions can
  # be put back in the original row order.
  # seq_len() stays correct (empty) even if there are zero rows.
  rf_m_pred <-
    rf_m_pred |>
    dplyr::mutate(.row_number = seq_len(nrow(rf_m_pred))) |>
    tidyr::unnest(cols = c(.pred)) |>
    dplyr::arrange(penalty, .row_number)

  # Each requested penalty must reproduce the reference prediction at that
  # lambda (to within the stated tolerance).
  for (pen in chi_data$vals) {
    exp_pred <- predict(rf_fit_exp, chi_data$chi_pred, lambda = pen)[, 1]
    obs_pred <-
      rf_m_pred |>
      dplyr::filter(penalty == pen) |>
      dplyr::pull(.pred)
    expect_equal(unname(exp_pred), obs_pred, tolerance = 0.1)
  }
})

# ------------------------------------------------------------------------------

# Fits a rule_fit regression through the x/y interface (fit_xy) and checks that
# the boosting log, the single-penalty predictions, and the multi-penalty
# predictions agree with a reference model fit directly with xrf::xrf().
test_that("non-formula method", {
  skip_on_cran()
  skip_if_not_installed("xrf")
  skip_if_not_installed("modeldata")

  chi_data <- make_chi_data()

  # Reference fit made directly with xrf.
  set.seed(4526)
  rf_fit_exp <-
    xrf::xrf(
      ridership ~ .,
      data = chi_data$chi_mod,
      family = "gaussian",
      xgb_control = list(nrounds = 3, min_child_weight = 3),
      verbose = 0
    )
  rf_pred_exp <- predict(rf_fit_exp, chi_data$chi_pred, lambda = 1)[, 1]

  expect_no_error(
    rf_mod <-
      rule_fit(trees = 3, min_n = 3, penalty = 1) |>
      set_engine("xrf") |>
      set_mode("regression")
  )

  # Re-seed before fitting, as the formula-method test does, so this fit
  # starts from the same RNG state as the reference fit.
  set.seed(4526)
  expect_no_error(
    rf_fit <- fit_xy(
      rf_mod,
      # NOTE(review): drops the first column by position; assumes `ridership`
      # is column 1 of chi_mod - confirm against make_chi_data().
      x = chi_data$chi_mod[, -1],
      y = chi_data$chi_mod$ridership
    )
  )
  rf_pred <- predict(rf_fit, chi_data$chi_pred)

  expect_equal(
    unname(rf_fit_exp$xgb$evaluation_log),
    unname(rf_fit$fit$xgb$evaluation_log)
  )

  expect_equal(names(rf_pred), ".pred")
  expect_true(tibble::is_tibble(rf_pred))
  expect_equal(rf_pred$.pred, unname(rf_pred_exp))

  expect_no_error(
    rf_m_pred <- multi_predict(
      rf_fit,
      chi_data$chi_pred,
      penalty = chi_data$vals
    )
  )
  # Tag each row before unnesting so that, within a penalty, predictions can
  # be put back in the original row order.
  # seq_len() stays correct (empty) even if there are zero rows.
  rf_m_pred <-
    rf_m_pred |>
    dplyr::mutate(.row_number = seq_len(nrow(rf_m_pred))) |>
    tidyr::unnest(cols = c(.pred)) |>
    dplyr::arrange(penalty, .row_number)

  # Each requested penalty must reproduce the reference prediction at that
  # lambda (to within the stated tolerance).
  for (pen in chi_data$vals) {
    exp_pred <- predict(rf_fit_exp, chi_data$chi_pred, lambda = pen)[, 1]
    obs_pred <-
      rf_m_pred |>
      dplyr::filter(penalty == pen) |>
      dplyr::pull(.pred)
    expect_equal(unname(exp_pred), obs_pred, tolerance = 0.1)
  }
})

# ------------------------------------------------------------------------------

# Checks that tidy() on a fitted xrf regression model agrees with the non-zero
# coefficients returned by coef() on the underlying fit, both at the rule level
# (default) and at the column level (unit = "column").
test_that("tidy method - regression", {
  skip_on_cran()
  skip_if_not_installed("xrf")
  skip_if_not_installed("modeldata")

  ames_data <- make_ames_data()

  # NOTE(review): attaches xrf for the rest of the session; presumably needed
  # so that coef() dispatches on the underlying xrf object - confirm before
  # removing.
  library(xrf)

  xrf_reg_mod <-
    rule_fit(trees = 3, penalty = 0.001) |>
    set_engine("xrf") |>
    set_mode("regression")

  set.seed(1)
  xrf_reg_fit <-
    xrf_reg_mod |>
    fit(
      Sale_Price ~
        Neighborhood + Longitude + Latitude + Gr_Liv_Area + Central_Air,
      data = ames_data$ames
    )

  # Rule-level results: one row per term with a non-zero coefficient.
  xrf_rule_res <- tidy(xrf_reg_fit, penalty = 0.001)
  raw_coef <- coef(xrf_reg_fit$fit, lambda = 0.001)
  raw_coef <- raw_coef[raw_coef[, 1] != 0, ]
  # expect_equal() reports both counts on failure, unlike expect_true(a == b).
  expect_equal(nrow(raw_coef), nrow(xrf_rule_res))
  expect_true(all(raw_coef$term %in% xrf_rule_res$rule_id))

  # Column-level results: terms are the original predictors (plus intercept),
  # while rule ids still match the non-zero raw coefficients.
  xrf_col_res <- tidy(xrf_reg_fit, unit = "column", penalty = 0.001)
  expect_equal(
    sort(unique(xrf_col_res$term)),
    c(
      "(Intercept)",
      "Central_Air",
      "Gr_Liv_Area",
      "Latitude",
      "Longitude",
      "Neighborhood"
    )
  )
  expect_equal(
    sort(unique(raw_coef$term)),
    sort(unique(xrf_col_res$rule_id))
  )
})

# Fits three rule_fit specs that differ only in `stop_iter` (unset, 3, 5) on
# the same simulated data and checks, via did_stop_early(), which fits stopped
# early. The code is left untouched because expect_snapshot() records
# the expression text together with its output.
test_that("early stopping works in xrf_fit", {
  skip_on_cran()
  skip_if_not_installed("xrf")
  skip_if_not_installed("modeldata")

  # Simulated regression data; seeded so every run sees the same rows.
  set.seed(1)
  reg_data <- modeldata::sim_regression(500)

  # No stop_iter: expected not to stop early (asserted below).
  rf_mod_1 <-
    rule_fit(trees = 50, learn_rate = 1) |>
    set_engine("xrf", validation = 0.1) |>
    set_mode("regression")

  # stop_iter = 3: expected to stop early (asserted below).
  rf_mod_2 <-
    rule_fit(trees = 50, learn_rate = 1, stop_iter = 3) |>
    set_engine("xrf", validation = 0.1) |>
    set_mode("regression")

  # stop_iter = 5: expected to stop early (asserted below).
  rf_mod_3 <-
    rule_fit(trees = 50, learn_rate = 1, stop_iter = 5) |>
    set_engine("xrf", validation = 0.1) |>
    set_mode("regression")

  # Each fit is preceded by the same seed so all three start from the same
  # RNG state.
  set.seed(2)
  expect_no_error(
    rf_fit_1 <- fit(rf_mod_1, outcome ~ ., data = reg_data)
  )

  set.seed(2)
  expect_no_error(
    rf_fit_2 <- fit(rf_mod_2, outcome ~ ., data = reg_data)
  )

  # The third fit is snapshotted with messages suppressed; whatever else it
  # emits is compared against the stored snapshot.
  set.seed(2)
  expect_snapshot(
    suppressMessages(
      rf_fit_3 <- fit(rf_mod_3, outcome ~ ., data = reg_data)
    )
  )

  # did_stop_early() is a helper defined outside this file.
  expect_false(did_stop_early(rf_fit_1))
  expect_true(did_stop_early(rf_fit_2))
  expect_true(did_stop_early(rf_fit_3))
})

# Verifies that `glm_control` values given to the xrf engine show up as
# arguments of the call stored on the fitted glm component.
test_that("xrf_fit is sensitive to glm_control", {
  skip_on_cran()
  skip_if_not_installed("xrf")

  # Build the specification one step at a time.
  spec <- rule_fit(trees = 3)
  spec <- set_engine(
    spec,
    "xrf",
    glm_control = list(type.measure = "deviance", nfolds = 8)
  )
  spec <- set_mode(spec, "regression")

  expect_no_error(
    fitted_mod <- fit(spec, mpg ~ ., data = mtcars)
  )

  # Pull the arguments out of the call recorded on the glm model.
  glm_call <- fitted_mod$fit$glm$model$call
  glm_call_args <- rlang::call_args(glm_call)

  expect_equal(glm_call_args$type.measure, "deviance")
  expect_equal(glm_call_args$nfolds, 8)
})

# Snapshot test for fitting a spec whose engine arguments include
# `xgb_control = list(nrounds = 3)` while the spec itself also sets
# `trees = 3`. The code is left untouched because expect_snapshot()
# records the expression text together with its output.
test_that("xrf_fit guards xgb_control", {
  skip_on_cran()
  skip_if_not_installed("xrf")

  rf_mod <-
    rule_fit(trees = 3) |>
    set_engine("xrf", xgb_control = list(nrounds = 3)) |>
    set_mode("regression")

  # Messages are suppressed; anything else the fit emits or prints is compared
  # with the stored snapshot.
  # NOTE(review): presumably the guard surfaces as a warning or error about
  # `nrounds` - confirm against the snapshot file.
  expect_snapshot(
    suppressMessages(
      fit(rf_mod, mpg ~ ., data = mtcars)
    )
  )
})

# Try the rules package in your browser

# Any scripts or data that you put into this service are public.

# rules documentation built on Jan. 27, 2026, 9:07 a.m.