#' Temporal Fusion Transformer Module
#'
#' Implements the Temporal Fusion Transformer (TFT) architecture for
#' interpretable multi-horizon time series forecasting.
#'
#' @param num_features a nested list giving the number of features used to
#' size each layer, excluding the batch dimension:
#' - `$encoder$past$(num|cat)`: number of past numeric/categorical features
#' - `$encoder$static$(num|cat)`: number of static numeric/categorical features
#' - `$decoder$known$(num|cat)`: number of known future numeric/categorical features
#' @param feature_sizes The number of unique levels of each categorical
#' variable in the dataset, split into `$past`, `$known` and `$static` groups.
#' @param hidden_state_size The size of the hidden representation shared across
#' multiple parts of the architecture.
#' @param dropout Dropout rate used in several places in the network.
#' @param num_heads Number of heads in the attention layer.
#' @param num_lstm_layers Number of LSTM layers used in the locality enhancement
#' layer. Usually 2 is enough.
#' @param num_quantiles The number of quantiles to predict.
#'
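#' @examples
#' \dontrun{
#' # A minimal sketch of how the model could be instantiated and called.
#' # All names, shapes and cardinalities below are illustrative, not taken
#' # from a real dataset.
#' num_features <- list(
#'   encoder = list(
#'     past = list(num = 3, cat = 2),
#'     static = list(num = 1, cat = 1)
#'   ),
#'   decoder = list(
#'     known = list(num = 2, cat = 1)
#'   )
#' )
#' feature_sizes <- list(
#'   past = c(12, 31),  # cardinality of each past categorical variable
#'   known = 7,         # e.g. day of week
#'   static = 5
#' )
#' model <- temporal_fusion_transformer_model(
#'   num_features = num_features,
#'   feature_sizes = feature_sizes,
#'   hidden_state_size = 16
#' )
#' # batch of 8 series, 24 past steps, 6 future steps
#' x <- list(
#'   encoder = list(
#'     past = list(
#'       num = torch::torch_randn(8, 24, 3),
#'       cat = torch::torch_randint(1, 12, c(8, 24, 2), dtype = torch::torch_long())
#'     ),
#'     static = list(
#'       num = torch::torch_randn(8, 1),
#'       cat = torch::torch_randint(1, 5, c(8, 1), dtype = torch::torch_long())
#'     )
#'   ),
#'   decoder = list(
#'     known = list(
#'       num = torch::torch_randn(8, 6, 2),
#'       cat = torch::torch_randint(1, 7, c(8, 6, 1), dtype = torch::torch_long())
#'     )
#'   )
#' )
#' model(x)  # [8, 6, 3]: one prediction per future step and quantile
#' }
#'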
#' @export
temporal_fusion_transformer_model <- torch::nn_module(
"temporal_fusion_transformer",
initialize = function(num_features, feature_sizes, hidden_state_size = 100,
dropout = 0.1, num_heads = 4, num_lstm_layers = 2,
num_quantiles = 3) {
    # sentinel parameter (e.g. for device and gradient checks)
    self$.check <- torch::nn_parameter(torch::torch_tensor(1, requires_grad = TRUE))
self$preprocessing <- preprocessing(
n_features = num_features,
feature_sizes = feature_sizes,
hidden_state_size = hidden_state_size
)
self$context <- static_context(
n_features = num_features$encoder$static,
hidden_state_size = hidden_state_size
)
self$temporal_selection <- temporal_selection(
n_features = num_features,
hidden_state_size = hidden_state_size
)
self$locality_enhancement <- locality_enhancement_layer(
hidden_state_size = hidden_state_size,
num_layers = num_lstm_layers,
dropout = dropout
)
self$temporal_attn <- temporal_self_attention(
n_heads = num_heads,
hidden_state_size = hidden_state_size,
dropout = dropout
)
self$position_wise <- position_wise_feedforward(
hidden_state_size = hidden_state_size,
dropout = dropout
)
self$output_layer <- quantile_output_layer(
n_quantiles = num_quantiles,
hidden_state_size = hidden_state_size
)
},
forward = function(x) {
    # We use entity embeddings [31] for categorical variables as feature
    # representations, and linear transformations for continuous variables,
    # transforming each input variable into a d_model-dimensional vector which
    # matches the dimensions in subsequent layers for skip connections.
transformed <- self$preprocessing(x)
# In contrast with other time series forecasting architectures, the TFT is carefully
# designed to integrate information from static metadata, using separate
    # GRN encoders to produce four different context vectors: c_s, c_e, c_c and c_h.
    # These context vectors are wired into various locations in the temporal fusion
# decoder (Sec. 4.5) where static variables play an important role in processing.
context <- self$context(transformed$encoder$static)
# TFT is designed to provide instance-wise variable selection through the use
# of variable selection networks applied to both static covariates and time-dependent
# covariates. Beyond providing insights into which variables are most significant
# for the prediction problem, variable selection also allows TFT to remove any
# unnecessary noisy inputs which could negatively impact performance.
transformed <- self$temporal_selection(transformed, context)
# For instance, [12] adopts a single convolutional layer for locality enhancement
# – extracting local patterns using the same filter across all time. However,
# this might not be suitable for cases when observed inputs exist, due to the
# differing number of past and future inputs. As such, we propose the application
# of a sequence-to-sequence model to naturally handle these differences
transformed <- self$locality_enhancement(transformed, context)
# Besides preserving causal information flow via masking, the self-attention layer
# allows TFT to pick up long-range dependencies that may be challenging for RNN-based
# architectures to learn. Following the self-attention layer, an additional gating
# layer is also applied to facilitate training
attn_output <- self$temporal_attn(transformed)
    # we also apply a gated residual connection which skips over the entire
    # transformer block, providing a direct path to the sequence-to-sequence
    # layer and yielding a simpler model when additional complexity is not required
output <- self$position_wise(attn_output, transformed$decoder$known)
# TFT also generates prediction intervals on top of point forecasts. This is
# achieved by the simultaneous prediction of various percentiles (e.g. 10th,
    # 50th and 90th) at each time step. Quantile forecasts are generated using a
    # linear transformation of the output from the temporal fusion decoder.
self$output_layer(output)
}
)
# Pinball (quantile) loss, averaged over all quantiles, time steps and batches.
quantile_loss <- torch::nn_module(
  initialize = function(quantiles) {
    # shape [1, 1, n_quantiles] so the quantiles broadcast over batch and time
    self$quantiles <- torch::torch_tensor(sort(quantiles))$unsqueeze(1)$unsqueeze(1)
},
forward = function(y_pred, y_true) {
other <- torch::torch_zeros_like(y_pred)
error <- y_true - y_pred
low_res <- torch::torch_max(error, other = other)
up_res <- torch::torch_max(-error, other = other)
quantiles <- self$quantiles$to(device = y_true$device)
torch::torch_mean(quantiles * low_res + (1 - quantiles) * up_res)
}
)
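# A minimal sketch of the loss above (tensors are illustrative): for quantile q
# and error e = y_true - y_pred, the pinball loss is
#   q * max(e, 0) + (1 - q) * max(-e, 0),
# so under-prediction is weighted by q and over-prediction by (1 - q).
#
# loss_fn <- quantile_loss(quantiles = c(0.1, 0.5, 0.9))
# y_pred <- torch::torch_randn(8, 6, 3)  # [batch, time, n_quantiles]
# y_true <- torch::torch_randn(8, 6, 1)  # broadcasts against every quantile
# loss_fn(y_pred, y_true)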
quantile_output_layer <- torch::nn_module(
initialize = function(n_quantiles, hidden_state_size) {
self$linear <- torch::nn_linear(hidden_state_size, n_quantiles)
},
forward = function(x) {
self$linear(x)
}
)
position_wise_feedforward <- torch::nn_module(
initialize = function(hidden_state_size, dropout = 0) {
self$grn <- time_distributed(gated_residual_network(
input_size = hidden_state_size,
output_size = hidden_state_size,
hidden_state_size = hidden_state_size,
dropout = dropout
))
self$layer_norm <- torch::nn_layer_norm(
normalized_shape = hidden_state_size
)
self$glu <- gated_linear_unit(
input_size = hidden_state_size,
output_size = hidden_state_size
)
},
forward = function(x, known) {
output <- self$grn(x)
self$layer_norm(known + self$glu(output))
}
)
temporal_self_attention <- torch::nn_module(
  initialize = function(n_heads, hidden_state_size, dropout) {
    self$multihead_attn <- interpretable_multihead_attention(
      n_heads = n_heads, hidden_state_size = hidden_state_size,
      dropout = dropout
    )
self$glu <- gated_linear_unit(
input_size = hidden_state_size,
output_size = hidden_state_size
)
self$layer_norm <- torch::nn_layer_norm(
normalized_shape = hidden_state_size
)
},
forward = function(x) {
full_seq <- torch::torch_cat(list(
x$encoder$past,
x$decoder$known
), dim = 2)
    attn_output <- self$multihead_attn(x$decoder$known, full_seq, full_seq)
    # keep the last decoder-length block of time steps
    attn_output <- attn_output[, -x$decoder$known$size(2):N, ]
self$layer_norm(self$glu(attn_output) + x$decoder$known)
}
)
# Interpretable multi-head attention: all heads share a single value projection
# and the head outputs are averaged, so attention weights can be read jointly.
interpretable_multihead_attention <- torch::nn_module(
  initialize = function(n_heads, hidden_state_size, dropout) {
attn_size <- trunc(hidden_state_size / n_heads)
self$query_layers <- seq_len(n_heads) %>%
purrr::map(~torch::nn_linear(hidden_state_size, attn_size)) %>%
torch::nn_module_list()
self$key_layers <- seq_len(n_heads) %>%
purrr::map(~torch::nn_linear(hidden_state_size, attn_size)) %>%
torch::nn_module_list()
self$value_layer <- torch::nn_linear(hidden_state_size, attn_size)
self$output_layer <- torch::nn_linear(attn_size, hidden_state_size)
self$attention <- scaled_dot_product_attention(dropout = dropout)
},
forward = function(q, k, v) {
queries <- purrr::map(as.list(self$query_layers), ~.x(q))
keys <- purrr::map(as.list(self$key_layers), ~.x(k))
value <- self$value_layer(v)
outputs <- purrr::map2(queries, keys, ~self$attention(.x, .y, value))
outputs %>%
torch::torch_stack(dim = 3) %>%
torch::torch_mean(dim = 3) %>%
self$output_layer()
}
)
scaled_dot_product_attention <- torch::nn_module(
initialize = function(dropout = 0) {
self$dropout <- torch::nn_dropout(p = dropout)
self$softmax <- torch::nn_softmax(dim = 3)
},
forward = function(q, k, v, mask = TRUE) {
scaling_factor <- sqrt(k$size(3))
attn <- q %>%
torch::torch_bmm(k$permute(c(1,3,2))) %>%
torch::torch_divide(scaling_factor)
    if (mask) {
      # causal mask: the diagonal offset accounts for keys covering both past
      # and future steps while the queries cover only the future steps
      m <- attn %>%
        torch::torch_ones_like() %>%
        torch::torch_triu(diagonal = 1 + (attn$size(3) - attn$size(2)))
      attn <- attn$masked_fill(m$to(dtype = torch::torch_bool()), -1e9)
    }
attn %>%
self$softmax() %>%
self$dropout() %>%
torch::torch_bmm(v)
}
)
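# A sketch of the causal mask above (sizes are illustrative). With 2 decoder
# steps attending over 5 key steps (3 past + 2 future), the diagonal offset of
# 1 + (5 - 2) = 4 lets decoder step t attend to keys 1..(3 + t):
#
# attn <- torch::torch_ones(1, 2, 5)
# torch::torch_triu(torch::torch_ones_like(attn), diagonal = 4)
# # row 1 masks key 5 only; row 2 masks nothing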
# Static enrichment: enhances temporal features with static metadata by
# conditioning a time-shared GRN on a static context vector.
static_enrichment_layer <- torch::nn_module(
  "static_enrichment_layer",
initialize = function(hidden_state_size, dropout = 0) {
self$grn <- time_distributed(gated_residual_network(
input_size = hidden_state_size,
output_size = hidden_state_size,
hidden_state_size = hidden_state_size,
dropout = dropout
))
},
forward = function(x, context) {
list(
encoder = list(
past = self$grn(x$encoder$past, context$static_enrichment)
),
decoder = list(
known = self$grn(x$decoder$known, context$static_enrichment)
)
)
}
)
locality_enhancement_layer <- torch::nn_module(
"locality_enhancement_layer",
  initialize = function(hidden_state_size, num_layers, dropout = 0) {
    # LSTM dropout only applies between layers; disable it for a single layer
    dropout <- if (num_layers > 1) dropout else 0
self$encoder <- torch::nn_lstm(
input_size = hidden_state_size,
hidden_size = hidden_state_size,
num_layers = num_layers,
dropout = dropout,
batch_first = TRUE
)
self$decoder <- torch::nn_lstm(
input_size = hidden_state_size,
hidden_size = hidden_state_size,
num_layers = num_layers,
dropout = dropout,
batch_first = TRUE
)
self$encoder_gate <- gated_linear_unit(
input_size = hidden_state_size,
output_size = hidden_state_size
)
self$decoder_gate <- gated_linear_unit(
input_size = hidden_state_size,
output_size = hidden_state_size
)
self$encoder_norm <- torch::nn_layer_norm(
normalized_shape = hidden_state_size
)
self$decoder_norm <- torch::nn_layer_norm(
normalized_shape = hidden_state_size
)
self$num_layers <- num_layers
},
forward = function(x, context) {
c(encoder_output, states) %<-% self$encoder(
input = x$encoder$past,
hx = self$expand_context(context$seq2seq_initial_state)
)
c(decoder_output, .) %<-% self$decoder(
input = x$decoder$known,
hx = states
)
list(
encoder = list(
past = encoder_output %>%
self$encoder_gate() %>%
magrittr::add(x$encoder$past) %>%
self$encoder_norm()
),
decoder = list(
known = decoder_output %>%
self$decoder_gate() %>%
magrittr::add(x$decoder$known) %>%
self$decoder_norm()
)
)
},
  expand_context = function(context) {
    # replicate the [batch, hidden] static context across every LSTM layer,
    # yielding tensors of shape [num_layers, batch, hidden]
    purrr::map(context, ~.x$expand(c(self$num_layers, -1, -1))$contiguous())
  }
)
temporal_selection <- torch::nn_module(
"selection",
initialize = function(n_features, hidden_state_size) {
self$known <- variable_selection_network(
n_features = sum(as.numeric(n_features$decoder$known)),
hidden_state_size = hidden_state_size
)
self$past <- variable_selection_network(
n_features = sum(as.numeric(n_features$encoder$past)),
hidden_state_size = hidden_state_size
)
},
forward = function(x, context) {
x$encoder$past <- x$encoder$past %>%
self$past(context = context$temporal_variable_selection)
x$decoder$known <- x$decoder$known %>%
self$known(context = context$temporal_variable_selection)
x
}
)
static_context <- torch::nn_module(
"static_context",
initialize = function(n_features, hidden_state_size) {
self$static <- variable_selection_network(
n_features = sum(as.numeric(n_features)),
hidden_state_size = hidden_state_size
)
self$temporal_variable_selection <- gated_residual_network(
input_size = hidden_state_size,
output_size = hidden_state_size,
hidden_state_size = hidden_state_size
)
self$cell_state <- gated_residual_network(
input_size = hidden_state_size,
output_size = hidden_state_size,
hidden_state_size = hidden_state_size
)
self$hidden_state <- gated_residual_network(
input_size = hidden_state_size,
output_size = hidden_state_size,
hidden_state_size = hidden_state_size
)
},
  forward = function(x) {
    # treat the static features as a single time step so the shared
    # variable selection network can be reused
    selected <- x %>%
      torch::torch_unsqueeze(dim = 2) %>%
      self$static() %>%
      torch::torch_squeeze(dim = 2)
    list(
      temporal_variable_selection = self$temporal_variable_selection(selected),
      seq2seq_initial_state = list(
        # ordered (hidden, cell) to match the hx argument of nn_lstm
        hidden_state = self$hidden_state(selected),
        cell_state = self$cell_state(selected)
      )
    )
}
)
variable_selection_network <- torch::nn_module(
"variable_selection_network",
initialize = function(n_features, hidden_state_size) {
self$global <- time_distributed(gated_residual_network(
input_size = n_features*hidden_state_size,
output_size = n_features,
hidden_state_size = hidden_state_size
))
self$local <- seq_len(n_features) %>%
purrr::map(~time_distributed(gated_residual_network(
input_size = hidden_state_size,
output_size = hidden_state_size,
hidden_state_size = hidden_state_size
))) %>%
torch::nn_module_list()
},
forward = function(x, context = NULL) {
v <- self$global(x, context = context) %>%
torch::nnf_softmax(dim = 3) %>%
torch::torch_unsqueeze(dim = 4)
x <- x %>%
torch::torch_unbind(dim = 3) %>%
purrr::map2(as.list(self$local), ~.y(.x)) %>%
torch::torch_stack(dim = 3)
torch::torch_sum(v*x, dim = 3)
}
)
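# A sketch of the module above (shapes are illustrative): `v` holds softmax
# weights over the feature axis, and the weighted sum collapses that axis.
#
# vsn <- variable_selection_network(n_features = 3, hidden_state_size = 8)
# vsn(torch::torch_randn(2, 5, 3, 8))  # [batch, time, features, hidden] -> [2, 5, 8]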
gated_residual_network <- torch::nn_module(
"gated_residual_network",
initialize = function(input_size, output_size, hidden_state_size, dropout = 0.1) {
self$input <- torch::nn_linear(input_size, hidden_state_size)
self$context <- torch::nn_linear(hidden_state_size, hidden_state_size, bias = FALSE)
self$hidden <- torch::nn_linear(hidden_state_size, hidden_state_size)
self$dropout <- torch::nn_dropout(dropout)
self$gate <- gated_linear_unit(hidden_state_size, output_size)
self$norm <- torch::nn_layer_norm(output_size)
self$elu <- torch::nn_elu()
if (input_size == output_size) {
self$skip <- torch::nn_identity()
} else {
self$skip <- torch::nn_linear(input_size, output_size)
}
},
forward = function(x, context = NULL) {
if (x$ndim > 2) {
x <- torch::torch_flatten(x, start_dim = 2)
}
skip <- self$skip(x)
x <- self$input(x)
if (!is.null(context)) {
x <- x + self$context(context)
}
hidden <- x %>%
self$elu() %>%
self$hidden() %>%
self$dropout()
self$norm(skip + self$gate(hidden))
}
)
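# The block above computes, for input a and optional context c,
#   GRN(a, c) = LayerNorm(skip(a) + GLU(Dropout(W_2 ELU(W_1 a + W_3 c + b_1) + b_2))),
# where skip() is the identity when input and output sizes match and a linear
# projection otherwise, and the W_3 c term is added only when c is supplied.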
gated_linear_unit <- torch::nn_module(
"gated_linear_unit",
initialize = function(input_size, output_size) {
self$gate <- torch::nn_sequential(
torch::nn_linear(input_size, output_size),
torch::nn_sigmoid()
)
self$activation <- torch::nn_linear(input_size, output_size)
},
forward = function(x) {
self$gate(x) * self$activation(x)
}
)
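# The module above implements GLU(x) = sigmoid(W_1 x + b_1) * (W_2 x + b_2);
# driving the sigmoid towards zero suppresses the whole branch, which is how
# the gated skip connections can bypass blocks that add no value.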
preprocessing <- torch::nn_module(
"preprocessing",
initialize = function(n_features, feature_sizes, hidden_state_size) {
self$past <- preprocessing_group(
n_features = n_features$encoder$past,
feature_sizes = feature_sizes$past,
hidden_state_size = hidden_state_size
)
self$known <- preprocessing_group(
n_features = n_features$decoder$known,
feature_sizes = feature_sizes$known,
hidden_state_size = hidden_state_size
)
self$static <- preprocessing_group(
n_features = n_features$encoder$static,
feature_sizes = feature_sizes$static,
hidden_state_size = hidden_state_size
)
},
forward = function(x) {
list(
encoder = list(
past = self$past(x$encoder$past),
static = x$encoder$static %>%
purrr::map(~torch::torch_unsqueeze(.x, dim = 2)) %>%
self$static() %>%
torch::torch_squeeze(dim = 2)
),
decoder = list(
known = self$known(x$decoder$known)
)
)
}
)
# Preprocess a group of time-varying variables.
#
# Handles both numeric and categorical variables.
# Each numeric variable passes through a linear transformation that is shared
# across every time step.
# Categorical variables are represented through embeddings.
preprocessing_group <- torch::nn_module(
"preprocessing_group",
initialize = function(n_features, feature_sizes, hidden_state_size) {
self$num <- linear_preprocessing(n_features$num, hidden_state_size)
self$cat <- embedding_preprocessing(n_features$cat, feature_sizes, hidden_state_size)
},
  forward = function(x) {
    # assigning NULL drops the element from the list, so a group with no
    # numeric (or no categorical) features is skipped by torch_cat()
    x$num <- self$num(x$num)
    x$cat <- self$cat(x$cat)
    torch::torch_cat(x, dim = 3)
  }
)
linear_preprocessing <- torch::nn_module(
"linear_preprocessing",
initialize = function(n_features, hidden_state_size) {
self$linears <- seq_len(n_features) %>%
purrr::map(~time_distributed(torch::nn_linear(1, hidden_state_size))) %>%
torch::nn_module_list()
},
# @param x `Tensor[batch, time_steps, n_features]`
forward = function(x) {
if (x$size(3) == 0) return(NULL)
x %>%
torch::torch_unsqueeze(dim = 4) %>%
torch::torch_unbind(dim = 3) %>%
purrr::imap(~self$linears[[.y]](.x)) %>%
torch::torch_stack(dim = 3)
}
)
embedding_preprocessing <- torch::nn_module(
"embedding_preprocessing",
initialize = function(n_features, feature_sizes, hidden_state_size) {
self$embeddings <- feature_sizes %>%
purrr::map(~time_distributed(torch::nn_embedding(.x, hidden_state_size))) %>%
torch::nn_module_list()
},
# @param x `Tensor[batch, time_steps, n_features]`
forward = function(x) {
if (x$size(3) == 0) return(NULL)
x %>%
torch::torch_unbind(dim = 3) %>%
purrr::imap(~self$embeddings[[.y]](.x)) %>%
torch::torch_stack(dim = 3)
}
)
time_distributed <- torch::nn_module(
"time_distributed",
initialize = function(module) {
self$module <- module
},
forward = function(x, ...) {
extra_args <- list(...)
x %>%
torch::torch_unbind(dim = 2) %>%
purrr::map(~rlang::exec(self$module, .x, !!!extra_args)) %>%
torch::torch_stack(dim = 2)
}
)
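# A sketch of the wrapper above (shapes are illustrative): a module that maps
# [batch, features] tensors is applied independently at each time step of a
# [batch, time, features] input.
#
# td <- time_distributed(torch::nn_linear(4, 8))
# td(torch::torch_randn(2, 10, 4))  # -> [2, 10, 8]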