# Default chunk options for every code chunk in this vignette.
knitr::opts_chunk$set(
  # Code is shown but not executed when the document is rendered.
  eval = FALSE,
  # Hide messages and warnings emitted by chunks.
  message = FALSE,
  warning = FALSE,
  # Merge source and output into one block; prefix output lines with "#>".
  collapse = TRUE,
  comment = "#>"
)

If your hyperparameter search is likely to take a long time to run, you might want to save intermediate results while the search is in progress.

Prepare the recipe:

data("attrition")

# Recode the outcome as 0/1, the label format used with 'binary:logistic'.
attrition <- attrition %>%
  mutate(Attrition = ifelse(Attrition == "Yes", 1, 0))

# 5-fold cross-validation splits.
resamples <- rsample::vfold_cv(attrition, v = 5)

# The recipe is created without a formula, so no column has a role yet.
# update_role() is the call for assigning a role to a variable that has none;
# add_role() only appends roles to variables that already have one.
rec <-
  recipe(attrition) %>%
  update_role(Attrition, new_role = "outcome") %>%
  update_role(-Attrition, new_role = "predictor") %>%
  step_novel(all_nominal(), -Attrition) %>%
  step_dummy(all_nominal(), -Attrition) %>%
  step_zv(all_predictors())

Prepare your scoring function:

# Fit an xgboost binary classifier on `train_df` and score it on `eval_df`.
#
# Arguments:
#   train_df    Data frame used for fitting; must contain `target_var`.
#   target_var  Name (a single string) of the label column.
#   params      List of xgboost parameters, passed to `xgb.train()`.
#   eval_df     Data frame used for scoring; must contain `target_var`.
#   ...         Further arguments forwarded to `xgb.train()`.
#
# Returns a named list: `score1` is `LogLoss()` and `score2` is `Accuracy()`
# (predictions cut at 0.5), both computed on `eval_df`.
xgboost_classif_score <- 
  function(train_df, 
           target_var, 
           params, 
           eval_df, 
           ...){

  # Fail early with a clear message instead of training on a NULL label.
  stopifnot(
    is.character(target_var),
    length(target_var) == 1L,
    target_var %in% names(train_df),
    target_var %in% names(eval_df)
  )

  # Build an xgb.DMatrix from every column except the target. The target is
  # excluded by exact name: a regex-based match would also drop predictors
  # whose names merely contain the target name.
  to_dmatrix <- function(df) {
    predictor_cols <- setdiff(names(df), target_var)
    xgb.DMatrix(
      as.matrix(df[, predictor_cols, drop = FALSE]),
      label = df[[target_var]]
    )
  }

  xgb_train_data <- to_dmatrix(train_df)
  xgb_eval_data <- to_dmatrix(eval_df)
  y_eval <- eval_df[[target_var]]

  model <- xgb.train(params = params,
                     data = xgb_train_data,
                     watchlist = list(train = xgb_train_data, eval = xgb_eval_data),
                     objective = 'binary:logistic',
                     verbose = FALSE,
                     ...)

  preds <- predict(model, xgb_eval_data)

  # You can also return a simple vector score, e.g. LogLoss(preds, y_eval).
  list(score1 = LogLoss(preds, y_eval), 
       score2 = Accuracy(ifelse(preds > 0.5, 1, 0), y_eval))
}
# Fix the RNG seed before running the search.
set.seed(123)

# Candidate hyperparameters: all combinations of the two `eta` values and the
# two `max_depth` values (4 rows).
xgboost_param_grid <- expand.grid(
  eta = c(0.1, 0.05),
  max_depth = c(3, 4)
)

# Run the search over the resamples with the recipe and scoring function
# defined earlier.
# NOTE(review): `nrounds` is presumably forwarded to the scoring function and
# on to xgb.train(); the exact meaning of `n` and `batch_size` is defined by
# grid_search() -- confirm against its documentation.
results_grid_search <- grid_search(
  resamples = resamples,
  recipe = rec,
  param_grid = xgboost_param_grid,
  scoring_func = xgboost_classif_score,
  nrounds = 100,
  n = 4,
  batch_size = 2
)


artichaud1/tidytune documentation built on May 20, 2019, 9:13 p.m.