View source: R/cross_validate_fn.R
Cross-validate your model function with one or multiple model formulas at once.
Perform repeated cross-validation. Preprocess the train/test split
within the cross-validation. Perform hyperparameter tuning with grid search.
Returns results in a tibble for easy comparison, reporting and further analysis.
Compared to cross_validate(), this function allows you to supply a custom model function, a predict function,
a preprocess function and the hyperparameter values to cross-validate.
Supports regression and classification (binary and multiclass). See `type`.
Note that some metrics may not be computable for some types of model objects.
data
Data frame. Must include one or more grouping factors for identifying folds, as made with groupdata2::fold().
formulas
Model formulas as strings. (Character) Will be converted to formula objects. E.g. c("score ~ diagnosis", "score ~ age"). Can contain random effects.
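type
Type of evaluation to perform:
"gaussian" for regression (like linear regression).
"binomial" for binary classification.
"multinomial" for multiclass classification.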
 
model_fn
Model function that returns a fitted model object.
Will usually wrap an existing model function like lm(), glm() or e1071::svm().
Must have the following function arguments:
function(train_data, formula, hyperparameters)
 
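predict_fn
Function for predicting the targets in the test folds/sets using the fitted model object.
Will usually wrap stats::predict().
Must have the following function arguments:
function(test_data, model, formula, hyperparameters, train_data)
Must return predictions in the following formats, depending on type:
Binomial: a vector with the probabilities of the second class (alphabetically).
Gaussian: a vector with the predicted values.
Multinomial: a data frame with one column per class containing the probabilities of that class, named after the classes in the target column.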
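To make the model_fn/predict_fn contract concrete, the sketch below calls two such pairs by hand on a single train/test split of R's built-in mtcars data, outside of cross_validate_fn(). The split, the `transmission` column and the function names are illustrative assumptions, not part of cvms. For a factor response, glm() models the probability of the second level, which matches the binomial format described above.

```r
# Illustrative model_fn / predict_fn pairs with the documented arguments
gaussian_model_fn <- function(train_data, formula, hyperparameters) {
  lm(formula = formula, data = train_data)
}
gaussian_predict_fn <- function(test_data, model, formula, hyperparameters, train_data) {
  # Gaussian: one predicted value per test row
  unname(stats::predict(model, newdata = test_data))
}
binomial_model_fn <- function(train_data, formula, hyperparameters) {
  glm(formula = formula, data = train_data, family = "binomial")
}
binomial_predict_fn <- function(test_data, model, formula, hyperparameters, train_data) {
  # Binomial: probability of the second factor level ("manual" here)
  unname(stats::predict(model, newdata = test_data, type = "response"))
}

# Hypothetical single train/test split of mtcars
cars <- mtcars
cars$transmission <- factor(ifelse(cars$am == 1, "manual", "automatic"))
train <- cars[1:24, ]
test <- cars[25:32, ]

# Formulas are given as strings and converted to formula objects
gaussian_formula <- as.formula("mpg ~ wt")
binomial_formula <- as.formula("transmission ~ wt")

gaussian_fit <- gaussian_model_fn(train, gaussian_formula, hyperparameters = NULL)
gaussian_preds <- gaussian_predict_fn(test, gaussian_fit, gaussian_formula, NULL, train)

binomial_fit <- binomial_model_fn(train, binomial_formula, hyperparameters = NULL)
binomial_preds <- binomial_predict_fn(test, binomial_fit, binomial_formula, NULL, train)
```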
 
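preprocess_fn
Function for preprocessing the training and test sets. Can, for instance, be used to standardize both the training and test sets with the scaling and centering parameters from the training set. Must have the following function arguments:
function(train_data, test_data, formula, hyperparameters)
Must return a list with the preprocessed train_data and test_data. It may also contain a tibble with the parameters used in preprocessing.
Additional elements in the returned list will be ignored. The optional parameters tibble will be included in the output.
N.B. When preprocess_once is FALSE, the current formula and hyperparameters will be passed to preprocess_fn.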
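A minimal sketch of a standardizing preprocess_fn in base R. It assumes the returned list uses the element names `train`, `test` and `parameters` (an assumption; see preprocess_functions() in cvms for the expected format), and it returns the parameters as a plain data frame rather than a tibble. The centering and scaling parameters are computed on the training set only and then applied to both sets.

```r
standardize_preprocess_fn <- function(train_data, test_data, formula, hyperparameters) {
  # Predictor columns: all variables in the formula except the response
  predictors <- all.vars(formula)[-1]
  is_num <- vapply(train_data[predictors], is.numeric, logical(1))
  numeric_cols <- predictors[is_num]

  # Centering and scaling parameters come from the training set only
  means <- vapply(train_data[numeric_cols], mean, numeric(1))
  sds <- vapply(train_data[numeric_cols], stats::sd, numeric(1))

  for (col in numeric_cols) {
    train_data[[col]] <- (train_data[[col]] - means[[col]]) / sds[[col]]
    test_data[[col]] <- (test_data[[col]] - means[[col]]) / sds[[col]]
  }

  # One row per parameter type, one column per standardized predictor
  parameters <- data.frame(
    Measure = c("Mean", "SD"),
    rbind(means, sds),
    row.names = NULL
  )
  list(train = train_data, test = test_data, parameters = parameters)
}

# Usage on a hypothetical split of mtcars
out <- standardize_preprocess_fn(
  train_data = mtcars[1:24, ],
  test_data = mtcars[25:32, ],
  formula = mpg ~ wt + hp,
  hyperparameters = NULL
)
```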
preprocess_once
Whether to apply the preprocessing once (ignoring the formula and hyperparameters arguments in preprocess_fn) or for every model separately. (Logical)
When preprocessing does not depend on the current formula or hyperparameters, we can do the preprocessing of each train/test split once, to save time. This may require holding a lot more data in memory though, which is why it is not the default setting.
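hyperparameters
Either a named list with hyperparameter values to combine in a grid, or a data frame with one row per hyperparameter combination.
Named list for grid search: Add ".n" to sample the combinations. E.g. list(".n" = 4, "kernel" = c("linear", "radial"), "cost" = c(1, 5, 10)).
Data frame with specific hyperparameter combinations, one row per combination to test. E.g.:

lrn_rate  h_layers  drop_out
0.1       10        0.65
0.1       1000      0.65
0.01      1000      0.63
...       ...       ...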
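The hypothetical make_hparam_grid() below sketches how a named list can be expanded into a full grid with expand.grid() and then sampled down to ".n" combinations. It only illustrates the idea; cvms's internal implementation may differ.

```r
make_hparam_grid <- function(hparams) {
  # Pull out the optional sample size and drop it from the value list
  n <- hparams[[".n"]]
  hparams[[".n"]] <- NULL
  # One row per combination of the supplied values
  grid <- expand.grid(hparams, stringsAsFactors = FALSE)
  if (!is.null(n) && n < nrow(grid)) {
    # Randomly keep n of the combinations
    grid <- grid[sample(nrow(grid), n), , drop = FALSE]
  }
  rownames(grid) <- NULL
  grid
}

set.seed(1)
svm_grid <- make_hparam_grid(list(
  ".n" = 4,
  "kernel" = c("linear", "radial"),
  "cost" = c(1, 5, 10)
))
svm_grid  # 4 of the 2 * 3 = 6 possible combinations
```

Passing a data frame instead skips the expansion step, since each row already is one combination to test.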
fold_cols
Name(s) of grouping factor(s) for identifying folds. (Character)
Include names of multiple grouping factors for repeated cross-validation.
cutoff
Threshold for predicted classes. (Numeric)
N.B. Binomial models only.
positive
Level from dependent variable to predict.
Either as character (preferable) or level index (1 or 2, alphabetically).
E.g. if we have the levels "cat" and "dog" and we want "dog" to be the positive class,
we can either provide "dog" or 2, as alphabetically, "dog" comes after "cat".
Note: For reproducibility, it's preferable to specify the name directly, as
different locales may sort the levels differently.
Used when calculating confusion matrix metrics and creating ROC curves.
The Positive Class column in the output can be used to verify this setting.
N.B. Only affects evaluation metrics, not the model training or returned predictions.
N.B. Binomial models only.
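Because a level index refers to the alphabetically sorted levels, which is also how factor() orders character values by default, the index can be translated to a level name as below (toy targets, base R only):

```r
targets <- c("dog", "cat", "dog", "cat", "cat")
# factor() sorts the character values, so "cat" becomes level 1
lvls <- levels(factor(targets))
# positive = 2 therefore refers to "dog"
positive <- 2
positive_class <- lvls[[positive]]
```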
metrics
list for enabling/disabling metrics.
E.g. list("RMSE" = FALSE) would remove RMSE from the regression results,
and list("Accuracy" = TRUE) would add the regular Accuracy metric
to the classification results.
Default values (TRUE/FALSE) will be used for the remaining available metrics.
You can enable/disable all metrics at once by including
"all" = TRUE/FALSE in the list. This is done prior to enabling/disabling
individual metrics, which is why, for instance, list("all" = FALSE, "RMSE" = TRUE)
would return only the RMSE metric.
The list can be created with
gaussian_metrics(), binomial_metrics(), or multinomial_metrics().
Also accepts the string "all".
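The order of precedence ("all" first, then individual metrics) can be illustrated with the hypothetical resolve_metrics() helper below. It is not a cvms function, and the default values are made up for the illustration rather than taken from cvms.

```r
resolve_metrics <- function(defaults, metrics = list()) {
  enabled <- defaults
  if (!is.null(metrics[["all"]])) {
    # Step 1: enable/disable every metric at once
    enabled[] <- metrics[["all"]]
  }
  for (name in setdiff(names(metrics), "all")) {
    # Step 2: individual settings override the "all" setting
    enabled[[name]] <- metrics[[name]]
  }
  names(enabled)[unlist(enabled)]
}

# Illustrative defaults only
defaults <- list(RMSE = TRUE, MAE = TRUE, RMSLE = TRUE)
resolve_metrics(defaults, list("all" = FALSE, "RMSE" = TRUE))  # "RMSE"
resolve_metrics(defaults, list("RMSE" = FALSE))                 # "MAE" "RMSLE"
```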
rm_nc
Remove non-converged models from output. (Logical)
parallel
Whether to cross-validate the list of models in parallel. (Logical)
Remember to register a parallel backend first.
E.g. with doParallel::registerDoParallel.
verbose
Whether to message process information like the number of model instances to fit. (Logical)
Packages used:
AIC : stats::AIC
AICc : MuMIn::AICc
BIC : stats::BIC
r2m : MuMIn::r.squaredGLMM
r2c : MuMIn::r.squaredGLMM
ROC and related metrics:
Binomial: pROC::roc
Multinomial: pROC::multiclass.roc
tibble with results for each model.
N.B. The Fold column in the nested tibbles contains the test fold in that train/test split.
A nested tibble with coefficients of the models from all iterations. The coefficients
are extracted from the model object with parameters::model_parameters() or
coef() (with some restrictions on the output).
If these attempts fail, a default coefficients tibble filled with NAs is returned.
Nested tibble with the used preprocessing parameters,
if a passed preprocess_fn returns the parameters in a tibble.
Number of total folds.
Number of fold columns.
Count of convergence warnings, using a limited set of keywords (e.g. "convergence"). If a convergence warning does not contain one of these keywords, it will be counted with other warnings. Consider discarding models that did not converge on all iterations. Note: you might still see results, but these should be taken with a grain of salt!
Nested tibble with the warnings and messages caught for each model.
A nested Process information object with information about the evaluation.
Name of dependent variable.
Names of fixed effects.
Names of random effects, if any.
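The keyword-based counting of convergence warnings described above can be sketched as follows. The keyword list is a hypothetical stand-in for cvms's own set; the point is that a convergence warning worded without any of the keywords ends up in the other-warnings count.

```r
# Hypothetical keyword list (cvms uses its own limited set)
convergence_keywords <- c("convergence", "did not converge")

count_warnings <- function(warning_messages, keywords = convergence_keywords) {
  # A warning counts as a convergence warning if it contains any keyword
  pattern <- paste(keywords, collapse = "|")
  is_convergence <- grepl(pattern, warning_messages, ignore.case = TRUE)
  c(convergence = sum(is_convergence), other = sum(!is_convergence))
}

caught <- c(
  "glm.fit: algorithm did not converge",
  "Model failed to converge with max|grad| = 0.002",
  "some other warning"
)
# The second message is about convergence but matches no keyword,
# so it is counted with the other warnings
counts <- count_warnings(caught)
```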
--------------------------------------------------
Gaussian Results
--------------------------------------------------
Average RMSE, MAE, NRMSE(IQR), RRSE, RAE, and RMSLE of all the iterations*,
omitting potential NAs from non-converged iterations.
See the additional metrics (disabled by default) at ?gaussian_metrics.
A nested tibble with the predictions and targets.
A nested tibble with the non-averaged results from all iterations.
* In repeated cross-validation, the metrics are first averaged for each fold column (repetition) and then averaged again.
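The sketch below writes out common textbook definitions of the default Gaussian metrics, followed by the two-stage averaging used in repeated cross-validation. The metric formulas are an assumption about cvms's exact implementation (see ?gaussian_metrics), and the per-fold RMSE values are toy numbers, with an NA standing in for a non-converged iteration.

```r
gaussian_sketch <- function(predictions, targets) {
  err <- predictions - targets
  rmse <- sqrt(mean(err^2))
  c(
    RMSE = rmse,
    MAE = mean(abs(err)),
    `NRMSE(IQR)` = rmse / stats::IQR(targets),
    RRSE = sqrt(sum(err^2) / sum((targets - mean(targets))^2)),
    RAE = sum(abs(err)) / sum(abs(targets - mean(targets))),
    RMSLE = sqrt(mean((log1p(predictions) - log1p(targets))^2))
  )
}
res <- gaussian_sketch(predictions = c(2, 4, 6), targets = c(1, 4, 7))

# Toy per-fold results for two fold columns (repetitions)
fold_results <- data.frame(
  fold_column = rep(c(".folds_1", ".folds_2"), each = 3),
  RMSE = c(1, 2, NA, 4, 5, 6)
)
# Stage 1: average within each fold column, omitting the NA
per_fold_column <- tapply(fold_results$RMSE, fold_results$fold_column, mean, na.rm = TRUE)
# Stage 2: average the fold-column averages, mean(c(1.5, 5)) = 3.25
two_stage <- mean(per_fold_column)
# For comparison, a flat average over all non-NA folds gives 3.6
flat <- mean(fold_results$RMSE, na.rm = TRUE)
```

When every fold column has the same number of non-NA folds, the two approaches give the same result; they differ only when NAs (or unequal fold counts) are present.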
--------------------------------------------------
Binomial Results
--------------------------------------------------
Based on the collected predictions from the test folds*,
a confusion matrix and a ROC curve are created to get the following:
ROC: AUC, Lower CI, and Upper CI
Confusion Matrix: Balanced Accuracy, F1, Sensitivity, Specificity, Positive Predictive Value, Negative Predictive Value, Kappa, Detection Rate, Detection Prevalence, Prevalence, and MCC (Matthews correlation coefficient).
See the additional metrics (disabled by default) at ?binomial_metrics.
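A minimal sketch of how the confusion matrix metrics follow from probabilities, a cutoff and a positive class. It assumes the probabilities refer to the positive class and that probabilities above the cutoff are predicted as positive; binomial_sketch() is hypothetical and covers only some of the listed metrics.

```r
binomial_sketch <- function(probabilities, targets, positive, cutoff = 0.5) {
  predicted_positive <- probabilities > cutoff
  actual_positive <- targets == positive
  # Counts as doubles to avoid integer overflow in the MCC denominator
  tp <- as.numeric(sum(predicted_positive & actual_positive))
  tn <- as.numeric(sum(!predicted_positive & !actual_positive))
  fp <- as.numeric(sum(predicted_positive & !actual_positive))
  fn <- as.numeric(sum(!predicted_positive & actual_positive))
  sensitivity <- tp / (tp + fn)
  specificity <- tn / (tn + fp)
  ppv <- tp / (tp + fp)
  c(
    `Balanced Accuracy` = (sensitivity + specificity) / 2,
    F1 = 2 * ppv * sensitivity / (ppv + sensitivity),
    Sensitivity = sensitivity,
    Specificity = specificity,
    `Pos Pred Value` = ppv,
    `Neg Pred Value` = tn / (tn + fn),
    MCC = (tp * tn - fp * fn) /
      sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
  )
}

res <- binomial_sketch(
  probabilities = c(0.9, 0.4, 0.2, 0.6, 0.8, 0.1),
  targets = c("dog", "dog", "cat", "cat", "dog", "cat"),
  positive = "dog"
)
```

Here TP = 2, FN = 1, FP = 1 and TN = 2, so Sensitivity, Specificity and Balanced Accuracy are all 2/3 and MCC is (4 - 1) / 9 = 1/3.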
Also includes:
A nested tibble with predictions, predicted classes (depends on cutoff), and the targets.
Note that the predictions are not necessarily of the specified positive class, but of
the model's positive class (second level of dependent variable, alphabetically).
The pROC::roc ROC curve object(s).
A nested tibble with the confusion matrix/matrices.
The Pos_ columns tell you whether a row is a
True Positive (TP), True Negative (TN), False Positive (FP), or False Negative (FN),
depending on which level is the "positive" class, i.e. the level you wish to predict.
A nested tibble with the results from all fold columns.
The name of the Positive Class.
* In repeated cross-validation, an evaluation is made per fold column (repetition) and averaged.
--------------------------------------------------
Multinomial Results
--------------------------------------------------
For each class, a one-vs-all binomial evaluation is performed. This creates
a Class Level Results tibble containing the same metrics as the binomial results
described above (excluding MCC, AUC, Lower CI and Upper CI),
along with a count of the class in the target column (Support).
These metrics are used to calculate the macro metrics. The nested class level results
tibble is also included in the output tibble,
and could be reported along with the macro and overall metrics.
The output tibble contains the macro and overall metrics.
The metrics that share their name with the metrics in the nested
class level results tibble are averages of those metrics
(note: does not remove NAs before averaging).
In addition to these, it also includes the Overall Accuracy and
the multiclass MCC.
Other available metrics (disabled by default, see metrics):
Accuracy, multiclass AUC, Weighted Balanced Accuracy, Weighted Accuracy, Weighted F1, Weighted Sensitivity, Weighted Specificity, Weighted Pos Pred Value, Weighted Neg Pred Value, Weighted Kappa, Weighted Detection Rate, Weighted Detection Prevalence, and Weighted Prevalence.
Note that the "Weighted" average metrics are weighted by the Support.
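The difference between the macro averages and the Support-weighted averages can be sketched with one-vs-all sensitivity on toy data. This is plain base R, not cvms output.

```r
targets   <- c("a", "a", "a", "a", "b", "b", "c", "c")
predicted <- c("a", "a", "a", "b", "b", "c", "c", "b")
classes <- sort(unique(targets))

class_level <- data.frame(
  Class = classes,
  # One-vs-all: the current class is "positive", all others are "negative"
  Sensitivity = vapply(classes, function(cl) {
    mean(predicted[targets == cl] == cl)
  }, numeric(1)),
  # Support: count of the class in the target column
  Support = vapply(classes, function(cl) sum(targets == cl), numeric(1)),
  row.names = NULL
)

# Macro: unweighted mean of the class-level values, (0.75 + 0.5 + 0.5) / 3
macro_sensitivity <- mean(class_level$Sensitivity)
# Weighted: mean weighted by Support, (0.75 * 4 + 0.5 * 2 + 0.5 * 2) / 8
weighted_sensitivity <- weighted.mean(class_level$Sensitivity, class_level$Support)
overall_accuracy <- mean(predicted == targets)
```

With imbalanced classes the two averages diverge: the macro average gives the small classes "b" and "c" as much say as the large class "a".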
Also includes:
A nested tibble with the predictions, predicted classes, and targets.
A list of ROC curve objects when AUC is enabled.
A nested tibble with the multiclass Confusion Matrix.
Class Level Results
Besides the binomial evaluation metrics and the Support,
the nested class level results tibble also contains a
nested tibble with the Confusion Matrix from the one-vs-all evaluation.
The Pos_ columns tell you whether a row is a
True Positive (TP), True Negative (TN), False Positive (FP), or False Negative (FN),
depending on which level is the "positive" class. In our case, 1 is the current class
and 0 represents all the other classes together.
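The one-vs-all recoding for a single class, where 1 is the current class and 0 represents all the other classes together, and the resulting TP/TN/FP/FN labels can be sketched as follows (toy data, base R only):

```r
targets   <- c("a", "b", "c", "b", "a")
predicted <- c("a", "b", "b", "c", "a")
current_class <- "b"

# 1 = current class, 0 = all other classes together
target_bin <- as.integer(targets == current_class)
predicted_bin <- as.integer(predicted == current_class)

# Label each target/prediction pair
outcome <- ifelse(
  predicted_bin == 1,
  ifelse(target_bin == 1, "TP", "FP"),
  ifelse(target_bin == 1, "FN", "TN")
)
table(factor(outcome, levels = c("TP", "FN", "FP", "TN")))
```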
Ludvig Renbo Olsen, rpkgs@ludvigolsen.dk
Other validation functions:
cross_validate(), validate_fn(), validate()
# Attach packages
library(cvms)
library(groupdata2) # fold()
library(dplyr) # %>% arrange() mutate()
# Note: More examples of custom functions can be found at:
# model_fn: model_functions()
# predict_fn: predict_functions()
# preprocess_fn: preprocess_functions()
# Data is part of cvms
data <- participant.scores
# Set seed for reproducibility
set.seed(7)
# Fold data
data <- fold(
  data,
  k = 4,
  cat_col = "diagnosis",
  id_col = "participant"
) %>%
  mutate(diagnosis = as.factor(diagnosis)) %>%
  arrange(.folds)
# Cross-validate multiple formulas
formulas_gaussian <- c(
  "score ~ diagnosis",
  "score ~ age"
)
formulas_binomial <- c(
  "diagnosis ~ score",
  "diagnosis ~ age"
)
#
# Gaussian
#
# Create model function that returns a fitted model object
lm_model_fn <- function(train_data, formula, hyperparameters) {
  lm(formula = formula, data = train_data)
}
# Create predict function that returns the predictions
lm_predict_fn <- function(test_data, model, formula,
                          hyperparameters, train_data) {
  stats::predict(
    object = model,
    newdata = test_data,
    type = "response",
    allow.new.levels = TRUE
  )
}
# Cross-validate the model function
cross_validate_fn(
  data,
  formulas = formulas_gaussian,
  type = "gaussian",
  model_fn = lm_model_fn,
  predict_fn = lm_predict_fn,
  fold_cols = ".folds"
)
#
# Binomial
#
# Create model function that returns a fitted model object
glm_model_fn <- function(train_data, formula, hyperparameters) {
  glm(formula = formula, data = train_data, family = "binomial")
}
# Create predict function that returns the predictions
glm_predict_fn <- function(test_data, model, formula,
                           hyperparameters, train_data) {
  stats::predict(
    object = model,
    newdata = test_data,
    type = "response",
    allow.new.levels = TRUE
  )
}
# Cross-validate the model function
cross_validate_fn(
  data,
  formulas = formulas_binomial,
  type = "binomial",
  model_fn = glm_model_fn,
  predict_fn = glm_predict_fn,
  fold_cols = ".folds"
)
#
# Support Vector Machine (svm)
# with hyperparameter tuning
#
# Create model function that returns a fitted model object
# We use the hyperparameters arg to pass in the kernel and cost values
svm_model_fn <- function(train_data, formula, hyperparameters) {
  # Expected hyperparameters:
  #  - kernel
  #  - cost
  if (!"kernel" %in% names(hyperparameters))
    stop("'hyperparameters' must include 'kernel'")
  if (!"cost" %in% names(hyperparameters))
    stop("'hyperparameters' must include 'cost'")
  e1071::svm(
    formula = formula,
    data = train_data,
    kernel = hyperparameters[["kernel"]],
    cost = hyperparameters[["cost"]],
    scale = FALSE,
    type = "C-classification",
    probability = TRUE
  )
}
# Create predict function that returns the predictions
svm_predict_fn <- function(test_data, model, formula,
                           hyperparameters, train_data) {
  predictions <- stats::predict(
    object = model,
    newdata = test_data,
    allow.new.levels = TRUE,
    probability = TRUE
  )
  # Extract probabilities
  probabilities <- dplyr::as_tibble(
    attr(predictions, "probabilities")
  )
  # Return second column
  probabilities[[2]]
}
# Specify hyperparameters to try
# The optional ".n" samples 4 combinations
svm_hparams <- list(
  ".n" = 4,
  "kernel" = c("linear", "radial"),
  "cost" = c(1, 5, 10)
)
# Cross-validate the model function
cv <- cross_validate_fn(
  data,
  formulas = formulas_binomial,
  type = "binomial",
  model_fn = svm_model_fn,
  predict_fn = svm_predict_fn,
  hyperparameters = svm_hparams,
  fold_cols = ".folds"
)
cv
# The `HParams` column has the nested hyperparameter values
cv %>%
  select(Dependent, Fixed, HParams, `Balanced Accuracy`, F1, AUC, MCC) %>%
  tidyr::unnest(cols = "HParams") %>%
  arrange(desc(`Balanced Accuracy`), desc(F1))
#
# Use parallelization
# The below examples show the speed gains when running in parallel
#
# Attach doParallel and register four cores
# Uncomment:
# library(doParallel)
# registerDoParallel(4)
# Specify hyperparameters such that we will
# cross-validate 20 models
hparams <- list(
  "kernel" = c("linear", "radial"),
  "cost" = 1:5
)
# Cross-validate a list of 20 models in parallel
# Make sure to uncomment the parallel argument
# (svm_model_fn is a classifier, so the binomial formulas are used)
system.time({
  cross_validate_fn(
    data,
    formulas = formulas_binomial,
    type = "binomial",
    model_fn = svm_model_fn,
    predict_fn = svm_predict_fn,
    hyperparameters = hparams,
    fold_cols = ".folds"
    #, parallel = TRUE  # Uncomment
  )
})
# Cross-validate a list of 20 models sequentially
system.time({
  cross_validate_fn(
    data,
    formulas = formulas_binomial,
    type = "binomial",
    model_fn = svm_model_fn,
    predict_fn = svm_predict_fn,
    hyperparameters = hparams,
    fold_cols = ".folds"
  )
})
