View source: R/gen-train-nn-parsnip.R
| train_nnsnip | R Documentation |
train_nnsnip() defines a neural network model specification that can be used
for classification or regression. It integrates with the tidymodels ecosystem
and uses train_nn() as the fitting backend, supporting any architecture
expressible via nn_arch() — feedforward, recurrent, convolutional, and beyond.
train_nnsnip(
  mode = "unknown",
  engine = "kindling",
  hidden_neurons = NULL,
  activations = NULL,
  output_activation = NULL,
  bias = NULL,
  epochs = NULL,
  batch_size = NULL,
  penalty = NULL,
  mixture = NULL,
  learn_rate = NULL,
  optimizer = NULL,
  validation_split = NULL,
  optimizer_args = NULL,
  loss = NULL,
  architecture = NULL,
  flatten_input = NULL,
  early_stopping = NULL,
  device = NULL,
  verbose = NULL,
  cache_weights = NULL
)
mode |
A single character string for the type of model. Possible values are "unknown", "regression", or "classification". |
engine |
A single character string specifying what computational engine to use for fitting. Currently only "kindling" is supported. |
hidden_neurons |
An integer vector for the number of units in each hidden layer. Can be tuned. |
activations |
A character vector of activation function names for each hidden layer (e.g., "relu", "tanh", "sigmoid"). Can be tuned. |
output_activation |
A character string for the output activation function. Can be tuned. |
bias |
Logical for whether to include bias terms. Can be tuned. |
epochs |
An integer for the number of training iterations. Can be tuned. |
batch_size |
An integer for the batch size during training. Can be tuned. |
penalty |
A number for the regularization penalty (lambda). Default |
mixture |
A number between 0 and 1 for the elastic net mixing parameter.
Default
|
learn_rate |
A number for the learning rate. Can be tuned. |
optimizer |
A character string for the optimizer type ("adam", "sgd", "rmsprop"). Can be tuned. |
validation_split |
A number between 0 and 1 for the proportion of data used for validation. Can be tuned. |
optimizer_args |
A named list of additional arguments passed to the
optimizer. Cannot be tuned — pass via |
loss |
A character string or a valid |
architecture |
An |
flatten_input |
Logical or |
early_stopping |
An |
device |
A character string for the device to use ("cpu", "cuda", "mps").
If |
verbose |
Logical for whether to print training progress. Default |
cache_weights |
Logical. If |
This function creates a model specification for a neural network that can be
used within tidymodels workflows. The underlying engine is train_nn(), which
is architecture-agnostic: when architecture = NULL it falls back to a
standard feed-forward network, but any architecture expressible via nn_arch()
can be used instead. The model supports:
Configurable hidden layers and activation functions (default MLP path)
Custom architectures via nn_arch() (recurrent, convolutional, etc.)
GPU acceleration (CUDA or MPS), with CPU as a fallback
Hyperparameter tuning integration
Both regression and classification tasks
When using the default MLP path (no custom architecture), hidden_neurons
accepts an integer vector where each element represents the number of neurons
in that hidden layer. For example, hidden_neurons = c(128, 64, 32) creates
a network with three hidden layers. Pass an nn_arch() object via
set_engine() to use a custom architecture instead.
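To make the layer layout concrete, the following base R sketch lists the dense layers implied by hidden_neurons = c(128, 64, 32). It assumes a plain fully connected network with bias terms, 4 input features, and 1 output unit. The helper mlp_layer_shapes() is hypothetical and is not part of kindling; the network that train_nn() actually builds may differ in detail.

```r
# Hypothetical helper (not part of kindling): list the dense layers implied by
# a vector of hidden layer sizes, assuming a plain fully connected network.
mlp_layer_shapes <- function(n_inputs, hidden_neurons, n_outputs, bias = TRUE) {
  sizes <- c(n_inputs, hidden_neurons, n_outputs)
  n_in <- head(sizes, -1)   # inputs to each layer
  n_out <- tail(sizes, -1)  # outputs of each layer
  data.frame(
    layer = seq_along(n_in),
    n_in = n_in,
    n_out = n_out,
    # one weight per input/output pair, plus one bias per output unit if enabled
    n_params = n_in * n_out + if (bias) n_out else 0
  )
}

# Assumed for illustration: 4 predictors and a single output unit.
shapes <- mlp_layer_shapes(
  n_inputs = 4,
  hidden_neurons = c(128, 64, 32),
  n_outputs = 1
)
print(shapes)
total_params <- sum(shapes$n_params)
```

The three hidden layers plus the output layer give four dense layers in total, so each extra element in hidden_neurons adds one layer to the table.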
The device parameter controls where computation occurs:
NULL (default): Auto-detect best available device (CUDA > MPS > CPU)
"cuda": Use NVIDIA GPU
"mps": Use Apple Silicon GPU
"cpu": Use CPU only
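The priority order above can be sketched as a small function. This is only an illustration of the documented CUDA > MPS > CPU fallback, not the package's internal code: pick_device() is a hypothetical helper, and the availability checks assume a torch version that provides cuda_is_available() and backends_mps_is_available() (any error from those calls is treated as "not available").

```r
# Hypothetical helper mirroring the documented priority: CUDA > MPS > CPU.
pick_device <- function(cuda_available, mps_available) {
  if (isTRUE(cuda_available)) {
    "cuda"
  } else if (isTRUE(mps_available)) {
    "mps"
  } else {
    "cpu"
  }
}

# Run an availability check, treating any error as "not available".
safe_check <- function(check) {
  tryCatch(isTRUE(check()), error = function(e) FALSE)
}

# Query torch only when the package is present; otherwise assume CPU only.
has_torch <- requireNamespace("torch", quietly = TRUE)
device <- pick_device(
  cuda_available = has_torch && safe_check(function() torch::cuda_is_available()),
  mps_available = has_torch && safe_check(function() torch::backends_mps_is_available())
)
print(device)
```

Passing the resulting string explicitly, for example device = "cpu", is a way to force a specific device instead of relying on auto-detection.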
When tuning, you can use special tune tokens:
For hidden_neurons: use tune("hidden_neurons") with a custom range
For activations: use tune("activations") with values like "relu", "tanh"
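As a rough sketch of how such tokens might appear in a specification, the block below marks hidden_neurons and learn_rate for tuning. It assumes the package is installed under the name kindling and that train_nnsnip() captures tune() placeholders the way other parsnip model specifications do; the candidate ranges still have to be supplied separately when building a tuning grid, which is not shown here.

```r
# Stays NULL when the required packages are not available.
tuned_spec <- NULL

if (requireNamespace("kindling", quietly = TRUE) &&
    requireNamespace("tune", quietly = TRUE) &&
    requireNamespace("torch", quietly = TRUE) &&
    torch::torch_is_installed()) {
  library(kindling)
  library(tune)

  # Placeholders only: no model is fitted at this point.
  tuned_spec <- train_nnsnip(
    mode = "classification",
    hidden_neurons = tune("hidden_neurons"),  # id taken from the text above
    learn_rate = tune(),
    epochs = 50                               # fixed, not tuned
  )
  print(tuned_spec)
} else {
  message("kindling, tune, or torch not available; skipping sketch")
}
```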
A model specification object with class train_nnsnip.
if (torch::torch_is_installed()) {
  # Import only the functions the example needs
  box::use(
    recipes[recipe],
    workflows[workflow, add_recipe, add_model],
    tune[tune],
    parsnip[fit]
  )

  # Model spec: two hidden layers with one activation per layer
  nn_spec = train_nnsnip(
    mode = "classification",
    hidden_neurons = c(30, 5),
    activations = c("relu", "elu"),
    epochs = 100
  )

  # Bundle the preprocessing recipe and the model spec into a workflow
  wf = workflow() |>
    add_recipe(recipe(Species ~ ., data = iris)) |>
    add_model(nn_spec)

  # Fit the workflow on the full iris data
  fit_wf = fit(wf, data = iris)
} else {
  message("Torch not fully installed; skipping example")
}