| gen-nn-train | R Documentation |
train_nn() is a generic function for training neural networks with a
user-defined architecture via nn_arch(). Dispatch is based on the class
of x.
Recommended workflow:
1. Define the architecture with nn_arch() (optional).
2. Train with train_nn().
3. Predict with predict.nn_fit().
All methods delegate to a shared implementation core after preprocessing.
When architecture = NULL, the model falls back to a plain feed-forward neural network
(nn_linear) architecture.
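The class-based dispatch described above can be illustrated with a toy S3 generic. This is only a sketch: demo_train() and its methods are hypothetical stand-ins written for this illustration, not functions from the package, and they return a label instead of training anything.

```r
# Hypothetical generic (not part of the package): picks a method from class(x),
# the same mechanism train_nn() relies on.
demo_train = function(x, ...) UseMethod("demo_train")

demo_train.matrix = function(x, ...) "matrix method"
demo_train.data.frame = function(x, ...) "data.frame method"
demo_train.formula = function(x, ...) "formula method"

# Fallback for classes without a dedicated method
demo_train.default = function(x, ...) {
  stop("unsupported input class: ", paste(class(x), collapse = "/"))
}

print(demo_train(as.matrix(iris[, 1:4])))  # dispatches on "matrix"
print(demo_train(iris))                    # dispatches on "data.frame"
print(demo_train(Sepal.Length ~ .))        # dispatches on "formula"
```

Because the first argument drives dispatch, the same call shape can serve raw matrices, data frames, and formulas without a separate function for each.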
train_nn(x, ...)
## S3 method for class 'matrix'
train_nn(
x,
y,
hidden_neurons = NULL,
activations = NULL,
output_activation = NULL,
bias = TRUE,
arch = NULL,
architecture = NULL,
early_stopping = NULL,
epochs = 100,
batch_size = 32,
penalty = 0,
mixture = 0,
learn_rate = 0.001,
optimizer = "adam",
optimizer_args = list(),
loss = "mse",
validation_split = 0,
device = NULL,
verbose = FALSE,
cache_weights = FALSE,
...
)
## S3 method for class 'data.frame'
train_nn(
x,
y,
hidden_neurons = NULL,
activations = NULL,
output_activation = NULL,
bias = TRUE,
arch = NULL,
architecture = NULL,
early_stopping = NULL,
epochs = 100,
batch_size = 32,
penalty = 0,
mixture = 0,
learn_rate = 0.001,
optimizer = "adam",
optimizer_args = list(),
loss = "mse",
validation_split = 0,
device = NULL,
verbose = FALSE,
cache_weights = FALSE,
...
)
## S3 method for class 'formula'
train_nn(
x,
data,
hidden_neurons = NULL,
activations = NULL,
output_activation = NULL,
bias = TRUE,
arch = NULL,
architecture = NULL,
early_stopping = NULL,
epochs = 100,
batch_size = 32,
penalty = 0,
mixture = 0,
learn_rate = 0.001,
optimizer = "adam",
optimizer_args = list(),
loss = "mse",
validation_split = 0,
device = NULL,
verbose = FALSE,
cache_weights = FALSE,
...
)
## Default S3 method:
train_nn(x, ...)
## S3 method for class 'dataset'
train_nn(
x,
y = NULL,
hidden_neurons = NULL,
activations = NULL,
output_activation = NULL,
bias = TRUE,
arch = NULL,
architecture = NULL,
flatten_input = NULL,
epochs = 100,
batch_size = 32,
penalty = 0,
mixture = 0,
learn_rate = 0.001,
optimizer = "adam",
optimizer_args = list(),
loss = "mse",
validation_split = 0,
device = NULL,
verbose = FALSE,
cache_weights = FALSE,
n_classes = NULL,
...
)
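| x | Input data. Dispatch is based on its class: a numeric matrix (used as-is, no preprocessing), a data frame (preprocessed with hardhat::mold()), a formula (evaluated against data and preprocessed with hardhat::mold()), or a torch dataset (batched with torch::dataloader()). |
| ... | Additional arguments passed to specific methods. |
| y | Outcome data. Interpretation depends on the method: for the matrix method, the outcomes for the rows of x; for the data.frame method, a vector, factor, or matrix of outcomes, or a formula of the form outcome ~ predictors evaluated against x; for the dataset method, ignored, because labels are taken from the dataset items. |
| hidden_neurons | Integer vector specifying the number of neurons in each hidden layer, e.g. c(64, 32). Default NULL (no hidden layers). |
| activations | Activation function specification(s) for the hidden layers. See act_funs(). |
| output_activation | Optional activation function for the output layer. Defaults to NULL. |
| bias | Logical. Whether to include bias terms in each layer. Default TRUE. |
| arch | Backward-compatible alias for architecture. |
| architecture | An nn_arch() object, or NULL (the default) for a plain feed-forward architecture. |
| early_stopping | An early_stop() object, or NULL (the default). |
| epochs | Positive integer. Number of full passes over the training data. Default 100. |
| batch_size | Positive integer. Number of samples per mini-batch. Default 32. |
| penalty | Non-negative numeric. L1/L2 regularization strength (lambda). Default 0. |
| mixture | Numeric in [0, 1]. Elastic net mixing parameter. Default 0. |
| learn_rate | Positive numeric. Step size for the optimizer. Default 0.001. |
| optimizer | Character. Optimizer algorithm. Default "adam". |
| optimizer_args | Named list of additional arguments forwarded to the optimizer constructor. Default list(). |
| loss | Character or function. Loss function used during training. Default "mse"; the dataset method switches to "cross_entropy" when labels are scalar tensors. |
| validation_split | Numeric in [0, 1). Proportion of training data held out for validation. Default 0. |
| device | Character. Compute device. Default NULL. |
| verbose | Logical. Whether to report progress during training. Default FALSE. |
| cache_weights | Logical. If TRUE, the weight matrices are stored in the cached_weights component of the returned object. Default FALSE. |
| data | A data frame. Required when x is a formula. |
| flatten_input | Logical or NULL (the default). Dataset method only. Set to FALSE to preserve the tensor dimensions expected by recurrent or convolutional layers. |
| n_classes | Positive integer. Number of output classes. Dataset method only. Required when the dataset labels are scalar tensors (classification). |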
The returned "nn_fit" object is a named list with the following components:
model — the trained torch::nn_module object
fitted — fitted values on the training data (or NULL for dataset fits)
loss_history — numeric vector of per-epoch training loss, trimmed to
actual epochs run (relevant when early stopping is active)
val_loss_history — per-epoch validation loss, or NULL if
validation_split = 0
n_epochs — number of epochs actually trained
stopped_epoch — epoch at which early stopping triggered, or NA if
training ran to completion
hidden_neurons, activations, output_activation — architecture spec
penalty, mixture — regularization settings
feature_names, response_name — variable names (tabular methods only)
no_x, no_y — number of input features and output nodes
is_classification — logical flag
y_levels, n_classes — class labels and count (classification only)
device — device the model is on
cached_weights — list of weight matrices, or NULL
arch — the nn_arch object used, or NULL
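As a hedged illustration of how these components might be inspected after a fit, the sketch below trains a small model and prints a few of the documented fields. It assumes the package that provides train_nn() is attached and that torch is installed; if either is missing, the block is skipped. Only component names listed above are used.

```r
# Runs only when train_nn() is available and torch is installed (assumption:
# the package providing train_nn() has been attached with library()).
if (exists("train_nn") &&
    requireNamespace("torch", quietly = TRUE) &&
    torch::torch_is_installed()) {
  fit = train_nn(
    x = Sepal.Length ~ .,
    data = iris[, 1:4],
    hidden_neurons = c(16, 8),
    activations = "relu",
    epochs = 30,
    validation_split = 0.2
  )

  print(class(fit))                     # includes "nn_fit"
  print(fit$n_epochs)                   # epochs actually trained
  print(fit$stopped_epoch)              # NA when training ran to completion
  print(tail(fit$loss_history, 3))      # last training losses
  print(tail(fit$val_loss_history, 3))  # last validation losses
  print(fit$is_classification)          # logical flag
}
```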
An object of class "nn_fit", or one of its subclasses:
c("nn_fit_tab", "nn_fit") — returned by the data.frame and formula methods
c("nn_fit_ds", "nn_fit") — returned by the dataset method
All subclasses share a common structure. See Details for the list of components.
train_nn() is task-agnostic by design (no explicit task argument).
Task behavior is determined by your input interface and architecture:
Tabular data: use matrix, data.frame, or formula methods.
Time series: use the dataset method with per-item tensors shaped as
[time, features] (or your preferred convention) and a recurrent
architecture via nn_arch().
Image classification: use the dataset method with per-item tensors
shaped for your first layer (commonly [channels, height, width] for
torch::nn_conv2d). If your source arrays are channel-last, reorder in the
dataset or via input_transform.
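The shape conventions mentioned above can be prepared in base R before the arrays are wrapped in tensors. The following is a minimal sketch on made-up data: the array sizes, the window length, and all variable names are assumptions chosen for illustration, and the package may expect a different layout for a given architecture.

```r
set.seed(1)

# Time series: slice a [n_obs, features] matrix into windows shaped [time, features].
series = matrix(rnorm(100 * 2), ncol = 2)   # hypothetical series with 2 features
window_len = 10
starts = seq_len(nrow(series) - window_len + 1)
windows = lapply(starts, function(s) {
  series[s:(s + window_len - 1), , drop = FALSE]
})
print(length(windows))      # number of windows
print(dim(windows[[1]]))    # time x features

# Images: reorder a channel-last array [height, width, channels]
# into channel-first [channels, height, width].
img_hwc = array(seq_len(32 * 28 * 3), dim = c(32, 28, 3))  # hypothetical image
img_chw = aperm(img_hwc, c(3, 1, 2))
print(dim(img_chw))

# The same pixel is reachable under both layouts
stopifnot(img_chw[2, 5, 7] == img_hwc[5, 7, 2])
```

Doing the reordering once, inside the dataset's initialize step, keeps the per-item work in .getitem small.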
When x is supplied as a raw numeric matrix, no preprocessing is applied.
Data is passed directly to the shared train_nn_impl core.
When x is a data frame, y can be either a vector / factor / matrix of
outcomes, or a formula of the form outcome ~ predictors evaluated against
x. Preprocessing is handled by hardhat::mold().
When x is a formula, data must be supplied as the data frame against
which the formula is evaluated. Preprocessing is handled by hardhat::mold().
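To get a feel for what the tabular methods work with after preprocessing, hardhat::mold() can be called directly on the same formula and data. This is a sketch under an assumption: it uses hardhat's default formula blueprint, and this page does not state which blueprint options train_nn() passes internally.

```r
# Skipped when hardhat is not installed
if (requireNamespace("hardhat", quietly = TRUE)) {
  molded = hardhat::mold(Sepal.Length ~ ., data = iris[, 1:4])

  print(names(molded$predictors))  # predictor columns after preprocessing
  print(names(molded$outcomes))    # outcome column
  print(dim(molded$predictors))    # rows x predictors
}
```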
Dataset method (train_nn.dataset())
Trains a neural network directly on a torch dataset object. Batching and
lazy loading are handled by torch::dataloader(), making this method
well-suited for large datasets that do not fit entirely in memory.
Architecture configuration follows the same contract as other train_nn()
methods via architecture = nn_arch(...) (or legacy arch = ...).
For non-tabular inputs (time series, images), set flatten_input = FALSE to
preserve tensor dimensions expected by recurrent or convolutional layers.
Labels are taken from the second element of each dataset item (i.e.
dataset[[i]][[2]]), so y is ignored. When the label is a scalar tensor,
a classification task is assumed and n_classes must be supplied. The loss
is automatically switched to "cross_entropy" in that case.
Fitted values are not cached in the returned object. Use
predict.nn_fit_ds() with newdata to obtain predictions after training.
predict.nn_fit(), nn_arch(), act_funs(), early_stop()
if (torch::torch_is_installed()) {
# Matrix method — no preprocessing
model = train_nn(
x = as.matrix(iris[, 2:4]),
y = iris$Sepal.Length,
hidden_neurons = c(64, 32),
activations = "relu",
epochs = 50
)
# Data frame method — y as a vector
model = train_nn(
x = iris[, 2:4],
y = iris$Sepal.Length,
hidden_neurons = c(64, 32),
activations = "relu",
epochs = 50
)
# Data frame method — y as a formula evaluated against x
model = train_nn(
x = iris,
y = Sepal.Length ~ . - Species,
hidden_neurons = c(64, 32),
activations = "relu",
epochs = 50
)
# Formula method — outcome derived from formula
model = train_nn(
x = Sepal.Length ~ .,
data = iris[, 1:4],
hidden_neurons = c(64, 32),
activations = "relu",
epochs = 50
)
# No hidden layers — linear model
model = train_nn(
x = Sepal.Length ~ .,
data = iris[, 1:4],
epochs = 50
)
# Architecture object (nn_arch -> train_nn)
mlp_arch = nn_arch(nn_name = "mlp_model")
model = train_nn(
x = Sepal.Length ~ .,
data = iris[, 1:4],
hidden_neurons = c(64, 32),
activations = "relu",
architecture = mlp_arch,
epochs = 50
)
# Custom layer architecture
custom_linear = torch::nn_module(
"CustomLinear",
initialize = function(in_features, out_features, bias = TRUE) {
self$layer = torch::nn_linear(in_features, out_features, bias = bias)
},
forward = function(x) self$layer(x)
)
custom_arch = nn_arch(
nn_name = "custom_linear_mlp",
nn_layer = ~ custom_linear
)
model = train_nn(
x = Sepal.Length ~ .,
data = iris[, 1:4],
hidden_neurons = c(16, 8),
activations = "relu",
architecture = custom_arch,
epochs = 50
)
# With early stopping
model = train_nn(
x = Sepal.Length ~ .,
data = iris[, 1:4],
hidden_neurons = c(64, 32),
activations = "relu",
epochs = 200,
validation_split = 0.2,
early_stopping = early_stop(patience = 10)
)
}
if (torch::torch_is_installed()) {
# torch dataset method — labels come from the dataset itself
iris_cls_dataset = torch::dataset(
name = "iris_cls_dataset",
initialize = function(data = iris) {
self$x = torch::torch_tensor(
as.matrix(data[, 1:4]),
dtype = torch::torch_float32()
)
# Species is a factor; convert to integer (1-indexed -> keep as-is for cross_entropy)
self$y = torch::torch_tensor(
as.integer(data$Species),
dtype = torch::torch_long()
)
},
.getitem = function(i) {
list(self$x[i, ], self$y[i])
},
.length = function() {
self$x$size(1)
}
)()
model_nn_ds = train_nn(
x = iris_cls_dataset,
hidden_neurons = c(32, 10),
activations = "relu",
epochs = 80,
batch_size = 16,
learn_rate = 0.01,
n_classes = 3, # Iris dataset has only 3 species
validation_split = 0.2,
verbose = TRUE
)
pred_nn = predict(model_nn_ds, iris_cls_dataset)
# Map predicted class indices back to the factor levels of iris$Species
class_preds = levels(iris$Species)[pred_nn]
# Confusion Matrix
table(actual = iris$Species, pred = class_preds)
}