layer_attention: Dot-product attention layer, a.k.a. Luong-style attention

View source: R/layer-attention.R

Description

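Dot-product attention layer, a.k.a. Luong-style attention. Given a query tensor and a value tensor (plus an optional key tensor), the layer scores every query timestep against every key timestep and returns, for each query timestep, the correspondingly weighted combination of the value vectors.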

Usage

layer_attention(
  inputs,
  use_scale = FALSE,
  score_mode = "dot",
  ...,
  dropout = NULL
)

Arguments

inputs

List of the following tensors:

  • query: Query tensor of shape [batch_size, Tq, dim].

  • value: Value tensor of shape [batch_size, Tv, dim].

  • key: Optional key tensor of shape [batch_size, Tv, dim]. If not given, value is used as both the key and the value, which is the most common case.
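The shape contract for inputs can be checked in plain R before data reaches the layer. The sketch below uses base R arrays in place of tensors and enforces the shapes listed above, with key defaulting to value. check_attention_inputs is a hypothetical helper and is not part of keras.

```r
# Hypothetical helper (not part of keras): checks that query, value and key
# follow the [batch_size, T, dim] contract described for `inputs`.
check_attention_inputs <- function(query, value, key = value) {
  stopifnot(length(dim(query)) == 3, length(dim(value)) == 3,
            length(dim(key)) == 3)
  qd <- dim(query); vd <- dim(value); kd <- dim(key)
  # All three tensors must share batch_size ...
  stopifnot(qd[1] == vd[1], qd[1] == kd[1])
  # ... and the feature size dim.
  stopifnot(qd[3] == vd[3], qd[3] == kd[3])
  # key and value must have the same number of timesteps Tv.
  stopifnot(kd[2] == vd[2])
  c(batch_size = qd[1], Tq = qd[2], Tv = vd[2], dim = qd[3])
}

query <- array(rnorm(2 * 3 * 4), dim = c(2, 3, 4))  # [batch_size = 2, Tq = 3, dim = 4]
value <- array(rnorm(2 * 5 * 4), dim = c(2, 5, 4))  # [batch_size = 2, Tv = 5, dim = 4]
check_attention_inputs(query, value)  # key omitted: value doubles as key
```

Note that Tq and Tv are free to differ; only batch_size and dim must agree across the three tensors, and key must match value timestep for timestep.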

use_scale

If TRUE, the layer creates a trainable scalar variable that scales the attention scores.
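The effect of use_scale can be sketched for a single batch element in base R. This assumes the scale is one scalar multiplied into the dot-product scores before the softmax. The scale values used here are hypothetical stand-ins for what the layer would learn during training.

```r
# Row-wise softmax over the Tv axis of a [Tq, Tv] score matrix.
softmax_rows <- function(x) {
  e <- exp(x - apply(x, 1, max))  # subtract each row's max for stability
  e / rowSums(e)
}

# Dot-product scores, optionally multiplied by a scalar `scale`
# (a hypothetical stand-in for the layer's learned variable).
scaled_scores <- function(query, key, scale = NULL) {
  scores <- query %*% t(key)      # [Tq, Tv]
  if (!is.null(scale)) scores <- scale * scores
  scores
}

query <- diag(2)  # [Tq = 2, dim = 2]
key   <- diag(2)  # [Tv = 2, dim = 2]
p_unscaled <- softmax_rows(scaled_scores(query, key))
p_sharp    <- softmax_rows(scaled_scores(query, key, scale = 3))
p_flat     <- softmax_rows(scaled_scores(query, key, scale = 0.1))
round(p_unscaled[1, ], 3)  # 0.731 0.269
round(p_sharp[1, ], 3)     # 0.953 0.047
```

A scale above 1 sharpens the attention distribution toward the best-matching key, and a scale below 1 flattens it toward uniform.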

score_mode

Function used to compute the attention scores, one of {"dot", "concat"}. "dot" refers to the dot product between the query and key vectors. "concat" refers to the hyperbolic tangent of the concatenation of the query and key vectors.
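The two score modes can be contrasted in base R for a single batch element. The dot_scores function follows the "dot" description above directly. The description of "concat" does not say how the tanh output is reduced to one score per query/key pair, so this sketch assumes a hypothetical weight vector w applied to tanh(c(q, k)), in the style of additive attention. Treat it as an illustration of the idea, not the layer's exact computation.

```r
# "dot": scores[i, j] = sum(query[i, ] * key[j, ]).
dot_scores <- function(query, key) query %*% t(key)

# "concat" (illustrative assumption): tanh of the concatenated query and key
# vectors, reduced to a scalar with a hypothetical weight vector `w`
# of length 2 * dim.
concat_scores <- function(query, key, w) {
  scores <- matrix(0, nrow(query), nrow(key))
  for (i in seq_len(nrow(query))) {
    for (j in seq_len(nrow(key))) {
      scores[i, j] <- sum(w * tanh(c(query[i, ], key[j, ])))
    }
  }
  scores
}

set.seed(1)
query <- matrix(rnorm(3 * 4), nrow = 3)  # [Tq = 3, dim = 4]
key   <- matrix(rnorm(5 * 4), nrow = 5)  # [Tv = 5, dim = 4]
w     <- rnorm(2 * 4)                    # hypothetical weights
dim(dot_scores(query, key))              # 3 5
dim(concat_scores(query, key, w))        # 3 5
```

Both modes produce a [Tq, Tv] score matrix, so the softmax and weighted-sum steps that follow are the same whichever mode is chosen.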

...

Standard layer arguments (e.g., batch_size, dtype, name, trainable, weights).

dropout

Float between 0 and 1. Fraction of the units to drop for the attention scores. The R default is NULL; the documented default dropout rate is 0.0 (no dropout).
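The sketch below shows what the dropout rate does, assuming standard inverted dropout applied to the attention weights (the softmax output) during training only. The function name dropout_attention and its training flag are hypothetical; the layer handles this internally.

```r
# Hypothetical sketch of inverted dropout on an attention distribution.
# `rate` mirrors the `dropout` argument; which tensor is dropped and the
# rescaling scheme are assumptions for illustration.
dropout_attention <- function(distribution, rate, training = TRUE) {
  if (!training || is.null(rate) || rate == 0) return(distribution)
  keep <- matrix(runif(length(distribution)) >= rate,
                 nrow = nrow(distribution))
  # Zero out dropped weights and rescale survivors by 1 / (1 - rate)
  # so the expected value of each weight is unchanged.
  distribution * keep / (1 - rate)
}

distribution <- matrix(1 / 4, nrow = 2, ncol = 4)  # uniform [Tq = 2, Tv = 4]
set.seed(42)
dropout_attention(distribution, rate = 0.5)
dropout_attention(distribution, rate = 0.5, training = FALSE)  # unchanged
```

Because surviving weights are rescaled, each row no longer sums exactly to 1 during training, but its expected sum is still 1; at inference time the distribution passes through untouched.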

Details

The inputs are a query tensor of shape [batch_size, Tq, dim], a value tensor of shape [batch_size, Tv, dim], and a key tensor of shape [batch_size, Tv, dim]. The calculation proceeds as follows:

  1. Calculate scores with shape [batch_size, Tq, Tv] as a query-key dot product: scores = tf$matmul(query, key, transpose_b = TRUE).

  2. Use scores to calculate a distribution with shape [batch_size, Tq, Tv], with the softmax taken over the last (Tv) axis: distribution = tf$nn$softmax(scores).

  3. Use distribution to create a linear combination of value with shape [batch_size, Tq, dim]: return tf$matmul(distribution, value).
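The three steps can be reproduced in base R, with arrays standing in for tensors. This is a sketch for checking shapes and arithmetic, not the layer's implementation. It loops over the batch instead of using batched matrix multiplication, and it omits scaling, masking, and dropout. The name dot_product_attention is hypothetical.

```r
# Row-wise softmax over the Tv axis of a [Tq, Tv] score matrix.
softmax_rows <- function(x) {
  e <- exp(x - apply(x, 1, max))  # subtract each row's max for stability
  e / rowSums(e)
}

# Arrays are [batch_size, T, dim]; key defaults to value.
dot_product_attention <- function(query, value, key = value) {
  batch_size <- dim(query)[1]
  Tq <- dim(query)[2]
  out <- array(0, dim = c(batch_size, Tq, dim(value)[3]))
  for (b in seq_len(batch_size)) {
    # matrix() restores 2-D shape even when a dimension has length 1.
    q <- matrix(query[b, , ], nrow = Tq)
    k <- matrix(key[b, , ], nrow = dim(key)[2])
    v <- matrix(value[b, , ], nrow = dim(value)[2])
    scores <- q %*% t(k)                  # step 1: [Tq, Tv]
    distribution <- softmax_rows(scores)  # step 2: each row sums to 1
    out[b, , ] <- distribution %*% v      # step 3: [Tq, dim]
  }
  out
}

set.seed(7)
query <- array(rnorm(2 * 3 * 4), dim = c(2, 3, 4))  # [batch_size = 2, Tq = 3, dim = 4]
value <- array(rnorm(2 * 5 * 4), dim = c(2, 5, 4))  # [batch_size = 2, Tv = 5, dim = 4]
out <- dot_product_attention(query, value)
dim(out)  # 2 3 4
```

Each output row is a convex combination of the value rows for its batch element, which is why the result keeps the query's Tq timesteps but the value's feature size.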

See Also

Other core layers: layer_activation(), layer_activity_regularization(), layer_dense(), layer_dense_features(), layer_dropout(), layer_flatten(), layer_input(), layer_lambda(), layer_masking(), layer_permute(), layer_repeat_vector(), layer_reshape()


keras documentation built on May 29, 2024, 3:20 a.m.