train_nn: Train Network

View source: R/nn-train.R

Description

Train the network with specified hyperparameters and return the trained model.

Usage

train_nn(
  train_data,
  train_target,
  validate_data,
  validate_target,
  model,
  alpha,
  epochs,
  batch_size = nrow(train_data),
  plot_acc = TRUE
)
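Because `batch_size` defaults to `nrow(train_data)`, leaving it unset means each epoch processes the whole training set as a single batch. The sketch below shows one common way to partition row indices into mini-batches; `make_batches` is a hypothetical helper written for illustration and is not taken from the simpleMLP source, which may shuffle or split differently.

```r
# Hypothetical helper (not part of simpleMLP): split n row indices
# into shuffled mini-batches of at most batch_size rows each.
make_batches <- function(n, batch_size = n) {
  idx <- sample(n)  # random permutation of 1..n
  # Assign consecutive positions of the permutation to batch 1, 2, ...
  split(idx, ceiling(seq_along(idx) / batch_size))
}

batches <- make_batches(10, batch_size = 4)
lengths(batches)                # batch sizes: 4, 4, 2
length(make_batches(10))        # default batch_size: a single batch
```

With 10 rows and `batch_size = 4`, an epoch would consist of three updates, the last on a smaller batch of two rows.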

Arguments

train_data

set of training data

train_target

set of training data targets in one-hot encoded form

validate_data

set of validation data

validate_target

set of validation data targets in one-hot encoded form

model

list of weights and biases

alpha

learning rate

epochs

number of epochs

batch_size

mini-batch size

plot_acc

whether or not to plot training and validation accuracy
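The target arguments are expected in one-hot encoded form, meaning each row holds a single 1 in the column of its class and 0 elsewhere. As a minimal sketch in base R, assuming integer class labels, a label vector could be converted as follows; `one_hot` is a hypothetical helper for illustration and is not a simpleMLP function.

```r
# Hypothetical helper (not part of simpleMLP): one-hot encode class labels.
one_hot <- function(labels, classes = sort(unique(labels))) {
  # Compare every label with every class; multiplying by 1 turns
  # the logical matrix into a numeric 0/1 matrix.
  encoded <- outer(labels, classes, FUN = "==") * 1
  colnames(encoded) <- classes
  encoded
}

labels <- c(3, 0, 1, 3)
targets <- one_hot(labels, classes = 0:3)
dim(targets)      # 4 rows (samples) by 4 columns (classes)
rowSums(targets)  # each row sums to 1
```

For digit labels such as MNIST, passing `classes = 0:9` would keep all ten columns even if some digit is absent from the vector being encoded.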

Value

list of weights and biases after training

Examples

## Not run: 
mlp_model <- init_nn(784, 100, 50, 10)
mnist <- load_mnist()
train_data <- mnist[1]
train_target <- mnist[2]
validate_data <- mnist[3]
validate_target <- mnist[4]
mlp_model <- train_nn(train_data, train_target, validate_data,
                      validate_target, mlp_model,
                      alpha = 0.01, epochs = 1, batch_size = 64)

## End(Not run)

simpleMLP documentation built on March 28, 2021, 9:07 a.m.