Nothing
skip_if_not_installed("xgboost")
logregobj <- function(preds, dtrain) {
labels <- xgboost::getinfo(dtrain, "label")
preds <- 1 / (1 + exp(-preds))
grad <- preds - labels
hess <- preds * (1 - preds)
return(list(grad = grad, hess = hess))
}
xgb_bin_data <- xgboost::xgb.DMatrix(
as.matrix(mtcars[, -9]),
label = mtcars$am
)
xgb_list <- list(
reg_sqr = list(objective = "reg:squarederror"),
bin_log = list(objective = "binary:logitraw"),
reg_log = list(objective = "reg:logistic"),
bin_log = list(objective = "binary:logistic"),
log_reg = list(objective = logregobj),
reg_log_base = list(objective = "reg:logistic", base_score = mean(mtcars$am)),
bin_log_base = list(objective = "binary:logistic", base_score = mean(mtcars$am)),
reg_log_large = list(objective = "reg:logistic", nrounds = 50),
bin_log_large = list(objective = "binary:logistic", nrounds = 50),
reg_log_deep = list(objective = "reg:logistic", max_depth = 20),
bin_log_deep = list(objective = "binary:logistic", max_depth = 20)
) %>%
purrr::map(~ {
if (is.null(.x$base_score)) .x$base_score <- 0.5
if (is.null(.x$nrounds)) .x$nrounds <- 4
if (is.null(.x$max_depth)) .x$max_depth <- 2
.x
})
xgb_models_all <- xgb_list %>%
purrr::imap(~ {
xgboost::xgb.train(
params = list(
max_depth = 2, objective = .x$objective, base_score = .x$base_score
),
data = xgb_bin_data,
nrounds = .x$nrounds
)
})
xgb_models <- xgb_models_all[names(xgb_models_all) != "log_reg"]
test_that("Predictions match to model's predict routine", {
td_tests <- xgb_models %>%
purrr::map(
tidypredict_test,
df = mtcars,
xg_df = xgb_bin_data,
threshold = 0.001
)
td_tests %>%
purrr::imap(
~ {
msg <- paste0("------ >> MODEL: ", .y)
expect_false(.x$alert, info = msg)
}
)
expect_warning(
tidypredict_test(
xgb_models_all$log_reg,
df = mtcars,
xg_df = xgb_bin_data
)
)
})
test_that("Confirm SQL function returns SQL query", {
xgb_sql <- xgb_models %>%
purrr::map(tidypredict_sql, dbplyr::simulate_odbc())
# Removing "_large" models because of precision issues with other
# non M1 machines
no_large <- xgb_sql[!grepl("_large", names(xgb_sql))]
for (i in seq_along(no_large)) {
expect_snapshot(no_large[i])
}
})
test_that("Base scores match", {
xgb_scores <- xgb_list %>%
purrr::map_dbl(~ .x$base_score)
xgb_scores_pm <- xgb_models_all %>%
purrr::map(parse_model) %>%
purrr::map_dbl(~ .x$general$params$base_score)
xgb_scores %>%
seq_along() %>%
purrr::map(
~ expect_equal(xgb_scores[.x], xgb_scores_pm[.x])
)
})
test_that("Model can be saved and re-loaded", {
mp <- tempfile(fileext = ".yml")
yaml::write_yaml(parse_model(xgb_models$reg_sqr), mp)
l <- yaml::read_yaml(mp)
pm <- as_parsed_model(l)
expect_snapshot(tidypredict_fit(pm))
})
test_that("Predictions are correct for different objectives", {
m <- parsnip::fit(
parsnip::set_engine(parsnip::boost_tree(mode = "regression"), "xgboost"),
am ~ .,
data = mtcars
)
expect_false(tidypredict_test(m, df = mtcars, threshold = 0.001)$alert)
})
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.