# R/tune_bayes.R

#' Bayesian optimization of model parameters.
#'
#' [tune_bayes()] uses models to generate new candidate tuning parameter
#'  combinations based on previous results.
#'
#' @inheritParams tune_grid
#' @param metrics A [yardstick::metric_set()] object containing information on how
#' models will be evaluated for performance. The first metric in `metrics` is the
#' one that will be optimized.
#' @param iter The maximum number of search iterations.
#' @param objective A character string for what metric should be optimized or
#' an acquisition function object.
#' @param initial An initial set of results in a tidy format (as would result
#' from [tune_grid()]) or a positive integer. It is suggested that the number of
#' initial results be greater than the number of parameters being optimized.
#' @param control A control object created by [control_bayes()]
#' @param ... Options to pass to [GPfit::GP_fit()] (mostly for the `corr` argument).
#' @return A tibble of results that mirror those generated by [tune_grid()].
#' However, these results contain an `.iter` column and replicate the `rset`
#' object multiple times over iterations (at limited additional memory costs).
#' @seealso [control_bayes()], [tune()], [autoplot.tune_results()],
#'  [show_best()], [select_best()], [collect_predictions()],
#'  [collect_metrics()], [prob_improve()], [exp_improve()], [conf_bound()],
#'  [fit_resamples()]
#' @details
#'
#' The optimization starts with a set of initial results, such as those
#'  generated by [tune_grid()]. If none exist, the function will create several
#'  combinations and obtain their performance estimates.
#'
#' Using one of the performance estimates as the _model outcome_, a Gaussian
#'  process (GP) model is created where the previous tuning parameter combinations
#'  are used as the predictors.
#'
#' A large grid of potential hyperparameter combinations is predicted using
#'  the model and scored using an _acquisition function_. These functions
#'  usually combine the predicted mean and variance of the GP to decide the best
#'  parameter combination to try next. For more information, see the
#'  documentation for [exp_improve()] and the corresponding package vignette.
#'
#' The best combination is evaluated using resampling and the process continues.
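#'
#' For example, the acquisition function is supplied via the `objective`
#'  argument; assuming a workflow `wflow` and resamples `folds` (placeholder
#'  names), and an arbitrary trade-off value, this might look like:
#'
#' \preformatted{
#'   tune_bayes(wflow, resamples = folds, iter = 25,
#'              objective = exp_improve(trade_off = 0.01))
#' }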
#'
#' @section Parallel Processing:
#'
#' The `foreach` package is used here. To execute the resampling iterations in
#' parallel, register a parallel backend function. See the documentation for
#' [foreach::foreach()] for examples.
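#'
#' For example, one possible backend (assuming that the `doParallel` package is
#'  installed) could be registered with:
#'
#' \preformatted{
#'   library(doParallel)
#'   cl <- parallel::makePSOCKcluster(2)
#'   registerDoParallel(cl)
#'   # subsequent calls to tune_bayes() will resample in parallel
#'   parallel::stopCluster(cl)
#' }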
#'
#' For the most part, warnings generated during training are shown as they occur
#' and are associated with a specific resample when
#' `control_bayes(verbose = TRUE)`. They are (usually) not aggregated until the
#' end of processing.
#'
#' For Bayesian optimization, parallel processing is used to estimate the
#' resampled performance values once a new candidate set of values is estimated.
#'
#' @section Initial Values:
#'
#' The results of [tune_grid()], or a previous run of [tune_bayes()] can be used
#' in the `initial` argument. `initial` can also be a positive integer. In this
#' case, a space-filling design will be used to populate a preliminary set of
#' results. For good results, the number of initial values should be more than
#' the number of parameters being optimized.
#'
#' @section Parameter Ranges and Values:
#'
#' In some cases, the tuning parameter values depend on the dimensions of the
#' data (they are said to contain [unknown][dials::unknown] values). For
#' example, `mtry` in random forest models depends on the number of predictors.
#' In such cases, the unknowns in the tuning parameter object must be determined
#' beforehand and passed to the function via the `param_info` argument.
#' [dials::finalize()] can be used to derive the data-dependent parameters.
#' Otherwise, a parameter set can be created via [dials::parameters()], and the
#' `dials` `update()` function can be used to specify the ranges or values.
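#'
#' As an illustration (the specification and data names below are
#'  placeholders), an unknown `mtry` could be finalized against the predictors
#'  and the result passed to the `param_info` argument:
#'
#' \preformatted{
#'   rf_param <-
#'     rf_mod %>%
#'     extract_parameter_set_dials() %>%
#'     finalize(x = dplyr::select(cell_data, -class))
#'
#'   # pass `rf_param` to tune_bayes() via the `param_info` argument
#' }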
#'
#' @section Performance Metrics:
#'
#' To use your own performance metrics, the [yardstick::metric_set()] function
#'  can be used to pick what should be measured for each model. If multiple
#'  metrics are desired, they can be bundled. For example, to estimate the area
#'  under the ROC curve as well as the sensitivity and specificity (under the
#'  typical probability cutoff of 0.50), the `metrics` argument could be given:
#'
#' \preformatted{
#'   metrics = metric_set(roc_auc, sens, spec)
#' }
#'
#' Each metric is calculated for each candidate model.
#'
#' If no metric set is provided, one is created:
#' \itemize{
#'   \item For regression models, the root mean squared error and coefficient
#'         of determination are computed.
#'   \item For classification, the area under the ROC curve and overall accuracy
#'         are computed.
#' }
#'
#' Note that the metrics also determine what type of predictions are estimated
#' during tuning. For example, in a classification problem, if metrics are used
#' that are all associated with hard class predictions, the classification
#' probabilities are not created.
#'
#' The out-of-sample estimates of these metrics are contained in a list column
#' called `.metrics`. This tibble contains a row for each metric and columns
#' for the value, the estimator type, and so on.
#'
#' [collect_metrics()] can be used for these objects to collapse the results
#' over the resamples (to obtain the final resampling estimates per tuning
#' parameter combination).
#'
#' @section Obtaining Predictions:
#'
#' When `control_bayes(save_pred = TRUE)`, the output tibble contains a list
#' column called `.predictions` that has the out-of-sample predictions for each
#' parameter combination in the grid and each fold (which can be very large).
#'
#' The elements of the tibble are tibbles with columns for the tuning
#' parameters, the row number from the original data object (`.row`), the
#' outcome data (with the same name(s) of the original data), and any columns
#' created by the predictions. For example, for simple regression problems, this
#' function generates a column called `.pred` and so on. As noted above, the
#' prediction columns that are returned are determined by the type of metric(s)
#' requested.
#'
#' This list column can be `unnested` using [tidyr::unnest()] or using the
#'  convenience function [collect_predictions()].
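#'
#' For example (using placeholder object names):
#'
#' \preformatted{
#'   ctrl <- control_bayes(save_pred = TRUE)
#'   res  <- tune_bayes(wflow, resamples = folds, iter = 10, control = ctrl)
#'   collect_predictions(res)
#' }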
#'
#' @inheritSection tune_grid Extracting Information
#'
#' @examplesIf tune:::should_run_examples(suggests = "kernlab")
#' library(recipes)
#' library(rsample)
#' library(parsnip)
#'
#' # define resamples and minimal recipe on mtcars
#' set.seed(6735)
#' folds <- vfold_cv(mtcars, v = 5)
#'
#' car_rec <-
#'   recipe(mpg ~ ., data = mtcars) %>%
#'   step_normalize(all_predictors())
#'
#' # define an svm with parameters to tune
#' svm_mod <-
#'   svm_rbf(cost = tune(), rbf_sigma = tune()) %>%
#'   set_engine("kernlab") %>%
#'   set_mode("regression")
#'
#' # use a space-filling design with 6 points
#' set.seed(3254)
#' svm_grid <- tune_grid(svm_mod, car_rec, folds, grid = 6)
#'
#' show_best(svm_grid, metric = "rmse")
#'
#' # use Bayesian optimization to evaluate at 6 more points
#' set.seed(8241)
#' svm_bayes <- tune_bayes(svm_mod, car_rec, folds, initial = svm_grid, iter = 6)
#'
#' # note that Bayesian optimization evaluated parameterizations
#' # similar to those that previously decreased rmse in svm_grid
#' show_best(svm_bayes, metric = "rmse")
#'
#' # specifying `initial` as a numeric rather than previous tuning results
#' # will result in `tune_bayes` initially evaluating a space-filling
#' # grid using `tune_grid` with `grid = initial`
#' set.seed(0239)
#' svm_init <- tune_bayes(svm_mod, car_rec, folds, initial = 6, iter = 6)
#'
#' show_best(svm_init, metric = "rmse")
#' @export
tune_bayes <- function(object, ...) {
  UseMethod("tune_bayes")
}

#' @export
tune_bayes.default <- function(object, ...) {
  msg <- paste0(
    "The first argument to [tune_bayes()] should be either ",
    "a model or workflow."
  )
  rlang::abort(msg)
}

#' @export
#' @rdname tune_bayes
tune_bayes.model_spec <- function(object,
                                  preprocessor,
                                  resamples,
                                  ...,
                                  iter = 10,
                                  param_info = NULL,
                                  metrics = NULL,
                                  objective = exp_improve(),
                                  initial = 5,
                                  control = control_bayes()) {
  if (rlang::is_missing(preprocessor) || !is_preprocessor(preprocessor)) {
    rlang::abort(paste(
      "To tune a model spec, you must preprocess",
      "with a formula or recipe"
    ))
  }

  control <- parsnip::condense_control(control, control_bayes())

  wflow <- add_model(workflow(), object)

  if (is_recipe(preprocessor)) {
    wflow <- add_recipe(wflow, preprocessor)
  } else if (rlang::is_formula(preprocessor)) {
    wflow <- add_formula(wflow, preprocessor)
  }

  tune_bayes_workflow(
    wflow,
    resamples = resamples, iter = iter, param_info = param_info,
    metrics = metrics, objective = objective, initial = initial,
    control = control, ...
  )
}


#' @export
#' @rdname tune_bayes
tune_bayes.workflow <-
  function(object,
           resamples,
           ...,
           iter = 10,
           param_info = NULL,
           metrics = NULL,
           objective = exp_improve(),
           initial = 5,
           control = control_bayes()) {

    control <- parsnip::condense_control(control, control_bayes())

    res <-
      tune_bayes_workflow(
        object,
        resamples = resamples, iter = iter, param_info = param_info,
        metrics = metrics, objective = objective, initial = initial,
        control = control, ...
      )
    .stash_last_result(res)
    res
  }

tune_bayes_workflow <-
  function(object, resamples, iter = 10, param_info = NULL, metrics = NULL,
           objective = exp_improve(),
           initial = 5, control = control_bayes(), ...) {
    start_time <- proc.time()[3]

    initialize_catalog(control = control)

    check_rset(resamples)
    rset_info <- pull_rset_attributes(resamples)

    metrics <- check_metrics(metrics, object)
    metrics_data <- metrics_info(metrics)
    metrics_name <- metrics_data$.metric[1]
    maximize <- metrics_data$direction[metrics_data$.metric == metrics_name] == "maximize"

    if (is.null(param_info)) {
      param_info <- hardhat::extract_parameter_set_dials(object)
    }
    check_workflow(object, check_dials = is.null(param_info), pset = param_info)
    check_backend_options(control$backend_options)

    unsummarized <- check_initial(
      initial, param_info, object, resamples,
      metrics, control,
      checks = "bayes"
    )

    # Pull outcome names from initialization run
    outcomes <- peek_tune_results_outcomes(unsummarized)

    evalq({
    # Return whatever we have if there is an error (or execution is stopped)
    on.exit({
      cli::cli_alert_danger("Optimization stopped prematurely; returning current results.")

      out <- new_iteration_results(
        x = unsummarized,
        parameters = param_info,
        metrics = metrics,
        outcomes = outcomes,
        rset_info = rset_info,
        workflow = NULL
      )

      .stash_last_result(out)

      return(out)
    })

    # Preempt `estimate_tune_results()` error and rely
    # on `on.exit()` condition to return preliminary results
    if (is_cataclysmic(unsummarized)) {
      return()
    }

    # Get the averaged resampling stats before stripping attributes
    mean_stats <- estimate_tune_results(unsummarized)

    # Strip off `tune_results` class and drop all attributes since
    # we add on an `iteration_results` class later.
    unsummarized <- new_bare_tibble(unsummarized)

    check_time(start_time, control$time_limit)

    score_card <- initial_info(mean_stats, metrics_name, maximize)

    if (control$verbose_iter) {
      message_wrap(paste("Optimizing", metrics_name, "using", objective$label))
    }

    prev_gp_mod <- NULL

    for (i in (1:iter) + score_card$overall_iter) {
      .notes <-
        tibble::new_tibble(
          list(location = character(0), type = character(0), note = character(0)),
          nrow = 0
        )

      log_best(control, i, score_card)

      check_time(start_time, control$time_limit)

      set.seed(control$seed[1] + i)
      gp_mod <-
        .catch_and_log(
          fit_gp(
            mean_stats %>% dplyr::select(-.iter),
            pset = param_info,
            metric = metrics_name,
            control = control,
            ...
          ),
          control,
          NULL,
          "Gaussian process model",
          notes = .notes,
          catalog = FALSE
        )

      gp_mod <- check_gp_failure(gp_mod, prev_gp_mod)

      save_gp_results(gp_mod, param_info, control, i, iter)

      check_time(start_time, control$time_limit)

      set.seed(control$seed[1] + i + 1)
      candidates <-
        pred_gp(
          gp_mod, param_info,
          control = control,
          current = mean_stats %>% dplyr::select(dplyr::all_of(param_info$id))
        )

      check_time(start_time, control$time_limit)

      acq_summarizer(control, iter = i, objective = objective)

      candidates <-
        dplyr::bind_cols(
          candidates,
          stats::predict(
            objective, candidates,
            iter = i,
            maximize = maximize, score_card$best_val
          )
        )

      check_time(start_time, control$time_limit)

      check_and_log_flow(control, candidates)

      candidates <- pick_candidate(candidates, score_card, control)
      if (score_card$uncertainty >= control$uncertain) {
        score_card$uncertainty <- -1 # is updated in update_score_card() below
      }

      check_time(start_time, control$time_limit)

      param_msg(control, candidates)
      set.seed(control$seed[1] + i + 2)
      tmp_res <-
        more_results(
          object,
          resamples = resamples,
          candidates = candidates,
          metrics = metrics,
          control = control,
          param_info = param_info
        )

      check_time(start_time, control$time_limit)

      all_bad <- is_cataclysmic(tmp_res)

      if (!inherits(tmp_res, "try-error") & !all_bad) {
        tmp_res[[".metrics"]] <- purrr::map(
          tmp_res[[".metrics"]],
          ~ dplyr::mutate(., .config = paste0("Iter", i))
        )
        if (control$save_pred) {
          tmp_res[[".predictions"]] <- purrr::map(
            tmp_res[[".predictions"]],
            ~ dplyr::mutate(., .config = paste0("Iter", i))
          )
        }
        unsummarized <- dplyr::bind_rows(unsummarized, tmp_res %>% mutate(.iter = i))
        rs_estimate <- estimate_tune_results(tmp_res)
        mean_stats <- dplyr::bind_rows(mean_stats, rs_estimate %>% dplyr::mutate(.iter = i))
        score_card <- update_score_card(score_card, i, tmp_res)
        log_progress(control, x = mean_stats, maximize = maximize, objective = metrics_name)
      } else {
        if (all_bad) {
          tune_log(control, split = NULL, task = "All models failed", type = "danger")
        }
        score_card$last_impr <- score_card$last_impr + 1
      }

      if (score_card$last_impr + 1 > control$no_improve) {
        cli::cli_alert_warning(
          "No improvement for {control$no_improve} iterations; returning current results."
        )
        break
      }
      prev_gp_mod <- gp_mod
      check_time(start_time, control$time_limit)
    }

    workflow_output <- set_workflow(object, control)

    # Reset `on.exit()` hook
    on.exit()

    res <-
      new_iteration_results(
        x = unsummarized,
        parameters = param_info,
        metrics = metrics,
        outcomes = outcomes,
        rset_info = rset_info,
        workflow = workflow_output
      )

    .stash_last_result(res)
    res
    }) # end of evalq() call
  }

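# Create the initial space-filling design (a Latin hypercube). By default, one
# more candidate is generated than the number of tuning parameters.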
create_initial_set <- function(param, n = NULL, checks) {
  check_param_objects(param)
  if (is.null(n)) {
    n <- nrow(param) + 1
  }
  if (any(checks == "bayes")) {
    check_bayes_initial_size(nrow(param), n)
  }
  dials::grid_latin_hypercube(param, size = n)
}


# ------------------------------------------------------------------------------

#' @export
#' @keywords internal
#' @rdname empty_ellipses
#' @param pset A `parameters` object.
#' @param as_matrix A logical for the return type.
encode_set <- function(x, pset, as_matrix = FALSE, ...) {
  # change the numeric variables to the transformed scale (if any)
  has_trans <- purrr::map_lgl(pset$object, ~ !is.null(.x$trans))
  if (any(has_trans)) {
    idx <- which(has_trans)
    for (i in idx) {
      x[[pset$id[i]]] <-
        dials::value_transform(pset$object[[i]], x[[pset$id[i]]])
    }
  }

  is_quant <- purrr::map_lgl(pset$object, inherits, "quant_param")
  # Convert all data to the [0, 1] scale based on their possible range (not on
  # their observed range)
  if (any(is_quant)) {
    new_vals <- purrr::map2(pset$object[is_quant], x[, is_quant], encode_unit, direction = "forward")
    names(new_vals) <- names(x)[is_quant]
    new_vals <- tibble::as_tibble(new_vals)
    x[, is_quant] <- new_vals
  }

  # Ensure that the right levels are used to create dummy variables
  if (any(!is_quant)) {
    for (i in which(!is_quant)) {
      x[[i]] <- factor(x[[i]], levels = pset$object[[i]]$values)
    }
  }

  if (as_matrix) {
    x <- stats::model.matrix(~ . + 0, data = x)
  }
  x
}

fit_gp <- function(dat, pset, metric, control, ...) {
  dat <-
    dat %>%
    dplyr::filter(.metric == metric) %>%
    check_gp_data() %>%
    dplyr::select(dplyr::all_of(pset$id), mean)

  x <- encode_set(dat %>% dplyr::select(-mean), pset, as_matrix = TRUE)

  if (nrow(x) <= ncol(x) + 1 && nrow(x) > 0) {
    msg <-
      paste(
        "The Gaussian process model is being fit using ", ncol(x),
        "features but only has", nrow(x), "data points to do so. This may cause",
        "errors or a poor model fit."
      )
    message_wrap(msg, prefix = "!", color_text = get_tune_colors()$message$warning)
  }

  opts <- list(...)
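  # Arguments in `...` are passed on to GPfit::GP_fit(); for example, a Matern
  # correlation structure could be requested with something like
  # `corr = list(type = "matern", nu = 5/2)` (see ?GPfit::GP_fit).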
  if (any(names(opts) == "trace") && opts$trace) {
    gp_fit <- GPfit::GP_fit(X = x, Y = dat$mean, ...)
  } else {
    tmp_output <- utils::capture.output(
      gp_fit <- GPfit::GP_fit(X = x, Y = dat$mean, ...)
    )
  }

  gp_fit
}

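# Predict the GP over a large space-filling grid of candidate parameter
# combinations, dropping any that have already been resampled, and return the
# predicted mean and standard deviation for each remaining candidate.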
pred_gp <- function(object, pset, size = 5000, current = NULL, control) {
  pred_grid <-
    dials::grid_latin_hypercube(pset, size = size) %>%
    dplyr::distinct()

  if (!is.null(current)) {
    pred_grid <-
      pred_grid %>%
      dplyr::anti_join(current, by = pset$id)
  }

  if (inherits(object, "try-error") | nrow(pred_grid) == 0) {
    if (nrow(pred_grid) == 0) {
      msg <- "No remaining candidate models"
    } else {
      msg <- "An error occurred when creating candidates parameters: "
      msg <- paste(msg, as.character(object))
    }
    tune_log(control, split = NULL, task = msg, type = "warning")
    return(pred_grid %>% dplyr::mutate(.mean = NA_real_, .sd = NA_real_))
  }

  tune_log(
    control,
    split = NULL,
    task = paste("Generating", nrow(pred_grid), "candidates"),
    type = "info",
    catalog = FALSE
  )

  x <- encode_set(pred_grid, pset, as_matrix = TRUE)
  gp_pred <- predict(object, x)

  tune_log(control, split = NULL, task = "Predicted candidates", type = "info", catalog = FALSE)

  pred_grid %>%
    dplyr::mutate(.mean = gp_pred$Y_hat, .sd = sqrt(gp_pred$MSE))
}


pick_candidate <- function(results, info, control) {
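  # Unless `control$uncertain` iterations have passed without improvement,
  # choose the candidate with the largest acquisition function value.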
  if (info$uncertainty < control$uncertain) {
    results <- results %>%
      dplyr::arrange(dplyr::desc(objective)) %>%
      dplyr::slice(1)
  } else {
    if (control$verbose_iter) {
      msg <- paste(blue(cli::symbol$circle_question_mark), "Uncertainty sample")
      message(msg)
    }
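    # Uncertainty sample: choose one candidate at random from the 10% of
    # candidates with the largest predicted standard deviation.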
    results <-
      results %>%
      dplyr::arrange(dplyr::desc(.sd)) %>%
      dplyr::slice(1:floor(.1 * nrow(results))) %>%
      dplyr::sample_n(1)
  }
  results
}

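# The "score card" tracks the best metric value seen so far, the iteration in
# which it occurred, the number of iterations since the last improvement, and a
# counter used to trigger uncertainty sampling.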
update_score_card <- function(info, iter, results, control) {
  current_val <-
    results %>%
    estimate_tune_results() %>%
    dplyr::filter(.metric == info$metrics) %>%
    dplyr::pull(mean)

  if (info$max) {
    is_better <- current_val > info$best_val
  } else {
    is_better <- current_val < info$best_val
  }

  if (!is.na(is_better) & is_better) {
    info$last_impr <- 0
    info$best_val <- current_val
    info$best_iter <- iter
    info$uncertainty <- 0
  } else {
    info$last_impr <- info$last_impr + 1
    info$uncertainty <- info$uncertainty + 1
  }
  info
}




# ------------------------------------------------------------------------------


# TODO: save metrics_name and maximize to simplify
initial_info <- function(stats, metrics, maximize) {
  best_res <-
    stats %>%
    dplyr::filter(.metric == metrics) %>%
    dplyr::filter(!is.na(mean))

  if (maximize) {
    best_res <-
      best_res %>%
      dplyr::arrange(desc(mean)) %>%
      slice(1)
  } else {
    best_res <-
      best_res %>%
      dplyr::arrange(mean) %>%
      slice(1)
  }
  best_val <- best_res$mean[1]
  best_iter <- best_res$.iter[1]
  last_impr <- 0
  overall_iter <- max(stats$.iter)

  # outputs:

  list(
    best_val = best_val,
    best_iter = best_iter,
    last_impr = last_impr,
    uncertainty = 0,
    overall_iter = overall_iter,
    metrics = metrics,
    max = maximize
  )
}


# ------------------------------------------------------------------------------


more_results <- function(object, resamples, candidates, metrics, control, param_info) {
  tune_log(control, split = NULL, task = "Estimating performance", type = "info")

  candidates <- candidates[, !(names(candidates) %in% c(".mean", ".sd", "objective"))]
  p_chr <- paste0(names(candidates), "=", format(as.data.frame(candidates), digits = 3))

  tmp_res <-
    try(
      tune_grid(
        object,
        resamples = resamples,
        param_info = param_info,
        grid = candidates,
        metrics = metrics,
        control = control
      ),
      silent = TRUE
    )

  if (inherits(tmp_res, "try-error")) {
    tune_log(
      control,
      split = NULL, task = "Couldn't estimate performance",
      type = "danger"
    )
  } else {
    all_bad <- is_cataclysmic(tmp_res)
    if (all_bad) {
      p_chr <- glue::glue_collapse(p_chr, width = options()$width - 28, sep = ", ")
      msg <- paste("All models failed for:", p_chr)
      tune_log(control, split = NULL, task = msg, type = "danger")
      tmp_res <- simpleError(msg)
    } else {
      tune_log(
        control,
        split = NULL, task = "Estimating performance",
        type = "success"
      )
    }
  }
  tmp_res
}


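# Returns TRUE when every resample failed, i.e., each `.metrics` element is
# either an error or not a non-empty tibble.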
is_cataclysmic <- function(x) {
  is_err <- purrr::map_lgl(x$.metrics, inherits, c("simpleError", "error"))
  if (any(!is_err)) {
    is_good <- purrr::map_lgl(
      x$.metrics[!is_err],
      ~ tibble::is_tibble(.x) && nrow(.x) > 0
    )
    is_err[!is_err] <- !is_good
  }
  all(is_err)
}



#' @export
#' @keywords internal
#' @rdname empty_ellipses
#' @param origin The calculation start time.
#' @param limit The allowable time (in minutes).
check_time <- function(origin, limit) {
  if (is.na(limit)) {
    return(invisible(NULL))
  }
  now_time <- proc.time()[3]
  if (now_time - origin >= limit * 60) {
    rlang::abort(paste("The time limit of", limit, "minutes has been reached."))
  }
  invisible(NULL)
}

# It may be better to completely refactor things into a high-level call and
# then use base's setTimeLimit().

# Make sure that rset object attributes are kept once joined
reup_rs <- function(resamples, res) {
  sort_cols <- grep("^id", names(resamples), value = TRUE)
  if (any(names(res) == ".iter")) {
    sort_cols <- c(".iter", sort_cols)
  }
  res <- dplyr::arrange(res, !!!syms(sort_cols))
  att <- attributes(res)
  rsample_att <- attributes(resamples)
  for (i in names(rsample_att)) {
    if (!any(names(att) == i)) {
      attr(res, i) <- rsample_att[[i]]
    }
  }

  class(res) <- unique(c("tune_results", class(res)))
  res
}

## -----------------------------------------------------------------------------

save_gp_results <- function(x, pset, ctrl, i, iter) {
  if (!ctrl$save_gp_scoring) {
    return(invisible(NULL))
  }

  nm <- recipes::names0(iter, "gp_candidates_")[i]
  file_name <- paste0(nm, ".RData")
  res <- try(save(x, pset, i, file = file.path(tempdir(), file_name)), silent = TRUE)
  if (inherits(res, "try-error")) {
    rlang::warn(paste("Could not save GP results:", as.character(res)))
  }
  invisible(res)
}
