train_on_grid: Train several models with different hyperparameters and...


View source: R/grid.R

Description

Tune a model based on a provided hyperparameter grid by cross-validation. Select the best one, optionally including the null model in the comparison.

Usage

train_on_grid(
  mod_spec,
  hyper_param_grid,
  mod_rec,
  training_data,
  outcome,
  cv_nfolds,
  cv_nreps = 1,
  strata = NULL,
  id_col = NULL,
  metric,
  selection_method = "Breiman",
  simplicity_params = NULL,
  include_nullmod = TRUE,
  err_if_nullmod = FALSE,
  warn_if_nullmod = TRUE,
  n_cores = 1
)

Arguments

mod_spec

A parsnip model specification. It must include the model mode and engine. See parsnip::set_mode() and parsnip::set_engine().
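
For example, a specification built with parsnip might look like the sketch below. The random forest model, the "ranger" engine and trees = 500 are illustrative choices only, not requirements of train_on_grid.

```r
library(parsnip)

# A model specification with both the engine and the mode set explicitly.
# rand_forest() starts with mode "unknown", so set_mode() is needed here.
mod_spec <- rand_forest(trees = 500) |>
  set_engine("ranger") |>
  set_mode("classification")
```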

hyper_param_grid

A data frame with one row per hyperparameter combination. The column names give the hyperparameter names. It can alternatively be passed as a list of candidate values for each hyperparameter, which is expanded into a tibble of all combinations by tidyr::expand_grid().
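
To illustrate the list form: a list of candidate values expands into one row per combination. The sketch below calls tidyr::expand_grid() directly; the hyperparameter names penalty and mixture are illustrative (they suit a glmnet model) and are not required by train_on_grid.

```r
library(tidyr)

# Candidate values for each hyperparameter (names are illustrative).
hp_list <- list(penalty = c(0.001, 0.01, 0.1), mixture = c(0, 1))

# Pass the list elements as named arguments: every combination becomes
# one row, giving 3 x 2 = 6 rows in a tibble.
hp_grid <- do.call(expand_grid, hp_list)
hp_grid
```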

mod_rec

The recipe for preparing the data for this model. See recipes::recipe().

training_data

A data frame. The data used to train the model.

outcome

A string. The name of the outcome variable. This must be a column in training_data.

cv_nfolds

A positive integer. The number of folds for cross-validation.

cv_nreps

A positive integer. The number of repeated rounds in the cross-validation.

strata

A string. Variable to stratify on when splitting for cross-validation.
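
cv_nfolds, cv_nreps and strata together describe a repeated, optionally stratified, V-fold split. As an illustration only (this page does not say which function train_on_grid uses internally), rsample::vfold_cv() produces such a split; with 5 folds repeated 3 times there are 15 resamples. The iris data and the "Species" stratum are stand-ins.

```r
library(rsample)

set.seed(1)
# 5-fold cross-validation, repeated 3 times, stratified on the outcome class.
folds <- vfold_cv(iris, v = 5, repeats = 3, strata = "Species")

# One row per resample: v * repeats = 15.
nrow(folds)
```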

id_col

A string. If there is a sample identifier column, specify it here to tell the model not to use it as a predictor.
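
In the recipes package, the usual way to keep an identifier out of the predictors is to give it a non-predictor role. The sketch below assumes a hypothetical sample_id column in toy data; it shows the recipes mechanism and is not necessarily how train_on_grid handles id_col internally.

```r
library(recipes)

# Toy data with a hypothetical sample identifier column.
dat <- data.frame(
  sample_id = c("s1", "s2", "s3", "s4"),
  x = c(1.2, 0.7, 3.1, 2.2),
  y = c(10, 8, 25, 19)
)

# All non-outcome columns start as predictors; update_role() then marks
# sample_id as an "id" so that models built from the recipe ignore it.
rec <- recipe(y ~ ., data = dat) |>
  update_role(sample_id, new_role = "id")

roles <- summary(rec)
roles[, c("variable", "role")]
```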

metric

A string. The metric used to evaluate the models and select the best one. Common choices are "rmse", "mae", "roc_auc", "accuracy" and "mn_log_loss". This should be a metric available in the yardstick package, given by its bare name (e.g. "mae", not "yardstick::mae"). If you specify a multi-element character vector, the first element is used to select the best model; the subsequent metrics are also reported for that model in the cv_performance attribute of the returned object.
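
To show why bare names such as "mae" are expected, the sketch below looks each name up as a yardstick function and scores a toy set of predictions. Pairing each name with yardstick's *_vec() variant via getExportedValue() is a hypothetical illustration, not necessarily what train_on_grid does; as described above, only the first metric would drive model selection.

```r
# Hypothetical lookup: map bare metric names to yardstick's vector interface.
metric_names <- c("rmse", "mae")
truth    <- c(1, 2, 3)
estimate <- c(1, 2, 5)

scores <- vapply(metric_names, function(m) {
  metric_fn <- getExportedValue("yardstick", paste0(m, "_vec"))
  metric_fn(truth, estimate)
}, numeric(1))

scores             # named by metric: rmse = sqrt(4/3), mae = 2/3
names(scores)[1]   # "rmse": the metric that would pick the best model
```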

selection_method

A string. How to select the best model. There are two options: "Breiman" and "absolute". "absolute" selects the model with the best mean performance according to the chosen metric. "Breiman" selects the simplest model whose mean performance is within one standard error of the best score. The idea is that simpler models tend to generalize better, so it is preferable to select a simple model with near-best performance.
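
The "Breiman" (one-standard-error) rule can be sketched in a few lines of base R. The cross-validation summary below is made up, the metric is assumed to be one where lower is better (such as "rmse"), and a larger penalty is taken to mean a simpler model. Per the simplicity_params description, train_on_grid delegates this step to tune::select_by_one_std_err(); the helper select_one_se() is hypothetical and only illustrates the idea.

```r
# Made-up CV summary: mean and standard error of RMSE for each penalty.
cv_summary <- data.frame(
  penalty = c(0.001, 0.01, 0.1, 1),
  mean    = c(1.000, 0.980, 1.005, 1.300),
  std_err = c(0.03, 0.03, 0.04, 0.05)
)

# Hypothetical helper: one-standard-error rule for a lower-is-better metric.
select_one_se <- function(summary, simplicity_col, simpler_is_larger = TRUE) {
  best <- which.min(summary$mean)
  threshold <- summary$mean[best] + summary$std_err[best]
  # Keep every candidate whose mean score is within one SE of the best.
  candidates <- summary[summary$mean <= threshold, , drop = FALSE]
  # Among those, return the simplest one.
  ord <- order(candidates[[simplicity_col]], decreasing = simpler_is_larger)
  candidates[ord[1], , drop = FALSE]
}

# "absolute": best mean score -> penalty 0.01
cv_summary$penalty[which.min(cv_summary$mean)]
# "Breiman": simplest model within one SE of the best -> penalty 0.1
select_one_se(cv_summary, "penalty")$penalty
```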

simplicity_params

A character vector. Only used when selection_method = "Breiman". These are passed directly to tune::select_by_one_std_err() and used to sort hyper_param_grid from simplest to most complex. To sort descending, put a minus in front of the parameter. For example, to sort ascending on "x" and then descending on "y", use simplicity_params = c("x", "-y"). See tune::select_by_one_std_err() for details.
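
The minus-prefix convention can be illustrated with a small base R helper that orders a grid the way c("x", "-y") describes. The helper order_by_simplicity() is hypothetical and only mirrors the documented convention; the actual sorting is done by tune::select_by_one_std_err().

```r
# Hypothetical helper: order a grid by simplicity_params-style specs,
# where a leading "-" means "sort this column descending".
order_by_simplicity <- function(grid, simplicity_params) {
  descending <- startsWith(simplicity_params, "-")
  cols <- sub("^-", "", simplicity_params)
  # Build one sort key per column, negated when descending.
  keys <- lapply(seq_along(cols), function(i) {
    key <- xtfrm(grid[[cols[i]]])
    if (descending[i]) -key else key
  })
  grid[do.call(order, keys), , drop = FALSE]
}

grid <- expand.grid(x = c(2, 1), y = c(10, 20))
sorted <- order_by_simplicity(grid, c("x", "-y"))
sorted  # x ascending, then y descending within each x
```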

include_nullmod

A bool. Include the null model (which always predicts the mean, or the most common class) in the model comparison? This is recommended. If the null model comes within one standard error of the best non-null model, the null model is chosen instead.
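
A null model ignores the predictors entirely. The hypothetical helper null_predict() below reproduces the behaviour described above, predicting the training mean for a numeric outcome and the most common class otherwise. (parsnip also provides null_model() for this purpose; whether train_on_grid uses it is not stated here.)

```r
# Hypothetical helper: null-model predictions for n_new new observations.
null_predict <- function(train_outcome, n_new) {
  if (is.numeric(train_outcome)) {
    # Regression: always predict the training mean.
    rep(mean(train_outcome), n_new)
  } else {
    # Classification: always predict the most common training class
    # (ties go to the first class in table order).
    counts <- table(train_outcome)
    rep(names(counts)[which.max(counts)], n_new)
  }
}

null_predict(c(1, 2, 3, 6), 2)             # 3 3
null_predict(factor(c("a", "b", "b")), 1)  # "b"
```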

err_if_nullmod

A bool. If the null model is chosen, throw an error rather than returning the null model.

warn_if_nullmod

A bool. Warn if returning the null model?

n_cores

A positive integer. The number of cores to use for running the cross-validation in parallel. The default, 1, means no parallel processing.

Value

A parsnip model_fit object with a predict() method. It also carries two attributes: recipe and cv_performance.


mirvie/mirmodels documentation built on Jan. 14, 2022, 11:12 a.m.