#' Hyperparameter tuning on rsample objects
#'
#' `tune` is a generic function that accepts a tibble of resampling partitions generated by the
#' resampling schemes available in the `rsample` package. A `parsnip` model specification and a
#' `recipes` recipe also need to be supplied to this function (the recipe is optional when
#' `object` is already a pipeline), along with a `yardstick` scoring function. The `param_grid`
#' argument accepts a `grid_regular` or `grid_random` object of tuning parameters generated by the
#' `dials` package. Hyperparameter tuning is then performed on the resamples, which are returned
#' as a tibble with an additional list-column called `tune_scores` containing the tuning scores.
#'
#' For ordinary resampling objects, the best-scoring hyperparameters of each resample are also
#' bound to the output as additional columns. If the resampling object is a `nested_cv` object,
#' hyperparameter tuning is performed on the inner resampling partitions, the inner scores are
#' averaged per outer fold, and the best hyperparameters per outer fold are returned as
#' additional columns in the output tibble.
#'
#' @param resamples tibble generated by rsample containing the resampling scheme
#' @param object parsnip model specification, or a pipeline object
#' @param recipe recipe object, optional if using a pipeline
#' @param param_grid dials grid_regular or grid_random object
#' @param scoring yardstick scoring function
#' @param maximize logical, maximize the scoring function, default = TRUE; any value other than
#'   `TRUE` minimizes it
#' @param .options future_options to pass additional packages required by tuning and fitting
#'   functions
#'
#' @return the `resamples` tibble with one column per tuned hyperparameter (best value per
#'   resample / outer fold), a `tune_scores` list-column, a `"tuning"` attribute holding the
#'   hyperparameter names, and `"tune"` appended to its class
#' @export
#' @importFrom furrr future_options
tune <- function(resamples, object, recipe = NULL, param_grid, scoring, maximize = TRUE,
                 .options = future_options()) {
  UseMethod("tune", resamples)
}
#' @importFrom purrr map map_dbl pmap
#' @importFrom furrr future_map2_dbl future_options
#' @importFrom dplyr bind_cols select rename group_by summarize
#' @importFrom tidyr expand_grid nest unnest
#' @importFrom rlang sym
#' @export
tune.default <- function(resamples, object, recipe = NULL, param_grid, scoring,
                         maximize = TRUE, .options = future_options()) {
  arg_names <- names(param_grid)
  # resample ids in their original row order; captured before any columns are
  # bound so every per-resample result below can be aligned with these rows
  fold_levels <- unique(resamples$id)

  # cross every resample with every hyperparameter combination
  inner_resamples <- resamples %>%
    expand_grid(param_grid)

  # build one pipeline per (resample, hyperparameter combination) row
  if (inherits(object, "pipeline")) {
    inner_resamples$pipelines <- inner_resamples %>%
      select(!!arg_names) %>%
      pmap(function(...) object %>% update(!!!list2(...)))
  } else {
    inner_resamples$pipelines <- inner_resamples %>%
      select(!!arg_names) %>%
      pmap(function(...) pipeline(recipe, object, !!!list2(...)))
  }

  # fit and score every row (possibly in parallel via furrr)
  inner_resamples$score <- future_map2_dbl(
    inner_resamples$splits, inner_resamples$pipelines,
    fit_and_score, scoring,
    .options = .options
  )

  # keep the best-scoring hyperparameter row per resample; which.max/which.min
  # return the first index on ties, so a single row is kept per id
  which_fun <- if (isTRUE(maximize)) which.max else which.min
  scores_per_fold <- inner_resamples %>%
    group_by(!!sym("id")) %>%
    filter(row_number() == which_fun(!!sym("score"))) %>%
    ungroup()

  # filter() preserves row order, so these rows follow the order of
  # `resamples` and can be column-bound directly
  resamples <- bind_cols(resamples, scores_per_fold %>% select(!!arg_names))

  # nest the scores back onto the resamples; split on a factor whose levels
  # follow the original row order (as.factor() would sort the ids
  # alphabetically and could attach scores to the wrong resample)
  fold_ids <- factor(inner_resamples$id, levels = fold_levels)
  resamples$tune_scores <- map(
    split(inner_resamples, fold_ids),
    function(x, ...) {
      nested <- x %>%
        select(-!!sym("splits")) %>%
        nest(tune_scores = c(!!sym("pipelines"), !!arg_names, !!sym("score")))
      nested$tune_scores[[1]]
    }
  )

  attr(resamples, "tuning") <- arg_names
  class(resamples) <- append(class(resamples), "tune")
  resamples
}
#' @importFrom purrr map map_dbl pmap
#' @importFrom furrr future_map2_dbl future_options
#' @importFrom dplyr bind_cols select rename group_by group_by_at summarize group_map
#' @importFrom tidyr expand_grid nest unnest
#' @importFrom rlang sym
#' @export
tune.nested_cv <- function(resamples, object, recipe = NULL, param_grid, scoring,
                           maximize = TRUE, .options = future_options()) {
  arg_names <- names(param_grid)
  # outer fold ids in their original row order; every per-outer-fold result
  # below is aligned with these before being attached to `resamples`
  outer_ids <- unique(as.character(resamples$id))

  # flatten the resamples: one row per (outer fold, inner split)
  inner_resamples <- resamples %>%
    select(!!sym("id"), !!sym("inner_resamples")) %>%
    rename(outer_fold = !!sym("id")) %>%
    unnest(!!sym("inner_resamples")) %>%
    rename(inner_fold = !!sym("id"))
  # a factor with row-order levels makes grouping verbs (group_by/summarize,
  # group_map) order the outer folds like `resamples`, not alphabetically
  inner_resamples$outer_fold <- factor(inner_resamples$outer_fold, levels = outer_ids)

  # cross with the param_grid and build one pipeline per row
  inner_resamples <- inner_resamples %>%
    expand_grid(param_grid)
  if (inherits(object, "pipeline")) {
    inner_resamples$pipelines <- inner_resamples %>%
      select(!!arg_names) %>%
      pmap(function(...) object %>% update(!!!list2(...)))
  } else {
    inner_resamples$pipelines <- inner_resamples %>%
      select(!!arg_names) %>%
      pmap(function(...) pipeline(recipe, object, !!!list2(...)))
  }

  # fit and score all inner resamples (possibly in parallel via furrr)
  inner_resamples$score <- future_map2_dbl(
    inner_resamples$splits, inner_resamples$pipelines,
    fit_and_score, scoring,
    .options = .options
  )

  # mean inner-resample score per outer fold / hyperparameter combination;
  # grouping by outer fold first keeps the summary ordered by outer fold
  scores <- inner_resamples %>%
    group_by_at(c("outer_fold", arg_names)) %>%
    summarize(score = mean(!!sym("score")))

  # best-scoring hyperparameter combination per outer fold; which.max/which.min
  # return the first index on ties, so a single row is kept per outer fold
  which_fun <- if (isTRUE(maximize)) which.max else which.min
  scores_per_fold <- scores %>%
    group_by(!!sym("outer_fold")) %>%
    filter(row_number() == which_fun(!!sym("score"))) %>%
    ungroup()
  # align explicitly with the rows of `resamples` before column-binding;
  # previously rows were ordered by hyperparameter value, which attached the
  # best hyperparameters to the wrong outer folds
  scores_per_fold <- scores_per_fold[match(outer_ids, scores_per_fold$outer_fold), ]

  # bind best hyperparameters per outer fold back onto resamples
  resamples <- bind_cols(resamples, scores_per_fold %>% select(!!arg_names))

  # nest scores back onto resampling df: one list per outer fold (in the
  # factor-level order, i.e. the row order of `resamples`), containing one
  # nested score tibble per inner fold, named by inner fold id
  resamples$tune_scores <- inner_resamples %>%
    group_by(!!sym("outer_fold")) %>%
    group_map(function(x, ...) {
      nested <- x %>%
        select(-!!sym("splits")) %>%
        nest(tune_scores = c(!!sym("pipelines"), !!arg_names, !!sym("score")))
      per_inner <- nested$tune_scores
      names(per_inner) <- nested$inner_fold
      per_inner
    })

  attr(resamples, "tuning") <- arg_names
  class(resamples) <- append(class(resamples), "tune")
  resamples
}
#' Fit a pipeline on one resample and score it on the held-out data
#'
#' @param rsplit an rsample split object
#' @param pipeline a pipeline object (recipe plus parsnip model specification)
#' @param scoring yardstick scoring function
#' @return the `.estimate` value(s) returned by `scoring`
#' @noRd
#' @importFrom stats formula predict
#' @importFrom rsample analysis assessment form_pred
#' @importFrom formula.tools lhs.vars
#' @importFrom dplyr filter
fit_and_score <- function(rsplit, pipeline, scoring) {
  # fit the pipeline on the analysis portion of the split
  fitted <- pipeline %>% fit(data = analysis(rsplit))

  # held-out assessment set
  X_test <- assessment(rsplit)

  # outcome = variables of the recipe formula that are not predictors
  rec_formula <- formula(fitted$recipe)
  outcome_name <- setdiff(all.vars(rec_formula), form_pred(rec_formula))

  # predict the assessment set and bind the predictions next to the data so
  # both the outcome and the prediction column are available to the metric
  pred <- fitted %>%
    predict(X_test) %>%
    bind_cols(X_test)

  # name of the prediction column, i.e. the metric's `estimate` argument
  model_mode <- pipeline$model_spec$mode
  estimate_col <- switch(
    model_mode,
    "classification" = ".pred_class",
    "regression" = ".pred",
    stop("Unsupported model mode: '", model_mode, "'", call. = FALSE)
  )

  score <- pred %>% scoring(
    truth = !!outcome_name,
    estimate = !!estimate_col
  )
  score[[".estimate"]]
}
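# NOTE(review): the two lines below appear to be web-page embedding boilerplate
# captured together with this source file; they are not R code and are kept
# only as comments so the file parses.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.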