# Nothing
# Checks that parsnip fits/predictions through the "lightgbm" engine match
# fits made by calling lightgbm::lgb.train() directly on equivalent data,
# for regression, multiclass classification, and binary classification.
test_that("boost_tree with lightgbm", {
  skip_if_not_installed("lightgbm")
  skip_if_not_installed("modeldata")

  suppressPackageStartupMessages({
    library(lightgbm)
    library(dplyr)
  })

  data("penguins", package = "modeldata")
  penguins <- penguins[complete.cases(penguins),]

  # printed model specifications
  expect_snapshot(
    boost_tree() %>% set_engine("lightgbm") %>% set_mode("regression")
  )
  expect_snapshot(
    boost_tree() %>% set_engine("lightgbm", nrounds = 100) %>% set_mode("classification")
  )

  # regression -----------------------------------------------------------------
  expect_error_free({
    pars_fit_1 <-
      boost_tree() %>%
      set_engine("lightgbm") %>%
      set_mode("regression") %>%
      fit(bill_length_mm ~ ., data = penguins)
  })
  expect_error_free({
    pars_preds_1 <-
      predict(pars_fit_1, penguins)
  })

  # Reference data for the raw lightgbm fits: character columns become
  # factors, and every factor becomes a zero-based integer code so the
  # frame can be coerced to a numeric matrix.
  peng <-
    penguins %>%
    mutate(across(where(is.character), ~as.factor(.x))) %>%
    mutate(across(where(is.factor), ~as.integer(.x) - 1))

  peng_y <- peng$bill_length_mm
  peng_m <- peng %>%
    select(-bill_length_mm) %>%
    as.matrix()

  # NOTE(review): assumes modeldata's column order (species, island, ...,
  # sex), so predictor columns 1, 2 and 6 are the recoded factors -- confirm
  peng_x <-
    lgb.Dataset(
      data = peng_m,
      label = peng_y,
      params = list(feature_pre_filter = FALSE),
      categorical_feature = c(1L, 2L, 6L)
    )

  params_1 <- list(
    objective = "regression"
  )

  lgbm_fit_1 <-
    lightgbm::lgb.train(
      data = peng_x,
      params = params_1,
      verbose = -1
    )

  lgbm_preds_1 <- predict(lgbm_fit_1, peng_m)

  expect_equal(pars_preds_1$.pred, lgbm_preds_1)

  # regression, adjusting a primary argument
  # (`trees` should map onto lightgbm's `num_iterations`)
  expect_error_free({
    pars_fit_2 <-
      boost_tree(trees = 20) %>%
      set_engine("lightgbm") %>%
      set_mode("regression") %>%
      fit(bill_length_mm ~ ., data = penguins)
  })
  expect_error_free({
    pars_preds_2 <-
      predict(pars_fit_2, penguins)
  })

  params_2 <- list(
    objective = "regression",
    num_iterations = 20
  )

  lgbm_fit_2 <-
    lightgbm::lgb.train(
      data = peng_x,
      params = params_2,
      verbose = -1
    )

  lgbm_preds_2 <- predict(lgbm_fit_2, peng_m)

  expect_equal(pars_preds_2$.pred, lgbm_preds_2)

  # regression, adjusting an engine argument (passed through unchanged)
  expect_error_free({
    pars_fit_3 <-
      boost_tree() %>%
      set_engine("lightgbm", lambda_l2 = .5) %>%
      set_mode("regression") %>%
      fit(bill_length_mm ~ ., data = penguins)
  })
  expect_error_free({
    pars_preds_3 <-
      predict(pars_fit_3, penguins)
  })

  params_3 <- list(
    objective = "regression",
    lambda_l2 = .5
  )

  lgbm_fit_3 <-
    lightgbm::lgb.train(
      data = peng_x,
      params = params_3,
      verbose = -1
    )

  lgbm_preds_3 <- predict(lgbm_fit_3, peng_m)

  expect_equal(pars_preds_3$.pred, lgbm_preds_3)

  # classification -------------------------------------------------------------

  # multiclass
  expect_error_free({
    pars_fit_4 <-
      boost_tree() %>%
      set_engine("lightgbm") %>%
      set_mode("classification") %>%
      fit(species ~ ., data = penguins)
  })
  expect_error_free({
    pars_preds_4 <-
      predict(pars_fit_4, penguins, type = "prob")

    pars_preds_raw_4 <-
      predict(pars_fit_4, penguins, type = "raw")
  })

  # raw predictions: one row per observation, one column per species
  expect_equal(nrow(pars_preds_raw_4), nrow(penguins))
  expect_equal(ncol(pars_preds_raw_4), 3)

  # strip names so the probabilities compare against lightgbm's bare matrix
  pars_preds_4_mtx <- as.matrix(pars_preds_4)
  dimnames(pars_preds_4_mtx) <- NULL

  peng_y_c <- peng$species
  peng_m_c <- peng %>%
    select(-species) %>%
    as.matrix()

  peng_x_c <-
    lgb.Dataset(
      data = peng_m_c,
      label = peng_y_c,
      params = list(feature_pre_filter = FALSE),
      categorical_feature = c(1L, 6L)
    )

  params_4 <- list(
    objective = "multiclass",
    num_class = 3
  )

  lgbm_fit_4 <-
    lightgbm::lgb.train(
      data = peng_x_c,
      params = params_4,
      verbose = -1
    )

  lgbm_preds_4 <-
    predict(lgbm_fit_4, peng_m_c) %>%
    reshape_lightgbm_multiclass_preds(num_rows = nrow(peng_m_c))

  expect_equal(pars_preds_4_mtx, lgbm_preds_4)

  # check class predictions: parsnip's hard classes must equal the most
  # probable class under the *raw lightgbm* probabilities. Column k of the
  # probability matrix is label k - 1, i.e. the k-th level of species, so
  # indexing into the levels (rather than relabelling a factor) also works
  # when some class is never the argmax.
  pars_preds_5 <-
    predict(pars_fit_4, penguins, type = "class") %>%
    `[[`(".pred_class") %>%
    as.character()

  lgbm_preds_5 <-
    levels(penguins$species)[max.col(lgbm_preds_4, ties.method = "first")]

  expect_equal(pars_preds_5, lgbm_preds_5)

  # classification on a two-level outcome
  expect_error_free({
    pars_fit_6 <-
      boost_tree() %>%
      set_engine("lightgbm") %>%
      set_mode("classification") %>%
      fit(sex ~ ., data = penguins)
  })
  expect_error_free({
    pars_preds_6 <-
      predict(pars_fit_6, penguins, type = "prob")

    pars_preds_raw_6 <-
      predict(pars_fit_6, penguins, type = "raw")
  })

  # binary raw predictions are a single value per row, not the prob tibble
  expect_equal(length(pars_preds_raw_6), nrow(penguins))
  expect_false(identical(pars_preds_6, pars_preds_raw_6))

  # lightgbm's binary score is for label 1, the second level of `sex` after
  # the zero-based recoding above
  pars_preds_6_b <- pars_preds_6$.pred_male

  peng_y_b <- peng$sex
  peng_m_b <- peng %>%
    select(-sex) %>%
    as.matrix()

  peng_x_b <-
    lgb.Dataset(
      data = peng_m_b,
      label = peng_y_b,
      params = list(feature_pre_filter = FALSE),
      categorical_feature = c(1L, 2L)
    )

  params_6 <- list(
    objective = "binary",
    num_class = 1
  )

  lgbm_fit_6 <-
    lightgbm::lgb.train(
      data = peng_x_b,
      params = params_6,
      verbose = -1
    )

  lgbm_preds_6 <- predict(lgbm_fit_6, peng_m_b)

  expect_equal(pars_preds_6_b, lgbm_preds_6)
})
# train_lightgbm() should infer the lightgbm objective from a factor outcome:
# a two-level factor gives "binary" with num_class 1, and a three-level
# factor gives "multiclass" with num_class 3.
test_that("bonsai correctly determines objective when label is a factor", {
  skip_if_not_installed("lightgbm")
  skip_if_not_installed("modeldata")
  suppressPackageStartupMessages({
    library(lightgbm)
    library(dplyr)
  })

  data("penguins", package = "modeldata")
  penguins <- penguins[complete.cases(penguins),]

  bill_cols <- c("bill_length_mm", "bill_depth_mm")

  # two-level outcome
  expect_error_free({
    bst <- train_lightgbm(
      x = penguins[, bill_cols],
      y = penguins[["sex"]],
      num_iterations = 5,
      verbose = -1L
    )
  })
  fitted_params <- bst$params
  expect_equal(fitted_params$objective, "binary")
  expect_equal(fitted_params$num_class, 1)

  # three-level outcome
  expect_error_free({
    bst <- train_lightgbm(
      x = penguins[, bill_cols],
      y = penguins[["species"]],
      num_iterations = 5,
      verbose = -1L
    )
  })
  fitted_params <- bst$params
  expect_equal(fitted_params$objective, "multiclass")
  expect_equal(fitted_params$num_class, 3)
})
# `mtry` is translated into lightgbm's `feature_fraction_bynode`. With the
# default engine argument `counts = TRUE` it is read as a number of
# predictors; with `counts = FALSE` it is read as a proportion. Supplying
# `feature_fraction_bynode` directly is removed with a warning.
test_that("bonsai handles mtry vs mtry_prop gracefully", {
  # every fit below uses the lightgbm engine
  skip_if_not_installed("lightgbm")
  skip_if_not_installed("modeldata")

  data("penguins", package = "modeldata")
  penguins <- penguins[complete.cases(penguins),]

  # number of predictors: every column except the outcome
  n_predictors <- ncol(penguins) - 1

  # supply no mtry
  expect_error_free({
    pars_fit_1 <-
      boost_tree() %>%
      set_engine("lightgbm") %>%
      set_mode("regression") %>%
      fit(bill_length_mm ~ ., data = penguins)
  })

  expect_equal(
    extract_fit_engine(pars_fit_1)$params$feature_fraction_bynode,
    1
  )

  # supply mtry = 1 (edge cases): one predictor when counting...
  expect_error_free({
    pars_fit_2 <-
      boost_tree(mtry = 1) %>%
      set_engine("lightgbm", counts = TRUE) %>%
      set_mode("regression") %>%
      fit(bill_length_mm ~ ., data = penguins)
  })

  expect_equal(
    extract_fit_engine(pars_fit_2)$params$feature_fraction_bynode,
    1 / n_predictors
  )

  # ...and the full proportion when not
  expect_error_free({
    pars_fit_3 <-
      boost_tree(mtry = 1) %>%
      set_engine("lightgbm", counts = FALSE) %>%
      set_mode("regression") %>%
      fit(bill_length_mm ~ ., data = penguins)
  })

  expect_equal(
    extract_fit_engine(pars_fit_3)$params$feature_fraction_bynode,
    1
  )

  # supply a count (with default counts = TRUE)
  expect_error_free({
    pars_fit_4 <-
      boost_tree(mtry = 3) %>%
      set_engine("lightgbm") %>%
      set_mode("regression") %>%
      fit(bill_length_mm ~ ., data = penguins)
  })

  expect_equal(
    extract_fit_engine(pars_fit_4)$params$feature_fraction_bynode,
    3 / n_predictors
  )

  # supply a proportion when count expected
  expect_snapshot_error({
    pars_fit_5 <-
      boost_tree(mtry = .5) %>%
      set_engine("lightgbm") %>%
      set_mode("regression") %>%
      fit(bill_length_mm ~ ., data = penguins)
  })

  # supply a count when proportion expected
  expect_snapshot_error({
    pars_fit_6 <-
      boost_tree(mtry = 3) %>%
      set_engine("lightgbm", counts = FALSE) %>%
      set_mode("regression") %>%
      fit(bill_length_mm ~ ., data = penguins)
  })

  # supply a feature fraction argument rather than mtry
  # TODO: is there any way to extend parsnip's warning here to
  # point users to mtry?
  expect_warning({
    pars_fit_7 <-
      boost_tree() %>%
      set_engine("lightgbm", feature_fraction_bynode = .5) %>%
      set_mode("regression") %>%
      fit(bill_length_mm ~ ., data = penguins)},
    "manually modified and were removed: feature_fraction_bynode."
  )

  # the removed engine argument falls back to the default of 1
  expect_equal(
    extract_fit_engine(pars_fit_7)$params$feature_fraction_bynode,
    1
  )

  # supply both feature fraction and mtry
  expect_snapshot({
    pars_fit_8 <-
      boost_tree(mtry = .5) %>%
      set_engine("lightgbm", feature_fraction_bynode = .5) %>%
      set_mode("regression") %>%
      fit(bill_length_mm ~ ., data = penguins)},
    error = TRUE
  )

  # with a valid count, mtry wins over the removed engine argument
  expect_warning({
    pars_fit_9 <-
      boost_tree(mtry = 2) %>%
      set_engine("lightgbm", feature_fraction_bynode = .5) %>%
      set_mode("regression") %>%
      fit(bill_length_mm ~ ., data = penguins)},
    "manually modified and were removed: feature_fraction_bynode."
  )

  expect_equal(
    extract_fit_engine(pars_fit_9)$params$feature_fraction_bynode,
    2 / n_predictors
  )
})
# Tuning `mtry` with the lightgbm engine should produce count-valued
# candidates (all >= 1). Fitting with an unresolved `tune()` placeholder
# should fail; its message is captured in a snapshot.
test_that("tuning mtry vs mtry_prop", {
  skip_if_not_installed("tune")
  skip_if_not_installed("rsample")
  # the specification being tuned uses the lightgbm engine
  skip_if_not_installed("lightgbm")
  skip_if_not_installed("modeldata")

  data("penguins", package = "modeldata")
  penguins <- penguins[complete.cases(penguins),]

  set.seed(1)

  suppressMessages(
    expect_error_free({
      gbm_tune <- tune::tune_grid(
        boost_tree(mtry = tune::tune()) %>%
          set_engine("lightgbm") %>%
          set_mode("regression"),
        grid = 4,
        preprocessor = bill_length_mm ~ .,
        resamples = rsample::bootstraps(penguins, times = 5)
      )
    })
  )

  # four distinct candidates, all interpretable as predictor counts
  mtrys <- unique(tune::collect_metrics(gbm_tune)$mtry)

  expect_equal(length(mtrys), 4)
  expect_true(all(mtrys >= 1))

  # supply tune() without tuning
  expect_snapshot({
    boost_tree(mtry = tune::tune()) %>%
      set_engine("lightgbm") %>%
      set_mode("regression") %>%
      fit(bill_length_mm ~ ., data = penguins)},
    error = TRUE
  )
})
# Engine arguments that bonsai sets itself (e.g. `colnames`, `callbacks`) are
# dropped with a warning that names them. Other argument combinations below
# have their outcome recorded in snapshots.
test_that("training wrapper warns on protected arguments", {
  skip_if_not_installed("lightgbm")
  skip_if_not_installed("modeldata")

  data("penguins", package = "modeldata")
  penguins <- penguins[complete.cases(penguins),]

  # one protected argument
  expect_warning(
    boost_tree() %>%
      set_engine("lightgbm", colnames = paste0("X", seq_len(ncol(penguins)))) %>%
      set_mode("regression") %>%
      fit(bill_length_mm ~ ., data = penguins),
    "guarded by bonsai.*colnames"
  )

  # two protected arguments: both are listed, in order
  expect_warning(
    boost_tree() %>%
      set_engine(
        "lightgbm",
        colnames = paste0("X", seq_len(ncol(penguins))),
        callbacks = list(p = print)
      ) %>%
      set_mode("regression") %>%
      fit(bill_length_mm ~ ., data = penguins),
    "guarded by bonsai.*colnames, callbacks"
  )

  # NOTE(review): this repeats the first check (colnames alone); confirm
  # whether a different protected-argument combination was intended here
  expect_warning(
    boost_tree() %>%
      set_engine(
        "lightgbm",
        colnames = paste0("X", seq_len(ncol(penguins)))
      ) %>%
      set_mode("regression") %>%
      fit(bill_length_mm ~ ., data = penguins),
    "guarded by bonsai.*colnames"
  )

  # presumably lightgbm aliases of parameters bonsai controls through main
  # arguments; the resulting conditions are recorded in snapshots
  expect_snapshot(
    error = TRUE,
    boost_tree() %>%
      set_engine("lightgbm", n_iter = 10) %>%
      set_mode("regression") %>%
      fit(bill_length_mm ~ ., data = penguins)
  )

  expect_snapshot(
    error = TRUE,
    boost_tree() %>%
      set_engine("lightgbm", num_tree = 10) %>%
      set_mode("regression") %>%
      fit(bill_length_mm ~ ., data = penguins)
  )

  expect_snapshot(
    error = TRUE,
    boost_tree() %>%
      set_engine("lightgbm", min_split_gain = 2) %>%
      set_mode("regression") %>%
      fit(bill_length_mm ~ ., data = penguins)
  )

  expect_snapshot(
    error = TRUE,
    boost_tree() %>%
      set_engine("lightgbm", min_split_gain = 2, lambda_l2 = .5) %>%
      set_mode("regression") %>%
      fit(bill_length_mm ~ ., data = penguins)
  )
})
# Early stopping is configured only through the `stop_iter` main argument.
# An `early_stopping_round` engine argument is removed with a warning. A
# validation score (`best_score`) is expected whenever `stop_iter` or
# `validation` is supplied.
test_that("training wrapper passes stop_iter correctly", {
  skip_if_not_installed("lightgbm")
  skip_if_not_installed("modeldata")

  data("penguins", package = "modeldata")
  penguins <- penguins[complete.cases(penguins),]

  # fit 1: stop_iter only
  expect_error_free(
    pars_fit_1 <-
      boost_tree(stop_iter = 10) %>%
      set_engine("lightgbm") %>%
      set_mode("regression") %>%
      fit(bill_length_mm ~ ., data = penguins)
  )

  # fit 2: early stopping requested via the engine -> removed
  expect_warning(
    pars_fit_2 <-
      boost_tree() %>%
      set_engine("lightgbm", early_stopping_round = 10) %>%
      set_mode("regression") %>%
      fit(bill_length_mm ~ ., data = penguins),
    "were removed: early_stopping_round"
  )

  # fit 3: defaults
  expect_error_free(
    pars_fit_3 <-
      boost_tree() %>%
      set_engine("lightgbm") %>%
      set_mode("regression") %>%
      fit(bill_length_mm ~ ., data = penguins)
  )

  # fit 4: validation set without early stopping
  expect_error_free(
    pars_fit_4 <-
      boost_tree() %>%
      set_engine("lightgbm", validation = .2) %>%
      set_mode("regression") %>%
      fit(bill_length_mm ~ ., data = penguins)
  )

  # fit 5: validation set with early stopping
  expect_error_free(
    pars_fit_5 <-
      boost_tree(stop_iter = 10) %>%
      set_engine("lightgbm", validation = .2) %>%
      set_mode("regression") %>%
      fit(bill_length_mm ~ ., data = penguins)
  )

  # early_stopping_round in the fitted params, then whether a validation
  # score was recorded, checked fit by fit
  expect_equal(pars_fit_1$fit$params$early_stopping_round, 10)
  expect_false(is.na(pars_fit_1$fit$best_score))

  expect_null(pars_fit_2$fit$params$early_stopping_round)
  expect_true(is.na(pars_fit_2$fit$best_score))

  expect_null(pars_fit_3$fit$params$early_stopping_round)
  expect_true(is.na(pars_fit_3$fit$best_score))

  expect_null(pars_fit_4$fit$params$early_stopping_round)
  expect_false(is.na(pars_fit_4$fit$best_score))

  expect_equal(pars_fit_5$fit$params$early_stopping_round, 10)
  expect_false(is.na(pars_fit_5$fit$best_score))
})
# `sample_size` maps to lightgbm's `bagging_fraction`. When a sample size is
# given, `bagging_freq` is expected to be 1 unless the user supplies it as an
# engine argument. Without a sample size, `bagging_freq` is left unset.
test_that("training wrapper handles bagging correctly", {
  skip_if_not_installed("lightgbm")
  skip_if_not_installed("modeldata")

  data("penguins", package = "modeldata")
  penguins <- penguins[complete.cases(penguins),]

  # fit an engine-configured spec as a regression on bill_length_mm
  fit_bill_length <- function(spec) {
    spec %>%
      set_mode("regression") %>%
      fit(bill_length_mm ~ ., data = penguins)
  }

  # no sample_size: full data, no bagging frequency
  fitted_params <- fit_bill_length(
    boost_tree() %>% set_engine("lightgbm")
  )$fit$params
  expect_equal(fitted_params$bagging_fraction, 1)
  expect_null(fitted_params$bagging_freq)

  # sample_size only: bagging every iteration
  fitted_params <- fit_bill_length(
    boost_tree(sample_size = .5) %>% set_engine("lightgbm")
  )$fit$params
  expect_equal(fitted_params$bagging_fraction, .5)
  expect_equal(fitted_params$bagging_freq, 1)

  # sample_size plus a user-supplied bagging_freq, which is respected
  fitted_params <- fit_bill_length(
    boost_tree(sample_size = .5) %>% set_engine("lightgbm", bagging_freq = 2)
  )$fit$params
  expect_equal(fitted_params$bagging_fraction, .5)
  expect_equal(fitted_params$bagging_freq, 2)
})
# For both multiclass and binary classifiers, multi_predict() without a
# `type` argument should return, per row, a tibble of class predictions
# (columns `trees` and `.pred_class`) whose values are levels of the outcome.
test_that("multi_predict() predicts classes if 'type' not given ", {
  skip_if_not_installed("lightgbm")
  skip_if_not_installed("modeldata")
  suppressPackageStartupMessages({
    library(lightgbm)
    library(dplyr)
  })

  data("penguins", package = "modeldata")
  penguins <- penguins[complete.cases(penguins),]

  # first ten rows with factors recoded as zero-based integers, so that
  # predictor matrices can be built from them
  penguins_subset_numeric <-
    penguins[1:10,] %>%
    mutate(across(where(is.character), ~as.factor(.x))) %>%
    mutate(across(where(is.factor), ~as.integer(.x) - 1))

  num_iterations <- 5

  # Checks one fitted classifier: it used `num_iterations` trees, and
  # multi_predict() on a single row yields a tibble of class predictions.
  expect_class_multi_predict <- function(model, outcome) {
    expect_equal(model$fit$current_iter(), num_iterations)

    predictor_mtx <-
      penguins_subset_numeric %>%
      select(-all_of(outcome)) %>%
      as.matrix()

    pred_tbl <- multi_predict(
      model,
      new_data = predictor_mtx[1, , drop = FALSE],
      trees = seq_len(num_iterations)
    )$.pred[[1]]

    # should be a tibble that looks like class predictions
    expect_s3_class(pred_tbl, "tbl_df")
    expect_named(pred_tbl, c("trees", ".pred_class"))
    predicted <- pred_tbl[[".pred_class"]]
    expect_s3_class(predicted, "factor")
    expect_true(all(as.character(predicted) %in% levels(penguins[[outcome]])))
  }

  # classification (multiclass) ------------------------------------------------
  expect_error_free({
    clf_multiclass_fit <-
      boost_tree(trees = num_iterations) %>%
      set_engine("lightgbm") %>%
      set_mode("classification") %>%
      fit(species ~ ., data = penguins)
  })
  expect_class_multi_predict(clf_multiclass_fit, "species")

  # classification (binary) ------------------------------------------------
  expect_error_free({
    clf_binary_fit <-
      boost_tree(trees = num_iterations) %>%
      set_engine("lightgbm") %>%
      set_mode("classification") %>%
      fit(sex ~ ., data = penguins)
  })
  expect_class_multi_predict(clf_binary_fit, "sex")
})
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.