View source: R/tabpfn-predict.R
| predict.tab_pfn | R Documentation |
Predict using TabPFN
## S3 method for class 'tab_pfn'
predict(object, new_data, ...)
## S3 method for class 'tab_pfn'
augment(x, new_data, ...)
object, x |
A tab_pfn object. |
new_data |
A data frame or matrix of new predictors. |
... |
Not used, but required for extensibility. |
predict() returns a tibble of predictions; augment() returns those
predictions together with the columns of new_data. In either case, the
tibble is guaranteed to have the same number of rows as new_data.
For regression, the prediction is in the column .pred. For
classification, the class predictions are in .pred_class and the
probability estimates are in columns named .pred_{level}, with one
column for each level of the outcome factor.
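The .pred_{level} naming pattern can be illustrated with base R alone. The sketch below does not call TabPFN; the outcome factor and the variable names are hypothetical, and the column order shown is an assumption rather than something the package is documented to guarantee.

```r
# Hypothetical outcome factor (illustrative data only, not from the package).
outcome <- factor(c("setosa", "virginica", "setosa", "versicolor"))

# One probability column per factor level, following the .pred_{level} pattern.
prob_cols <- paste0(".pred_", levels(outcome))

# Columns one would expect in a classification prediction.
# The ordering here is an assumption made for this sketch.
expected_cols <- c(".pred_class", prob_cols)
print(expected_cols)
```

A check such as all(expected_cols %in% names(predict(mod, car_test))) could then be used against a fitted classification model, assuming mod was fit with a factor outcome.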
# Minimal example for quick execution
car_train <- mtcars[1:5, ]
car_test <- mtcars[6, -1]  # drop the outcome column (mpg)
## Not run:
# Fit (only when TabPFN is installed and the session is interactive)
if (is_tab_pfn_installed() && interactive()) {
  mod <- tab_pfn(mpg ~ cyl + log(drat), car_train)
  # Predict
  predict(mod, car_test)
  augment(mod, car_test)
}
## End(Not run)