R/control.R

Defines functions new_backend_options val_parallel_over print.control_bayes control_bayes print.control_last_fit control_last_fit print.control_grid control_grid

Documented in control_bayes control_grid control_last_fit new_backend_options

#' Control aspects of the grid search process
#'
#' @inheritParams control_bayes
#'
#' @inheritSection collect_predictions Hyperparameters and extracted objects
#'
#' @details
#'
#' For `extract`, this function can be used to output the model object, the
#' recipe (if used), or some components of either or both. When evaluated, the
#' function's sole argument is a fitted workflow. If the formula method is
#' used, the recipe element will be `NULL`.
#'
#' The results of the `extract` function are added to a list column in the
#' output called `.extracts`. Each element of this list is a tibble with tuning
#' parameter columns and a list column (also called `.extracts`) that contains
#' the results of the function. If no extraction function is used, there is no
#' `.extracts` column in the resulting object. See [tune_bayes()] for more
#' specific details.
#'
#' Note that for [collect_predictions()], it is possible that each row of the
#'  original data might be represented multiple times per tuning
#'  parameter. For example, if the bootstrap or repeated cross-validation are
#'  used, there will be multiple rows since the same data point has been
#'  evaluated multiple times. This may cause issues when merging the predictions
#'  with the original data.
#'
#' [control_resamples()] is an alias for [control_grid()] and is meant to be
#' used with [fit_resamples()].
#' @export
control_grid <- function(verbose = FALSE, allow_par = TRUE,
                         extract = NULL, save_pred = FALSE,
                         pkgs = NULL, save_workflow = FALSE,
                         event_level = "first",
                         parallel_over = NULL,
                         backend_options = NULL) {
  # Any added arguments should also be added in superset control functions
  # in other packages.
  #
  # TODO: add options for seeds per resample.

  # Name of this function as reported in validation error messages.
  caller <- "control_grid()"

  # Each argument is passed to its validator as the bare symbol; the
  # validators (defined elsewhere) presumably deparse it to name the offending
  # argument in their error messages, so do not wrap or rename it here.
  val_class_and_single(verbose, "logical", caller)
  val_class_and_single(allow_par, "logical", caller)
  val_class_and_single(save_pred, "logical", caller)
  val_class_and_single(save_workflow, "logical", caller)
  val_class_and_single(event_level, "character", caller)

  # Arguments that may also be `NULL`.
  val_class_or_null(pkgs, "character", caller)
  val_class_or_null(extract, "function", caller)
  val_parallel_over(parallel_over, caller)

  # `list()` (unlike `c()`) keeps `NULL`-valued entries, so every setting is
  # present by name in the returned object.
  structure(
    list(
      verbose = verbose,
      allow_par = allow_par,
      extract = extract,
      save_pred = save_pred,
      pkgs = pkgs,
      save_workflow = save_workflow,
      event_level = event_level,
      parallel_over = parallel_over,
      backend_options = backend_options
    ),
    class = c("control_grid", "control_resamples")
  )
}

#' @export
print.control_grid <- function(x, ...) {
  # A fixed one-line summary; the object is handed back invisibly, as is
  # conventional for print methods.
  header <- "grid/resamples control object"
  cat(header, "\n", sep = "")
  invisible(x)
}

# `control_resamples()` is an alias: it is bound to the very same function
# object as `control_grid()`, so it has the same arguments and validation and
# returns an object of class c("control_grid", "control_resamples").
#' @rdname control_grid
#' @export
control_resamples <- control_grid

#' Control aspects of the last fit process
#'
#' @inheritParams control_grid
#'
#' @details
#'
#' [control_last_fit()] is a wrapper around [control_resamples()] and is meant
#'   to be used with [last_fit()].
#'
#' @export
control_last_fit <- function(
    verbose = FALSE,
    event_level = "first",
    allow_par = FALSE
) {
  # Any added arguments should also be added in superset control functions
  # in other packages.

  # Fixed settings for `last_fit()`: predictions are always saved, the
  # workflow is not attached as an attribute, and the identity function is
  # used as `extract` so the fitted workflow itself is kept in `.extracts`.
  ctrl <- control_resamples(
    verbose = verbose,
    allow_par = allow_par,
    event_level = event_level,
    extract = function(x) x,
    save_pred = TRUE,
    save_workflow = FALSE
  )

  # Prepend the subclass so `print.control_last_fit()` is dispatched first.
  structure(ctrl, class = c("control_last_fit", class(ctrl)))
}

#' @export
print.control_last_fit <- function(x, ...) {
  # Emit a single summary line, then return the object invisibly.
  header <- "last fit control object"
  cat(header, "\n", sep = "")
  invisible(x)
}

# ------------------------------------------------------------------------------

#' Control aspects of the Bayesian search process
#'
#' @param verbose A logical for logging results (other than warnings and errors,
#'   which are always shown) as they are generated during training in a single
#'   R process. When using most parallel backends, this argument typically will
#'   not result in any logging. If using a dark IDE theme, some logging messages
#'   might be hard to see; try setting the `tidymodels.dark` option with
#'   `options(tidymodels.dark = TRUE)` to print lighter colors.
#' @param verbose_iter A logical for logging results of the Bayesian search
#'   process. Defaults to FALSE. If using a dark IDE theme, some logging
#'   messages might be hard to see; try setting the `tidymodels.dark` option
#'   with `options(tidymodels.dark = TRUE)` to print lighter colors.
#' @param no_improve The integer cutoff for the number of iterations without
#'   better results. Must not be missing (`NA`).
#' @param uncertain The number of iterations with no improvement before an
#'  uncertainty sample is created where a sample with high predicted variance is
#'  chosen (i.e., in a region that has not yet been explored). The iteration
#'  counter is reset after each uncertainty sample. For example, if `uncertain =
#'  10`, this condition is triggered every 10 samples with no improvement. Must
#'  not be missing (`NA`).
#' @param seed An integer for controlling the random number stream. Tuning
#' functions are sensitive to both the state of RNG set outside of tuning
#' functions with `set.seed()` as well as the value set here. The value of the
#' former determines RNG for the higher-level tuning process, like grid
#' generation and setting the value of this argument if left as default. The
#' value of this argument determines RNG state in workers for each iteration
#' of model fitting, determined by the value of `parallel_over`.
#' @param time_limit A number for the minimum number of _minutes_ (elapsed) that
#'   the function should execute. The elapsed time is evaluated at internal
#'   checkpoints and, if over time, the results at that time are returned (with
#'   a warning). This means that the `time_limit` is not an exact limit, but a
#'   minimum time limit.
#'
#'   Note that timing begins immediately on execution. Thus, if the
#'   `initial` argument to [tune_bayes()] is supplied as a number, the elapsed
#'   time will include the time needed to generate initialization results.
#' @param extract An optional function with at least one argument (or `NULL`)
#'   that can be used to retain arbitrary objects from the model fit object,
#'   recipe, or other elements of the workflow.
#' @param save_pred A logical for whether the out-of-sample predictions should
#'   be saved for each model _evaluated_.
#' @param pkgs An optional character vector of R package names that should be
#'   loaded (by namespace) during parallel processing.
#' @param save_workflow A logical for whether the workflow should be appended
#'  to the output as an attribute.
#' @param save_gp_scoring A logical to save the intermediate Gaussian process
#'   models for each iteration of the search. These are saved to
#'  `tempdir()` with names `gp_candidates_{i}.RData` where `i` is the iteration.
#'  These results are deleted when the R session ends. This option is only
#'  useful for teaching purposes.
#' @param event_level A single string containing either `"first"` or `"second"`.
#'   This argument is passed on to yardstick metric functions when any type
#'   of class prediction is made, and specifies which level of the outcome
#'   is considered the "event".
#' @param parallel_over A single string containing either `"resamples"` or
#'   `"everything"` describing how to use parallel processing. Alternatively,
#'   `NULL` is allowed, which chooses between `"resamples"` and `"everything"`
#'   automatically.
#'
#'   If `"resamples"`, then tuning will be performed in parallel over resamples
#'   alone. Within each resample, the preprocessor (i.e. recipe or formula) is
#'   processed once, and is then reused across all models that need to be fit.
#'
#'   If `"everything"`, then tuning will be performed in parallel at two levels.
#'   An outer parallel loop will iterate over resamples. Additionally, an
#'   inner parallel loop will iterate over all unique combinations of
#'   preprocessor and model tuning parameters for that specific resample. This
#'   will result in the preprocessor being re-processed multiple times, but
#'   can be faster if that processing is extremely fast.
#'
#'   If `NULL`, chooses `"resamples"` if there is more than one resample,
#'   otherwise chooses `"everything"` to attempt to maximize core utilization.
#'
#'   Note that switching between `parallel_over` strategies is not guaranteed
#'   to use the same random number generation schemes. However, re-tuning a
#'   model using the same `parallel_over` strategy is guaranteed to be
#'   reproducible between runs.
#' @param backend_options An object of class `"tune_backend_options"` as created
#'   by `tune::new_backend_options()`, used to pass arguments to a specific
#'   tuning backend. Defaults to `NULL` for default backend options.
#' @param allow_par A logical to allow parallel processing (if a parallel
#'   backend is registered).
#'
#' @inheritSection collect_predictions Hyperparameters and extracted objects
#'
#' @details
#'
#' For `extract`, this function can be used to output the model object, the
#' recipe (if used), or some components of either or both. When evaluated, the
#' function's sole argument is a fitted workflow. If the formula method is
#' used, the recipe element will be `NULL`.
#'
#' The results of the `extract` function are added to a list column in the
#' output called `.extracts`. Each element of this list is a tibble with tuning
#' parameter columns and a list column (also called `.extracts`) that contains
#' the results of the function. If no extraction function is used, there is no
#' `.extracts` column in the resulting object. See [tune_bayes()] for more
#' specific details.
#'
#' Note that for [collect_predictions()], it is possible that each row of the
#'  original data might be represented multiple times per tuning
#'  parameter. For example, if the bootstrap or repeated cross-validation are
#'  used, there will be multiple rows since the same data point has been
#'  evaluated multiple times. This may cause issues when merging the predictions
#'  with the original data.
#' @export
control_bayes <-
  function(verbose = FALSE,
           verbose_iter = FALSE,
           no_improve = 10L,
           uncertain = Inf,
           seed = sample.int(10^5, 1),
           extract = NULL,
           save_pred = FALSE,
           time_limit = NA,
           pkgs = NULL,
           save_workflow = FALSE,
           save_gp_scoring = FALSE,
           event_level = "first",
           parallel_over = NULL,
           backend_options = NULL,
           allow_par = TRUE) {
    # Any added arguments should also be added in superset control functions
    # in other packages

    # TODO: add options for seeds per resample

    val_class_and_single(verbose, "logical", "control_bayes()")
    val_class_and_single(verbose_iter, "logical", "control_bayes()")
    val_class_and_single(save_pred, "logical", "control_bayes()")
    val_class_and_single(save_gp_scoring, "logical", "control_bayes()")
    val_class_and_single(save_workflow, "logical", "control_bayes()")
    val_class_and_single(no_improve, c("numeric", "integer"), "control_bayes()")
    val_class_and_single(uncertain, c("numeric", "integer"), "control_bayes()")
    val_class_and_single(seed, c("numeric", "integer"), "control_bayes()")
    val_class_or_null(extract, "function", "control_bayes()")
    val_class_and_single(time_limit, c("logical", "numeric"), "control_bayes()")
    val_class_or_null(pkgs, "character", "control_bayes()")
    val_class_and_single(event_level, "character", "control_bayes()")
    val_parallel_over(parallel_over, "control_bayes()")
    val_class_and_single(allow_par, "logical", "control_bayes()")

    # Class checks alone may let a missing number (e.g. `NA_real_`) through.
    # Without these guards the comparison below fails with R's opaque
    # "missing value where TRUE/FALSE needed" error. With the default
    # `uncertain = Inf`, that comparison short-circuits, so an `NA`
    # `no_improve` would otherwise be silently accepted. These checks run after
    # the class checks so that type errors are still reported first.
    if (is.na(no_improve)) {
      rlang::abort("`no_improve` must not be `NA` in `control_bayes()`.")
    }
    if (is.na(uncertain)) {
      rlang::abort("`uncertain` must not be `NA` in `control_bayes()`.")
    }

    # An uncertainty sample that fires only after more stagnant iterations
    # than the search tolerates can never be reached; tell the user.
    if (!is.infinite(uncertain) && uncertain > no_improve) {
      cli::cli_alert_warning(
        "Uncertainty sample scheduled after {uncertain} poor iterations but the search will stop after {no_improve}."
      )
    }

    res <-
      list(
        verbose = verbose,
        verbose_iter = verbose_iter,
        allow_par = allow_par,
        no_improve = no_improve,
        uncertain = uncertain,
        seed = seed,
        extract = extract,
        save_pred = save_pred,
        time_limit = time_limit,
        pkgs = pkgs,
        save_workflow = save_workflow,
        save_gp_scoring = save_gp_scoring,
        event_level = event_level,
        parallel_over = parallel_over,
        backend_options = backend_options
      )

    class(res) <- "control_bayes"
    res
  }

#' @export
print.control_bayes <- function(x, ...) {
  # Show a fixed one-line description and return the object invisibly.
  header <- "bayes control object"
  cat(header, "\n", sep = "")
  invisible(x)
}

# ------------------------------------------------------------------------------

# Validate the `parallel_over` control argument. `NULL` is always allowed
# (the strategy is then chosen automatically); otherwise it must be a single
# string equal to "resamples" or "everything". `where` names the calling
# function in error messages. Called for its side effect of erroring on bad
# input; returns `NULL` invisibly.
val_parallel_over <- function(parallel_over, where) {
  if (!is.null(parallel_over)) {
    # The bare `parallel_over` symbol is passed on purpose: the helper may use
    # it to name the offending argument in its error message.
    val_class_and_single(parallel_over, "character", where)
    rlang::arg_match0(
      parallel_over,
      values = c("resamples", "everything"),
      arg_nm = "parallel_over"
    )
  }
  invisible(NULL)
}

#' @param ... Named options for a specific tuning backend. They are collected
#'   with [rlang::list2()], so dynamic dots (e.g. splicing with `!!!`) are
#'   supported. Every option must be named.
#' @param class A character vector of subclasses to place in front of
#'   `"tune_backend_options"` in the class of the returned object.
#' @export
#' @keywords internal
#' @rdname control_grid
new_backend_options <- function(..., class = character()) {
  opts <- rlang::list2(...)

  # `names2()` always returns a character vector, with "" for each unnamed
  # element, so unnamed options are detected even when none are named.
  unnamed <- rlang::names2(opts) == ""
  if (any(unnamed)) {
    rlang::abort("All backend options must be named.")
  }

  # `class` here is the argument above; the replacement function `class<-`
  # is still the base one.
  class(opts) <- c(class, "tune_backend_options")
  opts
}

Try the tune package in your browser

Any scripts or data that you put into this service are public.

tune documentation built on May 29, 2024, 7:32 a.m.