cross_validate: Main Cross-Validation Function

Description Usage Arguments Value Examples

View source: R/cross_validate.R

Description

Applies cv_fun to each fold using future_lapply (or lapply when use_future = FALSE) and combines the results across folds using combine_results.

Usage

cross_validate(cv_fun, folds, ..., use_future = TRUE, .combine = TRUE,
  .combine_control = list(), .old_results = NULL)

Arguments

cv_fun

a function that takes a 'fold' as its first argument and returns a list of results from that fold. NOTE: an argument named 'X' is not allowed in any input function, because it would clash with the 'X' argument of lapply and future.apply::future_lapply.

folds

a list of folds to loop over generated using make_folds.

...

other arguments passed to cv_fun.

use_future

(logical) - whether to run the main cross-validation loop with future_lapply (TRUE) or with lapply (FALSE).

.combine

(logical) - whether combine_results should be called to combine the per-fold results.

.combine_control

(list) - arguments to combine_results.

.old_results

(list) - the returned result from a previous call to this function, which will be combined with the current results. This is useful for adding additional CV folds to a results object.

Value

A list of results, combined across folds.

Examples

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
###############################################################################
# This example explains how to use the cross_validate function naively.
###############################################################################
data("mtcars")

# resubstitution (in-sample) MSE: fit on every row, then average the
# squared residuals of that same fit
r <- lm(formula = mpg ~ ., data = mtcars)
mean(residuals(r)^2)

# function to calculate cross-validated squared error
cv_lm <- function(fold, data, reg_form) {
  # parse the regression formula once; reg_form may be a string or a formula
  form <- as.formula(reg_form)

  # outcome name = first variable in the formula. all.vars() parses the
  # formula, so this works with or without spaces around "~" (splitting the
  # string on " " broke for e.g. "mpg~.") and needs no extra package
  out_var <- all.vars(form)[1]

  # split up data into training and validation sets; training()/validation()
  # are called without a fold and presumably pick up `fold` from this
  # function's frame (origami convention) -- confirm against origami docs
  train_data <- training(data)
  valid_data <- validation(data)

  # fit linear model on training set and predict on validation set
  mod <- lm(form, data = train_data)
  preds <- predict(mod, newdata = valid_data)

  # results: coefficients as a one-row data.frame, and the squared prediction
  # error of each validation observation; `[[` always yields a plain vector,
  # even when `data` is a tibble
  list(coef = data.frame(t(coef(mod))),
       SE = (preds - valid_data[[out_var]])^2)
}

# Resubstitution: a single "fold" whose training and validation sets are
# both the full data, so this reproduces the in-sample MSE
resub <- make_folds(mtcars, fold_fun = folds_resubstitution)[[1L]]
resub_results <- cv_lm(fold = resub, data = mtcars,
                       reg_form = "mpg ~ .")
mean(resub_results[["SE"]])

# Cross-validated estimate over the default folds from make_folds()
folds <- make_folds(mtcars)
cv_results <- cross_validate(
  cv_fun = cv_lm,
  folds = folds,
  data = mtcars,
  reg_form = "mpg ~ ."
)
mean(cv_results[["SE"]])
###############################################################################
# This example explains how to use the cross_validate function with
# parallelization using the framework of the future package.
###############################################################################

suppressMessages(library(data.table))
library(future)
data("mtcars")
set.seed(1L)

# 1000 bootstrap resamples of mtcars: many folds, so the timed runs below
# have a meaningful amount of work
folds <- make_folds(mtcars, fold_fun = folds_bootstrap, V = 1e3)

# function to calculate cross-validated squared error for linear regression
cv_lm <- function(fold, data, reg_form) {
  # parse the regression formula once; reg_form may be a string or a formula
  form <- as.formula(reg_form)

  # outcome name = first variable in the formula. all.vars() is base R, so
  # this no longer depends on stringr's str_split, which this example never
  # attaches (the unqualified call errored on every fold, including on
  # future workers); it also works with or without spaces around "~"
  out_var <- all.vars(form)[1]

  # split up data into training and validation sets; training()/validation()
  # are called without a fold and presumably pick up `fold` from this
  # function's frame (origami convention) -- confirm against origami docs
  train_data <- training(data)
  valid_data <- validation(data)

  # fit linear model on training set and predict on validation set
  mod <- lm(form, data = train_data)
  preds <- predict(mod, newdata = valid_data)

  # results: coefficients as a one-row data.frame, and the squared prediction
  # error of each validation observation; `[[` always yields a plain vector,
  # even when `data` is a tibble
  list(coef = data.frame(t(coef(mod))),
       SE = (preds - valid_data[[out_var]])^2)
}

# Run the same cross-validation sequentially and then with multicore futures,
# timing each. plan() invisibly returns the previously active plan; it is
# restored at the end so the example does not leave the global future plan
# switched to multicore.
old_plan <- plan(sequential)
time_seq <- system.time({
  results_seq <- cross_validate(cv_fun = cv_lm, folds = folds, data = mtcars,
                                reg_form = "mpg ~ .")
})

# NOTE(review): per the future docs, multicore (forked) processing is not
# supported on Windows and is disabled in RStudio, where it runs sequentially
plan(multicore)
time_mc <- system.time({
  results_mc <- cross_validate(cv_fun = cv_lm, folds = folds, data = mtcars,
                               reg_form = "mpg ~ .")
})
plan(old_plan)

# with more than one core, the multicore run should not be meaningfully
# slower than the sequential one
if (availableCores() > 1) {
  time_mc["elapsed"] < 1.2 * time_seq["elapsed"]
}

origami documentation built on March 18, 2018, 1:25 p.m.