View source: R/conformal_infer_split.R
int_conformal_split    R Documentation
Nonparametric prediction intervals can be computed for fitted regression workflow objects using the split conformal inference method described by Lei et al. (2018).
int_conformal_split(object, ...)
## Default S3 method:
int_conformal_split(object, ...)
## S3 method for class 'workflow'
int_conformal_split(object, cal_data, ...)
object | A fitted regression workflow object. |
... | Not currently used. |
cal_data | A data frame with the original predictor and outcome data used to produce predictions (and residuals). If the workflow used a recipe, this should be the data that were passed to the recipe (not the preprocessed output of the recipe). |
This function implements what is usually called "split conformal inference" (see Algorithm 1 in Lei et al. (2018)).
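The steps of Algorithm 1 can be sketched in base R. This is a minimal illustration under stated assumptions, not the code behind int_conformal_split(): a plain lm() fit stands in for the workflow, and the helper split_conformal() and the simulated data are hypothetical. Following Lei et al. (2018), the rows are split in half, the model is fit on one half, and the interval half-width d is the k-th smallest absolute residual on the other half, with k = ceiling((n_cal + 1) * (1 - alpha)).

```r
# Sketch of split conformal inference (Algorithm 1 in Lei et al., 2018).
# Assumptions: lm() stands in for a fitted workflow; split_conformal()
# is a hypothetical helper, not part of any package.
split_conformal <- function(data, new_data, formula, outcome, alpha = 0.10) {
  n <- nrow(data)
  # 1. Randomly split the rows into two halves
  idx_fit <- sample.int(n, floor(n / 2))
  fit_set <- data[idx_fit, , drop = FALSE]
  cal_set <- data[-idx_fit, , drop = FALSE]

  # 2. Fit the regression model on the first half only
  fit <- lm(formula, data = fit_set)

  # 3. Absolute residuals on the second (calibration) half
  resid <- abs(cal_set[[outcome]] - predict(fit, newdata = cal_set))

  # 4. d = k-th smallest residual; too few rows gives an unbounded interval
  n_cal <- length(resid)
  k <- ceiling((n_cal + 1) * (1 - alpha))
  d <- if (k > n_cal) Inf else unname(sort(resid)[k])

  # 5. Interval: point prediction plus or minus d
  pred <- unname(predict(fit, newdata = new_data))
  data.frame(pred = pred, lower = pred - d, upper = pred + d)
}

set.seed(2)
sim <- data.frame(x = runif(400, 0, 10))
sim$y <- 2 * sim$x + rnorm(400)
new_x <- data.frame(x = c(1, 5, 9))

ints <- split_conformal(sim, new_x, y ~ x, outcome = "y", alpha = 0.10)
ints
```

Because d is a single residual quantile, every interval from this sketch has the same width; the package's own method may differ in details such as how the quantile is computed.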
This function prepares the statistics for the interval computations. The predict() method computes the intervals for new data, and the level of the intervals (e.g., level = 0.90) is specified there.
cal_data should be large enough to get a good estimate of an extreme quantile (e.g., the 95th percentile for a 95% interval) and should not include rows that were in the original training set.
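To see why the calibration set needs to be large enough, the sketch below computes the rank of the residual that sets the interval width, assuming the finite-sample rule k = ceiling((n + 1) * level) from Lei et al. (2018); int_conformal_split() may compute its quantile differently. The helpers conformal_rank() and min_cal_size() are hypothetical. When k exceeds the number of calibration rows, no finite quantile exists and the interval is unbounded.

```r
# Rank of the calibration residual that sets the interval half-width,
# assuming the finite-sample rule k = ceiling((n + 1) * level).
# conformal_rank() and min_cal_size() are hypothetical helpers.
conformal_rank <- function(n, level) {
  # Subtract a tiny tolerance so floating-point round-off cannot push the
  # rank up by one when (n + 1) * level should be a whole number
  ceiling((n + 1) * level - 1e-8)
}

conformal_rank(200, 0.95)  # 191: the 191st smallest of 200 residuals
conformal_rank(200, 0.90)  # 181
conformal_rank(10, 0.95)   # 11, more than the 10 rows available

# Smallest calibration set for which the k-th residual exists (k <= n)
min_cal_size <- function(level) {
  n <- 1
  while (conformal_rank(n, level) > n) n <- n + 1
  n
}

min_cal_size(0.95)  # 19
min_cal_size(0.99)  # 99
```

These minimums only make the interval finite; a stable estimate of an extreme quantile needs considerably more calibration rows.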
An object of class "int_conformal_split" containing the information to create intervals (including the fitted model passed as object). The predict() method is used to produce the intervals.
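The split between preparing statistics once and choosing the level at prediction time can be illustrated with a small S3 class in base R. The class toy_conformal and its predict() method are hypothetical stand-ins, not the package's int_conformal_split or predict.int_conformal_split(); an lm() fit replaces the workflow, and the residual quantile follows the finite-sample rule from Lei et al. (2018).

```r
# Hypothetical S3 class illustrating "prepare once, pick the level in predict()".
# toy_conformal() and predict.toy_conformal() are illustrative only.
toy_conformal <- function(fit, cal_data, outcome) {
  # Compute absolute calibration residuals once and store them, sorted
  resid <- abs(cal_data[[outcome]] - predict(fit, newdata = cal_data))
  structure(list(fit = fit, abs_resid = unname(sort(resid))),
            class = "toy_conformal")
}

predict.toy_conformal <- function(object, new_data, level = 0.95, ...) {
  n <- length(object$abs_resid)
  k <- ceiling((n + 1) * level)
  # Too few calibration rows for this level gives an unbounded interval
  q <- if (k > n) Inf else object$abs_resid[k]
  pred <- unname(predict(object$fit, newdata = new_data))
  data.frame(pred = pred, lower = pred - q, upper = pred + q)
}

set.seed(1)
x <- runif(700, 0, 10)
dat <- data.frame(x = x, y = 3 + x + rnorm(700))

# Train on the first 500 rows, calibrate on the remaining 200
fit <- lm(y ~ x, data = dat[1:500, ])
obj <- toy_conformal(fit, dat[501:700, ], outcome = "y")

# The same stored object serves any level
new <- data.frame(x = c(2, 5, 8))
p90 <- predict(obj, new, level = 0.90)
p99 <- predict(obj, new, level = 0.99)
```

Storing the sorted calibration residuals means a different level only changes which residual is read, so neither the model nor the calibration data has to be revisited.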
Lei, Jing, et al. "Distribution-free predictive inference for regression." Journal of the American Statistical Association 113.523 (2018): 1094-1111.
See also: predict.int_conformal_split()
library(probably)   # provides int_conformal_split()
library(workflows)
library(dplyr)
library(parsnip)
library(rsample)
library(tune)
library(modeldata)

# Simulate separate training, calibration, and new data sets
set.seed(2)
sim_train <- sim_regression(500)
sim_cal   <- sim_regression(200)
sim_new   <- sim_regression(5) %>% select(-outcome)

# We'll use a neural network model
mlp_spec <-
  mlp(hidden_units = 5, penalty = 0.01) %>%
  set_mode("regression")

mlp_wflow <-
  workflow() %>%
  add_model(mlp_spec) %>%
  add_formula(outcome ~ .)

# Fit on the training set only
mlp_fit <- fit(mlp_wflow, data = sim_train)

# Prepare the interval statistics using data not seen during training
mlp_int <- int_conformal_split(mlp_fit, sim_cal)
mlp_int

# 90% prediction intervals for the new samples
predict(mlp_int, sim_new, level = 0.90)