mmn: Train a Multi-Modal Neural Network (MMN)
Description

This function trains a Multi-Modal Neural Network (MMN) model on the provided data.

Usage
mmn(
formula,
dataList = NULL,
fusion_hidden = c(50L, 50L),
fusion_activation = c("relu", "leaky_relu", "tanh", "elu", "rrelu", "prelu",
"softplus", "celu", "selu", "gelu", "relu6", "sigmoid", "softsign", "hardtanh",
"tanhshrink", "softshrink", "hardshrink", "log_sigmoid"),
fusion_bias = TRUE,
fusion_dropout = 0,
loss = c("mse", "mae", "softmax", "cross-entropy", "gaussian", "binomial", "poisson"),
optimizer = c("sgd", "adam", "adadelta", "adagrad", "rmsprop", "rprop"),
lr = 0.01,
alpha = 0.5,
lambda = 0,
validation = 0,
batchsize = 32L,
burnin = 10,
shuffle = TRUE,
epochs = 100,
early_stopping = NULL,
lr_scheduler = NULL,
custom_parameters = NULL,
device = c("cpu", "cuda", "mps"),
plot = TRUE,
verbose = TRUE
)
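A minimal sketch of a call with two tabular modalities and simulated data, assuming mmn() comes from the cito package (suggested by the "citommn" class of the return value). The formula term syntax used here, mlp() terms combined with +, is an assumption for illustration and is not confirmed by this page; see the package examples for the exact multi-modal formula interface.

library(cito)
set.seed(42)
n <- 200
dataList <- list(
  Y = rnorm(n),                     # numeric response, matching the "mse" loss
  X1 = matrix(rnorm(n * 5), n, 5),  # first modality (tabular)
  X2 = matrix(rnorm(n * 3), n, 3)   # second modality (tabular)
)
m <- mmn(
  Y ~ mlp(X1) + mlp(X2),            # ASSUMED term syntax for the sub-networks
  dataList = dataList,
  fusion_hidden = c(50L, 50L),
  loss = "mse",
  epochs = 100,
  verbose = FALSE
)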
Arguments

formula
A formula object specifying the model structure. See the sketch after the usage block above and the package examples for more information.
dataList
A list containing the data for training the model. The list should contain all variables used in the formula.
fusion_hidden
A numeric vector specifying the number of units in each hidden layer of the fusion network.
fusion_activation
A character vector specifying the activation function for each hidden layer of the fusion network. Available options are: "relu", "leaky_relu", "tanh", "elu", "rrelu", "prelu", "softplus", "celu", "selu", "gelu", "relu6", "sigmoid", "softsign", "hardtanh", "tanhshrink", "softshrink", "hardshrink", "log_sigmoid".
fusion_bias
A logical value or vector (of length length(fusion_hidden) + 1) indicating whether to include bias terms in the layers of the fusion network.
fusion_dropout
The dropout rate for the fusion network, a numeric value or vector (of length length(fusion_hidden)) between 0 and 1.
loss
The loss function to be optimized during training. Available options are: "mse", "mae", "softmax", "cross-entropy", "gaussian", "binomial", "poisson".
optimizer
The optimization algorithm to be used during training. Available options are: "sgd", "adam", "adadelta", "adagrad", "rmsprop", "rprop".
lr
The learning rate for the optimizer.
alpha
The alpha parameter for elastic net regularization. Should be a value between 0 and 1.
lambda
The lambda parameter for elastic net regularization. Should be a positive value.
validation
The proportion of the training data to use for validation. Should be a value between 0 and 1.
batchsize
The batch size used during training.
burnin
Training is aborted if the training loss has not dropped below the baseline loss after burnin epochs.
shuffle
A logical indicating whether to shuffle the training data in each epoch.
epochs
The number of epochs to train the model.
early_stopping
If provided, the training will stop if the validation loss does not improve for the specified number of epochs. If set to NULL, early stopping is disabled.
lr_scheduler
Learning rate scheduler created with config_lr_scheduler() (see the sketch after this argument list).
custom_parameters
A list of parameters used by custom loss functions. See vignette for examples.
device
The device on which to perform computations. Available options are: "cpu", "cuda", "mps".
plot
A logical indicating whether to plot training and validation loss curves.
verbose
A logical indicating whether to display verbose output during training.
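A hedged sketch of how the training-control arguments interact: lambda and alpha jointly define an elastic-net penalty on the weights (the standard parameterization is lambda * (alpha * L1 + (1 - alpha) * L2); whether alpha weights the L1 or the L2 term here is not stated on this page), validation holds out a fraction of the data that early_stopping monitors, and lr_scheduler takes the object returned by config_lr_scheduler(). The scheduler type and its arguments below are assumptions; the formula term syntax is again the assumed one from the sketch above.

library(cito)
library(torch)
scheduler <- config_lr_scheduler(   # ASSUMED type and arguments; check ?config_lr_scheduler
  "reduce_on_plateau", patience = 5, factor = 0.5
)
m <- mmn(
  Y ~ mlp(X1) + mlp(X2),            # ASSUMED term syntax, see sketch above
  dataList = dataList,
  lr = 0.01,
  lambda = 0.001, alpha = 0.5,      # mild elastic-net penalty on the weights
  validation = 0.2,                 # hold out 20% of the data for validation
  early_stopping = 15,              # stop after 15 epochs without improvement
  lr_scheduler = scheduler,
  device = if (torch::cuda_is_available()) "cuda" else "cpu",
  epochs = 200
)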
Value

An object of class "citommn" containing the trained MMN model and other information.
See Also

predict.citommn, print.citommn, summary.citommn, continue_training, analyze_training
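A sketch of the typical downstream workflow with the returned "citommn" object, using the generics and functions listed above; argument names beyond the model object itself (e.g. newdata, epochs) are assumptions.

summary(m)                               # dispatches to summary.citommn
preds <- predict(m, newdata = dataList)  # ASSUMED 'newdata' argument name
m <- continue_training(m, epochs = 50)   # resume training from the last state
analyze_training(m)                      # inspect the recorded training history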