R/int_pctl.R

#' Bootstrap confidence intervals for performance metrics
#'
#' Using out-of-sample predictions, the bootstrap is used to create percentile
#' confidence intervals.
#' @inheritParams collect_predictions
#' @inheritParams rlang::args_dots_empty
#' @inheritParams rsample::int_pctl
#' @inheritParams rsample::bootstraps
#' @param .data An object with class `tune_results` where the `save_pred = TRUE`
#' option was used in the control function.
#' @param metrics A [yardstick::metric_set()]. By default, it uses the same
#' metrics as the original object.
#' @param allow_par A logical to allow parallel processing (if a parallel
#' backend is registered).
#' @param event_level A single string. Either `"first"` or `"second"` to specify
#' which level of truth to consider as the "event".
#' @param eval_time A vector of evaluation times for censored regression models.
#' `NULL` is appropriate otherwise. If `NULL` is used with censored models, an
#' evaluation time is selected, and a warning is issued.
#' @return A tibble of metrics with additional columns for `.lower` and
#' `.upper`.
#' @details
#' For each model configuration (if any), this function takes bootstrap samples
#' of the out-of-sample predicted values. For each bootstrap sample, the metrics
#' are computed and these are used to compute confidence intervals.
#' See [rsample::int_pctl()] and the references therein for more details.
#'
#' Note that the `.estimate` column is likely to be different from the results
#' given by [collect_metrics()] since a different estimator is used. Since
#' random numbers are used in sampling, set the random number seed prior to
#' running this function.
#'
#' The number of bootstrap samples should be large to have reliable intervals.
#' The default of `times = 1001` reflects the fewest samples that should be
#' used.
#'
#' The computations for each configuration can be extensive. To increase
#' computational efficiency, parallel processing can be used. The \pkg{future}
#' package is used here. To execute the resampling iterations in parallel,
#' specify a [plan][future::plan] with future first. The `allow_par` argument
#' can be used to avoid parallelism.
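#'
#' For example, a multisession plan could be registered before calling this
#' function (a sketch using the `lm_res` object created in the examples
#' below):
#'
#' ```r
#' library(future)
#' plan(multisession)
#'
#' set.seed(31)
#' int_pctl(lm_res)
#' ```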
#'
#' Also, if a censored regression model uses numerous evaluation times, the
#' computations can take a long time unless the times are filtered with the
#' `eval_time` argument.
#' @seealso [rsample::int_pctl()]
#' @references Davison, A., & Hinkley, D. (1997). _Bootstrap Methods and their
#'  Application_. Cambridge: Cambridge University Press.
#'  doi:10.1017/CBO9780511802843
#' @examplesIf !tune:::is_cran_check() & tune:::should_run_examples("modeldata")
#' data(Sacramento, package = "modeldata")
#' library(rsample)
#' library(parsnip)
#'
#' set.seed(13)
#' sac_rs <- vfold_cv(Sacramento)
#'
#' lm_res <-
#'   linear_reg() %>%
#'   fit_resamples(
#'     log10(price) ~ beds + baths + sqft + type + latitude + longitude,
#'     resamples = sac_rs,
#'     control = control_resamples(save_pred = TRUE)
#'   )
#'
#' set.seed(31)
#' int_pctl(lm_res)
#' @export
int_pctl.tune_results <- function(.data, metrics = NULL, eval_time = NULL,
                                  times = 1001, parameters = NULL,
                                  alpha = 0.05, allow_par = TRUE,
                                  event_level = "first", ...) {

  rlang::check_dots_empty()

  # performance speedup, enables tidymodels/yardstick#428
  initialize_catalog(control_resamples(allow_par = TRUE))

  if (is.null(metrics)) {
    metrics <- .get_tune_metrics(.data)
  } else {
    if (!inherits(metrics, "metric_set")) {
      cli::cli_abort("{.arg metrics} should be a metric set as generated by {.fun yardstick::metric_set}.")
    }
  }

  if (is.null(eval_time)) {
    eval_time <- .get_tune_eval_times(.data)
    eval_time <- maybe_choose_eval_time(.data, metrics, eval_time)
  } else {
    eval_time <- unique(eval_time)
    check_eval_time_in_tune_results(.data, eval_time)
    # Are there at least a minimal number of evaluation times?
    check_enough_eval_times(eval_time, metrics)
  }

  .data$.predictions <- filter_predictions_by_eval_time(.data$.predictions, eval_time)

  y_nm <- outcome_names(.data)

  config_keys <- get_configs(.data, parameters = parameters)
  p <- length(config_keys)

  #  TODO Changes in https://github.com/tidymodels/rsample/pull/465
  #  will affect how these computations are done since they will
  #  compute intervals for `terms` as well as any columns that begin
  #  with a period. This will simplify the code considerably for
  #  survival and non-survival models. We will make this version
  #  compatible with the future rsample version (but will refactor
  #  this code later).
  metrics_info <- metrics_info(metrics)

  res <-
    purrr::map2(
      config_keys, sample.int(10000, p),
      ~ boostrap_metrics_by_config(.x, .y, .data, metrics, times, allow_par,
                                   event_level, alpha, metrics_info)
    ) %>%
    purrr::list_rbind() %>%
    dplyr::arrange(.config, .metric)
  dplyr::as_tibble(res)
}

get_int_p_operator <- function(allow = TRUE) {
  is_par <- foreach::getDoParWorkers() > 1 || future::nbrOfWorkers() > 1
  if (allow && is_par) {
    res <- switch(
      # note some backends can return +Inf
      min(future::nbrOfWorkers(), 2),
      list(op = foreach::`%dopar%`, is_future = FALSE),
      list(op = doFuture::`%dofuture%`, is_future = TRUE)
    )

    if (!res[["is_future"]]) {
      warn_foreach_deprecation()
    }
  } else {
    res <- list(op = foreach::`%do%`, is_future = FALSE)
  }
  res
}
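
# For illustration (behavior depends on the registered backend; not run): with
# a plan such as `future::plan(multisession)` in place, `get_int_p_operator(TRUE)`
# returns doFuture's `%dofuture%` operator with `is_future = TRUE`; with no
# parallel backend it falls back to the sequential `%do%` operator from
# foreach with `is_future = FALSE`.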

boostrap_metrics_by_config <- function(config, seed, x, metrics, times, allow_par,
                                       event_level, alpha, metrics_info) {
  y_nm <- outcome_names(x)
  preds <- collect_predictions(x, summarize = TRUE, parameters = config)

  set.seed(seed)
  rs <- rsample::bootstraps(preds, times = times)

  do_op <- get_int_p_operator(allow_par)
  `%op%` <- do_op[[1]]
  is_future <- do_op[[2]]

  # Rather than generating them programmatically, write each `foreach()`
  # call out since `foreach()` `substitute()`s its dots. Note that
  # doFuture will error when passed `.packages`.
  if (is_future) {
    for_each <-
      foreach::foreach(
        i = seq_len(nrow(rs)),
        .options.future = list(seed = NULL, packages = c("tune", "rsample"))
      )
  } else {
    for_each <-
      foreach::foreach(
        i = seq_len(nrow(rs)),
        .packages = c("tune", "rsample"),
        .errorhandling = "pass"
      )
  }

  rs$.metrics <-
    for_each %op% {
      asNamespace("tune")$comp_metrics(
        rs$splits[[i]],
        y_nm,
        metrics,
        event_level,
        metrics_info
      )
    }
  if (any(grepl("survival", .get_tune_metric_names(x)))) {
    # compute by evaluation time
    res <- int_pctl_surv(rs, allow_par, alpha)
  } else {
    res <- rsample::int_pctl(rs, .metrics, alpha = alpha)
  }

  res$.metric <- res$term
  res$term <- NULL
  res$.estimator <- "bootstrap"
  res$.alpha <- NULL
  res$.method <- NULL
  res <- cbind(res, config)
  first_cols <- c(".metric", ".estimator")
  res[c(first_cols, setdiff(names(res), first_cols))]
}
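
# A sketch of the rows this helper returns (hypothetical values; not run):
#   .metric  .estimator  .lower  .estimate  .upper  .config
#   rmse     bootstrap   0.076   0.081      0.086   Preprocessor1_Model1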

fake_term <- function(x) {
  x$term <- paste(x$term, format(1:nrow(x)))
  x
}
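
# For illustration (hypothetical input; not run): duplicated survival metric
# terms such as
#   term = c("brier_survival", "brier_survival", "brier_survival")
# become
#   term = c("brier_survival 1", "brier_survival 2", "brier_survival 3")
# so that each row carries a unique term value.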

# tests in extratests
# nocov start
int_pctl_surv <- function(x, allow_par, alpha) {

  # int_pctl() expects terms to be unique. For (many) survival models, the
  # metrics are a combination of the metric name and the evaluation time.
  # We'll make a phony term value, run int_pctl(), then merge the original values
  # back in.
  met_key <- x$.metrics[[1]]
  met_key$estimate <- NULL
  met_key$old_term <- met_key$term
  met_key$order <- 1:nrow(met_key)
  met_key <- fake_term(met_key)

  x$.metrics <- purrr::map(x$.metrics, ~ fake_term(.x))
  res <- rsample::int_pctl(x, .metrics, alpha = alpha)

  merge_keys <- c("term", grep("^\\.", names(res), value = TRUE))
  merge_keys <- intersect(merge_keys, names(met_key))

  res <- res %>%
    dplyr::full_join(met_key, by = merge_keys) %>%
    dplyr::arrange(order) %>%
    dplyr::select(-term, -order) %>%
    dplyr::rename(term = old_term) %>%
    dplyr::relocate(term, dplyr::any_of(".eval_time"))
}
# nocov end

# ------------------------------------------------------------------------------

get_configs <- function(x, parameters = NULL, as_list = TRUE) {
  param <- .get_tune_parameter_names(x)
  config_cols <- c(".config", ".iter", param)
  config_keys <-
    collect_metrics(x, summarize = FALSE) %>%
    dplyr::distinct(dplyr::pick(dplyr::any_of(config_cols)))
  if (!is.null(parameters)) {
    merge_cols <- intersect(names(config_keys), names(parameters))
    config_keys <- dplyr::inner_join(config_keys, parameters, by = merge_cols)
  }
  if (as_list) {
    config_keys <- vctrs::vec_chop(config_keys, as.list(1:nrow(config_keys)))
  }
  config_keys
}
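
# A sketch of what `get_configs()` returns (hypothetical tuning parameter and
# values; not run): with `as_list = TRUE`, a list of one-row tibbles such as
#   [[1]] .config = "Preprocessor1_Model1", penalty = 0.01
#   [[2]] .config = "Preprocessor1_Model2", penalty = 0.10
# Each element is passed to `boostrap_metrics_by_config()` as `config`.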

# Compute metrics for a specific configuration
comp_metrics <- function(split,
                         y,
                         metrics,
                         event_level,
                         metrics_info) {
  dat <- rsample::analysis(split)

  res <-
    .estimate_metrics(
      dat,
      metric = metrics,
      param_names = NULL,
      outcome_name = y,
      event_level = event_level,
      metrics_info = metrics_info
    )

  res$term <- res$.metric
  res$.metric <- NULL
  res$estimate <- res$.estimate
  res$.estimate <- NULL
  res$.estimator <- NULL
  res
}
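
# For illustration (hypothetical values; not run): a yardstick-style row such as
#   .metric = "rmse", .estimator = "standard", .estimate = 0.081
# is renamed to the columns that `rsample::int_pctl()` expects:
#   term = "rmse", estimate = 0.081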

# ------------------------------------------------------------------------------


filter_predictions_by_eval_time <- function(x, eval_time = NULL) {
  if (is.null(eval_time)) {
    return(x)
  }
  purrr::map(x, thin_time, times = eval_time)
}

thin_time <- function(x, times) {
  subset_time <- function(x, times) {
    x[x$.eval_time %in% times, ]
  }

  x$.pred <- purrr::map(x$.pred, subset_time, times = times)
  x
}
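
# For illustration (hypothetical values; not run): if each element of the
# `.pred` list column has rows for `.eval_time` values 1, 5, and 10, then
# `thin_time(x, times = c(1, 10))` keeps only the rows where `.eval_time`
# is 1 or 10 in every nested tibble.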