nn_gru: Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence


Description

Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence. For each element in the input sequence, each layer computes the function given in the Details section.

Usage

nn_gru(
  input_size,
  hidden_size,
  num_layers = 1,
  bias = TRUE,
  batch_first = FALSE,
  dropout = 0,
  bidirectional = FALSE,
  ...
)

Arguments

input_size

The number of expected features in the input x

hidden_size

The number of features in the hidden state h

num_layers

Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two GRUs together to form a stacked GRU, with the second GRU taking in outputs of the first GRU and computing the final results. Default: 1

bias

If FALSE, then the layer does not use bias weights b_ih and b_hh. Default: TRUE

batch_first

If TRUE, then the input and output tensors are provided as (batch, seq, feature). Default: FALSE

dropout

If non-zero, introduces a Dropout layer on the outputs of each GRU layer except the last layer, with dropout probability equal to dropout. Default: 0

bidirectional

If TRUE, becomes a bidirectional GRU. Default: FALSE

...

currently unused.
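
As a usage sketch (not part of this reference; the sizes and dropout value are chosen only for illustration), batch_first and bidirectional change the expected input layout and the width of the output features:

library(torch)

# hypothetical sizes, for illustration only
rnn <- nn_gru(
  input_size = 10, hidden_size = 20, num_layers = 2,
  batch_first = TRUE, bidirectional = TRUE, dropout = 0.2
)
x <- torch_randn(3, 5, 10)  # (batch = 3, seq = 5, feature = 10), since batch_first = TRUE
out <- rnn(x)               # h_0 defaults to zeros when omitted
out[[1]]$shape              # output: (3, 5, 40) = (batch, seq, 2 * hidden_size)
out[[2]]$shape              # h_n: (4, 3, 20) = (num_layers * 2, batch, hidden_size)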

Details

\begin{array}{ll}
r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
n_t = \tanh(W_{in} x_t + b_{in} + r_t (W_{hn} h_{(t-1)} + b_{hn})) \\
h_t = (1 - z_t) n_t + z_t h_{(t-1)}
\end{array}

where h_t is the hidden state at time t, x_t is the input at time t, h_{(t-1)} is the hidden state of the layer at time t-1 (or the initial hidden state at time 0), and r_t, z_t, n_t are the reset, update, and new gates, respectively. \sigma is the sigmoid function.
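
A minimal sketch of these equations in R (biases omitted for brevity; W_ir, W_iz, W_in, W_hr, W_hz, W_hn are stand-in weight matrices, not the module's actual parameters):

library(torch)

# one GRU step for a single layer, following the equations above
# x_t: (batch, input_size); h_prev: (batch, hidden_size)
gru_step <- function(x_t, h_prev, W_ir, W_iz, W_in, W_hr, W_hz, W_hn) {
  r_t <- torch_sigmoid(torch_matmul(x_t, W_ir$t()) + torch_matmul(h_prev, W_hr$t()))    # reset gate
  z_t <- torch_sigmoid(torch_matmul(x_t, W_iz$t()) + torch_matmul(h_prev, W_hz$t()))    # update gate
  n_t <- torch_tanh(torch_matmul(x_t, W_in$t()) + r_t * torch_matmul(h_prev, W_hn$t())) # new gate
  (1 - z_t) * n_t + z_t * h_prev                                                        # h_t
}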

Inputs

Inputs: input, h_0

- input of shape (seq_len, batch, input_size): tensor containing the features of the input sequence.
- h_0 of shape (num_layers * num_directions, batch, hidden_size): tensor containing the initial hidden state for each element in the batch. Defaults to zeros if not provided.

Outputs

Outputs: output, h_n

- output of shape (seq_len, batch, num_directions * hidden_size): tensor containing the output features h_t from the last layer of the GRU, for each t.
- h_n of shape (num_layers * num_directions, batch, hidden_size): tensor containing the hidden state for t = seq_len.
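
In the R interface, calling the module returns these two tensors as a list; a brief sketch with the default layout (batch_first = FALSE, unidirectional):

library(torch)

rnn <- nn_gru(input_size = 10, hidden_size = 20, num_layers = 2)
input <- torch_randn(5, 3, 10)  # (seq_len = 5, batch = 3, input_size = 10)
res <- rnn(input)               # h_0 defaults to zeros
res[[1]]$shape                  # output: (5, 3, 20)
res[[2]]$shape                  # h_n:    (2, 3, 20)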

Attributes

- weight_ih_l[k]: the learnable input-hidden weights of the k-th layer (W_ir | W_iz | W_in), of shape (3 * hidden_size, input_size) for k = 0, otherwise (3 * hidden_size, num_directions * hidden_size).
- weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer (W_hr | W_hz | W_hn), of shape (3 * hidden_size, hidden_size).
- bias_ih_l[k]: the learnable input-hidden bias of the k-th layer (b_ir | b_iz | b_in), of shape (3 * hidden_size).
- bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer (b_hr | b_hz | b_hn), of shape (3 * hidden_size).

Note

All the weights and biases are initialized from \mathcal{U}(-\sqrt{k}, \sqrt{k}) where k = \frac{1}{\mbox{hidden\_size}}.
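
A quick sketch to check this (it only assumes that rnn$parameters lists the module's weights and biases):

library(torch)

rnn <- nn_gru(input_size = 10, hidden_size = 20)
bound <- sqrt(1 / 20)  # sqrt(k), with k = 1 / hidden_size

# every weight and bias should fall inside [-sqrt(k), sqrt(k)]
all(sapply(rnn$parameters, function(p) p$abs()$max()$item() <= bound))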

Examples

if (torch_is_installed()) {

# GRU with input_size = 10, hidden_size = 20 and num_layers = 2
rnn <- nn_gru(10, 20, 2)
# input: (seq_len = 5, batch = 3, input_size = 10)
input <- torch_randn(5, 3, 10)
# initial hidden state: (num_layers = 2, batch = 3, hidden_size = 20)
h0 <- torch_randn(2, 3, 20)
# returns a list with the output and the final hidden state h_n
output <- rnn(input, h0)

}
