augment: Augment data with predictions

augment.model_fit    R Documentation

Augment data with predictions

Description

augment() will add column(s) for predictions to the given data.

Usage

## S3 method for class 'model_fit'
augment(x, new_data, eval_time = NULL, ...)

Arguments

x

A model_fit object produced by fit.model_spec() or fit_xy.model_spec().

new_data

A data frame or matrix.

eval_time

For censored regression models, a vector of time points at which the survival probability is estimated.

...

Not currently used.

Details

Regression

For regression models, a .pred column is added. If x was created using fit.model_spec() and new_data contains a regression outcome column, a .resid column is also added.
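As a minimal sketch of the rule above, the snippet below fits a linear model through the formula interface and checks when .resid appears. The object names (lm_form, with_outcome, without_outcome) are illustrative only, and the residual is assumed here to be the observed outcome minus .pred.

```r
library(parsnip)

car_trn <- mtcars[11:32, ]
car_tst <- mtcars[1:10, ]

# Fit through the formula interface so the outcome name (mpg) is recorded.
lm_form <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit(mpg ~ ., data = car_trn)

# new_data contains mpg, so both .pred and .resid are expected.
with_outcome <- augment(lm_form, car_tst)

# new_data lacks mpg, so only .pred can be added.
without_outcome <- augment(lm_form, car_tst[, -1])

".resid" %in% names(with_outcome)
".resid" %in% names(without_outcome)

# Assumption: .resid is computed as observed outcome minus prediction.
all.equal(with_outcome$.resid, car_tst$mpg - with_outcome$.pred)
```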

Classification

For classification models, the results can include a column called .pred_class as well as class probability columns named .pred_{level}. Which of these columns appear depends on the prediction types available for the model.
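To illustrate the column naming, here is a short sketch using the two_class_dat data from the modeldata package (assumed to be installed). The object names glm_fit and cls_aug are illustrative. The probability columns are expected to take their suffixes from the levels of the outcome factor.

```r
library(parsnip)

data(two_class_dat, package = "modeldata")

glm_fit <-
  logistic_reg() %>%
  set_engine("glm") %>%
  fit(Class ~ ., data = two_class_dat[-(1:10), ])

cls_aug <- augment(glm_fit, two_class_dat[1:10, ])

# Outcome levels; each one should map to a .pred_{level} column.
levels(two_class_dat$Class)

# List only the prediction columns that augment() added.
grep("^\\.pred", names(cls_aug), value = TRUE)
```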

Censored Regression

For these models, predictions for the expected time and survival probability are created (if the model engine supports them). If the model supports survival prediction, the eval_time argument is required.

If survival predictions are created and new_data contains a survival::Surv() object, additional columns are added for inverse probability of censoring weights (IPCW); see the tidymodels.org page in the References below. These weights allow the user to compute performance metrics with the yardstick package.
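The following is a hedged sketch of the censored regression case, not a definitive recipe. It assumes the censored and survival packages are installed, uses the lung data shipped with survival, and picks a parametric survival_reg() model purely for illustration. The names lung_surv, surv_fit, and surv_aug are hypothetical. The outcome is stored as a Surv column so that augment() can find it in new_data.

```r
library(parsnip)
library(censored)  # assumed installed; provides the censored regression engines
library(survival)  # provides Surv() and the lung data

# Keep two predictors and store the outcome as a single Surv column.
lung_surv <- lung[, c("age", "sex")]
lung_surv$event_time <- Surv(lung$time, lung$status)

surv_fit <-
  survival_reg() %>%
  set_engine("survival") %>%
  fit(event_time ~ age + sex, data = lung_surv)

# eval_time is needed because this model can predict survival probabilities.
surv_aug <- augment(
  surv_fit,
  new_data = lung_surv[1:5, ],
  eval_time = c(100, 500)
)

names(surv_aug)

# .pred is a list column with one tibble per row of new_data and
# one row per evaluation time.
surv_aug$.pred[[1]]
```

Because new_data here includes the Surv column, the per-row tibbles in .pred are also expected to carry the censoring weight columns described above.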

References

https://www.tidymodels.org/learn/statistics/survival-metrics/

Examples


library(parsnip)

# Split mtcars into training rows and a small test set
car_trn <- mtcars[11:32, ]
car_tst <- mtcars[ 1:10, ]

reg_form <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit(mpg ~ ., data = car_trn)
reg_xy <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit_xy(car_trn[, -1], car_trn$mpg)

augment(reg_form, car_tst)
augment(reg_form, car_tst[, -1])

augment(reg_xy, car_tst)
augment(reg_xy, car_tst[, -1])

# ------------------------------------------------------------------------------

data(two_class_dat, package = "modeldata")
cls_trn <- two_class_dat[-(1:10), ]
cls_tst <- two_class_dat[  1:10 , ]

cls_form <-
  logistic_reg() %>%
  set_engine("glm") %>%
  fit(Class ~ ., data = cls_trn)
cls_xy <-
  logistic_reg() %>%
  set_engine("glm") %>%
  fit_xy(cls_trn[, -3], cls_trn$Class)

augment(cls_form, cls_tst)
augment(cls_form, cls_tst[, -3])

augment(cls_xy, cls_tst)
augment(cls_xy, cls_tst[, -3])


parsnip documentation built on June 24, 2024, 5:14 p.m.