View source: R/cross_validate.R
cross_validate (R Documentation)
Cross-validate one or multiple linear or logistic regression models at once. Supports repeated cross-validation. Returns results in a tibble for easy comparison, reporting and further analysis.
See cross_validate_fn() for use with custom model functions.
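To make "repeated cross-validation" concrete, here is a minimal base-R sketch of the procedure for a single Gaussian model. It is not how cvms is implemented. Fold assignment here is plain random sampling (the examples below use groupdata2::fold() instead), and the fold column names `.folds_1` and `.folds_2` are illustrative only.

```r
# Minimal sketch of repeated k-fold cross-validation in base R.
# NOT the cvms implementation: it only illustrates fitting on the training
# folds, evaluating on the held-out fold, and averaging per fold column.
set.seed(1)
dat <- data.frame(x = rnorm(60))
dat$y <- 2 * dat$x + rnorm(60)

k <- 4
fold_cols <- c(".folds_1", ".folds_2")  # hypothetical names: two repetitions
for (fc in fold_cols) {
  # Plain random fold assignment with (near) equal fold sizes
  dat[[fc]] <- factor(sample(rep(seq_len(k), length.out = nrow(dat))))
}

rmse <- function(obs, pred) sqrt(mean((obs - pred)^2))

# For each fold column: hold out each fold in turn and fit lm() on the rest
rmse_per_column <- sapply(fold_cols, function(fc) {
  rmse_per_fold <- sapply(levels(dat[[fc]]), function(f) {
    is_test <- dat[[fc]] == f
    fit <- lm(y ~ x, data = dat[!is_test, ])
    rmse(dat$y[is_test], predict(fit, newdata = dat[is_test, ]))
  })
  mean(rmse_per_fold)  # average within the fold column
})

cv_rmse <- mean(rmse_per_column)  # then average across repetitions
```

Each extra fold column adds another full round of k fits, which is why repeated cross-validation gives a more stable estimate at a proportionally higher cost.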
cross_validate(
data,
formulas,
family,
fold_cols = ".folds",
control = NULL,
REML = FALSE,
cutoff = 0.5,
positive = 2,
metrics = list(),
preprocessing = NULL,
rm_nc = FALSE,
parallel = FALSE,
verbose = FALSE,
link = deprecated(),
models = deprecated(),
model_verbose = deprecated()
)
data
Must include one or more grouping factors for identifying folds, as made with groupdata2::fold().
formulas
Model formulas as strings. (Character)
E.g. c("y~x", "y~z").
Can contain random effects. E.g. c("y~x+(1|r)", "y~z+(1|r)").
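family
Name of the family. (Character)
Currently supports "gaussian" for linear regression with lm() / lme4::lmer() and "binomial" for binary classification with glm() / lme4::glmer().
See cross_validate_fn() for use with other model functions.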
fold_cols
Name(s) of grouping factor(s) for identifying folds. (Character)
Include names of multiple grouping factors for repeated cross-validation.
control
Construct control structures for mixed model fitting (with lme4::lmer() or lme4::glmer()).
N.B. Ignored if fitting lm() or glm() models.
REML
Restricted Maximum Likelihood. (Logical)
cutoff
Threshold for predicted classes. (Numeric)
N.B. Binomial models only.
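positive
Level from dependent variable to predict. Either as character (preferable) or level index (1 or 2, alphabetically).
E.g. if we have the levels "cat" and "dog" and we want "dog" to be the positive class, we can either provide "dog" or 2, as "dog" comes after "cat" alphabetically.
Note: For reproducibility, it's preferable to specify the name directly, as different locales may sort the levels differently.
Used when calculating confusion matrix metrics and creating ROC curves.
The Process column in the output can be used to verify this setting.
N.B. Only affects evaluation metrics, not the model training or returned predictions.
N.B. Binomial models only.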
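metrics
list for enabling/disabling metrics.
E.g. list("RMSE" = FALSE) would remove RMSE from the results, and list("Accuracy" = TRUE) would add the regular Accuracy metric to the classification results. Default values (TRUE/FALSE) will be used for the remaining available metrics.
You can enable/disable all metrics at once by including "all" = TRUE/FALSE in the list. This is done prior to enabling/disabling individual metrics, so e.g. list("all" = FALSE, "RMSE" = TRUE) would return only the RMSE metric.
The list can be created with gaussian_metrics() or binomial_metrics().
Also accepts the string "all".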
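preprocessing
Name of preprocessing to apply. Available preprocessings are: "standardize", "scale", "center", and "range".
The preprocessing parameters (e.g. mean and SD) are extracted from the training folds and applied to both the training folds and the test fold. They are returned in the output for inspection.
N.B. The preprocessings should not affect the results to a noticeable degree, although "range" might.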
rm_nc
Remove non-converged models from output. (Logical)
parallel
Whether to cross-validate the list of models in parallel. (Logical)
Remember to register a parallel backend first. E.g. with doParallel::registerDoParallel.
verbose
Whether to message process information like the number of model instances to fit and which model function was applied. (Logical)
link, models, model_verbose
Deprecated.
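The on/off logic for the `metrics` list can be sketched with a small base-R helper. `resolve_metrics()` and the toy default set below are hypothetical and not cvms functions. The sketch assumes that an "all" entry is applied before the individual toggles, that unlisted metrics keep their defaults, and that the bare string "all" enables every metric.

```r
# Hypothetical helper (NOT part of cvms) illustrating how a `metrics` list
# could be resolved against a set of default on/off values.
resolve_metrics <- function(defaults, metrics = list()) {
  # Assumption: the bare string "all" means "enable everything"
  if (identical(metrics, "all")) metrics <- list(all = TRUE)
  enabled <- defaults
  # Assumption: "all" is applied first ...
  if (!is.null(metrics[["all"]])) {
    enabled[] <- isTRUE(metrics[["all"]])
  }
  # ... and individual toggles then override it
  for (name in setdiff(names(metrics), "all")) {
    enabled[[name]] <- isTRUE(metrics[[name]])
  }
  enabled
}

# Toy defaults, assumed for illustration (see ?gaussian_metrics for the real set)
defaults <- c(RMSE = TRUE, MAE = TRUE, RMSLE = TRUE, r2m = FALSE)

resolve_metrics(defaults, list(RMSE = FALSE))             # drops RMSE only
resolve_metrics(defaults, list(all = FALSE, RMSE = TRUE)) # keeps RMSE only
```

Applying "all" before the individual entries is what lets a list such as list("all" = FALSE, "RMSE" = TRUE) act as an allow-list.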
Packages used:
Gaussian models: stats::lm, lme4::lmer
Binomial models: stats::glm, lme4::glmer
AIC: stats::AIC
AICc: MuMIn::AICc
BIC: stats::BIC
r2m: MuMIn::r.squaredGLMM
r2c: MuMIn::r.squaredGLMM
ROC and AUC: pROC::roc
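As an illustration of one of the information criteria listed above, AICc can be computed by hand from the standard small-sample correction AICc = AIC + 2k(k + 1) / (n - k - 1), where k is the number of estimated parameters and n the number of observations. This base-R sketch uses the built-in mtcars data and does not call MuMIn; that MuMIn::AICc applies the same formula to such lm() fits is an assumption here.

```r
# Sketch: AICc from the textbook small-sample correction, base R only.
fit <- lm(mpg ~ wt, data = mtcars)

k <- attr(logLik(fit), "df")  # estimated parameters: intercept, slope, sigma
n <- nobs(fit)                # number of observations used in the fit

aic  <- AIC(fit)                               # as in stats::AIC
aicc <- aic + (2 * k * (k + 1)) / (n - k - 1)  # small-sample corrected AIC
bic  <- BIC(fit)                               # as in stats::BIC
```

The correction term shrinks toward zero as n grows relative to k, so AICc and AIC converge for large samples.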
tibble with results for each model.

Shared across families
A nested tibble with coefficients of the models from all iterations.
Number of total folds.
Number of fold columns.
Count of convergence warnings. Consider discarding models that did not converge on all iterations. Note: you might still see results, but these should be taken with a grain of salt!
Count of other warnings. These are warnings without keywords such as "convergence".
Count of Singular Fit messages. See lme4::isSingular for more information.
Nested tibble with the warnings and messages caught for each model.
A nested Process information object with information about the evaluation.
Name of dependent variable.
Names of fixed effects.
Names of random effects, if any.
Nested tibble with preprocessing parameters, if any.
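----------------------------------------------------------------
Gaussian Results
----------------------------------------------------------------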
Average RMSE, MAE, NRMSE(IQR), RRSE, RAE, RMSLE, AIC, AICc, and BIC of all the iterations*, omitting potential NAs from non-converged iterations. Note that the Information Criterion metrics (AIC, AICc, and BIC) are also averages.
See the additional metrics (disabled by default) at ?gaussian_metrics.
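For reference, the error metrics named above can be sketched in base R with their common textbook definitions. The exact definitions used here (in particular NRMSE normalized by the IQR of the targets, and RMSLE computed on log(1 + x)) are assumptions for illustration; ?gaussian_metrics and the cvms source are authoritative. The observations and predictions are toy values.

```r
# Toy targets and predictions (illustrative values only)
obs  <- c(3, 5, 7, 9)
pred <- c(2.5, 5.5, 6, 10)

err  <- obs - pred
rmse <- sqrt(mean(err^2))                                # root mean squared error
mae  <- mean(abs(err))                                   # mean absolute error
nrmse_iqr <- rmse / IQR(obs)                             # RMSE normalized by IQR of targets (assumed)
rrse <- sqrt(sum(err^2) / sum((obs - mean(obs))^2))      # root relative squared error
rae  <- sum(abs(err)) / sum(abs(obs - mean(obs)))        # relative absolute error
rmsle <- sqrt(mean((log1p(pred) - log1p(obs))^2))        # needs obs and pred > -1
```

RRSE and RAE compare the model against always predicting the mean of the targets, so values below 1 mean the model beats that baseline.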
A nested tibble with the predictions and targets.
A nested tibble with the non-averaged results from all iterations.
* In repeated cross-validation, the metrics are first averaged for each fold column (repetition) and then averaged again.
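The two-stage averaging in the footnote differs from a single flat average over all folds whenever the fold columns contain different numbers of folds. A minimal sketch with made-up per-fold RMSE values (not real results); `na.rm = TRUE` mirrors the omission of NAs from non-converged iterations.

```r
# Made-up per-fold RMSE values for two fold columns (illustration only)
per_fold <- data.frame(
  fold_column = c(".folds_1", ".folds_1", ".folds_1",
                  ".folds_2", ".folds_2"),
  RMSE        = c(1.0, 2.0, 3.0, 4.0, 6.0)
)

# Step 1: average within each fold column (NAs from non-converged fits dropped)
per_column <- tapply(per_fold$RMSE, per_fold$fold_column, mean, na.rm = TRUE)

# Step 2: average the per-column averages
reported <- mean(per_column)

# For comparison: one flat average over all folds
flat <- mean(per_fold$RMSE)
```

Here each repetition gets equal weight in `reported` (3.5), whereas the flat average (3.2) gives more weight to the column with more folds.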
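----------------------------------------------------------------
Binomial Results
----------------------------------------------------------------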
Based on the collected predictions from the test folds*, a confusion matrix and a ROC curve are created to get the following:
ROC: AUC, Lower CI, and Upper CI
Confusion Matrix: Balanced Accuracy, F1, Sensitivity, Specificity, Positive Predictive Value, Negative Predictive Value, Kappa, Detection Rate, Detection Prevalence, Prevalence, and MCC (Matthews correlation coefficient).
See the additional metrics (disabled by default) at ?binomial_metrics.
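The confusion-matrix metrics listed above can all be derived from the four cell counts. The sketch below uses standard textbook definitions (Cohen's Kappa with chance agreement from the marginals, MCC as the phi coefficient); treat them as assumptions and consult ?binomial_metrics for what cvms reports. `confusion_metrics()` is a hypothetical helper, not a cvms function.

```r
# Hypothetical helper: standard confusion-matrix metrics from cell counts
confusion_metrics <- function(TP, FP, TN, FN) {
  n    <- TP + FP + TN + FN
  sens <- TP / (TP + FN)   # Sensitivity (recall)
  spec <- TN / (TN + FP)   # Specificity
  ppv  <- TP / (TP + FP)   # Positive Predictive Value (precision)
  npv  <- TN / (TN + FN)   # Negative Predictive Value
  acc  <- (TP + TN) / n
  # Agreement expected by chance, used by Cohen's Kappa
  p_exp <- ((TP + FP) * (TP + FN) + (TN + FN) * (TN + FP)) / n^2
  c(
    `Balanced Accuracy`         = (sens + spec) / 2,
    F1                          = 2 * ppv * sens / (ppv + sens),
    Sensitivity                 = sens,
    Specificity                 = spec,
    `Positive Predictive Value` = ppv,
    `Negative Predictive Value` = npv,
    Kappa                       = (acc - p_exp) / (1 - p_exp),
    `Detection Rate`            = TP / n,
    `Detection Prevalence`      = (TP + FP) / n,
    Prevalence                  = (TP + FN) / n,
    MCC = (TP * TN - FP * FN) /
      sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))
  )
}

confusion_metrics(TP = 8, FP = 2, TN = 6, FN = 4)
```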
Also includes:
A nested tibble with predictions, predicted classes (depends on cutoff), and the targets. Note that the predictions are not necessarily of the specified positive class, but of the model's positive class (second level of dependent variable, alphabetically).
The pROC::roc ROC curve object(s).
A nested tibble with the confusion matrix/matrices. The Pos_ columns tell you whether a row is a True Positive (TP), True Negative (TN), False Positive (FP), or False Negative (FN), depending on which level is the "positive" class, i.e. the level you wish to predict.
A nested tibble with the results from all fold columns.
The name of the Positive Class.
* In repeated cross-validation, an evaluation is made per fold column (repetition) and averaged.
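The interplay of `cutoff`, `positive`, and the TP/TN/FP/FN labels can be sketched in base R. Assumptions: the predicted probability refers to the second factor level (as stated above), a probability strictly greater than `cutoff` is assigned to that level, and `positive` only changes which rows count as TP/TN/FP/FN, not the predicted classes. All variable names and values are hypothetical; this is not cvms code.

```r
# Toy targets and probabilities of the SECOND level ("b"), illustration only
target <- factor(c("a", "b", "b", "a", "b"))
prob_second_level <- c(0.2, 0.7, 0.4, 0.6, 0.9)
cutoff   <- 0.5
positive <- "a"  # evaluate "a" as the positive class

lvls <- levels(target)  # alphabetical: "a", "b"

# Predicted classes depend only on the second-level probability and `cutoff`
# (assumed strict ">" at the threshold)
predicted <- factor(ifelse(prob_second_level > cutoff, lvls[2], lvls[1]),
                    levels = lvls)

# Label each row relative to the chosen positive level
pos_label <- ifelse(
  predicted == positive,
  ifelse(target == positive, "TP", "FP"),
  ifelse(target == positive, "FN", "TN")
)
```

Switching `positive` to "b" would leave `predicted` unchanged but swap TP with TN and FP with FN, which is why the setting affects only the evaluation metrics.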
Ludvig Renbo Olsen, r-pkgs@ludvigolsen.dk
Benjamin Hugh Zachariae
Other validation functions: cross_validate_fn(), validate(), validate_fn()
# Attach packages
library(cvms)
library(groupdata2) # fold()
library(dplyr) # %>% arrange()
# Data is part of cvms
data <- participant.scores
# Set seed for reproducibility
set.seed(7)
# Fold data
data <- fold(
data,
k = 4,
cat_col = "diagnosis",
id_col = "participant"
) %>%
arrange(.folds)
#
# Cross-validate a single model
#
# Gaussian
cross_validate(
data,
formulas = "score~diagnosis",
family = "gaussian",
REML = FALSE
)
# Binomial
cross_validate(
data,
formulas = "diagnosis~score",
family = "binomial"
)
#
# Cross-validate multiple models
#
formulas <- c(
"score~diagnosis+(1|session)",
"score~age+(1|session)"
)
cross_validate(
data,
formulas = formulas,
family = "gaussian",
REML = FALSE
)
#
# Use parallelization
#
# Attach doParallel and register four cores
# Uncomment:
# library(doParallel)
# registerDoParallel(4)
# Cross-validate a list of model formulas in parallel
# Make sure to uncomment the parallel argument
cross_validate(
data,
formulas = formulas,
family = "gaussian"
#, parallel = TRUE # Uncomment
)