# R/tune_grid.R In tune: Tidy Tuning Tools

#### Documented in: pull_rset_attributes, tune_grid, tune_grid.model_spec, tune_grid.workflow

#' Model tuning via grid search
#'
#' [tune_grid()] computes a set of performance metrics (e.g. accuracy or RMSE)
#'  for a pre-defined set of tuning parameters that correspond to a model or
#'  recipe across one or more resamples of the data.
#'
#' @param object A parsnip model specification or a [workflows::workflow()].
#' @param preprocessor A traditional model formula or a recipe created using
#'   [recipes::recipe()].
#' @param resamples An rset() object.
#' @param param_info A [dials::parameters()] object or NULL. If none is given,
#' a parameters set is derived from other arguments. Passing this argument can
#' be useful when parameter ranges need to be customized.
#' @param grid A data frame of tuning combinations or a positive integer. The
#'  data frame should have columns for each parameter being tuned and rows for
#'  tuning parameter candidates. An integer denotes the number of candidate
#'  parameter sets to be created automatically.
#' @param metrics A [yardstick::metric_set()] or NULL.
#' @param control An object used to modify the tuning process.
#' @param ... Not currently used.
#' @return An updated version of resamples with extra list columns for .metrics and
#' .notes (optional columns are .predictions and .extracts). .notes
#' contains warnings and errors that occur during execution.
#' @seealso [control_grid()], [tune()], [fit_resamples()],
#' [autoplot.tune_results()], [show_best()], [select_best()],
#' [collect_predictions()], [collect_metrics()]
#' @details
#'
#' Suppose there are _m_ tuning parameter combinations. [tune_grid()] may not
#' require all _m_ model/recipe fits across each resample. For example:
#'
#' \itemize{
#'   \item In cases where a single model fit can be used to make predictions
#'         for different parameter values in the grid, only one fit is used.
#'         For example, for some boosted trees, if 100 iterations of boosting
#'         are requested, the model object for 100 iterations can be used to
#'         make predictions on iterations less than 100 (if all other
#'         parameters are equal).
#'   \item When the model is being tuned in conjunction with pre-processing
#'         and/or post-processing parameters, the minimum number of fits are
#'         used. For example, if the number of PCA components in a recipe step
#'         are being tuned over three values (along with model tuning
#'         parameters), only three recipes are trained. The alternative
#'         would be to re-train the same recipe multiple times for each model
#'         tuning parameter.
#' }
#'
#' The foreach package is used here. To execute the resampling iterations in
#' parallel, register a parallel backend function. See the documentation for
#' [foreach::foreach()] for examples.
#'
#' For the most part, warnings generated during training are shown as they occur
#' and are associated with a specific resample when control_grid(verbose = TRUE).
#' They are (usually) not aggregated until the end of processing.
#'
#' @section Parameter Grids:
#'
#' If no tuning grid is provided, a semi-random grid (via
#' [dials::grid_latin_hypercube()]) is created with 10 candidate parameter
#' combinations.
#'
#' When provided, the grid should have column names for each parameter and
#'  these should be named by the parameter name or id. For example, if a
#'  parameter is marked for optimization using penalty = tune(), there should
#'  be a column named penalty. If the optional identifier is used, such as
#'  penalty = tune(id = 'lambda'), then the corresponding column name should
#'  be lambda.
#'
#' In some cases, the tuning parameter values depend on the dimensions of the
#'  data. For example, mtry in random forest models depends on the number of
#'  predictors. In this case, the default tuning parameter object requires an
#'  upper range. [dials::finalize()] can be used to derive the data-dependent
#'  parameters. Otherwise, a parameter set can be created (via
#'  [dials::parameters()]) and the dials update() function can be used to
#'  change the values. This updated parameter set can be passed to the function
#'  via the param_info argument.
#'
#' @section Performance Metrics:
#'
#' To use your own performance metrics, the [yardstick::metric_set()] function
#'  can be used to pick what should be measured for each model. If multiple
#'  metrics are desired, they can be bundled. For example, to estimate the area
#'  under the ROC curve as well as the sensitivity and specificity (under the
#'  typical probability cutoff of 0.50), the metrics argument could be given:
#'
#' \preformatted{
#'   metrics = metric_set(roc_auc, sens, spec)
#' }
#'
#' Each metric is calculated for each candidate model.
#'
#' If no metric set is provided, one is created:
#' \itemize{
#'   \item For regression models, the root mean squared error and coefficient
#'         of determination are computed.
#'   \item For classification, the area under the ROC curve and overall accuracy
#'         are computed.
#' }
#'
#' Note that the metrics also determine what type of predictions are estimated
#' during tuning. For example, in a classification problem, if metrics are used
#' that are all associated with hard class predictions, the classification
#' probabilities are not created.
#'
#' The out-of-sample estimates of these metrics are contained in a list column
#' called .metrics. This tibble contains a row for each metric and columns
#' for the value, the estimator type, and so on.
#'
#' [collect_metrics()] can be used for these objects to collapse the results
#' over the resamples (to obtain the final resampling estimates per tuning
#' parameter combination).
#'
#' @section Obtaining Predictions:
#'
#' When control_grid(save_pred = TRUE), the output tibble contains a list
#'  column called .predictions that has the out-of-sample predictions for each
#'  parameter combination in the grid and each fold (which can be very large).
#'
#' The elements of the tibble are tibbles with columns for the tuning
#' parameters, the row number from the original data object (.row), the
#' outcome data (with the same name(s) of the original data), and any columns
#' created by the predictions. For example, for simple regression problems, this
#' function generates a column called .pred and so on. As noted above, the
#' prediction columns that are returned are determined by the type of metric(s)
#' requested.
#'
#' This list column can be unnested using [tidyr::unnest()] or using the
#'  convenience function [collect_predictions()].
#'
#' @section Extracting Information:
#'
#' The extract control option will result in an additional column to be
#'  returned called .extracts. This is a list column that has tibbles
#'  containing the results of the user's function for each tuning parameter
#'  combination. This can enable returning each model and/or recipe object that
#'  is created during resampling. Note that this could result in a large return
#'  object, depending on what is returned.
#'
#' The control function contains an option (extract) that can be used to
#'  retain any model or recipe that was created within the resamples. This
#'  argument should be a function with a single argument. The value of the
#'  argument that is given to the function in each resample is a workflow
#'  object (see [workflows::workflow()] for more information). There are
#'  helper functions that can be used to easily pull out the recipe (if any)
#'  and/or the model: [extract_recipe()] and [extract_model()].
#'
#' As an example, if there is interest in getting each model back, one could use:
#' \preformatted{
#'   extract = function (x) extract_fit_parsnip(x)
#' }
#'
#' Note that the function given to the extract argument is evaluated on
#'  every model that is _fit_ (as opposed to every model that is _evaluated_).
#' As noted above, in some cases, model predictions can be derived for
#'  sub-models so that, in these cases, not every row in the tuning parameter
#'  grid has a separate R object associated with it.
#' @examples
#' \donttest{
#' library(recipes)
#' library(rsample)
#' library(parsnip)
#' library(workflows)
#' library(ggplot2)
#'
#' # ---------------------------------------------------------------------------
#'
#' set.seed(6735)
#' folds <- vfold_cv(mtcars, v = 5)
#'
#' # ---------------------------------------------------------------------------
#'
#' # tuning recipe parameters:
#'
#' spline_rec <-
#'   recipe(mpg ~ ., data = mtcars) %>%
#'   step_ns(disp, deg_free = tune("disp")) %>%
#'   step_ns(wt, deg_free = tune("wt"))
#'
#' lin_mod <-
#'   linear_reg() %>%
#'   set_engine("lm")
#'
#' # manually create a grid
#' spline_grid <- expand.grid(disp = 2:5, wt = 2:5)
#'
#' # Warnings will occur from making spline terms on the holdout data that are
#' # extrapolations.
#' spline_res <-
#'   tune_grid(lin_mod, spline_rec, resamples = folds, grid = spline_grid)
#' spline_res
#'
#'
#' show_best(spline_res, metric = "rmse")
#'
#' # ---------------------------------------------------------------------------
#'
#' # tune model parameters only (example requires the kernlab package)
#'
#' car_rec <-
#'   recipe(mpg ~ ., data = mtcars) %>%
#'   step_normalize(all_predictors())
#'
#' svm_mod <-
#'   svm_rbf(cost = tune(), rbf_sigma = tune()) %>%
#'   set_engine("kernlab") %>%
#'   set_mode("regression")
#'
#' # Use a space-filling design with 7 points
#' set.seed(3254)
#' svm_res <- tune_grid(svm_mod, car_rec, resamples = folds, grid = 7)
#' svm_res
#'
#' show_best(svm_res, metric = "rmse")
#'
#' autoplot(svm_res, metric = "rmse") +
#'   scale_x_log10()
#'
#' # ---------------------------------------------------------------------------
#'
#' # Using a variables preprocessor with a workflow
#'
#' # Rather than supplying a preprocessor (like a recipe) and a model directly
#' # to tune_grid(), you can also wrap them up in a workflow and pass
#' # that along instead (note that this doesn't do any preprocessing to
#' # the variables, it passes them along as-is).
#' wf <- workflow() %>%
#'   add_variables(outcomes = mpg, predictors = everything()) %>%
#'   add_model(svm_mod)
#'
#' set.seed(3254)
#' svm_res_wf <- tune_grid(wf, resamples = folds, grid = 7)
#' }
#' @export
tune_grid <- function(object, ...) {
  # S3 generic: the method is selected from the class of `object`.
  UseMethod("tune_grid")
}

#' @export
tune_grid.default <- function(object, ...) {
  # Fallback S3 method: it unconditionally aborts, telling the caller which
  # kinds of first argument are accepted.
  rlang::abort(
    paste0(
      "The first argument to [tune_grid()] should be either ",
      "a model or workflow."
    )
  )
}

#' @export
#' @rdname tune_grid
tune_grid.model_spec <- function(object, preprocessor, resamples, ...,
                                 param_info = NULL, grid = 10, metrics = NULL,
                                 control = control_grid()) {
  # Method for a model specification. The model is combined with
  # `preprocessor` (a formula or a recipe) into a workflow, and the call is
  # then re-dispatched through tune_grid() on that workflow.
  #
  # Aborts when `preprocessor` is missing or is rejected by
  # is_preprocessor().
  if (rlang::is_missing(preprocessor) || !is_preprocessor(preprocessor)) {
    rlang::abort(paste(
      "To tune a model spec, you must preprocess",
      "with a formula or recipe"
    ))
  }

  # `...` is documented as "Not currently used"; hand it to the checker.
  empty_ellipses(...)

  # Build the workflow that is passed on below. Without this step `wflow`
  # would be undefined when tune_grid() is called.
  wflow <- workflows::add_model(workflows::workflow(), object)

  if (is_recipe(preprocessor)) {
    wflow <- workflows::add_recipe(wflow, preprocessor)
  } else if (rlang::is_formula(preprocessor)) {
    wflow <- workflows::add_formula(wflow, preprocessor)
  }

  tune_grid(
    wflow,
    resamples = resamples,
    param_info = param_info,
    grid = grid,
    metrics = metrics,
    control = control
  )
}

#' @export
#' @rdname tune_grid
tune_grid.workflow <- function(object, resamples, ..., param_info = NULL,
                               grid = 10, metrics = NULL,
                               control = control_grid()) {
  # Method for a workflow: after two argument checks, all of the work is
  # delegated to tune_grid_workflow().

  # `...` is documented as "Not currently used"; hand it to the checker.
  empty_ellipses(...)

  # A NULL grid is the special signal used when no tuning is required, so it
  # is not accepted here.
  if (is.null(grid)) {
    rlang::abort(grid_msg)
  }

  # `param_info` is forwarded under the internal argument name `pset`.
  tune_grid_workflow(
    workflow = object,
    resamples = resamples,
    grid = grid,
    metrics = metrics,
    pset = param_info,
    control = control
  )
}

# ------------------------------------------------------------------------------

tune_grid_workflow <- function(workflow,
                               resamples,
                               grid = 10,
                               metrics = NULL,
                               pset = NULL,
                               control = control_grid(),
                               rng = TRUE) {
  # Internal driver for grid tuning: checks the inputs, runs the tuning loop
  # over the resamples, and wraps the result with new_tune_results().

  # ---- Input checks; several return a (possibly updated) value ------------
  check_rset(resamples)

  metrics <- check_metrics(metrics, workflow)

  # The parameter check receives the data of the first split and the names
  # of `grid`.
  pset <- check_parameters(
    workflow = workflow,
    pset = pset,
    data = resamples$splits[[1]]$data,
    grid_names = names(grid)
  )

  check_workflow(workflow, pset = pset)

  grid <- check_grid(
    grid = grid,
    workflow = workflow,
    pset = pset
  )

  # ---- Keep the rset attributes aside, then work on a bare tibble ----------
  rset_attrs <- pull_rset_attributes(resamples)
  resamples <- new_bare_tibble(resamples)

  # ---- Run the tuning loop -------------------------------------------------
  resamples <- tune_grid_loop(
    resamples = resamples,
    grid = grid,
    workflow = workflow,
    metrics = metrics,
    control = control,
    rng = rng
  )

  if (is_cataclysmic(resamples)) {
    rlang::warn("All models failed. See the .notes column.")
  }

  # Reduce the outcome names to one value, then drop the helper column they
  # were carried in.
  outcome_names <- reduce_all_outcome_names(resamples)
  resamples[[".all_outcome_names"]] <- NULL

  # set_workflow() yields either the workflow or NULL, depending on `control`.
  workflow <- set_workflow(workflow, control)

  new_tune_results(
    x = resamples,
    parameters = pset,
    metrics = metrics,
    outcomes = outcome_names,
    rset_info = rset_attrs,
    workflow = workflow
  )
}

# ------------------------------------------------------------------------------

#' @export
#' @keywords internal
#' @rdname empty_ellipses
pull_rset_attributes <- function(x) {
  # Collect the attributes of `x` (apart from "names" and "row.names")
  # together with a label obtained from pretty(x).
  #
  # Returns a list with two elements:
  #   att   - the retained attributes; the class entry has the classes of an
  #           empty tibble and "rset" removed.
  #   label - the result of pretty(x), or NA_character_ when that call errors.
  excl_att <- c("names", "row.names")
  att <- attributes(x)
  att_nms <- setdiff(names(att), excl_att)

  # These are two separate statements; they must not share a line.
  att$class <- setdiff(class(x), class(tibble::tibble()))
  att$class <- att$class[att$class != "rset"]

  # pretty() may fail for `x`; in that case fall back to a missing label.
  lab <- try(pretty(x), silent = TRUE)
  if (inherits(lab, "try-error")) {
    lab <- NA_character_
  }

  list(att = att[att_nms], label = lab)
}

# ------------------------------------------------------------------------------

set_workflow <- function(workflow, control) {
  # Decide whether the workflow is kept in the tuning results.
  #
  # Returns `workflow` when `control$save_workflow` is TRUE and NULL
  # otherwise. When the workflow is kept and it contains a recipe larger than
  # 5MB in memory, an informational message is emitted first.
  if (control$save_workflow) {
    if (!is.null(workflow$pre$actions$recipe)) {
      w_size <- utils::object.size(workflow$pre$actions$recipe)
      # make 5MB cutoff
      if (w_size / 1024^2 > 5) {
        msg <- paste0(
          "The workflow being saved contains a recipe, which is ",
          format(w_size, units = "Mb", digits = 2),
          " in memory. If this was not intentional, please set the control ",
          "setting save_workflow = FALSE."
        )
        cols <- get_tune_colors()
        # Wrap the text, prefixing each line with the info symbol.
        msg <- strwrap(
          msg,
          prefix = paste0(cols$symbol$info(cli::symbol$info), " ")
        )
        msg <- cols$message$info(paste0(msg, collapse = "\n"))
        rlang::inform(msg)
      }
    }
    workflow
  } else {
    NULL
  }
}


## Try the tune package in your browser

# Any scripts or data that you put into this service are public.
#
# tune documentation built on March 19, 2022, 2:14 a.m.