Generalized Cross-Validation with Origami


Cross-validation is an essential tool for evaluating how any given data analytic procedure extends from a sample to the target population from which the sample is derived. It has seen widespread application in all facets of statistics, perhaps most notably statistical machine learning. When used for model selection, cross-validation has powerful optimality properties [@vaart2006oracle,@vdl2007super].

Cross-validation works by partitioning a sample into complementary subsets, applying a particular data analytic (statistical) routine on a subset (the "training" set), and evaluating the routine of choice on the complementary subset (the "testing" set). This procedure is repeated across multiple partitions of the data. A variety of different partitioning schemes exist, such as V-fold cross-validation and bootstrap cross-validation, many of which are supported by origami. The origami package provides a suite of tools that generalize the application of cross-validation to arbitrary data analytic procedures. The use of origami is best illustrated by example.

Cross-validation with linear regression

We'll start by examining a fairly simple data set: mtcars, which ships with R.


One might be interested in examining how the efficiency of a car, as measured by miles-per-gallon (mpg), is explained by various technical aspects of the car, with data across a variety of different models of cars. Linear regression is perhaps the simplest statistical procedure that could be used to make such deductions. Let's try it out:

lm_mod <- lm(mpg ~ ., data = mtcars)

We can assess how well the model fits the data by comparing the predictions of the linear model to the true outcomes observed in the data set. Averaging the squared differences gives the familiar mean squared error, which we can extract from the lm model object like so:

err <- mean(resid(lm_mod)^2)

The mean squared error is stored in err. An important problem arises when we assess the model in this way: we have trained our linear regression model on the full data set and assessed the error on the full data set, using up all of our data. We, of course, are generally not interested in how well the model explains variation in the observed data; rather, we are interested in how the explanation provided by the model generalizes to a target population from which the sample is presumably derived. Having used all of our available data, we cannot honestly evaluate how well the model fits (and thus explains) variation at the population level.

To resolve this issue, cross-validation allows for a particular procedure (e.g., linear regression) to be implemented over subsets of the data, evaluating how well the procedure fits on a testing ("validation") set, thereby providing an honest evaluation of the error.
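As a minimal sketch of the problem, assuming nothing beyond base R and the mtcars data, a single random hold-out split lets us compare the error on held-out cars with the error on the cars used for fitting. The seed and the hold-out size (8 cars) are arbitrary choices made for illustration, not part of the original analysis.

```r
# Illustrative hold-out split in base R (seed and split size are assumptions)
set.seed(2018)
n <- nrow(mtcars)
valid_idx <- sample(n, size = 8)

train <- mtcars[-valid_idx, ]
valid <- mtcars[valid_idx, ]

# fit on the training rows only
fit <- lm(mpg ~ ., data = train)

# error on the rows used for fitting vs. error on rows the model never saw
train_mse <- mean(resid(fit)^2)
valid_mse <- mean((valid$mpg - predict(fit, newdata = valid))^2)
c(train = train_mse, validation = valid_mse)
```

A single split like this depends heavily on which rows happen to be held out; repeating it over several complementary partitions is what cross-validation does.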

We can easily add cross-validation to our linear regression procedure using origami. First, let us define a new function to perform linear regression on a specific partition of the data (called a "fold"):

cv_lm <- function(fold, data, reg_form) {
  # get name and index of outcome variable from regression formula
  out_var <- as.character(unlist(str_split(reg_form, " "))[1])
  out_var_ind <- as.numeric(which(colnames(data) == out_var))

  # split up data into training and validation sets
  train_data <- training(data)
  valid_data <- validation(data)

  # fit linear model on training set and predict on validation set
  mod <- lm(as.formula(reg_form), data = train_data)
  preds <- predict(mod, newdata = valid_data)

  # capture results to be returned as output
  out <- list(coef = data.frame(t(coef(mod))),
              SE = ((preds - valid_data[, out_var_ind])^2))
  return(out)
}

Our cv_lm function is rather simple: we merely split the available data into training and validation sets, using the eponymous functions provided in origami, fit the linear model on the training set, and evaluate the model on the validation set. This is a simple example of what origami considers to be a cv_fun -- a function for using cross-validation to perform a particular routine over an input data set. Having defined such a function, we can simply generate a set of partitions using origami's make_folds function, and apply our cv_lm function over the resultant folds object. Below, we replicate the resubstitution estimate of the error -- we did this "by hand" above -- using the functions make_folds and cv_lm.

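library(origami) # provides make_folds, training, validation, cross_validate
library(stringr) # used in defining the cv_lm function above
# resubstitution estimate
resub <- make_folds(mtcars, fold_fun = folds_resubstitution)[[1]]
resub_results <- cv_lm(fold = resub, data = mtcars, reg_form = "mpg ~ .")
mean(resub_results$SE) # mean squared error by resubstitution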

This (very nearly) matches the estimate of the error that we obtained above.

We can more honestly evaluate the error by V-fold cross-validation, which partitions the data into $V$ subsets, fitting the model on $V - 1$ of the subsets and evaluating on the subset that was held out for testing. This is repeated such that each subset is used for testing exactly once. We can easily apply our cv_lm function using origami's cross_validate (n.b., make_folds produces folds for 10-fold cross-validation by default):

# cross-validated estimate
folds <- make_folds(mtcars)
cvlm_results <- cross_validate(cv_fun = cv_lm, folds = folds, data = mtcars,
                               reg_form = "mpg ~ .")
mean(cvlm_results$SE) # cross-validated mean squared error

Having performed 10-fold cross-validation, we quickly notice that our previous estimate of the model error (by resubstitution) was quite optimistic. The honest estimate of the error is several times larger.
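For intuition, the V-fold procedure can also be written out by hand. The following is a base R sketch, not origami's implementation: the random fold assignment, the seed, and the variable names are all assumptions made for illustration.

```r
# Hand-rolled 10-fold cross-validation of the linear model (illustrative only)
set.seed(2018)
V <- 10
n <- nrow(mtcars)

# assign each row to one of V folds at random, with near-equal fold sizes
fold_id <- sample(rep(seq_len(V), length.out = n))

sq_err <- numeric(n)
for (v in seq_len(V)) {
  is_valid <- fold_id == v
  # fit on the V - 1 training folds, predict on the held-out fold
  fit <- lm(mpg ~ ., data = mtcars[!is_valid, ])
  preds <- predict(fit, newdata = mtcars[is_valid, ])
  sq_err[is_valid] <- (mtcars$mpg[is_valid] - preds)^2
}

cv_mse <- mean(sq_err)
resub_mse <- mean(resid(lm(mpg ~ ., data = mtcars))^2)
c(resubstitution = resub_mse, cross_validated = cv_mse)
```

Every row is predicted exactly once by a model that did not see it, which is what makes the cross-validated estimate honest.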

General workflow

Generally, cross_validate usage will mirror the workflow in the above example. First, the user must define folds and a function that operates on each fold. Once these are passed to cross_validate, it will map the function across the folds and combine the results in a reasonable way. More details on each step of this process are given below.

Define folds

The folds object passed to cross_validate is a list of folds. Such lists can be generated using the make_folds function. Each fold consists of a list with a training index vector, a validation index vector, and a fold_index (its order in the list of folds). make_folds supports a variety of cross-validation schemes including V-fold and bootstrap cross-validation as well as time series methods like "Rolling Window". It can balance across levels of a variable (strata_ids), and it can also keep all observations from the same independent unit together (cluster_ids). See the documentation of the make_folds function for details about supported cross-validation schemes and arguments.
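A short sketch of inspecting a folds object follows. It assumes the fold elements are named v, training_set and validation_set, and that the stratification argument is strata_ids; check the make_folds documentation for your installed version. The choice of 5 folds and of cyl as the stratifying variable is arbitrary.

```r
library(origami)

set.seed(42)
# 5-fold cross-validation over the 32 rows of mtcars
folds <- make_folds(mtcars, fold_fun = folds_vfold, V = 5)
length(folds)    # number of folds
str(folds[[1]])  # fold index plus training and validation row indices

# stratify on the number of cylinders so folds have a similar mix of cyl values
strat_folds <- make_folds(mtcars, fold_fun = folds_vfold, V = 5,
                          strata_ids = mtcars$cyl)
table(mtcars$cyl[strat_folds[[1]]$validation_set])
```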

Define fold function

The cv_fun argument to cross_validate is a function that will perform some operation on each fold. The first argument to this function must be fold, which will receive an individual fold object to operate on. Additional arguments can be passed to cv_fun using the ... argument to cross_validate. Within this function, the convenience functions training, validation and fold_index return the corresponding components of a fold object. By default they retrieve the fold object from their calling environment; the fold can also be passed to them explicitly. If training or validation is passed an object, it will index into it in a sensible way: a vector is indexed directly, while a data.frame or matrix is indexed by rows. This allows the user to easily partition data into training and validation sets. The fold function must return a named list of results containing whatever fold-specific outputs are generated.
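The following sketch shows these conventions on a plain numeric vector. The function cv_mean and the simulated data are hypothetical, introduced only for illustration; the estimator is simply the training-set mean.

```r
library(origami)

set.seed(1)
y <- rnorm(50, mean = 3)  # hypothetical toy data

# cv_fun: estimate the mean on the training set, score it on the validation set
cv_mean <- function(fold, y) {
  train_y <- training(y)          # fold is found in the calling environment
  valid_y <- validation(y, fold)  # same indexing, with the fold passed explicitly
  est <- mean(train_y)
  # named list: one element per type of result
  list(
    SE = (valid_y - est)^2,
    est = data.frame(fold = fold_index(), est = est)
  )
}

folds <- make_folds(length(y), fold_fun = folds_vfold, V = 5)
results <- cross_validate(cv_fun = cv_mean, folds = folds, y = y)

length(results$SE)  # one squared error per observation
results$est         # one row per fold
```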

Apply cross_validate

After defining folds, cross_validate can be used to map the cv_fun across the folds using future_lapply. This means that it can be easily parallelized by specifying a parallelization scheme (i.e., a plan). See the future package for more details.
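As a hedged sketch of parallel use, the example below sets a multisession plan from the future package before calling cross_validate. The toy cv_fun, the simulated data, and the choice of two workers are assumptions for illustration; the appropriate plan depends on your machine.

```r
library(origami)
library(future)

# toy cv_fun (hypothetical): squared error of the training-set mean
cv_mean <- function(fold, y) {
  est <- mean(training(y))
  list(SE = (validation(y) - est)^2)
}

set.seed(1)
y <- rnorm(100)
folds <- make_folds(length(y), fold_fun = folds_vfold, V = 10)

plan(multisession, workers = 2)  # evaluate folds in two background R sessions
par_results <- cross_validate(cv_fun = cv_mean, folds = folds, y = y)
plan(sequential)                 # restore sequential evaluation

mean(par_results$SE)
```

For a problem this small the overhead of starting workers outweighs any gain; parallel plans pay off when each fold is expensive to fit.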

The application of cross_validate generates a list of results. As described above, each call to cv_fun itself returns a list of results, with different elements for each type of result we care about. The main loop generates a list of these individual lists of results (a sort of "meta-list"). This "meta-list" is then inverted such that there is one element per result type (each of which is itself a list of the results for each fold). By default, combine_results is used to combine the per-fold results within each result type.

For instance, in the above mtcars example, the list for the coef result type contains one coef data.frame from each fold. These are row-bound (via rbind) to form one data.frame containing the coefs from all folds in different rows. How results are combined is determined automatically by examining the data types of the results from the first fold. This can be modified by specifying a list of arguments to .combine_control. See the help for combine_results for more details. In most cases, the defaults should suffice.
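The inversion and combination steps can be mimicked in base R. This is an illustrative sketch only, not the code of combine_results: the per-fold values are made up, and the helper combine_one is hypothetical.

```r
# Hypothetical per-fold outputs, shaped like what a cv_fun returns
fold_results <- list(
  list(coef = data.frame(a = 1.0, b = 2.0), SE = c(0.5, 1.5)),
  list(coef = data.frame(a = 1.1, b = 1.9), SE = c(0.2, 0.8)),
  list(coef = data.frame(a = 0.9, b = 2.1), SE = c(1.0, 0.4))
)

# invert the "meta-list": one element per result type, each a list over folds
result_types <- names(fold_results[[1]])
inverted <- lapply(result_types, function(type) {
  lapply(fold_results, `[[`, type)
})
names(inverted) <- result_types

# pick a combiner from the data type of the first fold's result
combine_one <- function(type_list) {
  if (is.data.frame(type_list[[1]])) {
    do.call(rbind, type_list)  # stack data.frames row-wise
  } else {
    unlist(type_list)          # concatenate vectors
  }
}
combined <- lapply(inverted, combine_one)
str(combined)
```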

Cross-validation with random forests

To examine origami further, let us return to our example analysis using the mtcars data set. Here, we will write a new cv_fun type object. As an example, we will use L. Breiman's randomForest:

cv_rf <- function(fold, data, reg_form) {
  # get name and index of outcome variable from regression formula
  out_var <- as.character(unlist(str_split(reg_form, " "))[1])
  out_var_ind <- as.numeric(which(colnames(data) == out_var))

  # define training and validation sets based on input object of class "folds"
  train_data <- training(data)
  valid_data <- validation(data)

  # fit Random Forest regression on training set and predict on holdout set
  mod <- randomForest(formula = as.formula(reg_form), data = train_data)
  preds <- predict(mod, newdata = valid_data)

  # define output object to be returned as list (for flexibility)
  out <- list(coef = data.frame(mod$coefs),
              SE = ((preds - valid_data[, out_var_ind])^2))
  return(out)
}

Above, in writing our cv_rf function to cross-validate randomForest, we used our previous function cv_lm as an example. For now, individual cv_funs must be written by hand; however, in future releases, a wrapper may be available to support auto-generating cv_funs to be used with origami.

Below, we use cross_validate to apply our new cv_rf function over the folds object generated by make_folds.

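library(randomForest) # provides randomForest(), used in cv_rf above
folds <- make_folds(mtcars)
cvrf_results <- cross_validate(cv_fun = cv_rf, folds = folds, data = mtcars,
                               reg_form = "mpg ~ .")
mean(cvrf_results$SE) # cross-validated mean squared error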

Using 10-fold cross-validation (the default), we obtain an honest estimate of the prediction error of random forests. From this, we gather that the use of origami's cross_validate procedure can be generalized to arbitrary estimation techniques, given availability of an appropriate cv_fun function.

Cross-validation with dependence: time series

Cross-validation can also be used for forecast model selection in a time series setting. Here, the partitioning scheme mirrors the application of the forecasting model: we train the model on past observations (either all available or a recent subset), and then use it to forecast (predict) the next few observations. Consider the AirPassengers dataset, a monthly time series of passenger air traffic in thousands of people.


Suppose we want to pick between two forecasting models, stl, and arima (the details of these models are not important for this example). We can do that by evaluating their forecasting performance.

library(forecast) # provides stlm() and forecast(), used below
folds <- make_folds(AirPassengers, fold_fun = folds_rolling_origin,
                    first_window = 36, validation_size = 24)
fold <- folds[[1]]

# function to calculate cross-validated squared error
cv_forecasts <- function(fold, data) {
  train_data <- training(data)
  valid_data <- validation(data)
  valid_size <- length(valid_data)

  train_ts <- ts(log10(train_data), frequency = 12)

  # borrowed from AirPassengers help
  arima_fit <- arima(train_ts, c(0, 1, 1),
                     seasonal = list(order = c(0, 1, 1),
                                     period = 12))
  raw_arima_pred <- predict(arima_fit, n.ahead = valid_size)
  arima_pred <- 10^raw_arima_pred$pred
  arima_MSE <- mean((arima_pred - valid_data)^2)

  # stl model
  stl_fit <- stlm(train_ts, s.window = 12)
  raw_stl_pred <- forecast(stl_fit, h = valid_size)
  stl_pred <- 10^raw_stl_pred$mean
  stl_MSE <- mean((stl_pred - valid_data)^2)

  out <- list(mse = data.frame(fold = fold_index(),
                               arima = arima_MSE, stl = stl_MSE))
  return(out)
}

mses <- cross_validate(cv_fun = cv_forecasts, folds = folds,
                       data = AirPassengers)$mse
colMeans(mses[, c("arima", "stl")])
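To make the rolling-origin partitioning used above concrete, here is a base R sketch of how such training and validation indices can be laid out. The helper rolling_origin_indices is hypothetical and is not origami's implementation; it assumes the forecast origin advances by one observation per fold and that the training window always starts at the first observation.

```r
# Hypothetical helper: expanding training window, fixed-size validation window
rolling_origin_indices <- function(n, first_window, validation_size) {
  origins <- first_window:(n - validation_size)
  lapply(origins, function(origin) {
    list(
      training = seq_len(origin),                    # all observations up to the origin
      validation = origin + seq_len(validation_size) # the next few observations
    )
  })
}

idx <- rolling_origin_indices(n = length(AirPassengers),
                              first_window = 36, validation_size = 24)

length(idx)                    # number of folds under these assumptions
range(idx[[1]]$training)       # first fold trains on months 1 to 36
range(idx[[1]]$validation)     # and forecasts months 37 to 60
range(idx[[length(idx)]]$validation)
```

Because training indices always precede validation indices, no fold uses future observations to predict the past, which is the property that matters for dependent data.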


origami documentation built on March 18, 2018, 1:25 p.m.