predict.model_fit: Model predictions


View source: R/predict.R

Description

Apply a model to create different types of predictions. predict() can be used for all types of models and uses the "type" argument for more specificity.

Usage

## S3 method for class 'model_fit'
predict(object, new_data, type = NULL, opts = list(), ...)

## S3 method for class 'model_fit'
predict_raw(object, new_data, opts = list(), ...)

predict_raw(object, ...)

Arguments

object

An object of class model_fit

new_data

A rectangular data object, such as a data frame.

type

A single character value or NULL. Possible values are "numeric", "class", "prob", "conf_int", "pred_int", "quantile", "time", "hazard", "survival", or "raw". When NULL, predict() will choose an appropriate value based on the model's mode.

opts

A list of optional arguments to the underlying predict function that will be used when type = "raw". The list should not include options for the model object or the new data being predicted.

...

Arguments to the underlying model's prediction function cannot be passed here (use opts instead). Some parsnip-related options can be passed here, depending on the value of type. Possible arguments are:

  • level: for types "conf_int" and "pred_int", the coverage level of the intervals (e.g. the confidence level for confidence intervals). Default value is 0.95.

  • std_error: add the standard error of fit or prediction (on the scale of the linear predictors) for types "conf_int" and "pred_int". Default value is FALSE.

  • quantile: the quantile(s) for quantile regression (not implemented yet)

  • time: the time(s) for hazard and survival probability estimates.

Details

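If "type" is not supplied to predict(), then a choice is made based on the model's mode:

  • type = "numeric" for regression models,

  • type = "class" for classification, and

  • type = "time" for censored regression.

predict() is designed to provide a tidy result (see "Value" section below) in a tibble output format.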
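A minimal sketch of the default-type rule, assuming the mode-to-type mapping is regression to "numeric", classification to "class", and censored regression to "time". The function default_pred_type() is hypothetical and only illustrates the logic; parsnip makes this choice internally and does not export such a helper.

```r
# Hypothetical helper (not part of parsnip's API): mirrors the documented
# default choice of `type` when predict() is called with type = NULL
default_pred_type <- function(mode) {
  switch(mode,
    "regression"          = "numeric",
    "classification"      = "class",
    "censored regression" = "time",
    stop("no default prediction type for mode: ", mode)
  )
}

default_pred_type("classification")
# returns "class"
```

Unknown modes raise an error rather than guessing a type.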

Interval predictions

When using type = "conf_int" or type = "pred_int", the options level and std_error can be used. The latter is a logical that adds an extra column of standard error values (if available).
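For the "lm" engine, these options correspond closely to base R's predict.lm() arguments interval, level, and se.fit. The sketch below assumes the same training/prediction split as the Examples section, calls stats directly (not parsnip), and arranges the result in the documented .pred_lower/.pred_upper layout. The .std_error column name and the "level" attribute name are assumptions for illustration.

```r
# Same training/prediction split as the Examples section
train    <- mtcars[11:32, ]
new_cars <- mtcars[1:10, setdiff(names(mtcars), "mpg")]
fit <- lm(mpg ~ ., data = train)

# 90% confidence interval for the mean response, plus standard errors of fit
raw <- predict(fit, newdata = new_cars, interval = "confidence",
               level = 0.90, se.fit = TRUE)

# Arrange into the column layout described for type = "conf_int"
conf_int <- data.frame(
  .pred_lower = unname(raw$fit[, "lwr"]),
  .pred_upper = unname(raw$fit[, "upr"]),
  .std_error  = unname(raw$se.fit)   # assumed column name
)
attr(conf_int, "level") <- 0.90       # hypothetical attribute name
conf_int
```

Using interval = "prediction" instead gives the base-R analogue of type = "pred_int".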

Censored regression predictions

For censored regression, a numeric vector for time is required when survival probabilities or hazards are requested. Also, when type = "linear_pred", censored regression models will be formatted such that the linear predictor increases with time. This may be the opposite sign from what the underlying model's predict() method produces.
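To illustrate why the sign can differ: for a Cox model from the survival package, predict(type = "lp") is on the log-hazard scale, so larger values mean higher risk and shorter expected survival. Negating it gives a predictor that increases with time. This is a sketch of the convention using survival directly, not parsnip's implementation.

```r
library(survival)

# Cox model on the lung data shipped with the survival package
cox_fit <- coxph(Surv(time, status) ~ age, data = lung)

# Log-hazard scale: larger values mean higher risk (shorter survival)
lp_hazard <- predict(cox_fit, newdata = lung[1:5, ], type = "lp")

# Flipping the sign yields a linear predictor that increases with time
lp_time <- -lp_hazard

data.frame(age = lung$age[1:5], lp_hazard = lp_hazard, lp_time = lp_time)
```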

Value

With the exception of type = "raw", the results of predict.model_fit() will be a tibble with as many rows as there are rows in new_data, and the column names will be predictable.

For numeric results with a single outcome, the tibble will have a .pred column; for multivariate results, it will have .pred_Yname columns, where Yname is the name of each outcome.

For hard class predictions, the column is named .pred_class. When type = "prob", there is one column per class, named .pred_classlevel, where classlevel is the outcome's factor level (for example, .pred_yes).

type = "conf_int" and type = "pred_int" return tibbles with columns .pred_lower and .pred_upper, with an attribute for the confidence level. When intervals can be produced for class probabilities (or other non-scalar outputs), the columns are named .pred_lower_classlevel, .pred_upper_classlevel, and so on.
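The naming rules above can be summarized in a small function. pred_col_names() is a hypothetical helper written only to illustrate the convention; it is not a parsnip function, and the ordering of per-class interval columns (all lower bounds, then all upper bounds) is an assumption.

```r
# Hypothetical helper (not a parsnip function): builds the prediction column
# names described above for a given type and, optionally, outcome class levels
pred_col_names <- function(type, levels = NULL) {
  switch(type,
    numeric  = ".pred",
    class    = ".pred_class",
    prob     = paste0(".pred_", levels),
    conf_int = ,                       # falls through to pred_int
    pred_int = if (is.null(levels)) {
      c(".pred_lower", ".pred_upper")
    } else {
      c(paste0(".pred_lower_", levels), paste0(".pred_upper_", levels))
    },
    stop("unsupported type: ", type)
  )
}

pred_col_names("prob", levels = c("yes", "no"))
# returns c(".pred_yes", ".pred_no")
```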

Quantile predictions return a tibble with a column .pred, which is a list-column. Each list element contains a tibble with columns .pred and .quantile (and perhaps other columns).

Using type = "raw" with predict.model_fit() will return the unadulterated results of the prediction function.

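For censored regression:

  • type = "time" produces a column .pred_time.

  • type = "hazard" results in a column .pred_hazard.

  • type = "survival" results in a column .pred_survival.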

For the last two types, the results are a nested tibble with an overall column called .pred with sub-tibbles with the above format.
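As a sketch of the nested layout for type = "survival", the following uses the survival package directly (not parsnip) to compute survival probabilities at three evaluation times for three patients, storing one small data frame per patient in a list-column named .pred. The sub-table column names .time and .pred_survival are assumptions, and base data frames stand in for tibbles.

```r
library(survival)

cox_fit    <- coxph(Surv(time, status) ~ age, data = lung)
new_pts    <- lung[1:3, "age", drop = FALSE]
eval_times <- c(100, 200, 300)   # all later than the first event time

# One survival curve per row of new_pts; surv is a (time x patient) matrix
curves   <- survfit(cox_fit, newdata = new_pts)
surv_mat <- as.matrix(curves$surv)

# Step-function lookup: survival at t is the value at the last event time <= t
idx <- findInterval(eval_times, curves$time)

# Outer table: one row per patient, with a list-column of sub-tables
nested <- data.frame(row = seq_len(nrow(new_pts)))
nested$.pred <- lapply(seq_len(nrow(new_pts)), function(i) {
  data.frame(.time = eval_times, .pred_survival = surv_mat[idx, i])
})

nested$.pred[[1]]
```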

In the case of Spark-based models, since table columns cannot contain dots, the same convention is used except that 1) no dots appear in column names and 2) vectors are never returned by the type-specific prediction functions.

When the model fit failed and the error was captured, the predict() function will return the same structure as above but filled with missing values. This does not currently work for multivariate models.
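A minimal sketch of this behavior, using hypothetical helpers fit_or_error() and predict_or_na() (not parsnip functions): the fitting error is captured as a condition object, and prediction then returns the usual single-outcome numeric structure filled with NA instead of failing.

```r
# Hypothetical helper: capture a fitting error instead of stopping
fit_or_error <- function(formula, data) {
  tryCatch(lm(formula, data = data), error = function(e) e)
}

# Hypothetical helper: NA-filled .pred column when the fit failed
predict_or_na <- function(fit, new_data) {
  if (inherits(fit, "error")) {
    # Same shape as a successful numeric prediction, filled with NA
    return(data.frame(.pred = rep(NA_real_, nrow(new_data))))
  }
  data.frame(.pred = unname(predict(fit, newdata = new_data)))
}

bad_fit <- fit_or_error(mpg ~ no_such_column, data = mtcars)
predict_or_na(bad_fit, mtcars[1:3, ])
```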

Examples

library(parsnip)
library(dplyr)

lm_model <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit(mpg ~ ., data = mtcars %>% dplyr::slice(11:32))

pred_cars <-
  mtcars %>%
  dplyr::slice(1:10) %>%
  dplyr::select(-mpg)

predict(lm_model, pred_cars)

predict(
  lm_model,
  pred_cars,
  type = "conf_int",
  level = 0.90
)

predict(
  lm_model,
  pred_cars,
  type = "raw",
  opts = list(type = "terms")
)

Example output

Attaching package: 'dplyr'

The following objects are masked from 'package:stats':

    filter, lag

The following objects are masked from 'package:base':

    intersect, setdiff, setequal, union

# A tibble: 10 x 1
   .pred
   <dbl>
 1  23.4
 2  23.3
 3  27.6
 4  21.5
 5  17.6
 6  21.6
 7  13.9
 8  21.7
 9  25.6
10  17.1
# A tibble: 10 x 2
   .pred_lower .pred_upper
         <dbl>       <dbl>
 1       17.9         29.0
 2       18.1         28.5
 3       24.0         31.3
 4       17.5         25.6
 5       14.3         20.8
 6       17.0         26.2
 7        9.65        18.2
 8       16.2         27.2
 9       14.2         37.0
10       11.5         22.7
                           cyl       disp         hp        drat         wt
Mazda RX4         -0.001433177 -0.8113275  0.6303467 -0.06120265  2.4139815
Mazda RX4 Wag     -0.001433177 -0.8113275  0.6303467 -0.06120265  1.4488706
Datsun 710        -0.009315653 -1.3336453  0.8557288 -0.05014798  3.5494061
Hornet 4 Drive    -0.001433177  0.1730406  0.6303467  0.12009386  0.1620561
Hornet Sportabout  0.006449298  1.1975870 -0.2314083  0.10461733 -0.6895124
Valiant           -0.001433177 -0.1584303  0.6966356  0.19084372 -0.7652074
Duster 360         0.006449298  1.1975870 -1.1594522  0.09135173 -1.1815297
Merc 240D         -0.009315653 -0.9449204  1.2667197 -0.01477305  0.2566748
Merc 230          -0.009315653 -1.0041833  0.8292133 -0.06562451  0.4080647
Merc 280          -0.001433177 -0.7349888  0.4579957 -0.06562451 -0.6895124
                       qsec         vs       am        gear       carb
Mazda RX4         -1.567729  0.2006406  2.88774  0.02512680 -0.2497240
Mazda RX4 Wag     -0.736286  0.2006406  2.88774  0.02512680 -0.2497240
Datsun 710         1.624418 -0.3511210  2.88774  0.02512680  0.4668753
Hornet 4 Drive     2.856736 -0.3511210 -2.40645 -0.06700481  0.4668753
Hornet Sportabout -0.736286  0.2006406 -2.40645 -0.06700481  0.2280089
Valiant            4.014817 -0.3511210 -2.40645 -0.06700481  0.4668753
Duster 360        -2.488255  0.2006406 -2.40645 -0.06700481 -0.2497240
Merc 240D          3.688179 -0.3511210 -2.40645  0.02512680  0.2280089
Merc 230           7.993866 -0.3511210 -2.40645  0.02512680  0.2280089
Merc 280           1.164155 -0.3511210 -2.40645  0.02512680 -0.2497240
attr(,"constant")
[1] 19.96364

parsnip documentation built on July 21, 2021, 5:08 p.m.