#' Prediction intervals via conformal inference
#' Nonparametric prediction intervals can be computed for fitted workflow
#' objects using the conformal inference method described by Lei _et al_ (2018).
#' @param object A fitted [workflows::workflow()] object.
#' @param control A control object from [control_conformal_full()] with the
#' numeric minutiae.
#' @param ... Not currently used.
#' @param train_data A data frame with the _original predictor data_ used to
#' create the fitted workflow (predictors and outcomes). If the workflow does
#' not contain these values, pass them here. If the workflow used a recipe, this
#' should be the data that were inputs to the recipe (and not the product of a
#' recipe).
#' @return An object of class `"int_conformal_full"` containing the information
#' to create intervals (which includes the training set data). The `predict()`
#' method is used to produce the intervals.
#' @details
#' This function implements what is usually called "full conformal inference"
#' (see Algorithm 1 in Lei _et al_ (2018)) since it uses the entire training
#' set to compute the intervals.
#' This function prepares the objects for the computations. The [predict()]
#' method computes the intervals for new data.
#' For a given new_data observation, the predictors are appended to the original
#' training set. Then, different "trial" values of the outcome are substituted
#' in for that observation's outcome and the model is re-fit. From each model,
#' the residual associated with the trial value is compared to a quantile of the
#' distribution of the other residuals. Usually the absolute values of the
#' residuals are used. Once the residual of a trial value exceeds the
#' distributional quantile, the value is one of the bounds.
#' The literature proposed using a grid search of trial values to find the two
#' points that correspond to the prediction intervals. To use this approach,
#' set `method = "grid"` in [control_conformal_full()]. However, the default method
#' `"iterative"` uses two different one-dimensional iterative searches on either
#' side of the predicted value to find values that correspond to the prediction
#' intervals.
#'
#' For medium to large data sets, the iterative search method is likely to
#' generate slightly smaller intervals. For small training sets, grid search
#' is more likely to have somewhat smaller intervals (and will be more stable).
#' Otherwise, the iterative search method is more precise and several times
#' faster.
#' To determine a range of possible values of the intervals, used by both methods,
#' the initial set of training set residuals are modeled using a Gamma generalized
#' linear model with a log link (see the reference by Aitkin below). For a new
#' sample, the absolute size of the residual is estimated and a multiple of
#' this value is computed as an initial guess of the search boundaries.
#' @includeRmd man/rmd/parallel_intervals.Rmd details
#' @seealso [predict.int_conformal_full()]
#' @references
#' Jing Lei, Max G'Sell, Alessandro Rinaldo, Ryan J. Tibshirani and Larry
#' Wasserman (2018) Distribution-Free Predictive Inference for Regression,
#' _Journal of the American Statistical Association_, 113:523, 1094-1111
#' Murray Aitkin, Modelling Variance Heterogeneity in Normal Regression Using
#' GLIM, _Journal of the Royal Statistical Society Series C: Applied Statistics_,
#' Volume 36, Issue 3, November 1987, Pages 332–339.
#' @export
int_conformal_full <- function(object, ...) {
  # Entry point for the S3-style methods defined below:
  # `int_conformal_full.default()` (rejects unsupported objects) and
  # `int_conformal_full.workflow()` (does the actual preparation).

#' @export
#' @rdname int_conformal_full
int_conformal_full.default <- function(object, ...) {
  # Fallback method: reaching this point means no dedicated method exists for
  # the class of `object`, so an error is raised.
  rlang::abort("No known 'int_conformal_full' methods for this type of object.")

#' @export
#' @rdname int_conformal_full
int_conformal_full.workflow <-
  function(object, train_data, ..., control = control_conformal_full()) {
    # Check `train_data` against the predictor structure recorded in the
    # workflow before doing any other work.
    check_data(train_data, object)

    # --------------------------------------------------------------------------
    # check req packages

    # Collect every package that later computations may need: those reported
    # for the workflow, a fixed core set, and any listed by the user in
    # `control$required_pkgs`. `unique()` removes duplicates. The combined
    # vector is written back to `control`, which is stored in the returned
    # object and read again by `grid_one()` and `optimize_one()`.
    pkgs <- required_pkgs(object)
    pkgs <- unique(c(pkgs, "workflows", "parsnip", "probably", "mgcv", control$required_pkgs))
    control$required_pkgs <- pkgs

    # --------------------------------------------------------------------------
    # We need to set the potential range that encompasses the _possible_ values
    # of the prediction interval. This is done on a sample-by-sample basis using
    # a variance model on the training set residuals.

    var_gam <- var_model(object, train_data)

    # ------------------------------------------------------------------------------
    # save results

    # Bundle the workflow, the training data, the variance model, and the
    # control settings into the object consumed by the `predict()` method.
    new_conf_infer(object, train_data, var_gam, control)

#' @export
print.int_conformal_full <- function(x, ...) {
  # Header line identifying the kind of object being printed.
  cat("Conformal inference\n")

  # Summarise the stored workflow (`x$wflow`) and the stored training set
  # (`x$training`). The row count is formatted with a comma as the thousands
  # separator.
  cat("preprocessor:", .get_pre_type(x$wflow), "\n")
  cat("model:", .get_fit_type(x$wflow), "\n")
  cat("training set size:", format(nrow(x$training), big.mark = ","), "\n\n")

  # Remind the user how intervals are obtained from this object.
  cat("Use `predict(object, new_data, level)` to compute prediction intervals\n")

# ------------------------------------------------------------------------------

#' Prediction intervals from conformal methods
#' @param object An object produced by [int_conformal_full()].
#' @param new_data A data frame of predictors.
#' @param level The confidence level for the intervals.
#' @param ... Not currently used.
#' @return A tibble with columns `.pred_lower` and `.pred_upper`. If
#' the computations for the prediction bound fail, a missing value is used. For
#' objects produced by [int_conformal_cv()], an additional `.pred` column
#' is also returned (see Details below).
#' @seealso [int_conformal_full()], [int_conformal_cv()]
#' @details
#' For the CV+ estimator produced by [int_conformal_cv()], the intervals
#' are centered around the mean of the predictions produced by the
#' resample-specific model. For example, with 10-fold cross-validation, `.pred`
#' is the average of the predictions from the 10 models produced by each fold.
#' This may differ from the prediction generated from a model fit that was
#' trained on the entire training set, especially if the training sets are
#' small.
#' @export
predict.int_conformal_full <- function(object, new_data, level = 0.95, ...) {
  # Check the new predictors against the stored workflow.
  check_data(new_data, object$wflow)

  # Point predictions for `new_data` plus a per-row `.bound`: how far on either
  # side of the prediction the interval end points will be looked for (see
  # `setup_new_data()`). The width is scaled by `control$var_multiplier`.
  new_pred <- setup_new_data(object, new_data, object$control$var_multiplier)

  # ------------------------------------------------------------------------------
  # split new_data for row-wise processing

  # Predictors and predictions are bound column-wise and each row is placed in
  # its own group via a `.row` index, so that every new sample can be handled
  # separately. The code below reads the per-row pieces from `new_nest$data`.
  new_nest <-
    dplyr::bind_cols(new_data, new_pred) %>%
    dplyr::mutate(.row = dplyr::row_number()) %>%
    dplyr::group_by(.row) %>%

  # ------------------------------------------------------------------------------
  # compute intervals

  # Choose the algorithm from the control object: "grid" evaluates a grid of
  # trial values, any other value uses the root-finding search.
  if (object$control$method == "grid") {
    res <- grid_all(new_nest$data, object$wflow, object$training, level, object$control)
  } else {
    res <- optimize_all(new_nest$data, object$wflow, object$training, level, object$control)
  # `control$progress` is the option that requests a progress bar (see
  # `control_conformal_full()`).
  if (object$control$progress) {

# ------------------------------------------------------------------------------

# Constructor for the object returned by `int_conformal_full()`. It collects
# the workflow (`wflow`), the training data (`.data`), the variance model
# (`variance`), and the control settings (`control`) under named elements and
# tags the result with S3 classes.
new_conf_infer <- function(wflow, .data, variance, control) {
  res <-
      wflow = wflow,
      training = .data,
      # NOTE(review): stored as `var_model`, while `setup_new_data()` reads
      # `object$var_mod`; that only works through `$` partial name matching.
      var_model = variance,
      control = control
  # The `print()` and `predict()` methods in this file dispatch on
  # "int_conformal_full".
  class(res) <- c("conformal_reg_full", "int_conformal_full")

# Compare `.data` with the predictor prototype stored in the workflow's
# blueprint (`wflow$blueprint$ptypes$predictors`) using `hardhat::scream()`.
check_data <- function(.data, wflow) {
  hardhat::scream(.data, wflow$blueprint$ptypes$predictors)

# Helper that supplies the name of the outcome column for a workflow `x`.
# Callers in this file use the result to index data frames by name, e.g.
# `train_data[[y_name]]` in `var_model()`.
get_outcome_name <- function(x) {

# Helper whose result `check_workflow()` compares with "regression". It starts
# from the parsnip fit extracted from `x` and that fit's "spec" element.
get_mode <- function(x) {
  x %>%
    hardhat::extract_fit_parsnip() %>%
    purrr::pluck("spec") %>%

# Validate that `x` is a trained workflow whose model mode is "regression".
# `call` (by default the caller's environment) is forwarded to
# `rlang::abort()` so that errors are reported against the calling function.
check_workflow <- function(x, call = rlang::caller_env()) {
  # An untrained workflow cannot produce the predictions needed later.
  if (!workflows::is_trained_workflow(x)) {
    rlang::abort("'object' should be a fitted workflow object.", call = call)
  # Any mode other than "regression" is rejected.
  if (get_mode(x) != "regression") {
    rlang::abort("'object' should be a regression model.", call = call)

# ------------------------------------------------------------------------------
# Get a rough estimate of the local variance for each new sample being
# predicted. Aitkin (1987) gives a framework where, assuming normality of
# residuals, the squared residuals can be modeled using Gamma errors and the
# log link.
# This most likely underestimates the local variance so use a multiplier on the
# variance prediction to make the possible range of the bounds really wide.

var_model <- function(object, train_data, call = caller_env()) {
  # `object` is the fitted workflow and `train_data` the data it is predicted
  # on (it must contain the outcome column); `call` is used when an error is
  # raised below.
  y_name <- get_outcome_name(object)

  # Training-set predictions, raw residuals (observed minus predicted), and
  # squared residuals.
  train_res <- predict(object, train_data)
  train_res$resid <- train_data[[y_name]] - train_res$.pred
  train_res$sq <- train_res$resid^2

  # ------------------------------------------------------------------------------
  # Get ballpark estimate of range of trial values for new data

  # Model the squared residuals as a smooth function of the prediction, using
  # a Gamma family with a log link. Predictions from this model are therefore
  # on the squared-residual (variance) scale; `setup_new_data()` takes the
  # square root to get back to the units of the outcome.
  var_mod <-
      mgcv::gam(sq ~ s(.pred),
        data = train_res,
        family = stats::Gamma(link = "log")
      silent = TRUE

  # A failed fit is detected through the "try-error" class and re-raised with
  # the original condition message attached as an informational ("i") bullet.
  if (inherits(var_mod, "try-error")) {
    msg <- c(
      "The model to estimate the possible interval length failed with the following message:",
      "i" = conditionMessage(attr(var_mod, "condition"))
    rlang::abort(msg, call = call)


# Predict the new samples with the stored workflow and attach `.bound`, the
# distance on either side of each prediction within which interval end points
# are sought. `multiplier` widens that distance (default 10).
setup_new_data <- function(object, new_data, multiplier = 10) {
  new_pred <- predict(object$wflow, new_data)

  # Predicted squared residual for each new sample, on the response scale.
  # NOTE(review): `new_conf_infer()` stores this element as `var_model`; the
  # `$var_mod` accessor appears to rely on `$` partial name matching. Confirm
  # this is intended.
  var_pred <- as.vector(predict(object$var_mod, new_pred, type = "response"))
  # convert back to original units
  var_pred <- sqrt(var_pred)
  # Add a buffer
  new_pred$.bound <- multiplier * var_pred

# ------------------------------------------------------------------------------
# helpers for trial values

# ------------------------------------------------------------------------------
# For a trial value of the bounds, fit the model and return statistics
# combine train_data and new_data

trial_fit <- function(trial, trial_data, wflow, level) {
  # `trial` is the candidate outcome value, `trial_data` the training set plus
  # the new sample, `wflow` the workflow to refit, and `level` the probability
  # used for the residual quantile.
  y_name <- get_outcome_name(wflow)

  # trial_data is training set with one extra row for the sample being inferred
  # Last value of the outcome vector is NA; add our trial value there
  trial_data[[y_name]][nrow(trial_data)] <- trial
  # Refit on the augmented data. Errors are captured rather than thrown; for a
  # failed fit a row with missing statistics is built instead.
  tmp_fit <- try(fit(wflow, trial_data), silent = TRUE)
  if (inherits(tmp_fit, "try-error")) {
      dplyr::tibble(quantile = NA_real_, trial = trial, .abs_resid = NA_real_)

  # Compute the abs residuals then more statistics
  tmp_res <-
    augment(tmp_fit, trial_data) %>%
      .abs_resid = abs(!!rlang::sym(y_name) - .pred)
  # Quantiles of original training set (abs) residuals
  # NOTE(review): `[-1]` drops the FIRST residual, whereas the trial row is the
  # LAST row of `trial_data` (see above). Confirm which row is meant to be
  # excluded. Also, `prob` reaches `stats::quantile()` only by partial matching
  # of its `probs` argument.
  quant_val <- stats::quantile(tmp_res$.abs_resid[-1], prob = level)

  # Statistics for this trial value: the residual quantile, the trial value
  # itself, and the absolute residual of the new (last) row.
  res <-
      quantile = unname(quant_val),
      trial = trial,
      .abs_resid = tmp_res$.abs_resid[nrow(trial_data)]
  # Non-negative when the trial row's residual is at least as large as the
  # quantile; `compute_bound()` and the root search both work on this value.
  res$difference <- res$.abs_resid - res$quantile

# ------------------------------------------------------------------------------

# Grid-method driver called from `predict.int_conformal_full()` with the
# per-row pieces of the new data, the workflow (`model`), the training data,
# the level, and the control object.
grid_all <- function(new_data, model, train_data, level, ctrl) {
    # Progress reporting and the random-number seed both come from the control
    # object.
    .progress = ctrl$progress,
    .options = furrr::furrr_options(seed = ctrl$seed, stdout = FALSE)

# Grid method for a single new sample: evaluate `trial_fit()` over equally
# spaced trial outcome values around the prediction and convert the results
# to interval end points with `compute_bound()`.
grid_one <- function(new_data, model, train_data, level, ctrl) {
  y_name <- get_outcome_name(model)

  # Make sure every required package is installed before any model is refit.
  pkg_inst <- lapply(ctrl$required_pkgs, rlang::check_installed)
  # Pull out the prediction and search half-width, then remove those helper
  # columns so that only predictor columns remain.
  pred_val <- new_data$.pred
  bound <- new_data$.bound
  new_data <- dplyr::select(new_data, -.pred, -.bound)
  # Append the new sample to the training set with a missing outcome;
  # `trial_fit()` fills in that last value.
  new_data[[y_name]] <- NA_real_
  trial_data <- dplyr::bind_rows(train_data, new_data)

  # `ctrl$trial_points` equally spaced candidates spanning
  # [pred_val - bound, pred_val + bound].
  trial_vals <- seq(pred_val - bound, pred_val + bound, length.out = ctrl$trial_points)

  # One `trial_fit()` evaluation per candidate value (`.x`).
  res <-
      ~ trial_fit(.x,
        trial_data = trial_data,
        wflow = model,
        level = level

  compute_bound(res, pred_val)

# Convert per-trial statistics into interval end points around `predicted`.
# `x` has one row per trial value, with a `trial` column and a `difference`
# column (trial residual minus the residual quantile; see `trial_fit()`).
compute_bound <- function(x, predicted) {
  # Drop trials with missing statistics (e.g. a refit that failed).
  x <- x[stats::complete.cases(x), ]
  # If no trial residual reached its quantile, no end point can be located:
  # warn and build missing bounds. This branch is also entered when no rows
  # remain, because `all()` of an empty logical vector is TRUE.
  if (all(x$difference < 0)) {
    rlang::warn("Could not determine bounds.")
    res <-
        .pred_lower = NA_real_,
        .pred_upper = NA_real_

  # Upper end point: the smallest trial value at or above the prediction whose
  # residual reaches the quantile; missing if there is none.
  upper <- x[x$trial >= predicted & x$difference >= 0, ]
  if (nrow(upper) > 0) {
    upper <- min(upper$trial)
  } else {
    upper <- NA_real_
  # Lower end point: the largest trial value at or below the prediction whose
  # residual reaches the quantile; missing if there is none.
  lower <- x[x$trial <= predicted & x$difference >= 0, ]
  if (nrow(lower) > 0) {
    lower <- max(lower$trial)
  } else {
    lower <- NA_real_
  dplyr::tibble(.pred_lower = lower, .pred_upper = upper)

# ------------------------------------------------------------------------------

# Search-method driver called from `predict.int_conformal_full()`; the
# counterpart of `grid_all()`. The workflow, training data, level, and control
# object are forwarded unchanged for each new sample.
optimize_all <- function(new_data, model, train_data, level, ctrl) {
  # purrr::map_dfr(
    model = model,
    train_data = train_data,
    level = level,
    ctrl = ctrl,
    # Progress reporting and the random-number seed both come from the control
    # object.
    .progress = ctrl$progress,
    .options = furrr::furrr_options(seed = ctrl$seed, stdout = FALSE)

# Objective function handed to `stats::uniroot()` in `optimize_one()`: `val` is
# the trial outcome value and `...` carries the remaining `trial_fit()`
# arguments (trial data, workflow, level).
# NOTE(review): presumably the `difference` statistic of `res` is what the
# root search drives to zero. Confirm against the full definition.
get_diff <- function(val, ...) {
  res <- trial_fit(val, ...)

# Search method for a single new sample: run one root search above the
# prediction and one below it, and report the two roots as the interval end
# points.
optimize_one <- function(new_data, model, train_data, level, ctrl) {
  y_name <- get_outcome_name(model)

  # Make sure every required package is installed before any model is refit.
  pkg_inst <- lapply(ctrl$required_pkgs, rlang::check_installed)
  # Pull out the prediction and search half-width, then remove those helper
  # columns so that only predictor columns remain.
  pred_val <- new_data$.pred
  bound <- new_data$.bound
  new_data <- dplyr::select(new_data, -.pred, -.bound)
  # Append the new sample to the training set with a missing outcome;
  # `trial_fit()` fills in that last value.
  new_data[[y_name]] <- NA_real_
  trial_data <- dplyr::bind_rows(train_data, new_data)

  # Upper end point: root of `get_diff()` starting from
  # [pred_val, pred_val + bound]. `extendInt = "upX"` lets `uniroot()` widen
  # the interval, treating the function as increasing where it crosses zero.
  # The unnamed trailing arguments are passed through `...` to `get_diff()`
  # and on to `trial_fit()`.
  upper <-
      stats::uniroot(get_diff, c(pred_val, pred_val + bound),
        maxiter = ctrl$max_iter, tol = ctrl$tolerance,
        extendInt = "upX",
        trial_data, model, level
      silent = TRUE

  # Lower end point: the mirror-image search on [pred_val - bound, pred_val],
  # with the function treated as decreasing where it crosses zero.
  lower <-
      stats::uniroot(get_diff, c(pred_val - bound, pred_val),
        maxiter = ctrl$max_iter, tol = ctrl$tolerance,
        extendInt = "downX",
        trial_data, model, level
      silent = TRUE

    # `get_root()` turns each search result into a single end point.
    .pred_lower = get_root(lower, ctrl),
    .pred_upper = get_root(upper, ctrl)

# Post-process one search result from `optimize_one()`. `x` is either the
# value of `stats::uniroot()` or an object of class "try-error"; `ctrl`
# supplies `max_iter`.
get_root <- function(x, ctrl) {
  # A failed search: assemble a message that includes the original error text.
  if (inherits(x, "try-error")) {
    msg <- c(
      "Could not finish the search process due to the following error:",
      "i" = conditionMessage(attr(x, "condition"))
  # Using every allowed iteration is treated as non-convergence and reported
  # with a warning.
  if (x$iter == ctrl$max_iter) {
    rlang::warn("Search did not converge.")

# ------------------------------------------------------------------------------

#' Controlling the numeric details for conformal inference
#' @param method The method for computing the intervals. The options are
#' `'iterative'` (using [stats::uniroot()]) and `'grid'`.
#' @param trial_points When `method = "grid"`, how many points should be
#' evaluated?
#' @param var_multiplier A multiplier for the variance model that determines the
#' possible range of the bounds.
#' @param max_iter When `method = "iterative"`, the maximum number of iterations.
#' @param tolerance Tolerance value passed to [stats::uniroot()] (as `tol`) to
#' determine convergence during the search computations.
#' @param progress Should a progress bar be used to track execution?
#' @param required_pkgs An optional character string for which packages are
#' required.
#' @param seed A single integer used to control randomness when models are
#' (re)fit.
#' @return A list object with the options given by the user.
#' @export
control_conformal_full <-
  function(method = "iterative", trial_points = 100, var_multiplier = 10,
           max_iter = 100, tolerance = .Machine$double.eps^0.25, progress = FALSE,
           required_pkgs = character(0), seed = sample.int(10^5, 1)) {
    # When `seed` is not supplied, its default draws a random integer between
    # 1 and 10^5.
    # Only "iterative" and "grid" are accepted; anything else is an error.
    method <- rlang::arg_match0(method, c("iterative", "grid"))

      # The options are collected under their argument names; these are the
      # names read elsewhere in this file (e.g. `ctrl$max_iter`,
      # `ctrl$trial_points`).
      method = method,
      trial_points = trial_points,
      var_multiplier = var_multiplier,
      max_iter = max_iter,
      tolerance = tolerance,
      progress = progress,
      required_pkgs = required_pkgs,
      # Coerced to integer; only the first element is kept.
      seed = as.integer(seed)[1]
