augment.model_fit | R Documentation |
augment() will add column(s) for predictions to the given data.
## S3 method for class 'model_fit'
augment(x, new_data, eval_time = NULL, ...)
x | A model_fit object. |
new_data | A data frame or matrix. |
eval_time | For censored regression models, a vector of time points at which the survival probability is estimated. |
... | Not currently used. |
For regression models, a .pred column is added. If x was created using fit.model_spec() and new_data contains a regression outcome column, a .resid column is also added.
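As a minimal sketch of the regression behavior described above, the code below fits the same linear model through the formula and x/y interfaces and inspects which columns are returned. The object names (fit_formula, fit_matrix, and so on) are illustrative only, and the final comparison assumes that .resid is the observed outcome minus .pred, which this page does not state explicitly.

```r
library(parsnip)

car_trn <- mtcars[11:32, ]
car_tst <- mtcars[1:10, ]

# Same model, fitted through the formula interface and the x/y interface.
fit_formula <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit(mpg ~ wt + hp, data = car_trn)

fit_matrix <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit_xy(x = car_trn[, c("wt", "hp")], y = car_trn$mpg)

# Formula fit, outcome present in new_data: .pred and .resid are expected.
aug_formula <- augment(fit_formula, car_tst)

# Formula fit, outcome absent from new_data: only .pred is expected.
aug_no_outcome <- augment(fit_formula, car_tst[, c("wt", "hp")])

# x/y fit: only .pred is expected, even though mpg is present in new_data.
aug_matrix <- augment(fit_matrix, car_tst)

names(aug_formula)
names(aug_no_outcome)
names(aug_matrix)

# Assumption: the residual is the observed outcome minus the prediction.
all.equal(
  as.numeric(aug_formula$.resid),
  as.numeric(aug_formula$mpg - aug_formula$.pred)
)
```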
For classification models, the results can include a column called .pred_class as well as class probability columns named .pred_{level}. Which of these columns appear depends on the prediction types available for the model.
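To illustrate the .pred_{level} naming, here is a small sketch that uses mtcars with a two-level factor outcome. The data preparation and the level labels ("automatic", "manual") are assumptions made for this example; they are not part of the package examples.

```r
library(parsnip)

# Recode the transmission indicator as a factor with readable level names.
cars <- mtcars
cars$am <- factor(cars$am, levels = c(0, 1), labels = c("automatic", "manual"))

cls_fit <-
  logistic_reg() %>%
  set_engine("glm") %>%
  fit(am ~ mpg, data = cars[11:32, ])

cls_aug <- augment(cls_fit, cars[1:10, ])

# One probability column per outcome level, plus the hard class prediction.
prob_cols <- paste0(".pred_", levels(cars$am))
names(cls_aug)

# The class probabilities in each row should sum to 1.
rowSums(cls_aug[, prob_cols])
```

The glm engine supports both class and probability predictions, so both kinds of columns are expected here; an engine that lacks one of these prediction types would return fewer columns.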
For censored regression models, predictions for the expected time and survival probability are created (if the model engine supports them). If the model supports survival prediction, the eval_time argument is required.
If survival predictions are created and new_data contains a survival::Surv() object, additional columns are added for inverse probability of censoring weights (IPCW) (see the tidymodels.org page in the references below). This enables the user to compute performance metrics in the yardstick package.
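The examples on this page do not cover censored regression, so the following is a hedged sketch of what such a call might look like. It assumes the censored package is installed (it supplies the engine implementations for survival_reg()), uses the lung data from the survival package, and picks the eval_time values arbitrarily. The exact names of the nested columns can differ between parsnip versions, so the sketch only prints them.

```r
library(parsnip)
library(censored)  # assumption: installed; provides censored regression engines
library(survival)

# Store the outcome as a single Surv column so that augment() can find it
# in new_data.
lung_surv <- data.frame(age = lung$age, sex = lung$sex)
lung_surv$surv <- Surv(lung$time, lung$status)

surv_fit <-
  survival_reg() %>%
  set_engine("survival") %>%
  fit(surv ~ age + sex, data = lung_surv)

# eval_time is required because this engine supports survival predictions.
# The time points (in days) are arbitrary choices for illustration.
surv_aug <- augment(surv_fit, new_data = lung_surv[1:5, ], eval_time = c(100, 300))

names(surv_aug)

# Survival predictions are nested: one small table per row of new_data,
# with one row per evaluation time.
surv_aug$.pred[[1]]
```

Because new_data contains the Surv column here, the nested tables are also expected to carry the censoring-weight columns used by the yardstick survival metrics.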
https://www.tidymodels.org/learn/statistics/survival-metrics/
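library(parsnip)

# Hold out the first ten rows for prediction
car_trn <- mtcars[11:32, ]
car_tst <- mtcars[ 1:10, ]

# Fit with the formula interface
reg_form <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit(mpg ~ ., data = car_trn)

# Fit with the x/y interface
reg_xy <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit_xy(car_trn[, -1], car_trn$mpg)

# With and without the outcome column (mpg) in new_data
augment(reg_form, car_tst)
augment(reg_form, car_tst[, -1])

augment(reg_xy, car_tst)
augment(reg_xy, car_tst[, -1])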
# ------------------------------------------------------------------------------
data(two_class_dat, package = "modeldata")
cls_trn <- two_class_dat[-(1:10), ]
cls_tst <- two_class_dat[ 1:10 , ]

cls_form <-
  logistic_reg() %>%
  set_engine("glm") %>%
  fit(Class ~ ., data = cls_trn)

cls_xy <-
  logistic_reg() %>%
  set_engine("glm") %>%
  fit_xy(cls_trn[, -3], cls_trn$Class)

augment(cls_form, cls_tst)
augment(cls_form, cls_tst[, -3])

augment(cls_xy, cls_tst)
augment(cls_xy, cls_tst[, -3])