
tidycrossval

tidycrossval is an early-stage package for hyperparameter tuning and cross validation using tidymodels principles. Currently, this package is mainly designed for my own analysis purposes, but it already contains some handy functions for hyperparameter tuning and cross validation. The package is designed to integrate with the recipes, parsnip and rsample packages.

Installation

The package is not available on CRAN, and there are no plans to submit it in the near future. The package can, however, be installed directly from this GitHub page with:

devtools::install_github("stevenpawley/tidycrossval")

Basic Usage

The package currently contains three main functions: tune, cross_validate and select_best.

Tuning

tune is a generic function that accepts a tibble of resampling partitions generated by the resampling schemes available in the rsample package. A parsnip model specification and a recipes recipe also need to be supplied to this function, along with a yardstick scoring function. The param_grid argument accepts a grid_regular or grid_random object of tuning parameters generated by the dials package. Hyperparameter tuning is then performed on this object, which is returned as a tibble with an additional list-column called tune_scores containing the tuning scores.

If the resampling object is a nested_cv object, then hyperparameter tuning is performed on the inner resampling partitions, and the best hyperparameters per outer fold are also returned as additional columns in the output tibble.
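For a simple (non-nested) resampling object, a call might look like the sketch below; the recipe, model specification and grid are placeholders that mirror the arguments used in the fuller example later in this README:

# illustrative sketch of a plain (non-nested) cross validation tuning run
library(tidymodels)
library(tidycrossval)

rec_simple <- iris %>%
  recipe(Species ~ .) %>%
  step_scale(all_predictors()) %>%
  step_center(all_predictors())

clf_rf <- rand_forest(mode = "classification", min_n = varying()) %>%
  set_engine("ranger")

folds_cv <- vfold_cv(iris, v = 5)
params_rf <- grid_regular(min_n(c(1, 6)))

scores_cv <- folds_cv %>%
  tune(object = clf_rf, recipe = rec_simple, param_grid = params_rf,
       scoring = accuracy, maximize = TRUE)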

Tuning recipes parameters

One aspect currently missing from the tidymodels approach is the ability to tune parameters in the recipes object. For example, we might be performing correlation-based feature selection using the step_corr function in the recipes package, but we want to tune the correlation coefficient (i.e. threshold) parameter in the recipe as part of nested cross validation. This is available in the tidycrossval package using the following approach.

First we define a recipes object. A subtle difference here is that we have to explicitly define the identifiers of the steps, rather than using the identifiers that recipes autogenerates, so that we know which hyperparameters refer to which step:

library(tidymodels)
library(tidyverse)
library(tidycrossval)

data(iris)

rec <- iris %>%
  recipe(Species ~ .) %>%
  step_scale(all_predictors(), id = "scale") %>%
  step_center(all_predictors(), id = "center") %>%
  step_corr(all_predictors(), threshold = varying(), id = "correlation_filter")

Next we define a parsnip model specification, and a nested cross validation scheme using the rsample package:

clf <- rand_forest(mode = "classification", min_n = varying()) %>% 
  set_engine("ranger")

folds <- iris %>% nested_cv(outside = vfold_cv(v = 5), inside = mc_cv(times = 1))

Now we need to define a hyperparameter grid for model tuning. However, we are also going to tune the threshold parameter in the recipe, which we previously set to varying(). To do this, we need to define a new tuning parameter whose name is tied to the name of the step in the recipe. A quick approach is to use the convention from scikit-learn, where the name consists of the step id, followed by a double underscore and the parameter name, e.g. correlation_filter__threshold. Using dials, we can apply this convention to the label field when creating a new tuning parameter:

threshold <- new_quant_param(
  type = "double",
  range = c(0.7, 1.0),
  inclusive = c(TRUE, TRUE),
  trans = NULL,
  label = c(correlation_filter__threshold = "threshold"))

Now we can define a tuning grid:

params <- grid_regular(
  min_n(c(1, 6)), 
  threshold %>% range_set(c(0.8, 1.0))
  )

Finally, to perform hyperparameter tuning:

scores <- folds %>% 
  tune(object = clf, recipe = rec, param_grid = params, scoring = accuracy, maximize = TRUE)

This will fit and score the inner folds, and will return a tibble with a tune_scores list-column which contains the scores for each inner resampling partition and hyperparameter set, as well as columns with the best scoring hyperparameters per outer fold.
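Because tune_scores is a list-column, one quick way to look at the inner-loop results is simply to unnest it; the columns inside tune_scores are determined by the package, so this is just a way to inspect what came back:

# inspect the inner-loop tuning scores
inner_results <- scores %>% unnest(tune_scores)
head(inner_results)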

Cross Validation

We can pass our scoring tibble to the cross_validate function to fit and score the outer folds:

scores <- scores %>%
  cross_validate(object = clf, recipe = rec, scoring = metric_set(accuracy, f_meas))

This will return a new tibble with an outer_scores list-column containing the scores for the outer folds, based on the hyperparameters that were selected for that fold using the inner resampling partitions.
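Averaging the outer-fold scores gives the nested cross validation estimate of performance. Assuming the entries of outer_scores follow the usual yardstick layout with .metric and .estimate columns (an assumption, since the exact internal format may differ), a summary could look like:

# assumes yardstick-style .metric and .estimate columns inside outer_scores
scores %>%
  unnest(outer_scores) %>%
  group_by(.metric) %>%
  summarise(mean_score = mean(.estimate), sd_score = sd(.estimate))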

Selecting the Best Hyperparameters

After we have tuned and cross validated our model, we will usually want to refit our model on all of our training data, using the best average scoring hyperparameters. We can use the select_best function in tidycrossval to create a new parsnip model specification with the best hyperparameters:

clf_tuned <- clf %>% update(min_n = select_best(scores)$min_n)
rec <- rec %>% update(correlation_filter__threshold = select_best(scores)$correlation_filter__threshold)

Now we can train our model with the best hyperparameters on all of the data:

# estimate the preprocessing steps, then fit on the preprocessed training data
rec_prepped <- prep(rec, training = iris)
clf_fitted <- fit(clf_tuned, formula(rec_prepped), data = juice(rec_prepped))

Using pipelines

Instead of manually keeping track of recipe and model_spec objects, a pipeline object can be used. This is especially useful if we are tuning hyperparameters within both the recipe and model specification. An example of using pipelines is:

# create a new model specification and pipeline with the preprocessing recipe
clf <- decision_tree(mode = "classification", tree_depth = varying()) %>%
  set_engine("rpart")

clf <- pipeline(rec, clf)

# perform hyperparameter tuning and cross_validation
scores <- folds %>%
  tune(object = clf, param_grid = params, scoring = accuracy, maximize = TRUE)

scores <- scores %>%
  cross_validate(object = clf, recipe = rec, scoring = metric_set(accuracy, f_meas))

# update the pipeline with the best hyperparameters
clf_tuned <- clf %>%
  update(!!!select_best(scores)) %>%
  fit(data = iris)

# predict new data
predict(clf_tuned, new_data = iris)

