keras_rnn: Train a Recurrent Neural Network with Keras

View source: R/keras_rnn.R

Description

Trains a recurrent neural network using the Keras framework (with TensorFlow backend). Currently supports simple recurrent units ("simple"), gated recurrent units (GRU), and long short-term memory units (LSTM).
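As a rough illustration of the three supported unit types, the sketch below maps each model_type value onto the matching recurrent layer constructor of the R keras package (layer_simple_rnn(), layer_gru(), layer_lstm()). It assumes keras_rnn() builds its recurrent layer from these constructors. rnn_layer_name() is a hypothetical helper, not part of tsRNN, and returns only the constructor name so that it runs without Keras installed.

```r
# Hypothetical helper (not part of tsRNN): map model_type to the name of the
# R keras layer constructor assumed to build the recurrent layer.
rnn_layer_name <- function(model_type = c("simple", "gru", "lstm")) {
  model_type <- match.arg(model_type)  # stops with an error for any other value
  switch(model_type,
    simple = "layer_simple_rnn",
    gru    = "layer_gru",
    lstm   = "layer_lstm"
  )
}

rnn_layer_name("gru")  # returns "layer_gru"
```

With keras attached, match.fun(rnn_layer_name("lstm")) would return the layer_lstm constructor itself.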

Usage

keras_rnn(
  X,
  Y,
  model_type,
  tsteps,
  n_epochs = 200,
  n_units = 32,
  loss = "mse",
  metrics = NULL,
  dropout_in_test = FALSE,
  optimizer = optimizer_rmsprop(),
  dropout = 0,
  recurrent_dropout = 0,
  history = FALSE,
  live_plot = FALSE
)
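X and Y follow the Keras convention for recurrent layers: each element of X is a 3D array of shape (samples, tsteps, features) and each element of Y a 2D array (samples, outputs). The sketch below builds such a pair from a plain numeric vector with a sliding window of length tsteps. make_windows() is a hypothetical helper for illustration only; it is not how ts_nn_preparation() is implemented, and it assumes a single feature and a single one-step-ahead target.

```r
# Hypothetical helper (not tsRNN's ts_nn_preparation()): turn a numeric
# vector into Keras-style arrays using a sliding window of length tsteps.
#   x: 3D array [samples, tsteps, features = 1], the previous tsteps values
#   y: 2D array [samples, outputs = 1], the value that follows each window
make_windows <- function(series, tsteps) {
  n_samples <- length(series) - tsteps
  stopifnot(n_samples > 0)
  x <- array(NA_real_, dim = c(n_samples, tsteps, 1))
  y <- matrix(NA_real_, nrow = n_samples, ncol = 1)
  for (i in seq_len(n_samples)) {
    x[i, , 1] <- series[i:(i + tsteps - 1)]  # window of past values
    y[i, 1] <- series[i + tsteps]            # next value as the target
  }
  list(x = x, y = y)
}

w <- make_windows(as.numeric(1:10), tsteps = 3)
dim(w$x)  # 7 3 1
dim(w$y)  # 7 1
```

Lists of such arrays named "train", "val", and "test" would then be passed as X and Y.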

Arguments

X

List with elements "train", "val", and "test", each a 3D Keras input array of shape (samples, time steps, features)

Y

List with elements "train", "val", and "test", each a 2D Keras target array (samples, outputs)

model_type

One of "simple", "gru", or "lstm"

tsteps

Number of time steps in the Keras input shape (second dimension of each array in X)

n_epochs

Number of training epochs. Default 200

n_units

Number of units in the recurrent layer. Default 32 (currently fixed)

loss

Loss function passed to Keras. Default "mse" (mean squared error)

metrics

Metrics to monitor during training, passed to Keras. Default NULL

dropout_in_test

Apply dropout during training only (FALSE, the default) or also at test time (TRUE)? TRUE is required for dropout-based prediction intervals (Bayesian RNN).

optimizer

A Keras optimizer, e.g. optimizer_rmsprop() (the default)

dropout

Dropout rate applied to the inputs of the recurrent layer. Default 0

recurrent_dropout

Dropout rate applied to the recurrent state of the recurrent layer. Default 0

history

Return the training history in addition to the model? Note that the return value changes from the model to list(model, history) when history = TRUE.

live_plot

Plot loss and validation metric during training? Default FALSE
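With dropout_in_test = TRUE, dropout stays active at prediction time, so repeated predictions on the same input differ, and their spread can be summarised as an interval (Monte Carlo dropout). The sketch below assumes the repeated predictions have already been collected into a matrix with one row per forward pass and one column per forecast point. mc_dropout_interval() is a hypothetical helper, and the empirical-quantile interval is one common choice, not necessarily the method tsRNN uses.

```r
# Hypothetical helper: summarise repeated stochastic predictions (Monte Carlo
# dropout) as a point forecast plus an empirical-quantile interval.
#   draws: matrix [n_passes, n_points], one row per forward pass
#   level: nominal coverage of the interval, e.g. 0.9 for a 90% interval
mc_dropout_interval <- function(draws, level = 0.9) {
  stopifnot(is.matrix(draws), level > 0, level < 1)
  alpha <- (1 - level) / 2
  data.frame(
    mean  = colMeans(draws),
    lower = apply(draws, 2, quantile, probs = alpha, names = FALSE),
    upper = apply(draws, 2, quantile, probs = 1 - alpha, names = FALSE)
  )
}

# Simulated stand-in for 100 stochastic forward passes over 6 test points;
# with a real model, each row would be one prediction with dropout active.
set.seed(1)
draws <- matrix(rnorm(100 * 6, mean = 10, sd = 1), nrow = 100, ncol = 6)
mc_dropout_interval(draws, level = 0.9)
```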

Value

The Keras model by default (history = FALSE); otherwise a list with the Keras model (model) and its training history (history).
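Because the return type depends on history, code that may call keras_rnn() with either setting can normalise the result first. A minimal sketch: get_model() is a hypothetical helper that relies only on the list elements model and history used in the Examples, and it assumes the Keras model object itself is not a plain R list.

```r
# Hypothetical helper: return the fitted model whether keras_rnn() was called
# with history = FALSE (bare model) or history = TRUE (list(model, history)).
get_model <- function(result) {
  # Assumption: a Keras model object is not a plain R list, so is.list()
  # singles out the history = TRUE case before any $ access happens.
  if (is.list(result) && "model" %in% names(result)) result$model else result
}
```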

See Also

Keras Documentation

Examples

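library(tsRNN)
library(data.table)

data <- tsRNN::DT_apple
# Add the previous observation as a lagged feature
data[, value_lag1 := data.table::shift(value, type = "lag", n = 1)]
# Drop the first row, whose lag is NA
data <- data[!is.na(value_lag1)]

# Split into train/validation/test arrays (validation and test length 6)
nn_arrays <- ts_nn_preparation(data, tsteps = 1L, length_val = 6L, length_test = 6L)
# Fit a simple RNN for 20 epochs; returns the Keras model only
keras_rnn(nn_arrays$x, nn_arrays$y, model_type = "simple", tsteps = 1, n_epochs = 20)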

# return model and history
result <- keras_rnn(
  nn_arrays$x, nn_arrays$y, model_type = "simple", tsteps = 1, n_epochs = 20, history = TRUE
)

result$model
result$history

## Not run: 
# Plot result
plot(result$history)

## End(Not run)

thfuchs/tsRNN documentation built on April 17, 2021, 11:03 p.m.