# Regression-mode tests for rule_fit() with the "xrf" engine.
# Checks that `rule_fit()` fitted through the formula interface with the "xrf"
# engine agrees with a direct `xrf::xrf()` fit: same xgboost evaluation log,
# same predictions at the model's penalty, and matching `multi_predict()`
# output for every penalty in `chi_data$vals`.
test_that("formula method", {
  skip_on_cran()
  skip_if_not_installed("xrf")
  skip_if_not_installed("modeldata")

  chi_data <- make_chi_data()

  # Reference fit made directly with xrf. The same seed is set again before
  # the parsnip fit so both fits start from the same RNG state.
  set.seed(4526)
  rf_fit_exp <-
    xrf::xrf(
      ridership ~ .,
      data = chi_data$chi_mod,
      family = "gaussian",
      xgb_control = list(nrounds = 3, min_child_weight = 3),
      verbose = 0
    )
  rf_pred_exp <- predict(rf_fit_exp, chi_data$chi_pred, lambda = 1)[, 1]

  expect_no_error(
    rf_mod <-
      rule_fit(trees = 3, min_n = 3, penalty = 1) |>
      set_engine("xrf") |>
      set_mode("regression")
  )

  set.seed(4526)
  expect_no_error(
    rf_fit <- fit(rf_mod, ridership ~ ., data = chi_data$chi_mod)
  )
  rf_pred <- predict(rf_fit, chi_data$chi_pred)

  # Compare the reference fit against the parsnip fit (previously the
  # reference was compared with itself, which could never fail).
  expect_equal(
    unname(rf_fit_exp$xgb$evaluation_log),
    unname(rf_fit$fit$xgb$evaluation_log)
  )

  expect_equal(names(rf_pred), ".pred")
  expect_true(tibble::is_tibble(rf_pred))
  expect_equal(rf_pred$.pred, unname(rf_pred_exp))

  expect_no_error(
    rf_m_pred <- multi_predict(
      rf_fit,
      chi_data$chi_pred,
      penalty = chi_data$vals
    )
  )

  # Flatten the nested predictions to one row per (penalty, original row) so
  # each penalty's predictions come back in the original row order.
  rf_m_pred <-
    rf_m_pred |>
    dplyr::mutate(.row_number = seq_len(nrow(rf_m_pred))) |>
    tidyr::unnest(cols = c(.pred)) |>
    dplyr::arrange(penalty, .row_number)

  for (i in chi_data$vals) {
    exp_pred <- predict(rf_fit_exp, chi_data$chi_pred, lambda = i)[, 1]
    obs_pred <- rf_m_pred |> dplyr::filter(penalty == i) |> dplyr::pull(.pred)
    expect_equal(unname(exp_pred), obs_pred, tolerance = 0.1)
  }
})
# ------------------------------------------------------------------------------
# Checks that `rule_fit()` fitted through the x/y interface (`fit_xy()`) with
# the "xrf" engine agrees with a direct `xrf::xrf()` fit: same xgboost
# evaluation log, same predictions at the model's penalty, and matching
# `multi_predict()` output for every penalty in `chi_data$vals`.
test_that("non-formula method", {
  skip_on_cran()
  skip_if_not_installed("xrf")
  skip_if_not_installed("modeldata")

  chi_data <- make_chi_data()

  # Reference fit made directly with xrf.
  set.seed(4526)
  rf_fit_exp <-
    xrf::xrf(
      ridership ~ .,
      data = chi_data$chi_mod,
      family = "gaussian",
      xgb_control = list(nrounds = 3, min_child_weight = 3),
      verbose = 0
    )
  rf_pred_exp <- predict(rf_fit_exp, chi_data$chi_pred, lambda = 1)[, 1]

  expect_no_error(
    rf_mod <-
      rule_fit(trees = 3, min_n = 3, penalty = 1) |>
      set_engine("xrf") |>
      set_mode("regression")
  )

  # Re-seed so the parsnip fit starts from the same RNG state as the
  # reference fit; the equality checks below depend on it.
  set.seed(4526)
  expect_no_error(
    rf_fit <- fit_xy(
      rf_mod,
      x = chi_data$chi_mod[, -1],
      y = chi_data$chi_mod$ridership
    )
  )
  rf_pred <- predict(rf_fit, chi_data$chi_pred)

  expect_equal(
    unname(rf_fit_exp$xgb$evaluation_log),
    unname(rf_fit$fit$xgb$evaluation_log)
  )

  expect_equal(names(rf_pred), ".pred")
  expect_true(tibble::is_tibble(rf_pred))
  expect_equal(rf_pred$.pred, unname(rf_pred_exp))

  expect_no_error(
    rf_m_pred <- multi_predict(
      rf_fit,
      chi_data$chi_pred,
      penalty = chi_data$vals
    )
  )

  # Flatten the nested predictions to one row per (penalty, original row) so
  # each penalty's predictions come back in the original row order.
  rf_m_pred <-
    rf_m_pred |>
    dplyr::mutate(.row_number = seq_len(nrow(rf_m_pred))) |>
    tidyr::unnest(cols = c(.pred)) |>
    dplyr::arrange(penalty, .row_number)

  for (i in chi_data$vals) {
    exp_pred <- predict(rf_fit_exp, chi_data$chi_pred, lambda = i)[, 1]
    obs_pred <- rf_m_pred |> dplyr::filter(penalty == i) |> dplyr::pull(.pred)
    expect_equal(unname(exp_pred), obs_pred, tolerance = 0.1)
  }
})
# ------------------------------------------------------------------------------
# Checks that `tidy()` on a regression rule-fit model agrees with the non-zero
# coefficients returned by `coef()` on the underlying xrf fit, both in the
# default output and with `unit = "column"`.
test_that("tidy method - regression", {
  skip_on_cran()
  skip_if_not_installed("xrf")
  skip_if_not_installed("modeldata")

  ames_data <- make_ames_data()

  # NOTE(review): attaching xrf here alters the search path for the rest of
  # the test run; presumably needed so `coef()` dispatches on the xrf fit --
  # confirm before removing.
  library(xrf)

  xrf_reg_mod <-
    rule_fit(trees = 3, penalty = .001) |>
    set_engine("xrf") |>
    set_mode("regression")

  set.seed(1)
  xrf_reg_fit <-
    xrf_reg_mod |>
    fit(
      Sale_Price ~
        Neighborhood + Longitude + Latitude + Gr_Liv_Area + Central_Air,
      data = ames_data$ames
    )

  xrf_rule_res <- tidy(xrf_reg_fit, penalty = .001)

  # Keep only terms whose coefficient (first column) is non-zero at this
  # penalty; these are the terms the tidy output is compared against.
  raw_coef <- coef(xrf_reg_fit$fit, lambda = 0.001)
  raw_coef <- raw_coef[raw_coef[, 1] != 0, ]

  # expect_equal() reports both counts on failure, unlike expect_true(a == b).
  expect_equal(nrow(raw_coef), nrow(xrf_rule_res))
  expect_true(all(raw_coef$term %in% xrf_rule_res$rule_id))

  xrf_col_res <- tidy(xrf_reg_fit, unit = "column", penalty = .001)

  # With unit = "column", the terms are the intercept plus the predictors
  # named in the model formula above.
  expect_equal(
    sort(unique(xrf_col_res$term)),
    c(
      "(Intercept)",
      "Central_Air",
      "Gr_Liv_Area",
      "Latitude",
      "Longitude",
      "Neighborhood"
    )
  )
  expect_equal(
    sort(unique(raw_coef$term)),
    sort(unique(xrf_col_res$rule_id))
  )
})
# Fits three rule-fit specifications that differ only in `stop_iter` (unset,
# 3, and 5) on the same simulated data and checks, via the `did_stop_early()`
# helper (defined outside this file's visible portion), that only the fits
# with `stop_iter` set are reported as having stopped early.
test_that("early stopping works in xrf_fit", {
skip_on_cran()
skip_if_not_installed("xrf")
skip_if_not_installed("modeldata")
set.seed(1)
reg_data <- modeldata::sim_regression(500)
# All three specs pass `validation = 0.1` to the engine (presumably the
# proportion of data held out for monitoring -- confirm in the engine docs).
rf_mod_1 <-
rule_fit(trees = 50, learn_rate = 1) |>
set_engine("xrf", validation = 0.1) |>
set_mode("regression")
rf_mod_2 <-
rule_fit(trees = 50, learn_rate = 1, stop_iter = 3) |>
set_engine("xrf", validation = 0.1) |>
set_mode("regression")
rf_mod_3 <-
rule_fit(trees = 50, learn_rate = 1, stop_iter = 5) |>
set_engine("xrf", validation = 0.1) |>
set_mode("regression")
# The seed is reset before every fit so each one starts from the same RNG
# state.
set.seed(2)
expect_no_error(
rf_fit_1 <- fit(rf_mod_1, outcome ~ ., data = reg_data)
)
set.seed(2)
expect_no_error(
rf_fit_2 <- fit(rf_mod_2, outcome ~ ., data = reg_data)
)
set.seed(2)
# Snapshot test: the code inside expect_snapshot() is part of what gets
# recorded, so keep the expression and its variable names unchanged.
expect_snapshot(
suppressMessages(
rf_fit_3 <- fit(rf_mod_3, outcome ~ ., data = reg_data)
)
)
expect_false(did_stop_early(rf_fit_1))
expect_true(did_stop_early(rf_fit_2))
expect_true(did_stop_early(rf_fit_3))
})
# Confirms that `glm_control` options given to the "xrf" engine show up as
# arguments in the call stored on the fitted glm component.
test_that("xrf_fit is sensitive to glm_control", {
  skip_on_cran()
  skip_if_not_installed("xrf")

  # Build the specification step by step rather than in a single pipeline.
  base_spec <- rule_fit(trees = 3)
  engine_spec <- set_engine(
    base_spec,
    "xrf",
    glm_control = list(type.measure = "deviance", nfolds = 8)
  )
  glm_ctrl_spec <- set_mode(engine_spec, "regression")

  expect_no_error(
    glm_ctrl_fit <- fit(glm_ctrl_spec, mpg ~ ., data = mtcars)
  )

  # Pull the arguments out of the call recorded on the glm model object.
  stored_call <- glm_ctrl_fit$fit$glm$model$call
  stored_args <- rlang::call_args(stored_call)

  expect_equal(stored_args$nfolds, 8)
  expect_equal(stored_args$type.measure, "deviance")
})
# Snapshot test for fitting a rule-fit model when `xgb_control` is supplied
# directly as an engine argument alongside `trees = 3`.
# NOTE(review): the expected output lives in the snapshot file, not here;
# presumably it records a warning or error about `xgb_control` -- confirm.
test_that("xrf_fit guards xgb_control", {
skip_on_cran()
skip_if_not_installed("xrf")
rf_mod <-
rule_fit(trees = 3) |>
set_engine("xrf", xgb_control = list(nrounds = 3)) |>
set_mode("regression")
# The code inside expect_snapshot() is part of what gets recorded, so keep
# the expression and its variable names unchanged.
expect_snapshot(
suppressMessages(
fit(rf_mod, mpg ~ ., data = mtcars)
)
)
})
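# NOTE: the three lines below are boilerplate from the hosting web page, not
# part of the test file; they are kept only as comments so the file parses.
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.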