train_gbm: Train a gradient boosted model.


View source: R/gbm.R

Description

Train a gradient boosted model, selecting hyperparameters trees, tree_depth and learn_rate by cross-validation.

Usage

train_gbm(
  training_data,
  outcome,
  metric,
  hyper_param_grid = list(trees = c(5, 10, 20, 40), tree_depth = c(1, 2), learn_rate =
    c(0.01, 0.1, 0.2, 0.5)),
  cv_nfolds = 5,
  cv_nreps = 1,
  id_col = NULL,
  strata = NULL,
  selection_method = "Breiman",
  simplicity_params = NULL,
  include_nullmod = TRUE,
  err_if_nullmod = FALSE,
  warn_if_nullmod = TRUE,
  n_cores = 1
)

Arguments

training_data

A data frame. The data used to train the model.

outcome

A string. The name of the outcome variable. This must be a column in training_data.

metric

A string. The metric used to evaluate models and select the best one. Common choices are "mn_log_loss", "roc_auc" and "accuracy". This should be a metric that is available in the yardstick package, but use e.g. "mae" and not "yardstick::mae" in this argument. If you specify this as a multi-element character vector, the first element will be used to select the best model; subsequent metrics will also be reported for that model in the cv_performance attribute of the returned object.

hyper_param_grid

A data frame with one row per hyperparameter combination. The column names give the hyperparameter names. It can optionally be passed as a list, which will be converted to a tibble of all combinations by tidyr::expand_grid().
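For instance, a minimal sketch of that list-to-grid expansion (the values below are illustrative, not the defaults):

```r
# A list of candidate values is expanded to one row per combination,
# the same conversion train_gbm() applies via tidyr::expand_grid().
grid <- tidyr::expand_grid(
  trees = c(5, 10),
  tree_depth = c(1, 2)
)
grid  # a tibble with 2 x 2 = 4 rows, one per hyperparameter combination
```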

cv_nfolds

A positive integer. The number of folds for cross-validation.

cv_nreps

A positive integer. The number of repeated rounds in the cross-validation.

id_col

A string. If there is a sample identifier column, specify it here to tell the model not to use it as a predictor.

strata

A string. Variable to stratify on when splitting for cross-validation.

selection_method

A string. How to select the best model. There are two options: "Breiman" and "absolute". "absolute" selects the model with the best mean performance according to the chosen metric. "Breiman" selects the simplest model whose mean performance comes within one standard error of the best score. The idea is that simpler models tend to generalize better, so it is preferable to select a simple model with near-best performance.

simplicity_params

A character vector. For selection_method = "Breiman". These are passed directly to tune::select_by_one_std_err() and used to sort hyper_param_grid by simplicity. To sort descending, put a minus in front of the parameter. For example, to sort ascending on "x" and then descending on "y", use simplicity_params = c("x", "-y"). See tune::select_by_one_std_err() for details.
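As a sketch of the ordering this implies (hypothetical values; this is the sort applied to the grid before the one-standard-error selection, not a call into train_gbm() itself), simplicity_params = c("trees", "-learn_rate") corresponds to sorting ascending on trees and descending on learn_rate:

```r
# Order a grid by "simplicity": ascending on trees, descending on learn_rate,
# mirroring simplicity_params = c("trees", "-learn_rate").
grid <- tidyr::expand_grid(trees = c(5, 10), learn_rate = c(0.01, 0.1))
ordered <- dplyr::arrange(grid, trees, dplyr::desc(learn_rate))
ordered  # rows with fewer trees (and, within ties, larger learn_rate) come first
```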

include_nullmod

A bool. Include the null model (predicts the mean or most common class every time) in the model comparison? This is recommended. If the null model comes within one standard error of the otherwise best model, the null model is chosen instead.

err_if_nullmod

A bool. If the null model is chosen, throw an error rather than returning the null model.

warn_if_nullmod

A bool. Warn if returning the null model?

n_cores

A positive integer. The cross-validation can optionally be done in parallel. Specify the number of cores for parallel processing here.

Value

A parsnip::model_fit object. To use this fitted model mod to make predictions on some new data df_new, use predict(mod, new_data = df_new).

See Also

Other model trainers: train_glm(), train_lm()

Examples

iris_data <- janitor::clean_names(datasets::iris)
iris_data_split <- rsample::initial_split(iris_data, strata = species)
iris_training_data <- rsample::training(iris_data_split)
iris_testing_data <- rsample::testing(iris_data_split)
mod <- train_gbm(
  training_data = iris_training_data, outcome = "species",
  metric = "mn_log_loss",
  hyper_param_grid = list(
    trees = c(20, 50),
    tree_depth = c(1, 2),
    learn_rate = c(0.01, 0.1)
  ),
  simplicity_params = c("trees", "learn_rate"),
  strata = c("species"),
  n_cores = 5
)
preds <- predict(mod, new_data = iris_testing_data, type = "prob")
dplyr::bind_cols(preds, truth = iris_testing_data$species)
yardstick::mn_log_loss_vec(
  truth = iris_testing_data$species,
  estimate = as.matrix(preds)
)

mirvie/mirmodels documentation built on Jan. 14, 2022, 11:12 a.m.