mmn: Train and evaluate a Multi-Modal Neural Network (MMN) model

View source: R/mmn.R

mmn R Documentation

Train and evaluate a Multi-Modal Neural Network (MMN) model

Description

This function trains a Multi-Modal Neural Network (MMN) model on the provided data.

Usage

mmn(
  formula,
  dataList = NULL,
  fusion_hidden = c(50L, 50L),
  fusion_activation = c("relu", "leaky_relu", "tanh", "elu", "rrelu", "prelu",
    "softplus", "celu", "selu", "gelu", "relu6", "sigmoid", "softsign", "hardtanh",
    "tanhshrink", "softshrink", "hardshrink", "log_sigmoid"),
  fusion_bias = TRUE,
  fusion_dropout = 0,
  loss = c("mse", "mae", "softmax", "cross-entropy", "gaussian", "binomial", "poisson"),
  optimizer = c("sgd", "adam", "adadelta", "adagrad", "rmsprop", "rprop"),
  lr = 0.01,
  alpha = 0.5,
  lambda = 0,
  validation = 0,
  batchsize = 32L,
  burnin = 10,
  shuffle = TRUE,
  epochs = 100,
  early_stopping = NULL,
  lr_scheduler = NULL,
  custom_parameters = NULL,
  device = c("cpu", "cuda", "mps"),
  plot = TRUE,
  verbose = TRUE
)

Arguments

formula

A formula object specifying the model structure. See the examples for more information.

dataList

A list containing the data for training the model. The list should contain all variables used in the formula.

fusion_hidden

A numeric vector specifying the number of units in each hidden layer of the fusion network.

fusion_activation

A character vector specifying the activation function for each hidden layer of the fusion network. Available options are: "relu", "leaky_relu", "tanh", "elu", "rrelu", "prelu", "softplus", "celu", "selu", "gelu", "relu6", "sigmoid", "softsign", "hardtanh", "tanhshrink", "softshrink", "hardshrink", "log_sigmoid".

fusion_bias

A logical value, or a logical vector of length length(fusion_hidden) + 1, indicating whether to include bias terms in the layers of the fusion network.

fusion_dropout

The dropout rate for the fusion network: a numeric value, or a vector of length length(fusion_hidden), with values between 0 and 1.

loss

The loss function to be optimized during training. Available options are: "mse", "mae", "softmax", "cross-entropy", "gaussian", "binomial", "poisson".

optimizer

The optimization algorithm to be used during training. Available options are: "sgd", "adam", "adadelta", "adagrad", "rmsprop", "rprop".

lr

The learning rate for the optimizer.

alpha

The alpha parameter for elastic net regularization. Should be a value between 0 and 1.

lambda

The lambda parameter for elastic net regularization. Should be a positive value.

validation

The proportion of the training data to use for validation. Should be a value between 0 and 1.

batchsize

The batch size used during training.

burnin

Training is aborted if the training loss is not below the baseline loss after burnin epochs.

shuffle

A logical indicating whether to shuffle the training data in each epoch.

epochs

The number of epochs to train the model.

early_stopping

If provided, the training will stop if the validation loss does not improve for the specified number of epochs. If set to NULL, early stopping is disabled.

lr_scheduler

A learning rate scheduler created with config_lr_scheduler.

custom_parameters

A list of parameters used by custom loss functions. See vignette for examples.

device

The device on which to perform computations. Available options are: "cpu", "cuda", "mps".

plot

A logical indicating whether to plot training and validation loss curves.

verbose

A logical indicating whether to display verbose output during training.
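
Together, alpha and lambda define an elastic net penalty that is added to the training loss. A common formulation, assuming the weighting convention used elsewhere in cito (where alpha weights the L2 term; check the package documentation for the authoritative definition), is:

$$\text{penalty} = \lambda \left[ (1 - \alpha)\, \lVert w \rVert_1 + \alpha\, \lVert w \rVert_2^2 \right]$$

Under this convention, alpha = 0 gives pure L1 (lasso) and alpha = 1 pure L2 (ridge) regularization, while lambda scales the overall strength of the penalty.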

Value

An object of class "citommn" containing the trained MMN model and other information.

See Also

predict.citommn, print.citommn, summary.citommn, continue_training, analyze_training
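
Examples

This page has no Examples section; the sketch below is a minimal, hypothetical illustration only. The variable names (Y, X) and the single-modality formula are assumptions, not the package's documented multi-modal syntax — see the package vignette for how modalities are actually specified and combined. All arguments shown are taken from the Usage block above.

```r
## Not run:
library(cito)  # requires a working torch installation

# All variables referenced in the formula must be elements of dataList.
# These data are purely illustrative.
dataList <- list(
  Y = rnorm(100),                       # response
  X = matrix(rnorm(100 * 5), 100, 5)    # one (hypothetical) modality
)

# Formula syntax here is illustrative; consult the vignette for the
# authoritative way to declare and fuse multiple modalities.
m <- mmn(Y ~ X,
         dataList = dataList,
         fusion_hidden = c(50L, 50L),
         fusion_activation = "relu",
         loss = "mse",
         optimizer = "adam",
         lr = 0.01,
         epochs = 100,
         device = "cpu")

summary(m)                    # summarize the fitted "citommn" object
predict(m, newdata = dataList)
## End(Not run)
```

The fitted object can be passed to continue_training to resume optimization, or to analyze_training to inspect the recorded loss curves, as listed under See Also.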


citoverse/cito documentation built on Jan. 16, 2025, 11:49 p.m.