#' Conditional logging
#'
#' @param log Log the message? (TRUE|FALSE)
#' @param text Text string to be logged
#' @return Prints the message on screen (if log == TRUE)
log_info_if <- function(log = TRUE, text = "log") {
if (log) {message(text)}
}
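
# Minimal illustration (assumption, not part of the package examples):
# log_info_if(TRUE, "training started")   # prints "training started"
# log_info_if(FALSE, "training started")  # prints nothing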
#' Explain a binary target using xgboost
#'
#' Based on the hyperparameters defined in the setup parameter, XGBoost hyperparameter-tuning is
#' carried out using cross-validation. The best model is chosen and returned.
#' By default, the function returns the feature-importance plot.
#' To get all outputs, use parameter out = "all".
#'
#' @param data Data frame, must contain the variable defined in target,
#' but should not contain any customer-IDs or date/period columns
#' @param target Target variable (must be binary 0/1, FALSE/TRUE, no/yes)
#' @param log Log progress? (TRUE|FALSE)
#' @param nthread Number of threads used for training
#' @param setup Setup of model
#' @param out Output of the function: "plot" | "model" | "importance" | "all"
#' @return Plot of importance (if out = "plot")
#' @examples
#' data <- use_data_iris()
#' data$is_versicolor <- ifelse(data$Species == "versicolor", 1, 0)
#' data$Species <- NULL
#' explain_xgboost(data, target = is_versicolor, log = FALSE)
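#'
#' # Illustrative sketch, not run: a smaller custom setup and all outputs
#' # returned as a list; the values below are assumptions for demonstration,
#' # not tuned defaults.
#' \dontrun{
#' result <- explain_xgboost(
#'   data, target = is_versicolor, log = FALSE, out = "all",
#'   setup = list(
#'     cv_nfold = 2,
#'     max_nrounds = 100,
#'     early_stopping_rounds = 20,
#'     grid_xgboost = list(
#'       eta = c(0.3, 0.1),
#'       max_depth = c(3, 5),
#'       gamma = 0,
#'       colsample_bytree = 0.8,
#'       subsample = 0.8,
#'       min_child_weight = 1,
#'       scale_pos_weight = 1
#'     )
#'   )
#' )
#' result$plot        # feature-importance plot
#' result$importance  # feature-importance table
#' result$tune_plot   # hyperparameter-tuning curves
#' }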
#' @export
explain_xgboost <- function(data, target, log = TRUE, nthread = 1,
setup = list(
cv_nfold = 2, # Nr. of folds used for cross-validation during model training
max_nrounds = 1000,
early_stopping_rounds = 50,
grid_xgboost = list(
eta = c(0.3, 0.1, 0.01),
max_depth = c(3, 5),
gamma = 0,
colsample_bytree = 0.8,
subsample = 0.8,
min_child_weight = 1,
scale_pos_weight = 1
)),
out = "plot") {
# check if xgboost is installed
rlang::check_installed("xgboost", reason = "to create a xgboost model.")
# check data & target
check_data_frame_non_empty(data)
rlang::check_required(target)
# tidy eval for target
target_quo <- enquo(target)
target_txt <- quo_name(target_quo)[[1]]
if (!target_txt %in% names(data)) {
warning("target must be a variable of data")
return(invisible())
}
# check if target is binary
if (length(unique(data[[target_txt]])) != 2) {
warning("target must be binary (e.g. 0/1, TRUE/FALSE, 'yes'/'no')")
return(invisible())
}
# declare NSE variables as NULL to avoid R CMD check notes about undefined globals
variable <- NULL
iter <- NULL
train_auc_mean <- NULL
test_auc_mean <- NULL
model_nr <- NULL
best_iter_ind <- NULL
runtime <- NULL
# define hyperparameter grid
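# expand.grid builds one row per combination of the grid values; with the
# default grid (3 eta values x 2 max_depth values, all others scalar) this
# results in 6 tuning runs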
param_grid <- expand.grid(
#min_prop_train = setup$min_prop_train,
eta = setup$grid_xgboost$eta,
max_depth = setup$grid_xgboost$max_depth,
gamma = setup$grid_xgboost$gamma,
colsample_bytree = setup$grid_xgboost$colsample_bytree,
subsample = setup$grid_xgboost$subsample,
min_child_weight = setup$grid_xgboost$min_child_weight,
scale_pos_weight = setup$grid_xgboost$scale_pos_weight
)
# log details?
verbose <- FALSE
# containers to store the results of each tuning run
all_auc <- vector(mode = "numeric")
all_nrounds <- vector(mode = "numeric")
hp_tuning_log <- NULL
# train with cross validation ---------------------------------------------
# xgb.cv loop over all parameter combinations
for (k in seq_len(nrow(param_grid))) {
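# convert row k of the parameter grid into a named list, as expected by xgb.cv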
current_params <- stats::setNames(
as.list(t(param_grid[k, ])),
names(param_grid)
)
# build xgb.DMatrix: all columns except the target as features, the target as label
dtrain <- xgboost::xgb.DMatrix(
as.matrix(data[ ,names(data) != target_txt]),
label = data[[target_txt]])
t1 <- Sys.time()
log_info_if(log, paste("\ntrain xgboost nr", k))
log_str <- paste0(paste0(names(current_params), "=", current_params), collapse=", ")
log_info_if(log, paste0("", log_str))
# training
set.seed(42)
cv <- xgboost::xgb.cv(
objective = "binary:logistic",
eval_metric = "auc",
data = dtrain,
#label = ltrain,
params = current_params,
nthread = nthread,
nfold = setup$cv_nfold,
nrounds = setup$max_nrounds,
early_stopping_rounds = setup$early_stopping_rounds,
maximize = TRUE,
verbose = FALSE, #verbose, #log details?
print_every_n = 100
)
#log_info_if(log, paste("xgboost nr", k, "training finished"))
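# remember the best boosting iteration and its mean CV test AUC for this run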
all_nrounds[k] <- cv$best_iteration
model_log <- cv$evaluation_log[cv$evaluation_log$iter == cv$best_iteration, ]
all_auc[k] <- model_log$test_auc_mean
t2 <- Sys.time()
runtime_curr <- round(difftime(t2, t1, units = "mins"), 1)
# store train/test AUC for every boosting iteration of this run
hp_tuning_log_curr <- cv$evaluation_log %>%
dplyr::select(iter, train_auc_mean, test_auc_mean) %>%
dplyr::mutate(model_nr = k,
min_prop_train = current_params$min_prop_train,
eta = current_params$eta,
gamma = current_params$gamma,
max_depth = current_params$max_depth,
min_child_weight = current_params$min_child_weight,
subsample = current_params$subsample,
colsample_bytree = current_params$colsample_bytree,
scale_pos_weight = current_params$scale_pos_weight,
best_iter_ind = ifelse(iter == cv$best_iteration, 1, 0),
runtime = runtime_curr,
.before = iter)
hp_tuning_log <- rbind(hp_tuning_log, hp_tuning_log_curr)
rm(hp_tuning_log_curr)
log_info_if(log, paste0("",
all_nrounds[k], " iterations, ",
"training time=", runtime_curr, " min, ",
"auc=", round(all_auc[k],4)))
} #end xgb.cv loop
# Plot & Save Hy-Param Tuning Runs
suppressWarnings({
suppressMessages({
plot_hp_tuning <- hp_tuning_log %>% #nolint
dplyr::mutate(model_nr = as.factor(model_nr)) %>%
ggplot2::ggplot(ggplot2::aes(x = iter, y = test_auc_mean, color = model_nr)) +
ggplot2::geom_line(size = 1.5) +
ggplot2::labs(x = "XGBoost Training Iteration", y = "Mean Test AUC (CV)",
title = "Hyperparameter-Tuning") +
ggplot2::theme_light()
})
})
# keep the best iteration of each tuning run as a result data frame
hp_tuning_best <- hp_tuning_log %>%
dplyr::filter(best_iter_ind == 1)
# identify cols w/o variance for removal
drop <- hp_tuning_best %>%
dplyr::select(-runtime, -iter) %>% #should not be dropped, even if constant across all runs
explore::describe() %>%
dplyr::filter(unique == 1) %>%
dplyr::pull(variable)
# remove constant hy-params from result df
hp_tuning_best <- hp_tuning_best %>%
dplyr::select(-dplyr::one_of(drop))
# final model selection & training ----------------------------------------
# Select best params
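# if several runs share the maximum CV test AUC, the first one wins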
best_idx <- which(all_auc == max(all_auc))[1]
best_auc <- all_auc[best_idx]
best_nrounds <- all_nrounds[best_idx]
best_params <- stats::setNames(
as.list(t(param_grid[best_idx, ])),
names(param_grid)
)
log_info_if(log, paste0("\nbest model found in cross-validation: xgboost nr ", best_idx, ": "))
log_info_if(log, paste0("", paste0(paste0(names(current_params), "=", current_params), collapse=", ")))
log_info_if(log, paste0("nrounds=", best_nrounds,", test_auc=", round(best_auc, 4)))
log_info_if(log, paste0("\ntrain final model..."))
# train final model, no cross validation
t1 <- Sys.time()
model <- xgboost::xgb.train(
data = dtrain,
nrounds = best_nrounds,
nthread = nthread,
booster = "gbtree",
objective = "binary:logistic",
params = best_params,
verbose = verbose,
print_every_n = 100
)
t2 <- Sys.time()
runtime_curr <- round(difftime(t2, t1, units = "mins"), 1)
# feature importance
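# xgb.importance reports gain, cover and frequency per feature; gain is used as
# the importance measure below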
importance <- xgboost::xgb.importance(colnames(dtrain), model = model)
names(importance) <- c("variable", "gain", "cover", "frequency")
importance$importance <- importance$gain
importance <- importance %>% dplyr::arrange(dplyr::desc(importance))
# plot importance
p <- importance %>%
utils::head(30) %>%
ggplot2::ggplot(ggplot2::aes(
x = importance,
y = forcats::fct_reorder(variable, importance)
)) +
ggplot2::geom_col(color = "white", fill = "grey") +
ggplot2::ylab("variable") +
ggplot2::ggtitle("ML-feature-importance") +
ggplot2::theme_minimal()
# log
log_info_if(log, paste0("done, ", "training time=", runtime_curr, " min"))
# return result -----------------------------------------------------------
model <- list(
model = model,
importance = importance,
plot = p,
tune_data = hp_tuning_best,
tune_plot = plot_hp_tuning
)
# output
if (out %in% c("all", "list")) {
return(model)
} else if(out == "model") {
return(model$model)
} else if(out == "importance") {
return(model$importance)
}
# default output
model$plot
} # explain_xgboost