R/explain-xgboost.R

Defines functions explain_xgboost log_info_if

Documented in explain_xgboost log_info_if

#' Log conditional
#'
#' Emits `text` as a message (via `message()`) only when `log` is TRUE.
#'
#' @param log log (TRUE|FALSE); when FALSE nothing is emitted
#' @param text text string to be logged
#' @return NULL, invisibly. Called for its side effect of emitting a message
#'   (if log == TRUE).
log_info_if <- function(log = TRUE, text = "log") {
  if (log) {
    message(text)
  }
  invisible(NULL)
}

#' Explain a binary target using xgboost
#'
#' Based on the hyperparameters defined in the setup parameter, XGBoost
#' hyperparameter-tuning is carried out using cross-validation: one
#' `xgboost::xgb.cv()` run per combination of the values in
#' `setup$grid_xgboost`. The combination with the highest mean test AUC is
#' used to train a final model on the full data, and that model is returned.
#' By default, the function returns the feature-importance plot.
#' To get all outputs, use parameter out = "all".
#' Note: `set.seed(42)` is called before each cross-validation run, which
#' resets the random number generator state of the session.
#'
#' @param data Data frame, must contain the variable defined in target,
#' but should not contain any customer-IDs or date/period columns.
#' All other columns are used as predictors.
#' @param target Target variable (must be binary 0/1, FALSE/TRUE, no/yes).
#' Logical, factor and character targets are mapped to 0/1. For a factor the
#' last (used) level becomes 1; for character values the last value in sorted
#' order becomes 1 (e.g. "yes" for "no"/"yes").
#' @param log Log progress messages? (TRUE|FALSE)
#' @param nthread Number of threads used for training (cross-validation and
#' final model)
#' @param setup Setup of model: a list with `cv_nfold` (number of
#' cross-validation folds), `max_nrounds` (maximum number of boosting rounds),
#' `early_stopping_rounds` and `grid_xgboost` (a list of hyperparameter
#' values; every combination is cross-validated)
#' @param out Output of the function: "plot" | "model" | "importance" | "all"
#' ("list" is accepted as an alias for "all"; any other value returns the plot)
#' @return Plot of importance (if out = "plot"), the trained xgboost model
#' (if out = "model"), the importance table (if out = "importance"), or a list
#' with elements model, importance, plot, tune_data and tune_plot
#' (if out = "all"). Returns NULL invisibly with a warning if target is not a
#' variable of data, is not binary, or the hyperparameter grid is empty.
#' @examples
#' data <- use_data_iris()
#' data$is_versicolor <- ifelse(data$Species == "versicolor", 1, 0)
#' data$Species <- NULL
#' explain_xgboost(data, target = is_versicolor, log = FALSE)
#' @export
explain_xgboost <- function(data, target, log = TRUE, nthread = 1,
                               setup = list(
                                 cv_nfold = 2, # Nr. of folds used for cross-validation during model training
                                 max_nrounds = 1000,
                                 early_stopping_rounds = 50,
                                 grid_xgboost = list(
                                   eta = c(0.3, 0.1, 0.01),
                                   max_depth = c(3, 5),
                                   gamma = 0,
                                   colsample_bytree = 0.8,
                                   subsample = 0.8,
                                   min_child_weight = 1,
                                   scale_pos_weight = 1
                               )),
                            out = "plot") {

  # check if xgboost is installed
  rlang::check_installed("xgboost", reason = "to create a xgboost model.")

  # check data & target
  check_data_frame_non_empty(data)
  rlang::check_required(target)

  # tidy eval for target
  target_quo <- enquo(target)
  target_txt <- quo_name(target_quo)[[1]]
  if (!target_txt %in% names(data)) {
    warning("target must be a variable of data")
    return(invisible())
  }

  # check if target is binary
  if (length(unique(data[[target_txt]])) != 2)  {
    warning("target must be binary (e.g. 0/1, TRUE/FALSE, 'yes'/'no')")
    return(invisible())
  }

  # undefined variables to check CRAN tests
  variable <- NULL
  iter <- NULL
  train_auc_mean <- NULL
  test_auc_mean <- NULL
  model_nr <- NULL
  best_iter_ind <- NULL
  runtime <- NULL

  # define hy-param grid (one row per combination to cross-validate)
  param_grid <- expand.grid(
    eta = setup$grid_xgboost$eta,
    max_depth = setup$grid_xgboost$max_depth,
    gamma = setup$grid_xgboost$gamma,
    colsample_bytree = setup$grid_xgboost$colsample_bytree,
    subsample = setup$grid_xgboost$subsample,
    min_child_weight = setup$grid_xgboost$min_child_weight,
    scale_pos_weight = setup$grid_xgboost$scale_pos_weight
  )

  # a missing/empty entry in setup$grid_xgboost makes expand.grid() return
  # 0 rows, which would leave no model to select
  n_models <- nrow(param_grid)
  if (n_models == 0) {
    warning("setup$grid_xgboost must contain at least one value per hyperparameter")
    return(invisible())
  }

  # hyperparameters of grid row i as a named list (as passed to xgboost)
  grid_params <- function(i) {
    stats::setNames(as.list(t(param_grid[i, ])), names(param_grid))
  }

  # "name=value, name=value, ..." string used in the log
  format_params <- function(params) {
    paste0(names(params), "=", params, collapse = ", ")
  }

  # log details?
  verbose <- FALSE

  # prepare training data ---------------------------------------------------

  # xgboost needs a numeric 0/1 label: map logical, factor and character
  # targets (e.g. "no"/"yes") onto 0/1; numeric targets are used as they are
  label <- data[[target_txt]]
  if (is.logical(label)) {
    label <- as.numeric(label)
  } else if (!is.numeric(label)) {
    label_values <- if (is.factor(label)) {
      levels(droplevels(label))
    } else {
      sort(unique(as.character(label)))
    }
    label <- as.numeric(as.character(label) == label_values[length(label_values)])
  }

  # the training data is the same for every grid row, so build it once;
  # drop = FALSE keeps a data frame (and the column name) even if only a
  # single predictor is left
  predictors <- as.data.frame(data)[, names(data) != target_txt, drop = FALSE]
  dtrain <- xgboost::xgb.DMatrix(as.matrix(predictors), label = label)

  # train with cross validation ---------------------------------------------

  all_auc <- numeric(n_models)
  all_nrounds <- numeric(n_models)
  hp_tuning_logs <- vector(mode = "list", length = n_models)

  # xgb.cv loop
  for (k in seq_len(n_models)) {

    current_params <- grid_params(k)

    t1 <- Sys.time()

    log_info_if(log, paste("\ntrain xgboost nr", k))
    log_info_if(log, format_params(current_params))

    # training (seed fixed for reproducible cross-validation results)
    set.seed(42)
    cv <- xgboost::xgb.cv(
      objective = "binary:logistic",
      eval_metric = "auc",
      data = dtrain,
      params = current_params,
      nthread = nthread,
      nfold = setup$cv_nfold,
      nrounds = setup$max_nrounds,
      early_stopping_rounds = setup$early_stopping_rounds,
      maximize = TRUE,
      verbose = FALSE,
      print_every_n = 100
    )
    all_nrounds[k] <- cv$best_iteration
    model_log <- cv$evaluation_log[cv$evaluation_log$iter == cv$best_iteration, ]
    all_auc[k] <- model_log$test_auc_mean

    t2 <- Sys.time()
    runtime_curr <- round(difftime(t2, t1, units = "mins"), 1)

    # store train/test auc of every xgb iteration of this model
    hp_tuning_logs[[k]] <- cv$evaluation_log %>%
      dplyr::select(iter, train_auc_mean, test_auc_mean) %>%
      dplyr::mutate(model_nr = k,
                    eta = current_params$eta,
                    gamma = current_params$gamma,
                    max_depth = current_params$max_depth,
                    min_child_weight = current_params$min_child_weight,
                    subsample = current_params$subsample,
                    colsample_bytree = current_params$colsample_bytree,
                    scale_pos_weight = current_params$scale_pos_weight,
                    best_iter_ind = ifelse(iter == cv$best_iteration, 1, 0),
                    runtime = runtime_curr,
                    .before = iter)

    log_info_if(log, paste0("",
      all_nrounds[k], " iterations, ",
      "training time=", runtime_curr, " min, ",
      "auc=", round(all_auc[k], 4)))

  } #end xgb.cv loop

  # combine once instead of growing the table with rbind() inside the loop
  hp_tuning_log <- do.call(rbind, hp_tuning_logs)

  # Plot & Save Hy-Param Tuning Runs
  suppressWarnings({
    suppressMessages({
      plot_hp_tuning <- hp_tuning_log %>%  #nolint
        dplyr::mutate(model_nr = as.factor(model_nr)) %>%
        ggplot2::ggplot(ggplot2::aes(x = iter, y = test_auc_mean, color = model_nr)) +
        ggplot2::geom_line(size = 1.5) +
        ggplot2::labs(x = "XGBoost Training Iteration", y = "Mean Test AUC (CV)",
                      title = "Hyperparameter-Tuning") +
        ggplot2::theme_light()
    })
  })

  # Store Hy-Param Tuning Run Results as df (best iteration of each model)
  hp_tuning_best <- hp_tuning_log %>%
    dplyr::filter(best_iter_ind == 1)

  # identify cols w/o variance for removal
  drop <- hp_tuning_best %>%
    dplyr::select(-runtime, -iter) %>%  #should not be dropped, even if constant across all runs
    explore::describe() %>%
    dplyr::filter(unique == 1) %>%
    dplyr::pull(variable)

  # remove constant hy-params from result df
  hp_tuning_best <- hp_tuning_best %>%
    dplyr::select(-dplyr::one_of(drop))

  # final model selection & training ----------------------------------------

  # Select best params (first model with the highest mean test auc; NA auc
  # values are ignored by which.max)
  best_idx <- which.max(all_auc)
  best_auc <- all_auc[best_idx]
  best_nrounds <- all_nrounds[best_idx]
  best_params <- grid_params(best_idx)

  log_info_if(log, paste0("\nbest model found in cross-validation: xgboost nr ", best_idx, ": "))
  log_info_if(log, format_params(best_params))
  log_info_if(log, paste0("nrounds=", best_nrounds, ", test_auc=", round(best_auc, 4)))

  log_info_if(log, "\ntrain final model...")

  # train final model on all data, no cross validation
  t1 <- Sys.time()
  model <- xgboost::xgb.train(
    data = dtrain,
    nrounds = best_nrounds,
    nthread = nthread,
    booster = "gbtree",
    objective = "binary:logistic",
    params = best_params,
    verbose = verbose,
    print_every_n = 100
  )
  t2 <- Sys.time()
  runtime_curr <- round(difftime(t2, t1, units = "mins"), 1)

  # feature importance (gain is used as importance)
  importance <- xgboost::xgb.importance(colnames(dtrain), model = model)
  names(importance) <- c("variable", "gain", "cover", "frequency")
  importance$importance <- importance$gain
  importance <- importance %>% dplyr::arrange(dplyr::desc(importance))

  # plot importance (top 30 variables)
  p <- importance %>%
    utils::head(30) %>%
    ggplot2::ggplot(ggplot2::aes(
      x = importance,
      y = forcats::fct_reorder(variable, importance)
    )) +
    ggplot2::geom_col(color = "white", fill = "grey") +
    ggplot2::ylab("variable") +
    ggplot2::ggtitle("ML-feature-importance") +
    ggplot2::theme_minimal()

  # log
  log_info_if(log, paste0("done, ", "training time=", runtime_curr, " min"))

  # return result -----------------------------------------------------------

  result <- list(
    model = model,
    importance = importance,
    plot = p,
    tune_data = hp_tuning_best,
    tune_plot = plot_hp_tuning
  )

  # output
  if (out %in% c("all", "list")) {
    return(result)
  } else if (out == "model") {
    return(result$model)
  } else if (out == "importance") {
    return(result$importance)
  }

  # default output
  result$plot

} # explain_xgboost

Try the explore package in your browser

Any scripts or data that you put into this service are public.

explore documentation built on Sept. 11, 2024, 7:40 p.m.