View source: R/lantern_logistic_reg-fit.R
Description
lantern_logistic_reg() fits a logistic regression model.
Usage
lantern_logistic_reg(x, ...)

## Default S3 method:
lantern_logistic_reg(x, ...)

## S3 method for class 'data.frame'
lantern_logistic_reg(
  x,
  y,
  epochs = 100L,
  penalty = 0,
  validation = 0.1,
  learn_rate = 0.01,
  momentum = 0,
  batch_size = NULL,
  conv_crit = -Inf,
  verbose = FALSE,
  ...
)

## S3 method for class 'matrix'
lantern_logistic_reg(
  x,
  y,
  epochs = 100L,
  penalty = 0,
  validation = 0.1,
  learn_rate = 0.01,
  momentum = 0,
  batch_size = NULL,
  conv_crit = -Inf,
  verbose = FALSE,
  ...
)

## S3 method for class 'formula'
lantern_logistic_reg(
  formula,
  data,
  epochs = 100L,
  penalty = 0,
  validation = 0.1,
  learn_rate = 0.01,
  momentum = 0,
  batch_size = NULL,
  conv_crit = -Inf,
  verbose = FALSE,
  ...
)

## S3 method for class 'recipe'
lantern_logistic_reg(
  x,
  data,
  epochs = 100L,
  penalty = 0,
  validation = 0.1,
  learn_rate = 0.01,
  momentum = 0,
  batch_size = NULL,
  conv_crit = -Inf,
  verbose = FALSE,
  ...
)
Arguments
x: Depending on the context, a data frame of predictors, a matrix of predictors, or a recipe specifying a set of preprocessing steps created from recipes::recipe(). The predictor data should be standardized (e.g., centered and scaled).
... : Not currently used, but required for extensibility.
y: When x is a data frame or matrix, y is the outcome, such as a factor vector.
epochs: An integer for the number of epochs of training.
penalty: The amount of weight decay (i.e., L2 regularization).
validation: The proportion of the data randomly assigned to a validation set.
learn_rate: A positive number for the gradient descent learning rate (usually less than 0.1).
momentum: A positive number on
batch_size: An integer for the number of training set points in each batch.
conv_crit: A non-negative number for the early-stopping convergence criterion. The default, -Inf, trains for all epochs.
verbose: A logical; if TRUE, the iteration history is printed.
formula: A formula specifying the outcome terms on the left-hand side and the predictor terms on the right-hand side.
data: When a recipe or formula is used, data is a data frame containing both the predictors and the outcome.
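The penalty, learn_rate, and momentum arguments interact in the gradient descent update. The base R sketch below shows one conventional way they combine, assuming a torch-style SGD update in which the L2 term penalty * w is added to the gradient and a velocity v is updated as momentum * v + gradient. The helper names (penalized_nll, penalized_grad, momentum_step) are hypothetical; this is not lantern's internal code, and for simplicity it uses full-batch updates rather than mini-batches of batch_size.

```r
# Hypothetical helper: mean binary negative log-likelihood plus an L2 term.
# A loss term of (penalty / 2) * sum(w^2) has gradient penalty * w, which is
# what a torch-style "weight decay" adds to the gradient (assumption).
penalized_nll <- function(w, X, y, penalty) {
  p <- plogis(drop(X %*% w))  # predicted probabilities
  -mean(y * log(p) + (1 - y) * log(1 - p)) + penalty / 2 * sum(w^2)
}

# Hypothetical helper: gradient of penalized_nll() with respect to w
penalized_grad <- function(w, X, y, penalty) {
  p <- plogis(drop(X %*% w))
  drop(crossprod(X, p - y)) / nrow(X) + penalty * w
}

# Hypothetical helper: one gradient descent step with momentum (velocity v)
momentum_step <- function(w, v, X, y, learn_rate, momentum, penalty) {
  g <- penalized_grad(w, X, y, penalty)
  v <- momentum * v + g
  list(w = w - learn_rate * v, v = v)
}

# Simulated data: intercept column plus two standardized predictors
set.seed(1)
X <- cbind(1, scale(matrix(rnorm(200), ncol = 2)))
y <- rbinom(100, 1, plogis(X %*% c(0.5, 1, -1)))

w <- rep(0, 3)  # coefficients
v <- rep(0, 3)  # velocity
losses <- numeric(50)
for (epoch in 1:50) {
  step <- momentum_step(w, v, X, y, learn_rate = 0.1, momentum = 0.9,
                        penalty = 0.01)
  w <- step$w
  v <- step$v
  losses[epoch] <- penalized_nll(w, X, y, penalty = 0.01)
}
```

With momentum = 0 (the default) the update reduces to plain gradient descent, w - learn_rate * gradient; a larger penalty shrinks the coefficients toward zero.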
Details
Despite its name, this function can be used with three or more classes (i.e., multinomial regression).
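The multinomial case mentioned above generalizes the logistic link to a softmax over one linear predictor per class. The base R sketch below shows that computation; softmax_probs and multinom_nll are hypothetical helpers for illustration, not functions exported by lantern, and how lantern parameterizes the classes internally is not stated here.

```r
# Hypothetical helper: class probabilities from an n x K matrix of linear
# predictors (one column per class), computed with a softmax.
softmax_probs <- function(eta) {
  eta <- eta - apply(eta, 1, max)  # subtract row maxima for numerical stability
  expo <- exp(eta)
  expo / rowSums(expo)             # each row sums to 1
}

# Hypothetical helper: mean negative log-likelihood for factor outcome y
multinom_nll <- function(eta, y) {
  probs <- softmax_probs(eta)
  -mean(log(probs[cbind(seq_along(y), as.integer(y))]))
}

eta <- matrix(c(1, 2, 3,
                0, 0, 0), nrow = 2, byrow = TRUE)
probs <- softmax_probs(eta)
rowSums(probs)  # 1 1
y <- factor(c("c", "a"), levels = c("a", "b", "c"))
multinom_nll(eta, y)
```

With two classes and linear predictors (0, x), the softmax reduces to plogis(x), which is why one model form covers both the binary and the multinomial case.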
The predictor data should all be numeric and encoded in the same units (e.g., standardized to the same range or distribution). If there are factor predictors, use a recipe or formula to convert them to indicator variables (or make them numeric by some other method).
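Outside of a recipe, the same preparation can be sketched in base R: model.matrix() expands factor predictors into indicator columns, and scale() centers and scales every column. The toy data frame below is made up for illustration; in the recipe examples, step_dummy() and step_normalize() play these roles.

```r
# Toy data with one numeric and one factor predictor (made up)
df <- data.frame(
  size  = c(10, 20, 30, 40),
  color = factor(c("red", "blue", "red", "green"))
)

# Expand the factor into indicator columns; [, -1] drops the intercept
X <- model.matrix(~ size + color, data = df)[, -1]
colnames(X)  # "size" "colorgreen" "colorred"

# Center and scale each column so all predictors share the same units
X_std <- scale(X)
```

New data should be transformed with the training statistics, e.g. scale(X_new, center = attr(X_std, "scaled:center"), scale = attr(X_std, "scaled:scale")), rather than with its own means and standard deviations.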
If conv_crit is used, training stops when the change in the loss function falls below conv_crit or when the loss gets worse. By default, the model is trained for the full number of epochs.
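The stopping rule above can be sketched in base R as follows. This is a minimal illustration under stated assumptions: it compares the absolute improvement between consecutive epochs (whether lantern uses absolute or relative change is not stated here), and should_stop and epochs_used are hypothetical helpers, not lantern functions.

```r
# Hypothetical helper: stop when the loss improves by less than conv_crit
# or gets worse; the default conv_crit = -Inf never stops early.
should_stop <- function(loss_prev, loss_curr, conv_crit) {
  if (conv_crit == -Inf) return(FALSE)
  improvement <- loss_prev - loss_curr
  improvement < 0 || improvement < conv_crit
}

# Hypothetical helper: number of epochs run for a given loss history
epochs_used <- function(losses, conv_crit) {
  for (i in seq_along(losses)[-1]) {
    if (should_stop(losses[i - 1], losses[i], conv_crit)) return(i)
  }
  length(losses)
}

losses <- c(1.0, 0.6, 0.4, 0.35, 0.349, 0.3485)
epochs_used(losses, conv_crit = 0.01)  # 5
epochs_used(losses, conv_crit = -Inf)  # 6
```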
Value
A lantern_logistic_reg object with elements:
models: A list of serialized models, one for each epoch.
loss: A vector of loss values (MSE for regression, negative log-likelihood for classification) at each epoch.
dim: A list of data dimensions.
y_stats: A list of summary statistics for numeric outcomes.
parameters: A list of selected tuning parameter values.
blueprint: The hardhat blueprint data.
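Because loss holds one value per epoch, the training history can be summarized with base R. The sketch below uses a made-up loss vector; summarize_loss is a hypothetical helper, and applying it as summarize_loss(fit$loss) assumes that element is a plain numeric vector as described above.

```r
# Hypothetical helper: summarize a per-epoch loss vector
summarize_loss <- function(loss) {
  best <- which.min(loss)
  list(best_epoch = best, best_loss = loss[best], n_epochs = length(loss))
}

toy_loss <- c(0.70, 0.52, 0.45, 0.47)  # made-up values
summarize_loss(toy_loss)
```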
Examples
if (torch::torch_is_installed()) {
## -----------------------------------------------------------------------------
# Increase the number of epochs to get better results

library(recipes)    # recipe(), step_*(), and %>%
library(yardstick)  # roc_auc(), conf_mat()
library(dplyr)      # bind_cols()
library(ggplot2)    # autoplot()

data(cells, package = "modeldata")
cells$case <- NULL

set.seed(122)
in_train <- sample(seq_len(nrow(cells)), 1000)
cells_train <- cells[ in_train, ]
cells_test  <- cells[-in_train, ]

# Using matrices (predictors standardized with scale())
set.seed(1)
lantern_logistic_reg(
  x = scale(as.matrix(cells_train[, c("fiber_width_ch_1", "width_ch_1")])),
  y = cells_train$class,
  penalty = 0.10, epochs = 20L, batch_size = 32
)

# Using a recipe
cells_rec <-
  recipe(class ~ ., data = cells_train) %>%
  # Transform some highly skewed predictors
  step_YeoJohnson(all_predictors()) %>%
  step_normalize(all_predictors())

set.seed(2)
fit <- lantern_logistic_reg(cells_rec, data = cells_train,
                            penalty = .01, epochs = 100L, batch_size = 32)
fit
autoplot(fit)

# Area under the ROC curve on the test set
predict(fit, cells_test, type = "prob") %>%
  bind_cols(cells_test) %>%
  roc_auc(class, .pred_PS)

# ------------------------------------------------------------------------------
# Multinomial regression
data(penguins, package = "modeldata")
penguins <- penguins %>% na.omit()

set.seed(122)
in_train <- sample(seq_len(nrow(penguins)), 200)
penguins_train <- penguins[ in_train, ]
penguins_test  <- penguins[-in_train, ]

rec <- recipe(island ~ ., data = penguins_train) %>%
  step_dummy(species, sex) %>%
  step_normalize(all_predictors())

set.seed(3)
fit <- lantern_logistic_reg(rec, data = penguins_train,
                            epochs = 200L, batch_size = 32)
fit

# Confusion matrix on the test set
predict(fit, penguins_test) %>%
  bind_cols(penguins_test) %>%
  conf_mat(island, .pred_class)
}