learn_curve_cv: Create a cross-validation learning curve.


Description

This function needs a workflows::workflow that is ready to be passed to tune::fit_resamples(). It runs v-fold cross-validation for several values of v, which varies the size of the training set, and then collects the out-of-fold predictions and scores them with metric_calculator.
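Varying v varies the training-set size because each of the v models in v-fold cross-validation is trained on about (v - 1)/v of the rows. The short sketch below makes that arithmetic concrete. It is an illustration, not the package's implementation, and it assumes the 506-row BostonHousing data from the Examples and that each value of v is cross-validated repeats times, as the argument descriptions suggest.

```r
# Illustration only (not mirmodels code): how the values passed as `folds`
# translate into training-set sizes and into the number of model fits.
n <- 506        # rows in mlbench::BostonHousing, as used in the Examples
v <- 2:9        # values that would be passed as `folds`
repeats <- 3    # value that would be passed as `repeats`

# In v-fold CV each model is trained on all rows except one fold,
# i.e. on about (v - 1) / v of the data (fold sizes can differ by one row).
train_frac <- (v - 1) / v
train_size <- floor(n * train_frac)

# Each repeat of v-fold CV fits v models.
n_fits <- repeats * sum(v)

print(data.frame(v = v, train_frac = round(train_frac, 3), train_size = train_size))
print(n_fits)  # 132
```

Larger values of v give larger training sets but also more model fits, which is why running in parallel with n_cores can be worthwhile.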

Usage

learn_curve_cv(
  data,
  wf,
  folds,
  repeats,
  metric_calculator,
  strata = NULL,
  pkgs = c("mirmodels"),
  n_cores = 1
)

Arguments

data

A data frame. The data to be used for the modelling.

wf

A workflows::workflow(). It should have been constructed using data.

folds

An integer vector. The different values of v to use in rsample::vfold_cv(), one per training-set size.

repeats

The number of times to repeat each cross-validation.

metric_calculator

A function which takes a single data frame argument and returns a double. The data frame passed to this function is the output of tune::collect_predictions(), run on the output of tune::fit_resamples() with control = tune::control_resamples(save_pred = TRUE). See the example below.

strata

A string. Variable to stratify on when splitting for cross-validation.

pkgs

A character vector. Passed to tune::control_resamples().

n_cores

A positive integer. The number of cores to use. With a value greater than 1, the cross-validation is done in parallel.
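A metric_calculator does not have to use yardstick. The sketch below writes one as an ordinary R function that computes mean absolute error from the data frame it receives. It assumes, as for a regression on medv, that the predictions frame has an outcome column medv and a prediction column .pred (the name tune::collect_predictions() uses for numeric predictions). The names mae_calculator and toy_preds are hypothetical, and toy_preds is made-up data standing in for real collect_predictions() output.

```r
# Hypothetical metric_calculator: takes one data frame, returns one double.
# Assumes columns `medv` (observed outcome) and `.pred` (prediction).
mae_calculator <- function(preds) {
  mean(abs(preds$medv - preds$.pred))
}

# Made-up stand-in for tune::collect_predictions() output.
toy_preds <- data.frame(medv = c(24, 22, 35), .pred = c(25, 20, 35))
mae_calculator(toy_preds)  # 1
```

This computes the same quantity as the yardstick::mae() call in the example below, just without the yardstick dependency.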

Value

A tibble with 2 columns:

Examples

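library(mirmodels)
library(magrittr)  # provides the %>% pipe used below
data("BostonHousing", package = "mlbench")
# Keep only the numeric columns (drops the factor column chas).
bh <- dplyr::select_if(BostonHousing, is.numeric)
# Ordinary least-squares linear regression.
mod <- parsnip::linear_reg(penalty = 0, mixture = 0) %>%
  parsnip::set_engine("lm")
wf <- workflows::workflow() %>%
  workflows::add_formula(medv ~ .) %>%
  workflows::add_model(mod)
# Score each set of out-of-fold predictions by mean absolute error.
metric_calculator <- function(preds) yardstick::mae(preds, medv, .pred)$.estimate
# v = 2, ..., 9, each repeated 3 times, using 4 cores.
lccv <- suppressWarnings(
  learn_curve_cv(bh, wf, 2:9, 3, metric_calculator, n_cores = 4)
)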

mirvie/mirmodels documentation built on Jan. 14, 2022, 11:12 a.m.