# data-raw/01-test_pokemon_multi.R

library(tidyverse)
library(xgbh)
library(recipes)
library(yardstick)

# Load the `pokemon_multi_trn` / `pokemon_multi_tst` datasets via data(), keep
# them under the short names `df_trn` / `df_tst`, then remove the originals
# together with the helper vector holding their names.
nms <- c('pokemon_multi_trn', 'pokemon_multi_tst')
data(list = nms)
df_trn <- pokemon_multi_trn
df_tst <- pokemon_multi_tst
rm(list = c(nms, 'nms'))

col_y <- 'type_1'
col_id <- 'number'

# Preprocessing recipe ----
# Predict `col_y` from every other column except `name` (dropped up front).
# `col_id` gets the 'id' role, so it is carried along without being treated
# as a predictor. `type_2` / `total` are removed when present, and
# near-zero-variance numeric predictors are dropped.
# The id column is referenced through `col_id` rather than a hard-coded name,
# so the recipe stays in sync with the `col_id` passed to the fit/predict
# helpers below.
form <- paste0(col_y, ' ~ .') %>% as.formula()
rec_init <-
  df_trn %>%
  select(-c(name)) %>%
  recipe(form, data = .) %>%
  update_role(all_of(col_id), new_role = 'id') %>%
  step_rm(any_of(c('type_2', 'total'))) %>%
  # step_dummy(generation) %>%
  step_nzv(all_numeric_predictors())
rec_init

# checking ----
# Estimate the recipe once on the training data and reuse the prepped object.
# Re-running prep() for every bake repeats identical work.
rec_prep <- rec_init %>% prep()
jui_trn <- rec_prep %>% juice()
jui_trn

# Bake `df` with a prepped recipe and recode the outcome `col_y` to zero-based
# integer codes (first level -> 0L), the label form passed to xgboost below.
# Assumes `col_y` is a factor after baking -- as.integer() of a factor gives
# its level index.
#
# @param df  Data frame to bake (training or test rows).
# @param rec Prepped recipe to apply; defaults to `rec_prep`.
# @return The baked tibble with `col_y` as 0-based integer labels.
.to_xgb <- function(df, rec = rec_prep) {
  jui <- rec %>% bake(new_data = df)
  jui %>% mutate(across(all_of(col_y), ~as.integer(.x) - 1L))
}
jui_trn_xgb <- df_trn %>% .to_xgb()
jui_trn_xgb
jui_tst_xgb <- df_tst %>% .to_xgb()
jui_tst_xgb

set.seed(42)
col_y_sym <- sym(col_y)
# `.to_xgb()` encodes the outcome as (factor level index - 1), so labels range
# over 0..(nlevels - 1). `multi:softprob` needs every label to be below
# `num_class`. Size `num_class` by the factor's levels rather than by the
# classes that happen to occur in the training split: a level with no training
# rows can otherwise push the largest label past `num_class - 1`.
# Non-factor outcomes fall back to counting distinct values.
n_class <-
  if (is.factor(jui_trn[[col_y]])) {
    nlevels(jui_trn[[col_y]])
  } else {
    jui_trn %>% distinct(!!col_y_sym) %>% pull(!!col_y_sym) %>% length()
  }
# xgbh::do_fit() with the multiclass objective, metric and column settings
# fixed. Further arguments are inserted at the `... =` placeholder.
do_fit_partially <- partial(
  xgbh::do_fit,
  overwrite = FALSE,
  objective = 'multi:softprob',
  eval_metrics = list('mlogloss'),
  # eval_metrics = list('error'),
  num_class = n_class,
  col_y = col_y,
  col_id = col_id,
  ... =
)

# Tag used to name this run's outputs (also reused for the prediction suffixes).
suffix <- 'pokemon_nodummy_robust'
# Layer the tuning settings on top of `do_fit_partially`: a parameter grid
# built from the predictor columns (outcome and id excluded), 1000 boosting
# rounds and the run suffix. `data` is supplied per call at the `... =`
# placeholder.
# NOTE(review): purrr::partial() (>= 0.3.0) is documented to re-evaluate
# partialised arguments that are not unquoted with `!!` each time the partial
# is called. The `grid_params` expression below therefore presumably runs at
# call time (consuming RNG state then), not here -- confirm, and consider `!!`
# if the grid should be fixed once.
do_fit_robustly <- partial(
  do_fit_partially,
  grid_params =
    jui_trn_xgb %>%
    select(-any_of(c(col_y, col_id))) %>%
    generate_grid_params(n_param = 40),
  # n_param = 50,
  nrounds = 1000,
  suffix = suffix,
  ... =
)
do_fit_robustly_timely <- time_it(do_fit_robustly)

# Unpack the two results of the fit into `tune` and `fit`.
# NOTE(review): the destructuring operator `%<-%` is not attached by the
# library() calls at the top of this script; presumably it is re-exported by
# xgbh -- confirm.
c(tune, fit) %<-%
  do_fit_robustly_timely(
    data = jui_trn_xgb
  )
fit

# Feature importance ----
# `best_iteration` is only set on the booster when early stopping was used.
# Fall back to the total number of boosting rounds so the tree selection below
# never receives `length.out = NULL`.
n_round <- fit$best_iteration %||% fit$niter
n_class <- fit$params$num_class

# Importance over all trees (every class).
imp <- xgboost::xgb.importance(model = fit)

# Importance for the first class (label 0) only. A multiclass booster grows
# one tree per class per round. Per the xgb.importance() docs (zero-based tree
# indices), stepping by `n_class` from tree 0 keeps that class's trees for the
# first `n_round` rounds.
imp1 <- xgboost::xgb.importance(
  model = fit,
  trees = seq(from = 0, by = n_class, length.out = n_round)
)

xgboost::xgb.ggplot.importance(imp) +
  tonythemes::theme_tony()
xgboost::xgb.ggplot.importance(imp1) +
  tonythemes::theme_tony()

# Predictions + SHAP values ----
# do_predict() with the fitted model, `use_y` and column settings fixed.
# `data` and `suffix` are passed per call at the `... =` placeholder.
do_predict_partially <- partial(
  do_predict,
  # overwrite = FALSE,
  overwrite = TRUE,
  fit = fit,
  use_y = TRUE,
  col_y = col_y,
  col_id = col_id,
  ... =
)
do_predict_timely <- time_it(do_predict_partially)

# Training split: class probabilities and SHAP values.
suffix_trn <- sprintf('%s_trn', suffix)
c(probs_trn, shap_trn) %<-%
  do_predict_timely(data = jui_trn_xgb, suffix = suffix_trn)
probs_trn

# Test split: class probabilities and SHAP values.
suffix_tst <- sprintf('%s_tst', suffix)
c(probs_tst, shap_tst) %<-%
  do_predict_timely(data = jui_tst_xgb, suffix = suffix_tst)
probs_tst

# Evaluation ----
# Mean log loss and ROC AUC, scored from the `.prob*` columns of a prediction
# table.
met_set <- yardstick::metric_set(yardstick::mn_log_loss, yardstick::roc_auc)
do_eval <- function(probs) {
  # yardstick needs the truth column as a factor.
  # NOTE(review): factor() only creates levels for values present in `probs`.
  # If a class never occurs in a split (plausible for rare types in the test
  # set), the level count presumably stops matching the number of `.prob*`
  # columns and yardstick rejects the input. Consider setting the levels
  # explicitly once the encoding of `col_y` in `probs` is confirmed.
  scored <- mutate(probs, across(all_of(col_y), factor))
  met_set(scored, !!sym(col_y), matches('^[.]prob'))
}

do_eval(probs_trn)
do_eval(probs_tst)

# SHAP plot for the training split ----
shap_plot_suffix <- sprintf('%s_trn', suffix)
do_plot_shap(
  shap = shap_trn,
  suffix = shap_plot_suffix,
  overwrite = TRUE,
  col_y = col_y,
  col_id = col_id
)
# tonyelhabr/xgbh documentation built on Dec. 23, 2021, 11:59 a.m.