#' Cross-validate a model over a set of resamples
#'
#' Fits `object` on the analysis set of every split in `resamples` and scores
#' its predictions on the matching assessment set. For `nested_cv` objects, the
#' hyperparameter values stored in the columns named by the `"tuning"`
#' attribute are applied to each outer split before fitting.
#'
#' @param resamples rsample resampling object, e.g. a `nested_cv` object
#' @param object parsnip model specification, or a pipeline object
#' @param recipe recipe, optional if using pipeline object; combined with
#'   `object` into a pipeline when `object` is not already one
#' @param scoring yardstick scoring function, called with `truth` and
#'   `estimate` arguments
#' @param keep_preds logical, store predictions per fold in a `predictions`
#'   column, default is TRUE
#' @param keep_models logical, store fitted models per fold in a `models`
#'   column, default is FALSE
#' @param .options future_options to pass additional packages required by
#'   tuning and fitting functions
#'
#' @return `resamples` with added `training`, `testing` and `outer_scores`
#'   list columns, plus `predictions` and/or `models` columns when requested
#' @export
#' @importFrom furrr future_options
cross_validate <- function(resamples,
                           object,
                           recipe = NULL,
                           scoring,
                           keep_preds = TRUE,
                           keep_models = FALSE,
                           .options = future_options()) {
  # Dispatch on the class of the first argument, `resamples`.
  UseMethod("cross_validate")
}
#' @export
#' @importFrom purrr map
#' @importFrom furrr future_map2 future_options
#' @importFrom dplyr bind_cols mutate
#' @importFrom formula.tools lhs.vars
#' @importFrom rsample analysis assessment
#' @importFrom tibble as_tibble tibble
cross_validate.default <- function(resamples, object, recipe = NULL, scoring, keep_preds = TRUE,
                                   keep_models = FALSE, .options = future_options()) {
  # Plain cross-validation: every split is fitted with the same pipeline on
  # its analysis set and scored on its assessment set.

  # The model specification carries the mode; for a pipeline it is taken
  # from the pipeline's second element.
  estimator <- if (inherits(object, "pipeline")) object[[2]] else object
  model_mode <- estimator$mode
  # Fail early with a clear message instead of letting switch() yield NULL
  # and the error surface later inside `scoring()` on a worker.
  if (!is.character(model_mode) || length(model_mode) != 1L ||
      !(model_mode %in% c("classification", "regression"))) {
    stop(
      "`object` must have mode \"classification\" or \"regression\".",
      call. = FALSE
    )
  }
  # Prediction column handed to `scoring()` as the estimate.
  estimate_col <- if (model_mode == "classification") ".pred_class" else ".pred"

  # Wrap a bare model specification once, rather than once per fold.
  if (!inherits(object, "pipeline")) {
    object <- pipeline(preprocessing = recipe, model_spec = object)
  }

  # Materialize the analysis (train) and assessment (test) data of each split.
  resamples <- resamples %>%
    mutate(training = map(resamples$splits, analysis),
           testing = map(resamples$splits, assessment))

  results <- future_map2(resamples$training, resamples$testing, function(X_train, X_test) {
    fitted <- object %>% fit(X_train)
    preds <- fitted %>% predict(new_data = X_test)
    # The truth column is the outcome on the left-hand side of the recipe's
    # formula.
    scores <- preds %>%
      bind_cols(X_test) %>%
      scoring(
        truth = !!lhs.vars(formula(fitted$recipe)),
        estimate = !!estimate_col)
    # `if (FALSE) x` evaluates to NULL, so unrequested outputs are stored as NULL.
    list(scores = scores,
         predictions = if (keep_preds) preds,
         models = if (keep_models) fitted)
  }, .options = .options)
  names(results) <- seq_along(results)

  resamples$outer_scores <- map(results, "scores")
  if (keep_preds) resamples$predictions <- map(results, "predictions")
  if (keep_models) resamples$models <- map(results, "models")
  resamples
}
#' @export
#' @importFrom purrr map
#' @importFrom furrr future_pmap future_options
#' @importFrom dplyr bind_cols mutate
#' @importFrom formula.tools lhs.vars
#' @importFrom rsample analysis assessment
#' @importFrom tibble as_tibble tibble
cross_validate.nested_cv <- function(resamples, object, recipe = NULL, scoring,
                                     keep_preds = TRUE, keep_models = FALSE,
                                     .options = future_options()) {
  # Outer-loop evaluation of a nested CV: each outer split is fitted with the
  # hyperparameter values stored on its row (in the columns named by the
  # "tuning" attribute) and scored on its assessment set.

  # The model specification carries the mode; for a pipeline it is taken
  # from the pipeline's second element.
  estimator <- if (inherits(object, "pipeline")) object[[2]] else object
  model_mode <- estimator$mode
  # Fail early with a clear message instead of letting switch() yield NULL
  # and the error surface later inside `scoring()` on a worker.
  if (!is.character(model_mode) || length(model_mode) != 1L ||
      !(model_mode %in% c("classification", "regression"))) {
    stop(
      "`object` must have mode \"classification\" or \"regression\".",
      call. = FALSE
    )
  }
  # Prediction column handed to `scoring()` as the estimate.
  estimate_col <- if (model_mode == "classification") ".pred_class" else ".pred"

  # Build one named list of hyperparameter values per outer split. The tuning
  # columns are selected once; element `i` of each column belongs to split `i`.
  hyper_pars_names <- attr(resamples, "tuning")
  tuning_cols <- as.list(resamples[, hyper_pars_names, drop = FALSE])
  pars <- lapply(seq_len(nrow(resamples)), function(i) {
    row_pars <- lapply(tuning_cols, `[[`, i)
    names(row_pars) <- hyper_pars_names
    row_pars
  })

  # Materialize the analysis (train) and assessment (test) data of each split.
  resamples <- resamples %>%
    mutate(training = map(resamples$splits, analysis),
           testing = map(resamples$splits, assessment))

  results <- future_pmap(
    list(resamples$training, resamples$testing, pars),
    function(X_train, X_test, par) {
      # Apply this split's hyperparameters, either while wrapping a bare
      # model specification into a pipeline or by updating the pipeline.
      if (!inherits(object, "pipeline")) {
        object <- do.call(
          pipeline,
          c(list(preprocessing = recipe, model_spec = object), par)
        )
      } else {
        object <- object %>% update(!!!par)
      }
      fitted <- object %>% fit(X_train)
      preds <- fitted %>% predict(new_data = X_test)
      # The truth column is the outcome on the left-hand side of the recipe's
      # formula.
      scores <- preds %>%
        bind_cols(X_test) %>%
        scoring(
          truth = !!lhs.vars(formula(fitted$recipe)),
          estimate = !!estimate_col)
      # `if (FALSE) x` evaluates to NULL, so unrequested outputs are stored as NULL.
      list(scores = scores,
           predictions = if (keep_preds) preds,
           models = if (keep_models) fitted)
    }, .options = .options
  )
  names(results) <- seq_along(results)

  resamples$outer_scores <- map(results, "scores")
  if (keep_preds) resamples$predictions <- map(results, "predictions")
  if (keep_models) resamples$models <- map(results, "models")
  # Re-attach the tuning attribute in case the column edits above dropped it.
  attr(resamples, "tuning") <- hyper_pars_names
  resamples
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.