# inst/extdata/scripts/ml_xgboost.R

# load package, paths and variable sets from global.R --------------------------

source("inst/extdata/scripts/ml_resampling.R")
library(tidymodels)


# Specify model ----------------------------------------------------------------

xgb_model <- boost_tree(mtry = 6,
                        min_n = 10,
                        trees = 500,
                        tree_depth = 7,
                        loss_reduction = 10,
                        learn_rate = 0.1,
                        sample_size = 0.7) %>%
  set_engine("xgboost", nthread = parallel::detectCores()) %>%
  set_mode("regression")
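
# optional sanity check: print how parsnip translates this specification
# into the underlying xgboost fit call
xgb_model %>% translate()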


# Model training and assessment (regression) -----------------------------------

# Train model
set.seed(26)
xgb_fit <- xgb_model %>% fit(Qs_rel ~ ., data = df_training)
#usethis::use_data(xgb_fit, compress = "xz", overwrite = TRUE)

# Make predictions
predictions <- predict(xgb_fit, df_test)

# Evaluate model performance
df_pred <- df_test %>% select(Qs_rel) %>% bind_cols(predictions)
rmse(df_pred, truth = Qs_rel, estimate = .pred)
rsq(df_pred, truth = Qs_rel, estimate = .pred)
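
# optional: compute several metrics in one call via a yardstick metric set
# (mae is shown here as an example of an additional metric)
reg_metrics <- metric_set(rmse, rsq, mae)
reg_metrics(df_pred, truth = Qs_rel, estimate = .pred)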

# scatter plot
scatterplot(df_pred, lines_80perc = FALSE, alpha = 1, pointsize = 0.9)
ggsave("scatterplot_xgb_numeric.png", dpi = 600, width = 3.5, height = 3)



# classification performance ---------------------------------------------------

# classify Qs data
df_pred <- df_pred %>%
  mutate(Qs_rel_class = classify_Qs(Qs_rel),
         .pred_class = classify_Qs(.pred))

# confusion matrix
matrix <- conf_mat(df_pred, truth = Qs_rel_class, estimate = .pred_class)

# performance metrics
metrics <- matrix %>% summary()

save_data(matrix, getwd(), "xgb_numeric_to_class_matrix_split80", formats = "RData")
save_data(metrics, getwd(), "xgb_numeric_to_class_metrics_split80")
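
# optional: visualise the confusion matrix as a heatmap to spot
# misclassifications at a glance
autoplot(matrix, type = "heatmap")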



# Hyperparameter tuning --------------------------------------------------------

if (FALSE) {

  # specify model
  xgb_model <- boost_tree(
    trees = 500,
    tree_depth = tune(), min_n = tune(),
    loss_reduction = tune(),                     ## first three: model complexity
    sample_size = tune(), mtry = tune(),         ## randomness
    learn_rate = tune()                          ## step size
  ) %>%
    set_engine("xgboost") %>%
    set_mode("regression")

  # set up workflow
  xgb_wf <- workflow() %>%
    add_formula(Qs_rel ~ .) %>%
    add_model(xgb_model)

  # hyperparameter sampling v1: fully random grid (overwritten by v2 below)
  xgb_grid <- grid_random(tree_depth(),
                          min_n(),
                          loss_reduction(),
                          sample_size = sample_prop(),
                          finalize(mtry(), df_training),
                          learn_rate(range = c(0.01, 0.1), trans = NULL),
                          size = 1000)

  # hyperparameter sampling v2: space-filling Latin hypercube
  # (overwrites the v1 grid; keep whichever variant is needed)
  xgb_grid <- grid_latin_hypercube(
    tree_depth(),
    min_n(),
    loss_reduction(),
    sample_size = sample_prop(),
    finalize(mtry(), df_training),
    learn_rate(),
    size = 500
  )

  # define cross validation procedure
  cv_folds <- vfold_cv(df_training, v = 5)
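
  # optional variant: stratify the folds by the outcome so that the
  # distribution of Qs_rel is comparable across resamples
  # cv_folds <- vfold_cv(df_training, v = 5, strata = Qs_rel)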

  # register a parallel backend to speed up the grid search
  doParallel::registerDoParallel()

  # test different hyperparameters via cross validation on training data
  set.seed(234)
  xgb_tuning <- tune_grid(
    xgb_wf,
    resamples = cv_folds,
    grid = xgb_grid,
    control = control_grid(save_pred = TRUE)
  )

  # get assessment metrics
  metrics <- xgb_tuning %>% collect_metrics()
  save_data(metrics, getwd(), "metrics_tuning_xgb_random_resampling")

  # visualise results
  metrics %>%
    #filter(learn_rate > 0.01) %>%
    filter(.metric == "rmse") %>%
    select(mean, min_n, mtry, tree_depth, learn_rate, loss_reduction, sample_size) %>%
    pivot_longer(c(min_n, mtry, tree_depth, learn_rate, loss_reduction, sample_size),
                 values_to = "value",
                 names_to = "parameter"
    ) %>%
    ggplot(aes(value, mean, color = parameter)) +
    geom_point(show.legend = FALSE, size = 0.5) +
    facet_wrap(~parameter, scales = "free") +
    labs(x = NULL, y = "RMSE [%]") +
    sema.berlin.utils::my_theme()

  ggsave("xgb_regression_hyperparameter_tuning_plot_random_resampling_1000_v2.png", width = 8, height = 4, dpi = 600)

  # after example from https://juliasilge.com/blog/xgboost-tune-volleyball/

}