Description
Train a gradient boosted model, selecting the hyperparameters trees, tree_depth and learn_rate by cross-validation.
Usage

train_gbm(
training_data,
outcome,
metric,
hyper_param_grid = list(trees = c(5, 10, 20, 40), tree_depth = c(1, 2), learn_rate =
c(0.01, 0.1, 0.2, 0.5)),
cv_nfolds = 5,
cv_nreps = 1,
id_col = NULL,
strata = NULL,
selection_method = "Breiman",
simplicity_params = NULL,
include_nullmod = TRUE,
err_if_nullmod = FALSE,
warn_if_nullmod = TRUE,
n_cores = 1
)
Arguments
training_data

A data frame. The data used to train the model.
outcome

A string. The name of the outcome variable. This must be the name of a column in training_data.
metric

A string. The probability metric used to choose the best model. One common choice is "mn_log_loss" (used in the example below).
hyper_param_grid

A data frame with one row per hyperparameter combination. The column names give the hyperparameter names. Can optionally be passed as a named list of candidate values, which is made into a tibble with one row per combination of those values.
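To illustrate the list form of this argument, the default list of candidate values can be expanded into a one-row-per-combination grid with base R's expand.grid() (a sketch; the package may build the tibble differently):

```r
# Expand a named list of candidate values into a grid with one row per
# hyperparameter combination (sketch using base R; the package may use
# a tibble-based equivalent internally).
hyper_param_list <- list(
  trees = c(5, 10, 20, 40),
  tree_depth = c(1, 2),
  learn_rate = c(0.01, 0.1, 0.2, 0.5)
)
hyper_param_grid <- expand.grid(hyper_param_list)
nrow(hyper_param_grid)  # 4 * 2 * 4 = 32 combinations
```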
cv_nfolds

A positive integer. The number of folds for cross-validation.
cv_nreps

A positive integer. The number of repeated rounds of cross-validation.
id_col

A string. If there is a sample identifier column, specify it here to tell the model not to use it as a predictor.
strata

A string. The variable to stratify on when splitting for cross-validation.
selection_method

A string. How to select the best model. There are two options: "Breiman" and "absolute". "absolute" selects the model with the best mean performance according to the chosen metric. "Breiman" selects the simplest model whose score comes within one standard deviation of the best score, the idea being that simpler models generalize better, so it's better to select a simple model with near-best performance.
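The "Breiman" rule described here can be sketched in a few lines. This is a minimal illustration with hypothetical column names mean_score and std_dev, not the package's internal code:

```r
# Sketch of the "Breiman" selection rule: among models ordered from
# simplest to most complex, pick the simplest one whose mean score is
# within one standard deviation of the best mean score.
# (Hypothetical columns; not the package's internal representation.)
select_breiman <- function(results) {
  best <- which.max(results$mean_score)  # assumes higher is better
  threshold <- results$mean_score[best] - results$std_dev[best]
  min(which(results$mean_score >= threshold))
}

results <- data.frame(
  mean_score = c(0.80, 0.86, 0.88),  # e.g. accuracy, one row per model
  std_dev    = c(0.03, 0.03, 0.03)
)
select_breiman(results)  # row 2: simplest model within 0.88 - 0.03
```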
simplicity_params

A character vector. For selection_method = "Breiman", the names of the hyperparameters that determine how simple a model is.
include_nullmod

A bool. Include the null model (which predicts the mean or the most common class every time) in the model comparison? This is recommended. If the null model comes within a standard deviation of the otherwise best model, the null model is chosen instead.
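For intuition, a null model of the kind described here can be sketched as follows (hypothetical helper name; not the package's implementation):

```r
# A null model ignores all predictors: for a numeric outcome it always
# predicts the mean; otherwise it always predicts the most common class.
# (Hypothetical helper; not the package's implementation.)
null_model_prediction <- function(outcome) {
  if (is.numeric(outcome)) {
    mean(outcome)
  } else {
    names(which.max(table(outcome)))
  }
}

null_model_prediction(c("a", "a", "b"))  # "a"
null_model_prediction(c(1, 2, 3))        # 2
```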
err_if_nullmod

A bool. If the null model is chosen, throw an error rather than returning the null model?
warn_if_nullmod

A bool. Warn if returning the null model?
n_cores

A positive integer. The cross-validation can optionally be done in parallel; specify the number of cores for parallel processing here.
Value

A parsnip::model_fit object. To use a fitted model mod to make predictions on some new data df_new, use predict(mod, new_data = df_new).
See Also

Other model trainers: train_glm(), train_lm()
Examples

iris_data <- janitor::clean_names(datasets::iris)
iris_data_split <- rsample::initial_split(iris_data, strata = species)
iris_training_data <- rsample::training(iris_data_split)
iris_testing_data <- rsample::testing(iris_data_split)
mod <- train_gbm(
training_data = iris_training_data, outcome = "species",
metric = "mn_log_loss",
hyper_param_grid = list(
trees = c(20, 50),
tree_depth = c(1, 2),
learn_rate = c(0.01, 0.1)
),
simplicity_params = c("trees", "learn_rate"),
strata = c("species"),
n_cores = 5
)
preds <- predict(mod, new_data = iris_testing_data, type = "prob")
dplyr::bind_cols(preds, truth = iris_testing_data$species)
yardstick::mn_log_loss_vec(
truth = iris_testing_data$species,
estimate = as.matrix(preds)
)