R/int_pctl.R

#' Bootstrap confidence intervals for performance metrics
#'
#' Using out-of-sample predictions, the bootstrap is used to create percentile
#' confidence intervals.
#' @inheritParams collect_predictions
#' @inheritParams rlang::args_dots_empty
#' @inheritParams rsample::int_pctl
#' @inheritParams rsample::bootstraps
#' @param .data An object with class `tune_results` in which the
#' `save_pred = TRUE` option was used in the control function.
#' @param metrics A [yardstick::metric_set()]. By default, it uses the same
#' metrics as the original object.
#' @param allow_par A logical to allow parallel processing (if a parallel
#' backend is registered).
#' @param event_level A single string. Either `"first"` or `"second"` to specify
#' which level of truth to consider as the "event".
#' @param eval_time A vector of evaluation times for censored regression models.
#' `NULL` is appropriate otherwise. If `NULL` is used with censored models, an
#' evaluation time is selected, and a warning is issued.
#' @param keep_replicates A logical for saving the individual estimates from
#' each bootstrap sample (as a list column called `.values`).
#' @return A tibble of metrics with additional columns for `.lower` and
#' `.upper` (and potentially, `.values`).
#' @details
#' For each model configuration (if any), this function takes bootstrap samples
#' of the out-of-sample predicted values. For each bootstrap sample, the metrics
#' are computed and these are used to compute confidence intervals.
#' See [rsample::int_pctl()] and the references therein for more details.
#'
#' Note that the `.estimate` column is likely to be different from the results
#' given by [collect_metrics()] since a different estimator is used. Since
#' random numbers are used in sampling, set the random number seed prior to
#' running this function.
#'
#' The number of bootstrap samples should be large in order to produce reliable
#' intervals; the default for `times` is close to the minimum that should be used.
#'
#' The computations for each configuration can be extensive. To increase
#' computational efficiency, parallel processing can be used. The \pkg{future}
#' package is used here. To execute the resampling iterations in parallel,
#' specify a [plan][future::plan] with the \pkg{future} package first. Setting
#' `allow_par = FALSE` avoids parallelism.
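#'
#' For example, a multisession plan can be registered before calling this
#' function (a minimal sketch reusing the `lm_res` object from the example
#' below; any [future::plan()] strategy can be used):
#'
#' \preformatted{
#'   library(future)
#'   plan(multisession)
#'
#'   set.seed(31)
#'   int_pctl(lm_res, allow_par = TRUE)
#' }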
#'
#' Also, if a censored regression model is evaluated at numerous evaluation
#' times, the computations can take a long time unless the times are filtered
#' with the `eval_time` argument.
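#'
#' For instance, with a hypothetical censored regression result `surv_res`, the
#' intervals could be restricted to a few of the original evaluation times:
#'
#' \preformatted{
#'   int_pctl(surv_res, eval_time = c(1, 5, 10))
#' }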
#' @seealso [rsample::int_pctl()]
#' @references Davison, A., & Hinkley, D. (1997). _Bootstrap Methods and their
#'  Application_. Cambridge: Cambridge University Press.
#'  doi:10.1017/CBO9780511802843
#' @examplesIf !tune:::is_cran_check()
#' if (rlang::is_installed("modeldata")) {
#'   data(Sacramento, package = "modeldata")
#'   library(rsample)
#'   library(parsnip)
#'
#'   set.seed(13)
#'   sac_rs <- vfold_cv(Sacramento)
#'
#'   lm_res <-
#'     linear_reg() |>
#'     fit_resamples(
#'       log10(price) ~ beds + baths + sqft + type + latitude + longitude,
#'       resamples = sac_rs,
#'       control = control_resamples(save_pred = TRUE)
#'     )
#'
#'   set.seed(31)
#'   int_pctl(lm_res)
#' }
#' @export
int_pctl.tune_results <- function(
  .data,
  metrics = NULL,
  eval_time = NULL,
  times = 1001,
  parameters = NULL,
  alpha = 0.05,
  allow_par = TRUE,
  event_level = "first",
  keep_replicates = FALSE,
  ...
) {
  rlang::check_dots_empty()

  if (is.null(metrics)) {
    metrics <- .get_tune_metrics(.data)
  } else {
    if (!inherits(metrics, "metric_set")) {
      cli::cli_abort(
        "{.arg metrics} should be a metric set as generated by {.fun yardstick::metric_set}."
      )
    }
  }

  if (is.null(eval_time)) {
    eval_time <- .get_tune_eval_times(.data)
    eval_time <- maybe_choose_eval_time(.data, metrics, eval_time)
  } else {
    eval_time <- unique(eval_time)
    check_eval_time_in_tune_results(.data, eval_time)
    # Are there at least a minimal number of evaluation times?
    check_enough_eval_times(eval_time, metrics)
  }

  .data$.predictions <- filter_predictions_by_eval_time(
    .data$.predictions,
    eval_time
  )

  y_nm <- outcome_names(.data)

  config_keys <- get_configs(
    .data,
    parameters = parameters,
    as_list = FALSE
  )

  metrics_info <- metrics_info(metrics)
  res <-
    int_pctl_iter(
      config_keys,
      .data,
      metrics = metrics,
      times = times,
      allow_par = allow_par,
      event_level = event_level,
      alpha = alpha,
      metrics_info = metrics_info,
      keep_replicates = keep_replicates
    )

  res
}
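
# A minimal usage sketch (assuming the `lm_res` object from the roxygen example
# above and that the tidyr package is installed):
#
#   set.seed(31)
#   boot_ci <- int_pctl(lm_res, times = 2000, keep_replicates = TRUE)
#   # the individual bootstrap estimates live in the `.values` list column
#   tidyr::unnest(boot_ci, .values)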

# ------------------------------------------------------------------------------

int_pctl_iter <- function(
  config,
  x,
  metrics,
  times,
  allow_par,
  event_level,
  alpha,
  metrics_info,
  keep_replicates
) {
  y_nm <- outcome_names(x)
  param_names <- .get_tune_parameter_names(x)

  preds <- collect_predictions(x, summarize = TRUE, parameters = config)

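  # Bootstrap the pooled out-of-sample predictions; stratifying on `.config`
  # ensures that every bootstrap sample contains rows for each configuration.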
  rs <- rsample::bootstraps(preds, times = times, strata = ".config")

  strategy <-
    choose_framework(
      object = NULL,
      control = list(allow_par = allow_par),
      verbose = FALSE,
      default = "mirai"
    )

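  # Build and evaluate the (possibly parallel) call that computes the metric
  # estimates for each bootstrap sample, yielding a list of metric tibbles.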
  cl <- pctl_call(strategy)
  rs$.metrics <- rlang::eval_bare(cl)

  res <- rsample::int_pctl(rs, .metrics, alpha = alpha)
  has_iter <- any(names(res) == ".iter")

  nest_cols <- c(".config", "term")
  if (has_iter) {
    nest_cols <- c(".iter", nest_cols)
  }

  if (keep_replicates) {
    rs_est <- purrr::list_rbind(rs$.metrics)

    rs_est <-
      vctrs::vec_split(
        x = rs_est[setdiff(colnames(rs_est), nest_cols)],
        by = rs_est[nest_cols]
      )

    rs_est <- vctrs::vec_cbind(
      rs_est$key,
      tibble::new_tibble(list(.values = rs_est$val))
    )

    res <- dplyr::full_join(res, rs_est, by = nest_cols)
  }

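  # Standardize the results: rsample's `term` column becomes `.metric`, the
  # `.alpha`/`.method` columns are dropped, and an `.estimator` label is added.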
  res$.metric <- res$term
  res$term <- NULL
  res$.estimator <- "bootstrap"
  res$.alpha <- NULL
  res$.method <- NULL

  by_cols <- nest_cols[nest_cols != "term"]
  res <- dplyr::inner_join(res, config, by = by_cols)

  sort_cols <- rlang::syms(c(".metric", by_cols))
  res <- res |> dplyr::arrange(!!!sort_cols)
  first_cols <- c(
    param_names,
    ".metric",
    ".estimator",
    ".lower",
    ".estimate",
    ".upper",
    ".config"
  )
  if (has_iter) {
    first_cols <- c(first_cols, ".iter")
  }

  res[, c(first_cols, setdiff(names(res), first_cols))]
}

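# Compute the metric estimates for a single bootstrap sample of the pooled
# predictions; used once per resample by `int_pctl_iter()` above.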
boot_metrics <- function(
  split,
  y,
  metrics,
  param_names = NULL,
  event_level,
  metrics_info,
  configs
) {
  dat <- rsample::analysis(split)

  res <-
    .estimate_metrics(
      dat,
      metric = metrics,
      param_names = param_names,
      outcome_name = y,
      event_level = event_level,
      metrics_info = metrics_info
    )

  if (length(param_names) > 0) {
    res <- dplyr::inner_join(res, configs, by = param_names)
  } else {
    res <- vctrs::vec_cbind(res, configs)
  }

  res$term <- res$.metric
  res$.metric <- NULL
  res$estimate <- res$.estimate
  res$.estimate <- NULL
  res$.estimator <- NULL
  res
}

# ------------------------------------------------------------------------------

get_configs <- function(x, parameters = NULL, as_list = TRUE) {
  param <- .get_tune_parameter_names(x)
  config_cols <- c(".config", ".iter", param)
  config_keys <-
    collect_metrics(x, summarize = FALSE) |>
    dplyr::distinct(dplyr::pick(dplyr::any_of(config_cols)))
  if (!is.null(parameters)) {
    merge_cols <- intersect(names(config_keys), names(parameters))
    config_keys <- dplyr::inner_join(config_keys, parameters, by = merge_cols)
  }
  if (as_list) {
    config_keys <- vctrs::vec_chop(config_keys, as.list(1:nrow(config_keys)))
  }
  config_keys
}
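
# A minimal sketch of what get_configs() returns (assuming the `lm_res` object
# from the roxygen example above): `get_configs(lm_res)` gives a list of
# single-row tibbles, one per model configuration, while `as_list = FALSE`
# returns a single tibble with one row per distinct `.config` (plus any tuning
# parameter and `.iter` columns).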

filter_predictions_by_eval_time <- function(x, eval_time = NULL) {
  if (is.null(eval_time)) {
    return(x)
  }
  purrr::map(x, thin_time, times = eval_time)
}

thin_time <- function(x, times) {
  subset_time <- function(x, times) {
    x[x$.eval_time %in% times, ]
  }

  x$.pred <- purrr::map(x$.pred, subset_time, times = times)
  x
}
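
# A minimal illustration of thin_time() with hypothetical prediction data,
# where `.pred` is a list column of per-observation survival curves:
#
#   preds <- tibble::tibble(
#     .pred = list(
#       tibble::tibble(.eval_time = c(1, 5, 10), .pred_survival = c(0.9, 0.7, 0.5))
#     )
#   )
#   thin_time(preds, times = c(1, 10))  # keeps only rows at evaluation times 1 and 10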
