View source: R/multinomial_reg-predict.R
predict.brulee_multinomial_reg | R Documentation |
Predict from a brulee_multinomial_reg
Usage
## S3 method for class 'brulee_multinomial_reg'
predict(object, new_data, type = NULL, epoch = NULL, ...)
Arguments
object | A brulee_multinomial_reg object. |
new_data | A data frame or matrix of new predictors. |
type | A single character. The type of predictions to generate. Valid options are: "class" for hard class predictions; "prob" for soft class predictions (i.e., class probabilities). |
epoch | An integer for the epoch to make predictions. If this value is larger than the maximum number that was fit, a warning is issued and the parameters from the last epoch are used. If left NULL, the epoch associated with the smallest loss is used. |
... | Not used, but required for extensibility. |
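To make the type and epoch arguments concrete, here is a minimal sketch that reuses the data split and recipe from the Examples section and assumes brulee, torch, recipes, and modeldata are installed. The column names .pred_Biscoe, .pred_Dream, and .pred_Torgersen are an assumption based on the usual hardhat ".pred_{level}" convention for probability columns; epoch = 2 simply asks for the parameters saved after the second of the five fitted epochs.

```r
library(brulee)
library(recipes)

if (torch::torch_is_installed()) {
  data(penguins, package = "modeldata")
  penguins <- na.omit(penguins)

  # Same split and preprocessing as in the Examples section
  set.seed(122)
  in_train <- sample(seq_len(nrow(penguins)), 200)
  penguins_train <- penguins[in_train, ]
  penguins_test  <- penguins[-in_train, ]

  rec <- recipe(island ~ ., data = penguins_train) %>%
    step_dummy(species, sex) %>%
    step_normalize(all_numeric_predictors())

  set.seed(3)
  fit <- brulee_multinomial_reg(rec, data = penguins_train, epochs = 5)

  # Hard class predictions: a single factor column, .pred_class
  class_pred <- predict(fit, penguins_test, type = "class")

  # Class probabilities: one .pred_{level} column per level of island
  prob_pred <- predict(fit, penguins_test, type = "prob")

  # Probabilities computed from the parameters saved at epoch 2
  epoch_2_pred <- predict(fit, penguins_test, type = "prob", epoch = 2)

  head(prob_pred)
}
```

Because each row of prob_pred holds one probability per island, the rows should sum to one up to floating-point error.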
Value
A tibble of predictions. The number of rows in the tibble is guaranteed to be the same as the number of rows in new_data.
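Because the output has exactly one row per row of new_data, predictions can be bound column-wise to the data they came from without a join key. The sketch below relies on that to score class probabilities with yardstick's mn_log_loss(). It repeats the setup from the Examples section, assumes brulee, torch, recipes, dplyr, yardstick, and modeldata are installed, and assumes the probability columns are the only columns whose names start with ".pred_".

```r
library(brulee)
library(recipes)
library(dplyr)
library(yardstick)

if (torch::torch_is_installed()) {
  data(penguins, package = "modeldata")
  penguins <- na.omit(penguins)

  set.seed(122)
  in_train <- sample(seq_len(nrow(penguins)), 200)
  penguins_train <- penguins[in_train, ]
  penguins_test  <- penguins[-in_train, ]

  rec <- recipe(island ~ ., data = penguins_train) %>%
    step_dummy(species, sex) %>%
    step_normalize(all_numeric_predictors())

  set.seed(3)
  fit <- brulee_multinomial_reg(rec, data = penguins_train, epochs = 5)

  # One prediction row per test row, so positional binding lines up
  # each probability vector with its observed island
  prob_pred <- predict(fit, penguins_test, type = "prob")
  scored <- bind_cols(prob_pred, penguins_test)

  # Multinomial log-loss computed from the .pred_* probability columns
  ll <- scored %>% mn_log_loss(island, starts_with(".pred_"))
  ll
}
```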
Examples
if (torch::torch_is_installed()) {
  library(brulee)
  library(dplyr)      # for bind_cols()
  library(recipes)
  library(yardstick)

  data(penguins, package = "modeldata")
  penguins <- penguins %>% na.omit()

  # Use 200 random rows for training and the rest for testing
  set.seed(122)
  in_train <- sample(1:nrow(penguins), 200)
  penguins_train <- penguins[ in_train,]
  penguins_test  <- penguins[-in_train,]

  # Dummy-encode the factor predictors, then center and scale all numeric ones
  rec <- recipe(island ~ ., data = penguins_train) %>%
    step_dummy(species, sex) %>%
    step_normalize(all_numeric_predictors())

  set.seed(3)
  fit <- brulee_multinomial_reg(rec, data = penguins_train, epochs = 5)
  print(fit)

  # Join class predictions to the test set and tabulate them against the truth
  predict(fit, penguins_test) %>%
    bind_cols(penguins_test) %>%
    conf_mat(island, .pred_class)
}