

safepredict has two goals: to provide a consistent interface to prediction via the safe_predict() generic, and to accurately quantify prediction uncertainty.


safepredict follows the tidymodels prediction specification.


safepredict is in the early stages of development and is available only on GitHub. You can install it with:

# install.packages("devtools")
devtools::install_github("alexpghayes/safepredict")


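The three main arguments to safe_predict() are always the same: the fitted model object, new_data (the data to predict on), and type (the kind of prediction to return, such as "prob", "class", "link", or "conf_int").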


Suppose you fit a logistic regression using glm:


library(safepredict)
library(tibble)

data <- tibble(
  y = as.factor(rep(c("A", "B"), each = 50)),
  x = c(rnorm(50, 1), rnorm(50, 3))
)

fit <- glm(y ~ x, data, family = binomial)

You can predict class probabilities:


test <- tibble(x = rnorm(10, 2))

safe_predict(fit, new_data = test, type = "prob")
#> # A tibble: 10 x 2
#>   .pred_A .pred_B
#>     <dbl>   <dbl>
#> 1  0.333    0.667
#> 2  0.0410   0.959
#> 3  0.619    0.381
#> 4  0.467    0.533
#> 5  0.132    0.868
#> # ... with 5 more rows
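
For comparison, here is a rough base R sketch of how the same two probability columns can be built by hand from predict(). It is not safepredict's implementation; the seed, the data frame names, and the use of data.frame instead of tibble are assumptions made to keep the example self-contained.

```r
# Simulate data shaped like the example above (the seed is arbitrary).
set.seed(27)
train <- data.frame(
  y = factor(rep(c("A", "B"), each = 50)),
  x = c(rnorm(50, 1), rnorm(50, 3))
)
fit <- glm(y ~ x, data = train, family = binomial)
test <- data.frame(x = rnorm(10, 2))

# For a binomial glm with a factor response, the first level ("A") is
# treated as failure, so the fitted probability is P(y = "B").
p_b <- predict(fit, newdata = test, type = "response")

# One column per class; each row sums to 1.
probs <- data.frame(.pred_A = 1 - p_b, .pred_B = p_b)
head(probs, 5)
```

Note that base predict() takes newdata and returns a bare numeric vector, which is part of what a consistent interface smooths over.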

Or you can jump straight to hard class decisions:

safe_predict(fit, new_data = test, type = "class")
#> # A tibble: 10 x 1
#>   .pred_class
#>   <fct>      
#> 1 B          
#> 2 B          
#> 3 A          
#> 4 B          
#> 5 B          
#> # ... with 5 more rows
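
The hard class decisions shown here are consistent with thresholding the class probabilities at 0.5. The sketch below applies that rule to the first five .pred_B values printed in the class-probability output; the 0.5 cutoff is an assumption for illustration, not a documented safepredict default.

```r
# .pred_B for the first five test rows, copied from the "prob" output above.
p_b <- c(0.667, 0.959, 0.381, 0.533, 0.868)

# Assumed rule: predict "B" when P(B) exceeds 0.5, otherwise "A".
pred_class <- factor(ifelse(p_b > 0.5, "B", "A"), levels = c("A", "B"))
pred_class
#> [1] B B A B B
#> Levels: A B
```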

You can also get predictions on the link scale:

safe_predict(fit, new_data = test, type = "link")
#> # A tibble: 10 x 1
#>    .pred
#>    <dbl>
#> 1  0.696
#> 2  3.15 
#> 3 -0.485
#> 4  0.132
#> 5  1.88 
#> # ... with 5 more rows
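
For logistic regression the link scale is the log-odds, so applying the inverse logit to a link-scale prediction recovers the class probability. A quick check using the five link values printed above and base R's plogis():

```r
# Link-scale predictions for the first five test rows, copied from above.
link <- c(0.696, 3.15, -0.485, 0.132, 1.88)

# plogis() is the inverse logit: 1 / (1 + exp(-x)).
p_b <- plogis(link)
round(p_b, 3)
#> [1] 0.667 0.959 0.381 0.533 0.868
```

These agree with the .pred_B column from the type = "prob" call.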

Or you can get confidence intervals on the response scale:

safe_predict(fit, new_data = test, type = "conf_int")
#> # A tibble: 10 x 3
#>   .pred .pred_lower .pred_upper
#>   <dbl>       <dbl>       <dbl>
#> 1 0.667       0.795       0.510
#> 2 0.959       0.989       0.862
#> 3 0.381       0.545       0.240
#> 4 0.533       0.680       0.380
#> 5 0.868       0.943       0.724
#> # ... with 5 more rows
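
One common way to build such an interval for a glm is to form a Wald interval on the link scale and then map both endpoints through the inverse link. The sketch below does this in base R. How safepredict computes its intervals is not stated here, so the 95% level, the Wald construction, the seed, and the variable names are all assumptions.

```r
# Simulate data shaped like the example above (the seed is arbitrary).
set.seed(27)
train <- data.frame(
  y = factor(rep(c("A", "B"), each = 50)),
  x = c(rnorm(50, 1), rnorm(50, 3))
)
fit <- glm(y ~ x, data = train, family = binomial)
test <- data.frame(x = rnorm(10, 2))

# Point estimates and standard errors on the link (log-odds) scale.
link <- predict(fit, newdata = test, type = "link", se.fit = TRUE)

# Assumed 95% Wald interval, transformed to the probability scale.
z <- qnorm(0.975)
ci <- data.frame(
  .pred       = plogis(link$fit),
  .pred_lower = plogis(link$fit - z * link$se.fit),
  .pred_upper = plogis(link$fit + z * link$se.fit)
)
head(ci, 5)
```

Because plogis() is monotone increasing, intervals built this way always stay inside [0, 1] and satisfy .pred_lower <= .pred <= .pred_upper.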

Related work

alexpghayes/safepredict documentation built on May 29, 2019, 11:02 p.m.