validate_fn (R Documentation)
Fit your model function on a training set and validate it by predicting a test/validation set. Validate different hyperparameter combinations and formulas at once. Preprocess the train/test split. Returns the results and fitted models in a tibble for easy reporting and further analysis.

Compared to validate(), this function allows you to supply a custom model function, a predict function, a preprocess function, and the hyperparameter values to validate.

Supports regression and classification (binary and multiclass). See `type`.

Note that some metrics may not be computable for some types of model objects.
validate_fn(
train_data,
formulas,
type,
model_fn,
predict_fn,
test_data = NULL,
preprocess_fn = NULL,
preprocess_once = FALSE,
hyperparameters = NULL,
partitions_col = ".partitions",
cutoff = 0.5,
positive = 2,
metrics = list(),
rm_nc = FALSE,
parallel = FALSE,
verbose = TRUE
)
train_data

Data frame. Can contain a grouping factor for identifying partitions - as made with groupdata2::partition() (see `partitions_col`).
formulas

Model formulas as strings. (Character) Will be converted to formula objects. E.g. "y ~ x". Can contain random effects, e.g. "y ~ x + (1|r)".
type

Type of evaluation to perform: "gaussian" for regression (like linear regression), "binomial" for binary classification, or "multinomial" for multiclass classification.
model_fn

Model function that returns a fitted model object. Will usually wrap an existing model function like lm(), glm(), or e1071::svm().

Must have the following function arguments:

function(train_data, formula, hyperparameters)
predict_fn

Function for predicting the targets in the test folds/sets using the fitted model object. Will usually wrap stats::predict().

Must have the following function arguments:

function(test_data, model, formula, hyperparameters, train_data)

Must return predictions in the following formats, depending on `type`:

Binomial: a vector with the predicted probabilities of the second class (alphabetically).

N.B. When unsure whether a model type produces probabilities based off the alphabetic order of your classes, using 0 and 1 as classes in the dependent variable instead of the class names should increase the chance of getting probabilities of the right class.

Gaussian: a vector with the predicted values.

Multinomial: a data frame with one column per class, containing the predicted probabilities of that class. The column names should match the class names.
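The examples below cover the gaussian and binomial formats. As a minimal sketch of the multinomial format (assuming a model like nnet::multinom, which is not used elsewhere on this page):

# Return a data frame with one probability column per class
multinom_predict_fn <- function(test_data, model, formula,
                                hyperparameters, train_data) {
  # predict(type = "probs") returns a matrix of class probabilities
  as.data.frame(predict(model, newdata = test_data, type = "probs"))
}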
test_data

Data frame to validate on. If NULL, the train/test split is instead extracted from `train_data` via the `partitions_col` grouping factor.
preprocess_fn

Function for preprocessing the training and test sets. Can, for instance, be used to standardize both the training and test sets with the scaling and centering parameters from the training set.

Must have the following function arguments:

function(train_data, test_data, formula, hyperparameters)

Must return a named list with the preprocessed `train_data` and `test_data`. It may also contain a tibble with the preprocessing parameters used, e.g. list("train" = train_data, "test" = test_data, "parameters" = preprocess_parameters).

Additional elements in the returned list are ignored. The optional parameters tibble is included in the output, making it easy to report the preprocessing parameters used.

N.B. When `preprocess_once` is TRUE, the `formula` and `hyperparameters` arguments are ignored (see `preprocess_once`).
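A minimal sketch of a preprocess function that standardizes a numeric column with the training set's parameters (the column name "score" is taken from the examples below; adapt as needed):

# Standardize the (assumed numeric) "score" column,
# using the mean and SD of the training set only
standardize_preprocess_fn <- function(train_data, test_data,
                                      formula, hyperparameters) {
  mean_score <- mean(train_data[["score"]])
  sd_score <- sd(train_data[["score"]])
  train_data[["score"]] <- (train_data[["score"]] - mean_score) / sd_score
  test_data[["score"]] <- (test_data[["score"]] - mean_score) / sd_score
  list(
    "train" = train_data,
    "test" = test_data,
    "parameters" = tibble::tibble(
      "mean_score" = mean_score,
      "sd_score" = sd_score
    )
  )
}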
preprocess_once

Whether to apply the preprocessing once (ignoring the `formula` and `hyperparameters` arguments in `preprocess_fn`) or for each model separately. (Logical)

When the preprocessing does not depend on the current formula or hyperparameters, we can do the preprocessing of each train/test split once, to save time. This may require holding a lot more data in memory, though, which is why it is not the default setting.
hyperparameters

Either a named list with hyperparameter values to combine in a grid or a data frame with one row per hyperparameter combination to validate.

Named list for grid search: map each hyperparameter name to a vector of values; all combinations are validated. Add ".n" to sample the combinations (a count or a percentage between 0 and 1).

Data frame with specific combinations, e.g.:
lrn_rate | h_layers | drop_out |
0.1 | 10 | 0.65 |
0.1 | 1000 | 0.65 |
0.01 | 1000 | 0.63 |
... | ... | ... |
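A minimal sketch of the two forms, echoing the svm example below (the specific values are illustrative assumptions):

# Named list: all combinations of the values are validated
hparams_grid <- list(
  "kernel" = c("linear", "radial"),
  "cost" = c(1, 5, 10)
)

# Data frame: only the listed combinations are validated
hparams_specific <- data.frame(
  "kernel" = c("linear", "radial"),
  "cost" = c(10, 5)
)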
partitions_col

Name of the grouping factor for identifying partitions. (Character)

Rows with the value 1 in `partitions_col` are used as the training set and rows with the value 2 are used as the test set.

N.B. Only used if `test_data` is NULL.
cutoff

Threshold for predicted classes. (Numeric)

N.B. Binomial models only.
positive

Level from the dependent variable to predict. Either as character (preferable) or level index (1 or 2 - alphabetically).

E.g. if we have the levels "cat" and "dog" and we want "dog" to be the positive class, we can either provide "dog" or 2, as "dog" comes after "cat" alphabetically.

Note: For reproducibility, it's preferable to specify the name directly, as different locales may sort the levels differently.

Used when calculating confusion matrix metrics and creating ROC curves. The Process column in the output can be used to verify this setting.

N.B. Only affects the evaluation metrics, not the model training or the returned predictions.

N.B. Binomial models only.
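A small runnable illustration of why the name and the alphabetical index select the same level:

y <- factor(c("cat", "dog", "dog", "cat"))
levels(y) # "cat" "dog" -> "dog" is level 2
# So `positive = "dog"` and `positive = 2` select the same class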
metrics

List for enabling/disabling metrics.

E.g. list("RMSE" = FALSE) would remove RMSE from the regression results, and list("Accuracy" = TRUE) would add the regular Accuracy metric to the classification results. Default values (TRUE/FALSE) will be used for the remaining available metrics.

You can enable/disable all metrics at once by including "all" = TRUE/FALSE in the list. This is done prior to enabling/disabling individual metrics, which is why, for instance, list("all" = FALSE, "RMSE" = TRUE) would return only the RMSE metric.

The list can be created with gaussian_metrics(), binomial_metrics(), or multinomial_metrics().

Also accepts the string "all".
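For instance, a sketch of keeping only the RMSE metric in a regression validation (the variable name is arbitrary):

# Disable all metrics, then re-enable RMSE
metrics_spec <- list("all" = FALSE, "RMSE" = TRUE)
# Passed as e.g. `validate_fn(..., type = "gaussian", metrics = metrics_spec)`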
rm_nc
Remove non-converged models from output. (Logical)
parallel

Whether to validate the list of models in parallel. (Logical)

Remember to register a parallel backend first. E.g. with doParallel::registerDoParallel.
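A minimal sketch of registering a backend (the worker count is an arbitrary assumption):

# Register a parallel backend with 4 workers
doParallel::registerDoParallel(4)
# Then call validate_fn() with `parallel = TRUE`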
verbose
Whether to message process information like the number of model instances to fit. (Logical)
Packages used:
AIC : stats::AIC
AICc : MuMIn::AICc
BIC : stats::BIC
r2m : MuMIn::r.squaredGLMM
r2c : MuMIn::r.squaredGLMM
ROC and related metrics:
Binomial: pROC::roc
Multinomial: pROC::multiclass.roc
Returns a tibble with the results and model objects. Among the contents:

A nested tibble with the coefficients of the models. The coefficients are extracted from the model object with parameters::model_parameters() or coef() (with some restrictions on the output). If these attempts fail, a default coefficients tibble filled with NAs is returned.

A nested tibble with the used preprocessing parameters, if the passed `preprocess_fn` returns the parameters in a tibble.

A count of convergence warnings, using a limited set of keywords (e.g. "convergence"). If a convergence warning does not contain one of these keywords, it will be counted with the other warnings. Consider discarding models that did not converge on all iterations. Note: you might still see results, but these should be taken with a grain of salt!

A nested tibble with the warnings and messages caught for each model.

The specified family.

The nested model objects.

The name of the dependent variable.

The names of the fixed effects.

The names of the random effects, if any.
—————————————————————-
Gaussian Results
—————————————————————-
RMSE, MAE, NRMSE(IQR), RRSE, RAE, and RMSLE.

See the additional metrics (disabled by default) at ?gaussian_metrics.

A nested tibble with the predictions and targets.
—————————————————————-
Binomial Results
—————————————————————-
Based on the predictions of the test set, a confusion matrix and a ROC curve are created to get the following:

ROC: AUC, Lower CI, and Upper CI.

Confusion Matrix: Balanced Accuracy, F1, Sensitivity, Specificity, Positive Predictive Value, Negative Predictive Value, Kappa, Detection Rate, Detection Prevalence, Prevalence, and MCC (Matthews correlation coefficient).

See the additional metrics (disabled by default) at ?binomial_metrics.

Also includes:

A nested tibble with the predictions, predicted classes (depends on `cutoff`), and the targets. Note that the predictions are not necessarily of the specified positive class, but of the model's positive class (second level of the dependent variable, alphabetically).

The pROC::roc ROC curve object(s).

A nested tibble with the confusion matrix/matrices. The Pos_ columns tell you whether a row is a True Positive (TP), True Negative (TN), False Positive (FP), or False Negative (FN), depending on which level is the "positive" class, i.e. the level you wish to predict.

The name of the Positive Class.
—————————————————————-
Multinomial Results
—————————————————————-
For each class, a one-vs-all binomial evaluation is performed. This creates a Class Level Results tibble containing the same metrics as the binomial results described above (excluding MCC, AUC, Lower CI and Upper CI), along with a count of the class in the target column (Support). These metrics are used to calculate the macro-averaged metrics. The nested class level results tibble is also included in the output tibble, and could be reported along with the macro and overall metrics.

The output tibble contains the macro and overall metrics. The metrics that share their name with the metrics in the nested class level results tibble are averages of those metrics (note: does not remove NAs before averaging). In addition to these, it also includes the Overall Accuracy and the multiclass MCC.

Note: Balanced Accuracy is the macro-averaged metric, not the macro sensitivity as sometimes used!
Other available metrics (disabled by default, see `metrics`): Accuracy, multiclass AUC, Weighted Balanced Accuracy, Weighted Accuracy, Weighted F1, Weighted Sensitivity, Weighted Specificity, Weighted Pos Pred Value, Weighted Neg Pred Value, Weighted Kappa, Weighted Detection Rate, Weighted Detection Prevalence, and Weighted Prevalence.

Note that the "Weighted" average metrics are weighted by the Support.
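As a sketch of that weighting (the values are hypothetical, not the package's internals):

# Support-weighted average of a class-level metric
weighted_avg <- function(metric_values, support) {
  sum(metric_values * support) / sum(support)
}
# E.g. with hypothetical class-level sensitivities and supports:
weighted_avg(c(0.9, 0.7, 0.8), support = c(50, 30, 20))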
Also includes:

A nested tibble with the predictions, predicted classes, and targets.

A list of ROC curve objects when AUC is enabled.

A nested tibble with the multiclass Confusion Matrix.

Class Level Results

Besides the binomial evaluation metrics and the Support, the nested class level results tibble also contains a nested tibble with the Confusion Matrix from the one-vs-all evaluation. The Pos_ columns tell you whether a row is a True Positive (TP), True Negative (TN), False Positive (FP), or False Negative (FN), depending on which level is the "positive" class. In our case, 1 is the current class and 0 represents all the other classes together.
Ludvig Renbo Olsen, r-pkgs@ludvigolsen.dk
Other validation functions:
cross_validate()
,
cross_validate_fn()
,
validate()
# Attach packages
library(cvms)
library(groupdata2) # partition()
library(dplyr) # %>% arrange() mutate()
# Note: More examples of custom functions can be found at:
# model_fn: model_functions()
# predict_fn: predict_functions()
# preprocess_fn: preprocess_functions()
# Data is part of cvms
data <- participant.scores
# Set seed for reproducibility
set.seed(7)
# Partition data
data <- partition(
data,
p = 0.8,
cat_col = "diagnosis",
id_col = "participant",
list_out = FALSE
) %>%
mutate(diagnosis = as.factor(diagnosis)) %>%
arrange(.partitions)
# Formulas to validate
formula_gaussian <- "score ~ diagnosis"
formula_binomial <- "diagnosis ~ score"
#
# Gaussian
#
# Create model function that returns a fitted model object
lm_model_fn <- function(train_data, formula, hyperparameters) {
lm(formula = formula, data = train_data)
}
# Create predict function that returns the predictions
lm_predict_fn <- function(test_data, model, formula,
hyperparameters, train_data) {
stats::predict(
object = model,
newdata = test_data,
type = "response",
allow.new.levels = TRUE
)
}
# Validate the model function
v <- validate_fn(
data,
formulas = formula_gaussian,
type = "gaussian",
model_fn = lm_model_fn,
predict_fn = lm_predict_fn,
partitions_col = ".partitions"
)
v
# Extract model object
v$Model[[1]]
#
# Binomial
#
# Create model function that returns a fitted model object
glm_model_fn <- function(train_data, formula, hyperparameters) {
glm(formula = formula, data = train_data, family = "binomial")
}
# Create predict function that returns the predictions
glm_predict_fn <- function(test_data, model, formula,
hyperparameters, train_data) {
stats::predict(
object = model,
newdata = test_data,
type = "response",
allow.new.levels = TRUE
)
}
# Validate the model function
validate_fn(
data,
formulas = formula_binomial,
type = "binomial",
model_fn = glm_model_fn,
predict_fn = glm_predict_fn,
partitions_col = ".partitions"
)
#
# Support Vector Machine (svm)
# with known hyperparameters
#
# Only run if the `e1071` package is installed
if (requireNamespace("e1071", quietly = TRUE)){
# Create model function that returns a fitted model object
# We use the hyperparameters arg to pass in the kernel and cost values
# These will usually have been found with cross_validate_fn()
svm_model_fn <- function(train_data, formula, hyperparameters) {
# Expected hyperparameters:
# - kernel
# - cost
if (!"kernel" %in% names(hyperparameters))
stop("'hyperparameters' must include 'kernel'")
if (!"cost" %in% names(hyperparameters))
stop("'hyperparameters' must include 'cost'")
e1071::svm(
formula = formula,
data = train_data,
kernel = hyperparameters[["kernel"]],
cost = hyperparameters[["cost"]],
scale = FALSE,
type = "C-classification",
probability = TRUE
)
}
# Create predict function that returns the predictions
svm_predict_fn <- function(test_data, model, formula,
hyperparameters, train_data) {
predictions <- stats::predict(
object = model,
newdata = test_data,
allow.new.levels = TRUE,
probability = TRUE
)
# Extract probabilities
probabilities <- dplyr::as_tibble(
attr(predictions, "probabilities")
)
# Return second column
probabilities[[2]]
}
# Specify hyperparameters to use
# We found these in the examples in ?cross_validate_fn()
svm_hparams <- list(
"kernel" = "linear",
"cost" = 10
)
# Validate the model function
validate_fn(
data,
formulas = formula_binomial,
type = "binomial",
model_fn = svm_model_fn,
predict_fn = svm_predict_fn,
hyperparameters = svm_hparams,
partitions_col = ".partitions"
)
} # closes `e1071` package check
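#
# Multinomial
#

# A minimal multinomial sketch (not part of the original examples):
# uses `nnet::multinom` and the built-in `iris` data, since
# `participant.scores` only has two diagnosis classes.

# Only run if the `nnet` package is installed
if (requireNamespace("nnet", quietly = TRUE)){
  multinom_data <- partition(
    iris,
    p = 0.8,
    cat_col = "Species",
    list_out = FALSE
  )
  # Create model function that returns a fitted model object
  multinom_model_fn <- function(train_data, formula, hyperparameters) {
    nnet::multinom(formula = formula, data = train_data, trace = FALSE)
  }
  # Create predict function that returns a data frame
  # with one probability column per class
  multinom_predict_fn <- function(test_data, model, formula,
                                  hyperparameters, train_data) {
    as.data.frame(predict(model, newdata = test_data, type = "probs"))
  }
  # Validate the model function
  validate_fn(
    multinom_data,
    formulas = "Species ~ Sepal.Length + Sepal.Width",
    type = "multinomial",
    model_fn = multinom_model_fn,
    predict_fn = multinom_predict_fn,
    partitions_col = ".partitions"
  )
} # closes `nnet` package check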