augment: Augment data with predictions

augment.model_fit R Documentation

Augment data with predictions


augment() adds column(s) of predictions to the given data.


Usage

## S3 method for class 'model_fit'
augment(x, new_data, eval_time = NULL, ...)



Arguments

x
A model_fit object produced by fit.model_spec() or fit_xy.model_spec().

new_data
A data frame or matrix.

eval_time
For censored regression models, a vector of time points at which the survival probability is estimated.

...
Not currently used.



For regression models, a .pred column is added. If x was created using fit.model_spec() and new_data contains a regression outcome column, a .resid column is also added.
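As a rough illustration of what the two regression columns hold, the base-R sketch below rebuilds them by hand with stats::lm() on the same mtcars split used in the examples. It assumes the conventional definition of a residual (observed outcome minus .pred); the object name augmented is hypothetical, and this is not parsnip's own code path.

```r
# Base-R sketch (not parsnip's implementation) of the .pred and .resid columns
# for a regression model whose new_data contains the outcome column.
car_trn <- mtcars[11:32, ]
car_tst <- mtcars[1:10, ]

fit <- lm(mpg ~ ., data = car_trn)

augmented <- car_tst                                          # hypothetical name
augmented$.pred  <- unname(predict(fit, newdata = car_tst))  # predicted mpg
augmented$.resid <- augmented$mpg - augmented$.pred          # observed minus predicted

head(augmented[, c("mpg", ".pred", ".resid")])
```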


For classification models, the results can include a column called .pred_class as well as class probability columns named .pred_{level}. Which of these columns appear depends on the prediction types available for the model.
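The naming rule for the classification columns can be sketched in base R. This is an illustration under stated assumptions, not parsnip's implementation: it fits stats::glm() to a two-level factor built from mtcars$am (the labels "automatic" and "manual" are chosen here purely for illustration), names one probability column per level as .pred_{level}, and takes the more probable level as .pred_class.

```r
# Base-R sketch (not parsnip's implementation): .pred_{level} probability
# columns plus a .pred_class column for a two-class outcome.
dat <- mtcars
dat$am <- factor(dat$am, labels = c("automatic", "manual"))  # illustrative labels

fit <- glm(am ~ wt, data = dat, family = binomial)

# For a factor outcome, glm() models the probability of the second level
p_second <- unname(predict(fit, newdata = dat, type = "response"))
lvls <- levels(dat$am)

probs <- data.frame(1 - p_second, p_second)
names(probs) <- paste0(".pred_", lvls)  # .pred_automatic, .pred_manual

# Hard class prediction: the level with the larger probability
pred_class <- factor(lvls[max.col(as.matrix(probs), ties.method = "first")],
                     levels = lvls)

augmented <- cbind(dat, .pred_class = pred_class, probs)
head(augmented[, c("am", ".pred_class", names(probs))])
```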

Censored Regression

For censored regression models, predictions of the expected survival time and of the survival probability are created, if the model engine supports them. If the model supports survival probability predictions, the eval_time argument is required.
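To show what a vector of evaluation times means in practice, the sketch below uses the survival package to estimate survival probabilities at a few time points. It is a model-free stand-in (a Kaplan-Meier curve on survival::lung rather than a fitted censored regression model), and the chosen times are hypothetical.

```r
library(survival)

# Kaplan-Meier fit on survival::lung (status: 1 = censored, 2 = dead)
km <- survfit(Surv(time, status) ~ 1, data = lung)

# Hypothetical evaluation times in days, playing the role of eval_time
eval_time <- c(100, 200, 365)

# Estimated survival probability at each evaluation time
surv_prob <- summary(km, times = eval_time)$surv
data.frame(eval_time = eval_time, survival = surv_prob)
```

Survival probabilities are non-increasing in time, so later evaluation times give equal or smaller values.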

If survival predictions are created and new_data contains a survival::Surv() object, additional columns for inverse probability of censoring weights (IPCW) are also created (see the page listed in the references). These weights make it possible to compute performance metrics with the yardstick package.
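IPCW can be sketched with a reverse Kaplan-Meier estimate of the censoring distribution G(t) = P(C > t). The code below assumes the weighting scheme commonly used for time-dependent Brier scores: observations still at risk at the evaluation time get 1/G(eval_time), events before it get 1/G(T-), and observations censored before it get no weight. It uses survival::lung and a hypothetical evaluation time of 365 days; parsnip and yardstick may differ in details such as tie handling, so treat it as an illustration only.

```r
library(survival)

dat <- lung
dat$event <- as.integer(dat$status == 2)  # survival::lung: 2 = dead, 1 = censored

# Reverse Kaplan-Meier: censoring is treated as the "event" to estimate G(t)
cens_km <- survfit(Surv(time, 1 - event) ~ 1, data = dat)

# G(t) as a right-continuous step function, and its left limit G(t-)
G       <- stepfun(cens_km$time, c(1, cens_km$surv))
G_minus <- stepfun(cens_km$time, c(1, cens_km$surv), right = TRUE)

eval_time <- 365  # hypothetical evaluation time (days)

w <- ifelse(dat$time > eval_time, 1 / G(eval_time),   # still at risk at eval_time
     ifelse(dat$event == 1, 1 / G_minus(dat$time),     # died before eval_time
            NA_real_))                                 # censored before eval_time
summary(w)
```

Because G(t) never exceeds 1, every non-missing weight is at least 1; weights grow as censoring accumulates.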



Examples

library(parsnip)

car_trn <- mtcars[11:32,]
car_tst <- mtcars[ 1:10,]

reg_form <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit(mpg ~ ., data = car_trn)
reg_xy <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit_xy(car_trn[, -1], car_trn$mpg)

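# new_data includes the outcome (mpg), so .pred and .resid are added
augment(reg_form, car_tst)
# without the outcome column, only .pred is added
augment(reg_form, car_tst[, -1])

# model created with fit_xy(): .resid is documented only for fit() models
augment(reg_xy, car_tst)
augment(reg_xy, car_tst[, -1])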

# ------------------------------------------------------------------------------

data(two_class_dat, package = "modeldata")
cls_trn <- two_class_dat[-(1:10), ]
cls_tst <- two_class_dat[  1:10 , ]

cls_form <-
  logistic_reg() %>%
  set_engine("glm") %>%
  fit(Class ~ ., data = cls_trn)
cls_xy <-
  logistic_reg() %>%
  set_engine("glm") %>%
  fit_xy(cls_trn[, -3],
         cls_trn$Class)

# adds .pred_class and the .pred_{level} class probability columns
augment(cls_form, cls_tst)
augment(cls_form, cls_tst[, -3])

augment(cls_xy, cls_tst)
augment(cls_xy, cls_tst[, -3])

parsnip documentation built on Aug. 18, 2023, 1:07 a.m.