validate: Cross-Validated Prediction Metrics

Description

validate is a generic function for cross-validating predictions from the results of various model fitting functions. The function invokes particular methods which depend on the class of the first argument.

Usage

validate(object, ...)

## S3 method for class 'lm'
validate(
  object,
  data = NULL,
  n_folds = 10,
  n_reps = 10,
  seed = 42,
  silent = FALSE,
  ...
)

## S3 method for class 'glm'
validate(
  object,
  data = NULL,
  n_folds = 10,
  n_reps = 10,
  seed = 42,
  silent = FALSE,
  ...
)

## S3 method for class 'zeroinfl'
validate(object, n_folds = 10, n_reps = 10, seed = 42, silent = FALSE, ...)

## S3 method for class 'glmnet'
validate(
  object,
  x = NULL,
  y = NULL,
  lambda = NULL,
  offset = NULL,
  weights = NULL,
  n_folds = 10,
  n_reps = 10,
  seed = 42,
  envir = .GlobalEnv,
  silent = FALSE,
  ...
)

## S3 method for class 'beset'
validate(object, ...)

## S3 method for class 'nested'
validate(object, metric = "auto", oneSE = TRUE, ...)

## S3 method for class 'randomForest'
validate(
  object,
  data = NULL,
  x = NULL,
  n_folds = 10,
  n_reps = 10,
  seed = 42,
  ...,
  parallel_type = NULL,
  n_cores = NULL,
  cl = NULL,
  silent = FALSE
)

Arguments

object

A model object for which a cross-validated R-squared is desired.

...

Additional arguments passed to model summary methods.

data

Data frame that was used to train the model. Only needed if the training data is not contained in the model object.

n_folds

Integer indicating the number of folds to use for cross-validation.

n_reps

Integer indicating the number of times cross-validation should be repeated.

seed

Integer used to seed the random number generator.

silent

Logical indicating whether to suppress warning messages related to autocorrection of the n_folds and n_reps parameters when they are inappropriate for the sample size or cross-validation method.

x

Model matrix that was used to train elastic net.

y

Response variable that was used to train elastic net.

lambda

Numeric value of the penalty parameter lambda at which predictions are cross-validated. Default is the entire sequence used to create the model, but importance scores will only be generated if a single lambda value is given.

offset

Numeric offset vector that was used to train elastic net.

weights

Numeric vector of observation weights that was used to train elastic net.

envir

Environment in which to look for variables listed in the object call.

metric

Character string giving the prediction metric on which to base model selection. Can be one of "auc" for area under the (ROC) curve (only available for binomial family), "mae" for mean absolute error (not available for binomial family), "mce" for mean cross entropy, or "mse" for mean squared error. Default is "auto", which uses MSE for Gaussian-family models and MCE for all other families.

oneSE

Logical indicating whether or not to use the "one standard error" rule. If TRUE (default) the simplest model within one standard error of the optimal model is returned. If FALSE the model with the optimal cross-validation performance is returned.

parallel_type

(Optional) character string indicating the type of parallel operation to be used, either "fork" or "sock". If omitted and n_cores > 1, the default is "sock" for Windows and otherwise either "fork" or "sock" depending on which process is being run.

n_cores

Integer value indicating the number of workers to run in parallel during subset search and cross-validation. By default, this will be set to one fewer than the maximum number of physical cores you have available, as indicated by detectCores. Set to 1 to disable parallel processing.

cl

(Optional) parallel or snow cluster for use if parallel_type = "sock". If not supplied, a cluster on the local machine is automatically created.
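
The "one standard error" rule described under oneSE can be illustrated with a small base-R sketch. The error and standard-error values below are hypothetical numbers chosen for illustration, and the code is not taken from the package; it only shows the general rule of picking the simplest model whose cross-validated error is within one standard error of the best model's error.

```r
# Hypothetical cross-validation results for five candidate models,
# ordered from simplest (1) to most complex (5). Values are made up.
cv_error <- c(1.90, 1.42, 1.21, 1.18, 1.20)
cv_se    <- c(0.10, 0.09, 0.08, 0.08, 0.09)

# Model with the lowest cross-validated error (what oneSE = FALSE describes)
best <- which.min(cv_error)

# Simplest model whose error is within one SE of the best model's error
# (what oneSE = TRUE describes)
threshold <- cv_error[best] + cv_se[best]
one_se    <- min(which(cv_error <= threshold))

best    # 4
one_se  # 3
```

Here the rule trades a small increase in estimated error (1.21 versus 1.18) for a simpler model.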

Details

To obtain cross-validation statistics, first fit a model as you normally would and then pass the model object to validate.
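
A minimal usage sketch of this workflow follows, assuming the beset package is installed (it is hosted on GitHub as jashu/beset). The model formula and the mtcars data set are arbitrary choices for illustration; the argument and element names are the ones documented on this page. The call is guarded so the script still runs when the package is absent.

```r
# Fit a model as you normally would
fit <- lm(mpg ~ wt + hp, data = mtcars)

# Assumption: beset is installed and exports validate()
if (requireNamespace("beset", quietly = TRUE)) {
  cv <- beset::validate(fit, data = mtcars,
                        n_folds = 5, n_reps = 10, seed = 42)
  print(cv$cv_stats)    # cross-validated prediction metrics
  print(cv$parameters)  # folds, repetitions, and seed that were used
}
```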

Value

a "cross_valid" object consisting of a list with the following elements:

cv_stats

a list of cross-validated prediction metrics, each containing the mean, between-fold standard error ("btwn_fold_se"), and between-repetition min-max range ("btwn_rep_range") of each metric.

predictions

a data frame containing the hold-out predictions for each row in the training data, with a separate column for each repetition of the k-fold cross-validation

fold_assignments

a data frame of equal dimensions to predictions giving the number of the hold-out fold of the corresponding element in predictions

parameters

a list documenting how many folds and repetitions were used for cross-validation, and the seed passed to the random number generator, which will be needed to reproduce the random fold assignments
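
To make the shape of these elements concrete, here is a base-R sketch of repeated k-fold cross-validation that builds objects analogous to fold_assignments, predictions, and one entry of cv_stats. It is an illustration under stated assumptions, not the package's internal code: the model, the data set, and in particular the formula used for the between-fold standard error are assumptions, and the package's exact definitions may differ.

```r
set.seed(42)          # seed that makes the fold assignments reproducible
n_folds <- 5
n_reps  <- 3
n       <- nrow(mtcars)

# One column of hold-out fold labels per repetition
fold_assignments <- as.data.frame(
  replicate(n_reps, sample(rep_len(seq_len(n_folds), n)))
)
names(fold_assignments) <- paste0("Rep", seq_len(n_reps))

# Hold-out predictions with the same dimensions as fold_assignments
predictions <- as.data.frame(
  matrix(NA_real_, nrow = n, ncol = n_reps,
         dimnames = list(NULL, names(fold_assignments)))
)
for (r in seq_len(n_reps)) {
  for (k in seq_len(n_folds)) {
    test <- fold_assignments[[r]] == k
    fit  <- lm(mpg ~ wt + hp, data = mtcars[!test, ])   # train without fold k
    predictions[test, r] <- predict(fit, newdata = mtcars[test, ])
  }
}

# Squared errors, then MSE per repetition and per fold within repetition
sq_err   <- (mtcars$mpg - as.matrix(predictions))^2
rep_mse  <- colMeans(sq_err)
fold_mse <- sapply(seq_len(n_reps), function(r) {
  tapply(sq_err[, r], fold_assignments[[r]], mean)
})

# Assumed summary: SE across folds within each repetition, averaged over reps
mse <- list(
  mean           = mean(rep_mse),
  btwn_fold_se   = mean(apply(fold_mse, 2, sd) / sqrt(n_folds)),
  btwn_rep_range = range(rep_mse)
)
```

Because every row falls in exactly one hold-out fold per repetition, each column of predictions contains one out-of-sample prediction for every training row.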

Methods (by class)

  • validate(lm): Cross-validation of linear models

  • validate(glm): Cross-validation of GLMs

  • validate(zeroinfl): Cross-validation of zero-inflated count models

  • validate(glmnet): Cross-validation of elastic-net models fit with glmnet

  • validate(beset): Cross-validation of beset objects

  • validate(nested): Extract test error estimates from "nested beset" objects with nested cross-validation

  • validate(randomForest): Cross-validation of random forests


jashu/beset documentation built on April 20, 2023, 5:28 a.m.