int_conformal_full: Prediction intervals via conformal inference

View source: R/conformal_infer.R

Description

Nonparametric prediction intervals can be computed for fitted workflow objects using the conformal inference method described by Lei et al. (2018).

Usage

int_conformal_full(object, ...)

## Default S3 method:
int_conformal_full(object, ...)

## S3 method for class 'workflow'
int_conformal_full(object, train_data, ..., control = control_conformal_full())

Arguments

object

A fitted workflows::workflow() object.

...

Not currently used.

train_data

A data frame with the original data (predictors and outcomes) used to fit the workflow. If the workflow does not contain these values, pass them here. If the workflow used a recipe, this should be the data that were the inputs to the recipe (not the output of the recipe).

control

A control object from control_conformal_full() with the numeric minutiae.

Details

This function implements what is usually called "full conformal inference" (see Algorithm 1 in Lei et al. (2018)) because it uses the entire training set to compute the intervals.

This function prepares the objects for the computations. The predict() method computes the intervals for new data.

For a given new_data observation, its predictors are appended to the original training set. Then, different "trial" values are substituted for that observation's outcome, and the model is re-fit for each one. From each model, the residual associated with the trial value is compared to a quantile of the distribution of the other residuals; usually the absolute values of the residuals are used. On each side of the prediction, the trial value at which this residual first exceeds the quantile is one of the interval bounds.
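The procedure above can be sketched in base R for a single new observation. This is a minimal illustration, not the probably implementation: the helper conformal_grid_lm() is hypothetical, it uses lm() as the model, and it accepts a trial value using the rank of the new observation's absolute residual among all n + 1 residuals, as in Algorithm 1 of Lei et al. (2018). The interval is taken as the range of accepted trial values, which assumes the accepted set is contiguous.

```r
# Hypothetical helper (not part of probably): full conformal interval for a
# single new observation, using lm() and a grid of trial outcome values.
conformal_grid_lm <- function(formula, train, new_row, trial_values, level = 0.90) {
  outcome <- all.vars(formula)[1]
  n_total <- nrow(train) + 1

  keep <- vapply(trial_values, function(y_trial) {
    # Append the new predictors, using the trial value as their outcome
    new_row[[outcome]] <- y_trial
    augmented <- rbind(train, new_row[names(train)])

    # Re-fit on the augmented data and take the absolute residuals
    res <- abs(residuals(lm(formula, data = augmented)))

    # Rank of the new observation's residual among all n + 1 residuals;
    # the trial value is kept unless its residual is too extreme
    sum(res <= res[n_total]) <= ceiling(level * n_total)
  }, logical(1))

  if (!any(keep)) {
    return(c(lower = NA_real_, upper = NA_real_))
  }
  # Assumes the accepted trial values form one contiguous interval
  c(lower = min(trial_values[keep]), upper = max(trial_values[keep]))
}

# Small simulated example
set.seed(1)
train <- data.frame(x = rnorm(50))
train$y <- 2 + 3 * train$x + rnorm(50)
new_row <- data.frame(x = 0)

trial_grid <- seq(-5, 9, by = 0.05)
grid_int <- conformal_grid_lm(y ~ x, train, new_row, trial_grid, level = 0.90)
grid_int
```

Each trial value costs one model fit, which is why the grid resolution and training set size dominate the run time.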

The literature proposes using a grid search over trial values to find the two points that define the prediction interval. To use this approach, set method = "grid" in control_conformal_full(). However, the default, method = "search", uses two separate one-dimensional iterative searches, one on each side of the predicted value, to find the interval bounds.
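As a rough illustration of the search idea, the sketch below finds each bound with base R's uniroot(), applied on either side of the predicted value. It is an assumption-laden stand-in, not the algorithm used by probably: the helpers conformity_gap() and conformal_search_lm() are hypothetical, the trial residual is compared to the quantile of the other absolute residuals as described in the Details, and the initial bracket of 10 residual standard errors is an arbitrary choice.

```r
# Hypothetical helper: the new observation's absolute residual minus the
# `level` quantile of the other absolute residuals, after re-fitting with
# `y_trial` as the new observation's outcome
conformity_gap <- function(y_trial, formula, train, new_row, level = 0.90) {
  outcome <- all.vars(formula)[1]
  new_row[[outcome]] <- y_trial
  augmented <- rbind(train, new_row[names(train)])
  res <- abs(residuals(lm(formula, data = augmented)))
  n_total <- length(res)
  res[[n_total]] - quantile(res[-n_total], probs = level, names = FALSE)
}

# Hypothetical search: one root-finding call on each side of the prediction
conformal_search_lm <- function(formula, train, new_row, level = 0.90,
                                width_mult = 10, tol = 1e-4) {
  base_fit <- lm(formula, data = train)
  center <- unname(predict(base_fit, newdata = new_row))
  # Assumed initial bracket: a multiple of the residual standard error
  width <- width_mult * sigma(base_fit)
  gap <- function(y) conformity_gap(y, formula, train, new_row, level)
  # The gap is negative at the prediction and positive far from it,
  # so each half-bracket should contain a sign change for uniroot()
  lower <- uniroot(gap, c(center - width, center), tol = tol)$root
  upper <- uniroot(gap, c(center, center + width), tol = tol)$root
  c(lower = lower, upper = upper)
}

set.seed(1)
train <- data.frame(x = rnorm(50))
train$y <- 2 + 3 * train$x + rnorm(50)
new_row <- data.frame(x = 0)

search_int <- conformal_search_lm(y ~ x, train, new_row, level = 0.90)
search_int
```

Because each root-finding step needs only a handful of re-fits, this kind of search usually needs far fewer model fits than a fine grid, which is consistent with the speed difference described below.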

For medium to large data sets, the iterative search method is likely to produce slightly smaller intervals. For small training sets, grid search is more likely to produce somewhat smaller intervals (and will be more stable). In general, the iterative search method is more precise and several times faster.

To determine a range of possible values for the intervals, used by both methods, the absolute residuals from the initial fit on the training set are modeled using a Gamma generalized linear model with a log link (see Aitkin (1987) in the references below). For a new sample, the expected absolute size of its residual is estimated, and a multiple of this value is used as an initial guess for the search boundaries.
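The residual-size model can be sketched with base R's glm() and the Gamma(link = "log") family. In this hypothetical example the noise grows with the single predictor x, the Gamma GLM is fit to the absolute lm() residuals using the same predictor, and the multiplier of 5 that turns the expected residual size into initial search bounds is an assumption; the model terms and multiple used by probably may differ.

```r
set.seed(2)
n <- 100
train <- data.frame(x = runif(n, 0, 10))
# The noise standard deviation grows with x, so residual size varies
train$y <- 1 + 2 * train$x + rnorm(n, sd = 0.5 + 0.3 * train$x)

fit <- lm(y ~ x, data = train)
train$abs_res <- abs(residuals(fit))

# Gamma GLM with a log link for the absolute residuals (after Aitkin, 1987)
scale_fit <- glm(abs_res ~ x, family = Gamma(link = "log"), data = train)

new_row <- data.frame(x = c(1, 9))
center <- unname(predict(fit, newdata = new_row))
res_size <- unname(predict(scale_fit, newdata = new_row, type = "response"))

# Hypothetical multiplier turning the expected residual size into a bracket
multiplier <- 5
bounds <- data.frame(
  x = new_row$x,
  lower = center - multiplier * res_size,
  upper = center + multiplier * res_size
)
bounds
```

The log link keeps the predicted residual size positive, and the bracket widens for new samples where the training residuals tend to be larger.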

Speed

The time it takes to compute the intervals depends on the training set size, search parameters (i.e., convergence criterion, number of iterations), the grid size, and the number of worker processes that are used. For the last item, the computations can be parallelized using the future and furrr packages.

To use parallelism, the future::plan() function can be invoked to create a parallel backend. For example, let’s make an initial workflow:

library(tidymodels)
library(probably)
library(future)

tidymodels_prefer()

## Make a fitted workflow from some simulated data:
set.seed(121)
train_dat <- sim_regression(200)
new_dat   <- sim_regression(  5) %>% select(-outcome)

lm_fit <- 
  workflow() %>% 
  add_model(linear_reg()) %>% 
  add_formula(outcome ~ .) %>% 
  fit(data = train_dat)

# Create the object to be used to make prediction intervals
lm_conform <- int_conformal_full(lm_fit, train_dat)

We’ll use a "multisession" parallel processing plan to compute the intervals for the five new samples in parallel:

plan("multisession")

# This is run in parallel:
predict(lm_conform, new_dat)
## # A tibble: 5 x 2
##   .pred_lower .pred_upper
##         <dbl>       <dbl>
## 1       -17.9        59.6
## 2       -33.7        51.1
## 3       -30.6        48.2
## 4       -17.3        59.6
## 5       -23.3        55.2

In simulations, computing the row-wise intervals with parallel processing gave slightly sub-linear speed-ups.

For comparison, the parametric prediction intervals from the linear model are:

predict(lm_fit, new_dat, type = "pred_int")
## # A tibble: 5 x 2
##   .pred_lower .pred_upper
##         <dbl>       <dbl>
## 1       -19.2        59.1
## 2       -31.8        49.7
## 3       -31.0        47.6
## 4       -17.8        60.1
## 5       -23.6        54.3

Value

An object of class "int_conformal_full" containing the information to create intervals (which includes the training set data). The predict() method is used to produce the intervals.

References

Jing Lei, Max G'Sell, Alessandro Rinaldo, Ryan J. Tibshirani and Larry Wasserman (2018) Distribution-Free Predictive Inference for Regression, Journal of the American Statistical Association, 113:523, 1094-1111.

Murray Aitkin (1987) Modelling Variance Heterogeneity in Normal Regression Using GLIM, Journal of the Royal Statistical Society Series C: Applied Statistics, 36:3, 332-339.

See Also

predict.int_conformal_full()


probably documentation built on May 29, 2024, 4:44 a.m.