cross_validate_fn: Cross-validate custom model functions for model selection

Description

Lifecycle: experimental

Cross-validate your model function with one or multiple model formulas at once. Perform repeated cross-validation. Preprocess the train/test split within the cross-validation. Perform hyperparameter tuning with grid search. Returns results in a tibble for easy comparison, reporting and further analysis.

Compared to cross_validate(), this function allows you to supply a custom model function, a predict function, a preprocess function and the hyperparameter values to cross-validate.

Supports regression and classification (binary and multiclass). See `type`.

Note that some metrics may not be computable for some types of model objects.

Usage

cross_validate_fn(
  data,
  formulas,
  type,
  model_fn,
  predict_fn,
  preprocess_fn = NULL,
  preprocess_once = FALSE,
  hyperparameters = NULL,
  fold_cols = ".folds",
  cutoff = 0.5,
  positive = 2,
  metrics = list(),
  rm_nc = FALSE,
  parallel = FALSE,
  verbose = TRUE
)

Arguments

data

data.frame.

Must include one or more grouping factors for identifying folds, as made with groupdata2::fold().

formulas

Model formulas as strings. (Character)

Will be converted to formula objects before being passed to `model_fn`.

E.g. c("y~x", "y~z").

Can contain random effects.

E.g. c("y~x+(1|r)", "y~z+(1|r)").

type

Type of evaluation to perform:

"gaussian" for regression (like linear regression).

"binomial" for binary classification.

"multinomial" for multiclass classification.

model_fn

Model function that returns a fitted model object. Will usually wrap an existing model function like e1071::svm or nnet::multinom.

Must have the following function arguments:

function(train_data, formula, hyperparameters)
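A minimal sketch of this signature, assuming a regression task: the hypothetical `loess_model_fn()` below wraps `stats::loess()` and reads an assumed `span` hyperparameter, falling back to loess's own default of 0.75 when none is supplied. It only uses `names()` and `[[`, so it should work whether `hyperparameters` arrives as a named list or as a one-row data.frame.

```r
# Hypothetical model function wrapping stats::loess() (regression only)
loess_model_fn <- function(train_data, formula, hyperparameters) {
  # Assumed hyperparameter: `span` (the smoothing parameter).
  # Fall back to loess()'s default of 0.75 when it is not supplied
  span <- 0.75
  if (!is.null(hyperparameters) && "span" %in% names(hyperparameters)) {
    span <- hyperparameters[["span"]]
  }
  stats::loess(formula = formula, data = train_data, span = span)
}

# Usage on the built-in `cars` data
fit <- loess_model_fn(cars, dist ~ speed, list(span = 0.5))
fit_default <- loess_model_fn(cars, dist ~ speed, NULL)
```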

predict_fn

Function for predicting the targets in the test folds/sets using the fitted model object. Will usually wrap stats::predict(), but doesn't have to.

Must have the following function arguments:

function(test_data, model, formula, hyperparameters, train_data)

Must return predictions in the following formats, depending on `type`:

Binomial

vector or one-column matrix / data.frame with probabilities (0-1) of the second class, alphabetically. E.g.:

c(0.3, 0.5, 0.1, 0.5)

N.B. When unsure whether a model type produces probabilities based on the alphabetical order of your classes, using 0 and 1 as classes in the dependent variable instead of the class names should increase the chance of getting probabilities of the right class.
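To illustrate the binomial format, the sketch below uses a made-up two-class dataset. For a binomial `stats::glm()` fit on a factor response, the first level is treated as failure, so `predict(type = "response")` returns the probability of the second level, which for a factor created with default settings is the second class alphabetically. `glm_binomial_predict_fn()` and the toy data are hypothetical.

```r
# Toy data: "cat" sorts before "dog", so "dog" is the second level
toy <- data.frame(
  animal = factor(c("cat", "cat", "dog", "cat", "dog", "cat", "dog", "dog")),
  weight = c(4, 6, 7, 9, 10, 12, 20, 25)
)

# Returns a plain vector of probabilities of the second class
glm_binomial_predict_fn <- function(test_data, model, formula,
                                    hyperparameters, train_data) {
  # For a binomial glm(), type = "response" gives P(second factor level)
  unname(stats::predict(model, newdata = test_data, type = "response"))
}

model <- stats::glm(animal ~ weight, data = toy, family = "binomial")
p_dog <- glm_binomial_predict_fn(toy, model, animal ~ weight, NULL, toy)
```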

Gaussian

vector or one-column matrix / data.frame with the predicted value. E.g.:

c(3.7, 0.9, 1.2, 7.3)
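A sketch of a Gaussian predict function, assuming a `stats::loess()` model; `as.numeric()` makes sure a plain vector is returned. Note that loess's default `surface = "interpolate"` does not extrapolate, so test rows outside the training range get NA predictions. `loess_predict_fn()` is a hypothetical name.

```r
# Hypothetical predict function for a loess() regression model
loess_predict_fn <- function(test_data, model, formula,
                             hyperparameters, train_data) {
  # as.numeric() drops any names/dimensions, leaving a plain vector
  as.numeric(stats::predict(model, newdata = test_data))
}

# Train on speeds up to 20 only
train <- cars[cars$speed <= 20, ]
fit <- stats::loess(dist ~ speed, data = train)

# The third prediction is NA: speed 24 lies outside the training range
preds <- loess_predict_fn(data.frame(speed = c(10, 15, 24)), fit,
                          dist ~ speed, NULL, train)
```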

Multinomial

data.frame with one column per class containing probabilities of the class. Column names should be identical to how the class names are written in the target column. E.g.:

class_1 class_2 class_3
  0.269   0.528   0.203
  0.368   0.322   0.310
  0.375   0.371   0.254
    ...     ...     ...

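A sketch of the multinomial format, assuming `nnet::multinom()` (nnet is a recommended package that ships with most R installations). `predict(type = "probs")` returns one column per class, named after the class levels; with a single test row it returns a named vector, which the sketch turns back into a one-row table. The function names and the iris split are illustrative.

```r
# Hypothetical model and predict functions for multiclass classification
multinom_model_fn <- function(train_data, formula, hyperparameters) {
  nnet::multinom(formula = formula, data = train_data, trace = FALSE)
}

multinom_predict_fn <- function(test_data, model, formula,
                                hyperparameters, train_data) {
  # One column of probabilities per class (for 3+ classes)
  probabilities <- stats::predict(model, newdata = test_data, type = "probs")
  # A single test row gives a named vector; t() makes it a one-row
  # matrix that keeps the class names as column names
  if (is.null(dim(probabilities))) probabilities <- t(probabilities)
  as.data.frame(probabilities)
}

train <- iris[c(1:40, 51:90, 101:140), ]
test <- iris[c(41:50, 91:100, 141:150), ]
model <- multinom_model_fn(train, Species ~ Sepal.Length + Sepal.Width, NULL)
p <- multinom_predict_fn(test, model, NULL, NULL, train)
```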
preprocess_fn

Function for preprocessing the training and test sets.

Can, for instance, be used to standardize both the training and test sets with the scaling and centering parameters from the training set.

Must have the following function arguments:

function(train_data, test_data, formula, hyperparameters)

Must return a list with the preprocessed `train_data` and `test_data`. It may also contain a tibble with the parameters used in preprocessing:

list("train" = train_data,
     "test" = test_data,
     "parameters" = preprocess_parameters)

Additional elements in the returned list will be ignored.

The optional parameters tibble will be included in the output. It could have the following format:

Measure   var_1   var_2
Mean     37.921  88.231
SD       12.4     5.986
...        ...     ...

N.B. When `preprocess_once` is FALSE, the current formula and hyperparameters will be provided. Otherwise, these arguments will be NULL.
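As an illustration of the standardization use case, here is a hypothetical `standardize_preprocess_fn()` in base R. It assumes `preprocess_once = FALSE` (so `formula` is not NULL) and only scales numeric predictors named in the formula, leaving the dependent variable untouched. The means and SDs come from the training set only and are applied to both sets. Because it sticks to base R, `parameters` is a data.frame rather than a tibble; wrap it in `tibble::as_tibble()` if a tibble is needed.

```r
# Hypothetical preprocess function: standardize numeric predictors
# with the training set's means and standard deviations
standardize_preprocess_fn <- function(train_data, test_data,
                                      formula, hyperparameters) {
  # Assumes preprocess_once = FALSE, so `formula` is available
  predictors <- all.vars(stats::as.formula(formula))[-1]
  is_num <- vapply(train_data[predictors], is.numeric, logical(1))
  num_cols <- predictors[is_num]

  # Estimate the parameters on the training set only ...
  means <- vapply(train_data[num_cols], mean, numeric(1))
  sds <- vapply(train_data[num_cols], stats::sd, numeric(1))

  # ... and apply them to both the training and test sets
  for (col in num_cols) {
    train_data[[col]] <- (train_data[[col]] - means[[col]]) / sds[[col]]
    test_data[[col]] <- (test_data[[col]] - means[[col]]) / sds[[col]]
  }

  # Parameters in the Measure / var_1 / var_2 layout
  parameters <- data.frame(
    Measure = c("Mean", "SD"),
    rbind(means, sds),
    row.names = NULL
  )

  list("train" = train_data, "test" = test_data, "parameters" = parameters)
}

out <- standardize_preprocess_fn(mtcars[1:20, ], mtcars[21:32, ],
                                 "mpg ~ wt + hp", NULL)
```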

preprocess_once

Whether to apply the preprocessing once (ignoring the formula and hyperparameters arguments in `preprocess_fn`) or for every model separately. (Logical)

When preprocessing does not depend on the current formula or hyperparameters, the preprocessing of each train/test split can be done once to save time. This may require holding a lot more data in memory, though, which is why it is not the default setting.

hyperparameters

Either a named list with hyperparameter values to combine in a grid or a data.frame with one row per hyperparameter combination.

Named list for grid search

Add ".n" to sample the combinations. It can be the number of combinations to use or a proportion between 0 and 1.

E.g.

list(".n" = 10, # sample 10 combinations
     "lrn_rate" = c(0.1, 0.01, 0.001),
     "h_layers" = c(10, 100, 1000),
     "drop_out" = runif(5, 0.3, 0.7))

data.frame with specific hyperparameter combinations

One row per combination to test.

E.g.

lrn_rate h_layers drop_out
     0.1       10     0.65
     0.1     1000     0.65
    0.01     1000     0.63
     ...      ...      ...

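To show how a named list relates to the grid of combinations, the hypothetical `expand_hparam_grid()` below builds every combination with `expand.grid()` and then samples `.n` of them. It is a sketch of the idea, not the cvms implementation; in particular, treating values below 1 as a proportion and rounding up is an assumption.

```r
# Hypothetical helper: expand a named list of hyperparameter values into
# a data.frame with one row per combination, optionally sampling ".n" rows
expand_hparam_grid <- function(hparams) {
  n <- hparams[[".n"]]
  hparams[[".n"]] <- NULL

  # Every combination of the supplied values
  grid <- expand.grid(hparams, stringsAsFactors = FALSE)

  if (!is.null(n)) {
    # Assumption: values below 1 are a proportion of all combinations
    if (n < 1) n <- ceiling(n * nrow(grid))
    n <- min(n, nrow(grid))
    grid <- grid[sample(nrow(grid), n), , drop = FALSE]
  }
  rownames(grid) <- NULL
  grid
}

set.seed(1)
values <- list("kernel" = c("linear", "radial"), "cost" = c(1, 5, 10))
full <- expand_hparam_grid(values)                         # 6 combinations
sampled <- expand_hparam_grid(c(list(".n" = 4), values))   # 4 of them
half <- expand_hparam_grid(c(list(".n" = 0.5), values))    # 3 of them
```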
fold_cols

Name(s) of grouping factor(s) for identifying folds. (Character)

Include names of multiple grouping factors for repeated cross-validation.
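For repeated cross-validation, the data needs one fold column per repetition. groupdata2::fold() is the intended tool; the base-R helper below is a hypothetical stand-in that only assigns rows to folds at random (no `cat_col` balancing or `id_col` grouping) and names the columns `.folds_1`, `.folds_2`, and so on. The resulting names are then passed as `fold_cols`.

```r
# Hypothetical helper: add `num_fold_cols` random fold columns
add_fold_cols <- function(data, k, num_fold_cols) {
  for (i in seq_len(num_fold_cols)) {
    # Spread rows as evenly as possible over k folds, in random order
    folds <- sample(rep(seq_len(k), length.out = nrow(data)))
    data[[paste0(".folds_", i)]] <- factor(folds)
  }
  data
}

set.seed(1)
folded <- add_fold_cols(mtcars, k = 4, num_fold_cols = 3)

# These names would be passed as `fold_cols`
fold_cols <- paste0(".folds_", 1:3)
```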

cutoff

Threshold for predicted classes. (Numeric)

N.B. Binomial models only.
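A sketch of how a cutoff turns probabilities of the second class into predicted classes. The helper `apply_cutoff()` is hypothetical, and mapping probabilities strictly above the cutoff to the second class is an assumption about how exact ties are handled.

```r
# Hypothetical helper: probabilities of the second class -> predicted classes
apply_cutoff <- function(probabilities, classes, cutoff = 0.5) {
  classes <- sort(classes)
  # Assumption: strictly above the cutoff gives the second class
  predicted <- ifelse(probabilities > cutoff, classes[2], classes[1])
  factor(predicted, levels = classes)
}

probs <- c(0.3, 0.62, 0.1, 0.81)
default_cut <- apply_cutoff(probs, c("cat", "dog"))              # cat dog cat dog
strict_cut <- apply_cutoff(probs, c("cat", "dog"), cutoff = 0.7) # cat cat cat dog
```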

positive

Level from the dependent variable to predict. Either as a character (preferable) or as a level index (1 or 2, alphabetically).

E.g. if we have the levels "cat" and "dog" and we want "dog" to be the positive class, we can either provide "dog" or 2, as alphabetically, "dog" comes after "cat".

Note: For reproducibility, it's preferable to specify the name directly, as different locales may sort the levels differently.

Used when calculating confusion matrix metrics and creating ROC curves.

The Process column in the output can be used to verify this setting.

N.B. Only affects evaluation metrics, not the model training or returned predictions.

N.B. Binomial models only.
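The sketch below illustrates that `positive` only changes the evaluation: the same predicted classes give different sensitivities depending on which level is considered positive. Both helpers are hypothetical; `resolve_positive()` mimics the documented rule that an index refers to the alphabetically sorted levels (and `sort()` is locale-dependent, which is why a name is preferable).

```r
# Hypothetical: a level index refers to the alphabetically sorted levels
resolve_positive <- function(positive, classes) {
  if (is.numeric(positive)) sort(unique(classes))[positive] else positive
}

# Sensitivity: share of the positive class that was predicted as positive
sensitivity <- function(targets, predicted, positive) {
  is_pos <- targets == positive
  mean(predicted[is_pos] == positive)
}

targets   <- c("cat", "dog", "dog", "cat", "dog")
predicted <- c("cat", "dog", "cat", "dog", "dog")

positive_2 <- resolve_positive(2, targets)                 # "dog"
sens_dog <- sensitivity(targets, predicted, positive_2)    # 2/3
sens_cat <- sensitivity(targets, predicted,
                        resolve_positive(1, targets))      # 1/2
```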

metrics

list for enabling/disabling metrics.

E.g. list("RMSE" = FALSE) would remove RMSE from the regression results, and list("Accuracy" = TRUE) would add the regular Accuracy metric to the classification results. Default values (TRUE/FALSE) will be used for the remaining available metrics.

You can enable/disable all metrics at once by including "all" = TRUE/FALSE in the list. This is done prior to enabling/disabling individual metrics, which is why, for instance, list("all" = FALSE, "RMSE" = TRUE) would return only the RMSE metric.

The list can be created with gaussian_metrics(), binomial_metrics(), or multinomial_metrics().

Also accepts the string "all".
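The precedence described above (first "all", then individual metrics) can be sketched as follows. `resolve_metrics()` and the default settings are made up for illustration; in practice the defaults come from cvms and the list is passed as `metrics`.

```r
# Hypothetical sketch: apply "all" first, then individual overrides
resolve_metrics <- function(defaults, metrics = list()) {
  if (identical(metrics, "all")) metrics <- list("all" = TRUE)
  enabled <- defaults
  if (!is.null(metrics[["all"]])) enabled[] <- metrics[["all"]]
  individual <- metrics[names(metrics) != "all"]
  for (name in names(individual)) enabled[[name]] <- individual[[name]]
  enabled
}

# Made-up default settings for illustration
defaults <- c("RMSE" = TRUE, "MAE" = TRUE, "r2m" = FALSE)

only_rmse <- resolve_metrics(defaults, list("all" = FALSE, "RMSE" = TRUE))
no_rmse <- resolve_metrics(defaults, list("RMSE" = FALSE))
everything <- resolve_metrics(defaults, "all")
```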

rm_nc

Remove non-converged models from output. (Logical)

parallel

Whether to cross-validate the list of models in parallel. (Logical)

Remember to register a parallel backend first. E.g. with doParallel::registerDoParallel.

verbose

Whether to message process information like the number of model instances to fit. (Logical)

Details

Packages used:

Results

Shared

AIC : stats::AIC

AICc : MuMIn::AICc

BIC : stats::BIC

Gaussian

r2m : MuMIn::r.squaredGLMM

r2c : MuMIn::r.squaredGLMM

Binomial and Multinomial

ROC and related metrics:

Binomial: pROC::roc

Multinomial: pROC::multiclass.roc

Value

tibble with results for each model.

N.B. The Fold column in the nested tibbles contains the test fold in that train/test split.

Shared across families

A nested tibble with coefficients of the models from all iterations. The coefficients are extracted from the model object with parameters::model_parameters() or coef() (with some restrictions on the output). If these attempts fail, a default coefficients tibble filled with NAs is returned.

Nested tibble with the used preprocessing parameters, if a passed preprocess_fn returns the parameters in a tibble.

Total number of folds.

Number of fold columns.

Count of convergence warnings, using a limited set of keywords (e.g. "convergence"). If a convergence warning does not contain one of these keywords, it will be counted with other warnings. Consider discarding models that did not converge on all iterations. Note: you might still see results, but these should be taken with a grain of salt!

Nested tibble with the warnings and messages caught for each model.

A nested Process information object with information about the evaluation.

Name of dependent variable.

Names of fixed effects.

Names of random effects, if any.

—————————————————————-

Gaussian Results

—————————————————————-

Average RMSE, MAE, NRMSE(IQR), RRSE, RAE, RMSLE of all the iterations*, omitting potential NAs from non-converged iterations.

See the additional metrics (disabled by default) at ?gaussian_metrics.

A nested tibble with the predictions and targets.

A nested tibble with the non-averaged results from all iterations.

* In repeated cross-validation, the metrics are first averaged for each fold column (repetition) and then averaged again.
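A small numeric example of this two-stage average, with made-up RMSE values for two fold columns where one iteration did not converge (NA). Averaging per fold column first gives each repetition equal weight, which can differ from a flat average over all folds.

```r
# Made-up per-fold RMSE values; the NA stands in for a non-converged fit
fold_results <- data.frame(
  fold_column = rep(c(".folds_1", ".folds_2"), each = 3),
  RMSE = c(2, 4, 6, 1, 3, NA)
)

# First average within each fold column (repetition), omitting NAs ...
per_repetition <- tapply(fold_results$RMSE, fold_results$fold_column,
                         mean, na.rm = TRUE)      # 4 and 2
# ... then average those averages
two_stage <- mean(per_repetition)                 # 3

# For comparison: a flat average over all non-NA folds
flat <- mean(fold_results$RMSE, na.rm = TRUE)     # 16 / 5 = 3.2
```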

—————————————————————-

Binomial Results

—————————————————————-

Based on the collected predictions from the test folds*, a confusion matrix and a ROC curve are created to get the following:

ROC:

AUC, Lower CI, and Upper CI

Confusion Matrix:

Balanced Accuracy, F1, Sensitivity, Specificity, Positive Predictive Value, Negative Predictive Value, Kappa, Detection Rate, Detection Prevalence, Prevalence, and MCC (Matthews correlation coefficient).

See the additional metrics (disabled by default) at ?binomial_metrics.

Also includes:

A nested tibble with predictions, predicted classes (depends on cutoff), and the targets. Note that the predictions are not necessarily of the specified positive class, but of the model's positive class (second level of the dependent variable, alphabetically).

The pROC::roc ROC curve object(s).

A nested tibble with the confusion matrix/matrices. The Pos_ columns tell you whether a row is a True Positive (TP), True Negative (TN), False Positive (FP), or False Negative (FN), depending on which level is the "positive" class (i.e. the level you wish to predict).

A nested tibble with the results from all fold columns.

The name of the Positive Class.

* In repeated cross-validation, an evaluation is made per fold column (repetition) and averaged.
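A sketch of this per-fold-column evaluation with made-up predicted classes: within each fold column, the predictions from all test folds are pooled into a single evaluation (here only Balanced Accuracy), and the per-fold-column results are then averaged. The helper is hypothetical and computes only one of the listed metrics.

```r
# Hypothetical helper computing Balanced Accuracy for a given positive class
balanced_accuracy <- function(targets, predicted, positive) {
  sensitivity <- mean(predicted[targets == positive] == positive)
  specificity <- mean(predicted[targets != positive] != positive)
  (sensitivity + specificity) / 2
}

# Made-up predicted classes collected from the test folds of two fold columns
collected <- data.frame(
  fold_column = rep(c(".folds_1", ".folds_2"), each = 6),
  target = rep(c("0", "0", "0", "1", "1", "1"), times = 2),
  predicted = c("0", "0", "1", "1", "1", "0",
                "0", "0", "0", "1", "1", "0")
)

# One evaluation per fold column on all of its collected predictions ...
per_fold_column <- vapply(
  split(collected, collected$fold_column),
  function(d) balanced_accuracy(d$target, d$predicted, positive = "1"),
  numeric(1)
)                                       # 2/3 and 5/6
# ... averaged across fold columns
overall <- mean(per_fold_column)        # 0.75
```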

—————————————————————-

Multinomial Results

—————————————————————-

For each class, a one-vs-all binomial evaluation is performed. This creates a Class Level Results tibble containing the same metrics as the binomial results described above (excluding MCC, AUC, Lower CI and Upper CI), along with a count of the class in the target column (Support). These metrics are used to calculate the macro metrics. The nested class level results tibble is also included in the output tibble, and could be reported along with the macro and overall metrics.

The output tibble contains the macro and overall metrics. The metrics that share their name with the metrics in the nested class level results tibble are averages of those metrics (note: does not remove NAs before averaging). In addition to these, it also includes the Overall Accuracy and the multiclass MCC.

Other available metrics (disabled by default, see metrics): Accuracy, multiclass AUC, Weighted Balanced Accuracy, Weighted Accuracy, Weighted F1, Weighted Sensitivity, Weighted Specificity, Weighted Pos Pred Value, Weighted Neg Pred Value, Weighted Kappa, Weighted Detection Rate, Weighted Detection Prevalence, and Weighted Prevalence.

Note that the "Weighted" average metrics are weighted by the Support.
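A sketch of the one-vs-all computation with made-up classes, showing only Sensitivity: each class is evaluated as 1 against all other classes as 0, the macro metric is the plain average of the class level values, and the weighted variant weights them by Support.

```r
# Made-up targets and predicted classes for three classes
targets   <- c("a", "a", "b", "b", "b", "c")
predicted <- c("a", "b", "b", "b", "c", "c")

classes <- sort(unique(targets))

# One-vs-all: for each class, that class is 1 and all others are 0
class_level <- data.frame(
  Class = classes,
  Sensitivity = vapply(classes, function(cl) {
    mean(predicted[targets == cl] == cl)
  }, numeric(1)),
  Support = vapply(classes, function(cl) sum(targets == cl), integer(1)),
  row.names = NULL
)

# Macro: plain average of the class level values (1/2, 2/3, 1)
macro_sensitivity <- mean(class_level$Sensitivity)           # 13/18
# Weighted: average weighted by Support (2, 3, 1)
weighted_sensitivity <- stats::weighted.mean(class_level$Sensitivity,
                                             class_level$Support)  # 2/3
```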

Also includes:

A nested tibble with the predictions, predicted classes, and targets.

A list of ROC curve objects when AUC is enabled.

A nested tibble with the multiclass Confusion Matrix.

Class Level Results

Besides the binomial evaluation metrics and the Support, the nested class level results tibble also contains a nested tibble with the Confusion Matrix from the one-vs-all evaluation. The Pos_ columns tell you whether a row is a True Positive (TP), True Negative (TN), False Positive (FP), or False Negative (FN), depending on which level is the "positive" class. In our case, 1 is the current class and 0 represents all the other classes together.

Author(s)

Ludvig Renbo Olsen, r-pkgs@ludvigolsen.dk

See Also

Other validation functions: cross_validate(), validate_fn(), validate()

Examples


# Attach packages
library(cvms)
library(groupdata2) # fold()
library(dplyr) # %>% arrange() mutate() select()

# Note: More examples of custom functions can be found at:
# model_fn: model_functions()
# predict_fn: predict_functions()
# preprocess_fn: preprocess_functions()

# Data is part of cvms
data <- participant.scores

# Set seed for reproducibility
set.seed(7)

# Fold data
data <- fold(
  data,
  k = 4,
  cat_col = "diagnosis",
  id_col = "participant"
) %>%
  mutate(diagnosis = as.factor(diagnosis)) %>%
  arrange(.folds)

# Cross-validate multiple formulas

formulas_gaussian <- c(
  "score ~ diagnosis",
  "score ~ age"
)
formulas_binomial <- c(
  "diagnosis ~ score",
  "diagnosis ~ age"
)

#
# Gaussian
#

# Create model function that returns a fitted model object
lm_model_fn <- function(train_data, formula, hyperparameters) {
  lm(formula = formula, data = train_data)
}

# Create predict function that returns the predictions
lm_predict_fn <- function(test_data, model, formula,
                          hyperparameters, train_data) {
  stats::predict(
    object = model,
    newdata = test_data,
    type = "response",
    # Only used by lme4 mixed models; ignored when predicting with lm()
    allow.new.levels = TRUE
  )
}

# Cross-validate the model function
cross_validate_fn(
  data,
  formulas = formulas_gaussian,
  type = "gaussian",
  model_fn = lm_model_fn,
  predict_fn = lm_predict_fn,
  fold_cols = ".folds"
)

#
# Binomial
#

# Create model function that returns a fitted model object
glm_model_fn <- function(train_data, formula, hyperparameters) {
  glm(formula = formula, data = train_data, family = "binomial")
}

# Create predict function that returns the predictions
glm_predict_fn <- function(test_data, model, formula,
                           hyperparameters, train_data) {
  stats::predict(
    object = model,
    newdata = test_data,
    type = "response",
    # Only used by lme4 mixed models; ignored when predicting with glm()
    allow.new.levels = TRUE
  )
}

# Cross-validate the model function
cross_validate_fn(
  data,
  formulas = formulas_binomial,
  type = "binomial",
  model_fn = glm_model_fn,
  predict_fn = glm_predict_fn,
  fold_cols = ".folds"
)

#
# Support Vector Machine (svm)
# with hyperparameter tuning
#

# Only run if the `e1071` package is installed
if (requireNamespace("e1071", quietly = TRUE)){

# Create model function that returns a fitted model object
# We use the hyperparameters arg to pass in the kernel and cost values
svm_model_fn <- function(train_data, formula, hyperparameters) {

  # Expected hyperparameters:
  #  - kernel
  #  - cost
  if (!"kernel" %in% names(hyperparameters))
    stop("'hyperparameters' must include 'kernel'")
  if (!"cost" %in% names(hyperparameters))
    stop("'hyperparameters' must include 'cost'")

  e1071::svm(
    formula = formula,
    data = train_data,
    kernel = hyperparameters[["kernel"]],
    cost = hyperparameters[["cost"]],
    scale = FALSE,
    type = "C-classification",
    probability = TRUE
  )
}

# Create predict function that returns the predictions
svm_predict_fn <- function(test_data, model, formula,
                           hyperparameters, train_data) {
  predictions <- stats::predict(
    object = model,
    newdata = test_data,
    allow.new.levels = TRUE,
    probability = TRUE
  )

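  # Extract probabilities (one column per class)
  probabilities <- dplyr::as_tibble(
    attr(predictions, "probabilities")
  )

  # Return the probabilities of the second class (alphabetically).
  # Select the column by name, as the column order of the probabilities
  # may not follow the alphabetical order of the classes
  dependent_col <- all.vars(stats::as.formula(formula))[1]
  second_class <- levels(as.factor(train_data[[dependent_col]]))[2]
  probabilities[[second_class]]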
}

# Specify hyperparameters to try
# The optional ".n" samples 4 combinations
svm_hparams <- list(
  ".n" = 4,
  "kernel" = c("linear", "radial"),
  "cost" = c(1, 5, 10)
)

# Cross-validate the model function
cv <- cross_validate_fn(
  data,
  formulas = formulas_binomial,
  type = "binomial",
  model_fn = svm_model_fn,
  predict_fn = svm_predict_fn,
  hyperparameters = svm_hparams,
  fold_cols = ".folds"
)

cv

# The `HParams` column has the nested hyperparameter values
cv %>%
  select(Dependent, Fixed, HParams, `Balanced Accuracy`, F1, AUC, MCC) %>%
  tidyr::unnest(cols = "HParams") %>%
  arrange(desc(`Balanced Accuracy`), desc(F1))

#
# Use parallelization
# The below examples show the speed gains when running in parallel
#

# Attach doParallel and register four cores
# Uncomment:
# library(doParallel)
# registerDoParallel(4)

# Specify hyperparameters such that we will
# cross-validate 20 models
hparams <- list(
  "kernel" = c("linear", "radial"),
  "cost" = 1:5
)

# Cross-validate a list of 20 models in parallel
# (2 formulas x 2 kernels x 5 cost values)
# Make sure to uncomment the parallel argument
# Note: The SVM functions above perform classification,
# so we use the binomial formulas
system.time({
  cross_validate_fn(
    data,
    formulas = formulas_binomial,
    type = "binomial",
    model_fn = svm_model_fn,
    predict_fn = svm_predict_fn,
    hyperparameters = hparams,
    fold_cols = ".folds"
    #, parallel = TRUE  # Uncomment
  )
})

# Cross-validate the same 20 models sequentially
system.time({
  cross_validate_fn(
    data,
    formulas = formulas_binomial,
    type = "binomial",
    model_fn = svm_model_fn,
    predict_fn = svm_predict_fn,
    hyperparameters = hparams,
    fold_cols = ".folds"
  )
})

} # closes `e1071` package check


cvms documentation built on July 9, 2023, 6:56 p.m.