fit_resamples: Fit multiple models via resampling

Description

fit_resamples() computes a set of performance metrics across one or more resamples. It does not perform any tuning (see tune_grid() and tune_bayes() for that), and is instead used for fitting a single model+recipe or model+formula combination across many resamples.

Usage

fit_resamples(object, ...)

## S3 method for class 'model_spec'
fit_resamples(
  object,
  preprocessor,
  resamples,
  ...,
  metrics = NULL,
  control = control_resamples()
)

## S3 method for class 'workflow'
fit_resamples(
  object,
  resamples,
  ...,
  metrics = NULL,
  control = control_resamples()
)

Arguments

object

A parsnip model specification or a workflows::workflow(). No tuning parameters are allowed.

...

Currently unused.

preprocessor

A traditional model formula or a recipe created using recipes::recipe().

resamples

A resample rset created from an rsample function such as rsample::vfold_cv().

metrics

A yardstick::metric_set(), or NULL to compute a standard set of metrics.

control

A control_resamples() object used to fine-tune the resampling process.

Performance Metrics

To use your own performance metrics, the yardstick::metric_set() function can be used to pick what should be measured for each model. If multiple metrics are desired, they can be bundled. For example, to estimate the area under the ROC curve as well as the sensitivity and specificity (under the typical probability cutoff of 0.50), the metrics argument could be given:

  metrics = metric_set(roc_auc, sens, spec)

Each metric is calculated for each candidate model.

If no metric set is provided, one is created:

  • For regression models, the root mean squared error and coefficient of determination are computed.

  • For classification, the area under the ROC curve and overall accuracy are computed.

Note that the metrics also determine what type of predictions are estimated during resampling. For example, in a classification problem, if all of the requested metrics are based on hard class predictions, the class probabilities are not created.
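
As a rough illustration of the point above, the sketch below requests only hard-class metrics and then inspects the saved predictions. The data set (mtcars with am recoded as a factor), the object names (cars, class_folds, log_mod, class_res), and the choice of a glm logistic regression are assumptions made for this example only.

```r
library(tune)
library(parsnip)
library(rsample)
library(yardstick)

# Hypothetical two-class problem built from mtcars
cars <- mtcars
cars$am <- factor(cars$am, labels = c("automatic", "manual"))

set.seed(6735)
class_folds <- vfold_cv(cars, v = 5, strata = am)

log_mod <- set_engine(logistic_reg(), "glm")

# All three metrics use hard class predictions only
class_metrics <- metric_set(accuracy, sens, spec)

class_res <- fit_resamples(
  log_mod,
  am ~ mpg,
  class_folds,
  metrics = class_metrics,
  control = control_resamples(save_pred = TRUE)
)

# Expect a .pred_class column but no class probability columns
pred_names <- names(collect_predictions(class_res))
pred_names
```

Adding a probability-based metric such as roc_auc to the metric set should cause the probability columns to be produced as well.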

The out-of-sample estimates of these metrics are contained in a list column called .metrics. This tibble contains a row for each metric and columns for the value, the estimator type, and so on.

collect_metrics() can be used on these objects to collapse the results over the resamples and obtain the final resampling estimates.
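
A minimal sketch of supplying a custom metric set and summarizing it follows. The formula mpg ~ wt + disp, the choice of rmse and mae, and the object names (reg_metrics, metric_res, metric_summary) are assumptions for illustration, not part of the original documentation.

```r
library(tune)
library(parsnip)
library(rsample)
library(yardstick)

set.seed(6735)
folds <- vfold_cv(mtcars, v = 5)

lin_mod <- set_engine(linear_reg(), "lm")

# Ask for RMSE and MAE instead of the default regression metrics
reg_metrics <- metric_set(rmse, mae)

metric_res <- fit_resamples(
  lin_mod,
  mpg ~ wt + disp,
  folds,
  metrics = reg_metrics
)

# Per-resample results: one tibble in .metrics for each fold
metric_res$.metrics[[1]]

# Averaged over the folds: one row per metric
metric_summary <- collect_metrics(metric_res)
metric_summary
```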

Obtaining Predictions

When control_resamples(save_pred = TRUE) is used, the output tibble contains a list column called .predictions that holds the out-of-sample predictions for each resample (which can be very large).

Each element of this list column is a tibble with columns for the row number from the original data object (.row), the outcome data (with the same name(s) as in the original data), and any columns created by the predictions. For example, for simple regression problems, the predictions are stored in a column called .pred. As noted above, the prediction columns that are returned are determined by the type of metric(s) requested.

This list column can be unnested using tidyr::unnest() or using the convenience function collect_predictions().
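
The following sketch shows one way to save and collect the out-of-sample predictions. The formula and the object names (pred_ctrl, pred_res, preds) are assumptions chosen for this example.

```r
library(tune)
library(parsnip)
library(rsample)

set.seed(6735)
folds <- vfold_cv(mtcars, v = 5)

lin_mod <- set_engine(linear_reg(), "lm")

# Keep the assessment-set predictions from every resample
pred_ctrl <- control_resamples(save_pred = TRUE)

pred_res <- fit_resamples(lin_mod, mpg ~ wt + disp, folds, control = pred_ctrl)

# Predictions for the first fold only
pred_res$.predictions[[1]]

# All folds stacked into a single tibble
preds <- collect_predictions(pred_res)
preds
```

With V-fold cross-validation each row of the original data is held out once, so preds should have one row per row of mtcars.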

Extracting Information

The extract control option will result in an additional column being returned, called .extracts. This is a list column that has tibbles containing the results of the user's function for each resample. This can enable returning each model and/or recipe object that is created during resampling. Note that this could result in a large return object, depending on what is returned.

The control function contains an option (extract) that can be used to retain any model or recipe that was created within the resamples. This argument should be a function with a single argument. The value of the argument that is given to the function in each resample is a workflow object (see workflows::workflow() for more information). Several helper functions can be used to easily pull out the preprocessing and/or model information from the workflow, such as extract_preprocessor() and extract_fit_parsnip().

As an example, if there is interest in getting each parsnip model fit back, one could use:

  extract = function (x) extract_fit_parsnip(x)

Note that the function given to the extract argument is evaluated on every model that is fit (as opposed to every model that is evaluated). In some tuning settings, model predictions can be derived for sub-models, so not every row in a tuning parameter grid has a separate R object associated with it.
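
As a sketch of how the extract option might be used end to end, the example below keeps each fitted parsnip model and pulls the coefficients from the first one. The formula and the object names (extract_ctrl, extract_res, first_fit) are assumptions for illustration.

```r
library(tune)
library(parsnip)
library(rsample)
library(workflows)

set.seed(6735)
folds <- vfold_cv(mtcars, v = 5)

lin_mod <- set_engine(linear_reg(), "lm")

# Return the fitted parsnip model from the workflow in each resample
extract_ctrl <- control_resamples(
  extract = function(x) extract_fit_parsnip(x)
)

extract_res <- fit_resamples(
  lin_mod,
  mpg ~ wt + disp,
  folds,
  control = extract_ctrl
)

# Each element of the outer .extracts column is a tibble whose own
# .extracts column holds the value returned by the extract function
first_fit <- extract_res$.extracts[[1]]$.extracts[[1]]

# The underlying lm object is stored in the $fit element
coef(first_fit$fit)
```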

See Also

control_resamples(), collect_predictions(), collect_metrics()

Examples


library(tune)
library(recipes)
library(rsample)
library(parsnip)
library(workflows)

set.seed(6735)
folds <- vfold_cv(mtcars, v = 5)

spline_rec <- recipe(mpg ~ ., data = mtcars) %>%
  step_ns(disp) %>%
  step_ns(wt)

lin_mod <- linear_reg() %>%
  set_engine("lm")

control <- control_resamples(save_pred = TRUE)

spline_res <- fit_resamples(lin_mod, spline_rec, folds, control = control)

spline_res

show_best(spline_res, metric = "rmse")

# You can also wrap up a preprocessor and a model into a workflow, and
# supply that to `fit_resamples()` instead. Here, a workflows "variables"
# preprocessor is used, which lets you supply terms using dplyr selectors.
# The variables are used as-is; no preprocessing is done to them.
wf <- workflow() %>%
  add_variables(outcomes = mpg, predictors = everything()) %>%
  add_model(lin_mod)

wf_res <- fit_resamples(wf, folds)


tune documentation built on Aug. 24, 2023, 1:09 a.m.