gen-nn-predict    R Documentation
Description

Generate predictions from an "nn_fit" object produced by train_nn().
Three S3 methods are registered:

predict.nn_fit(): the base method for matrix-trained models.
predict.nn_fit_tab(): extends the base method for tabular fits; runs new
data through hardhat::forge() before predicting.
predict.nn_fit_ds(): extends the base method for torch dataset fits.
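One common S3 pattern for "extending" a base method is a subclass method that prepares newdata and then hands off with NextMethod(). The sketch below illustrates that pattern with hypothetical toy classes (toy_fit, toy_fit_tab) and a plain column selection plus as.matrix() standing in for hardhat::forge(); it is not this package's internal code.

```r
# Hypothetical toy classes; they only mimic the nn_fit / nn_fit_tab layering.

# Base method: a linear map stands in for a trained network's forward pass.
predict.toy_fit <- function(object, newdata, ...) {
  x <- as.matrix(newdata)
  drop(x %*% object$weights)
}

# Tabular method: prepare the data, then defer to the next class in line.
predict.toy_fit_tab <- function(object, newdata, ...) {
  # Stand-in for hardhat::forge(): keep only the training predictors,
  # in training order, and convert to a numeric matrix.
  newdata <- as.matrix(newdata[, object$predictors, drop = FALSE])
  # NextMethod() forwards the modified `newdata` to predict.toy_fit().
  NextMethod()
}

fit <- structure(
  list(weights = c(2, 3), predictors = c("a", "b")),
  class = c("toy_fit_tab", "toy_fit")
)
new_df <- data.frame(b = c(1, 0), extra = c(9, 9), a = c(0, 1))
predict(fit, newdata = new_df)
```

Because NextMethod() passes the current value of each formal argument, the base method never sees the unused "extra" column or the original column order.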
Usage

## S3 method for class 'nn_fit'
predict(object, newdata = NULL, new_data = NULL, type = "response", ...)
## S3 method for class 'nn_fit_tab'
predict(object, newdata = NULL, new_data = NULL, type = "response", ...)
## S3 method for class 'nn_fit_ds'
predict(object, newdata = NULL, new_data = NULL, type = "response", ...)
Arguments

object
A fitted model object returned by train_nn().
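newdata
New predictor data. Accepted forms depend on the method: a matrix for
predict.nn_fit(), tabular data (a data frame), which is run through
hardhat::forge(), for predict.nn_fit_tab(), and a torch dataset for
predict.nn_fit_ds().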
new_data
Legacy alias for newdata.
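type
Character. Output type: "response" (the default) returns predicted values
for regression models or predicted class labels for classification models;
"prob" returns class probabilities for classification models.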
...
Currently unused; reserved for future extensions.
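For illustration only, here is one way a method could reconcile the newdata/new_data alias and validate type. The helper resolve_predict_args() is hypothetical, and the rules it applies (new_data is used only when newdata is NULL, supplying both is an error, and only "response" and "prob" are accepted) are assumptions rather than documented behavior of these methods.

```r
# Hypothetical helper; not part of the package's API.
resolve_predict_args <- function(newdata = NULL, new_data = NULL,
                                 type = "response") {
  # Assumption: passing both spellings is ambiguous, so it is rejected.
  if (!is.null(newdata) && !is.null(new_data)) {
    stop("Supply only one of `newdata` or `new_data`.", call. = FALSE)
  }
  # Fall back to the legacy alias when `newdata` is not given.
  if (is.null(newdata)) newdata <- new_data
  # match.arg() accepts a listed choice (or an unambiguous prefix of one)
  # and errors otherwise.
  type <- match.arg(type, c("response", "prob"))
  list(newdata = newdata, type = type)
}

res <- resolve_predict_args(new_data = matrix(1:4, nrow = 2), type = "prob")
str(res)
```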
Value

Regression: a numeric vector (single output) or matrix (multiple outputs).
Classification, type = "response": a factor with levels matching those
seen during training.
Classification, type = "prob": a numeric matrix with one column per
class, columns named by class label.
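The two classification outputs are typically related by taking the highest-probability column as the predicted label. The sketch below assumes that rule and derives a factor from a matrix shaped like the type = "prob" output; the probabilities and class labels are made up, and probs_to_class() is a hypothetical helper, not part of the package.

```r
# Hypothetical helper: turns a type = "prob"-style matrix into labels,
# assuming the label is the highest-probability class.
probs_to_class <- function(probs, classes = colnames(probs)) {
  # Column index of the row-wise maximum; ties resolve to the first column.
  idx <- max.col(probs, ties.method = "first")
  # Keep every training level, even ones no row selects.
  factor(classes[idx], levels = classes)
}

# Made-up probabilities for three hypothetical classes.
probs <- matrix(
  c(0.7, 0.2, 0.1,
    0.1, 0.3, 0.6),
  nrow = 2, byrow = TRUE,
  dimnames = list(NULL, c("a", "b", "c"))
)
probs_to_class(probs)
```

Keeping the full set of levels (including "b", which no row selects) mirrors the guarantee that the returned factor's levels match those seen during training.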
See Also

train_nn()