# Fixed (non-learned) sinusoidal positional encoding:
#   PE(pos, 2i)     = sin(pos / 10000^(2i / d_model))
#   PE(pos, 2i + 1) = cos(pos / 10000^(2i / d_model))
#
# @param d_model size of the last dimension of the encoding. The sin/cos
#   interleaving below assumes an even value.
# @param dropout dropout probability applied to the table in `forward()`.
# @param max_len number of positions that are precomputed.
#
# `forward()` takes no argument and returns the whole table, after dropout,
# with shape (1, max_len, d_model).
positional_encoding <- torch::nn_module(
  "positional_encoding",
  initialize = function(d_model, dropout = 0.1, max_len = 5000L) {
    self$dropout <- torch::nn_dropout(p = dropout)
    self$max_len <- max_len
    self$d_model <- d_model
    # torch_arange() includes its end value in R, so this is the column of
    # 0-based positions 0 .. max_len - 1, shape (max_len, 1).
    position <- torch::torch_arange(start = 0, end = max_len - 1)$unsqueeze(2)
    # Even exponents 2i = 0, 2, ..., one per sin/cos column pair.
    even_dims <- torch::torch_arange(start = 0, end = d_model - 1, step = 2)
    div_term <- torch::torch_exp(even_dims * (-log(10000) / d_model))
    # No `device` argument: this module never sets `self$config`, so the
    # table is created on the default device.
    pe <- torch::torch_zeros(1, max_len, d_model)
    angles <- position * div_term
    # Odd (1-based) columns hold the sines, even columns the cosines.
    pe[1, , 1:N:2] <- torch::torch_sin(angles)
    pe[1, , 2:N:2] <- torch::torch_cos(angles)
    self$pe <- pe
  },
  forward = function() {
    # Drop the leading singleton, apply dropout, then restore it.
    x <- self$pe[1, 1:self$max_len]
    self$dropout(x)$unsqueeze(1)
  }
)
# Visual feature extractor: a pretrained ResNet-50 without its head, a 1x1
# convolution to `config$hidden_size` channels, a ReLU, and a linear layer
# that maps the flattened spatial grid to `config$max_position_embeddings`.
#
# @param config list-like object providing `hidden_size`,
#   `intermediate_ff_size_factor` and `max_position_embeddings`.
#
# `forward(x)` divides the image tensor `x` by 255 before the backbone and
# returns a tensor laid out as (batch, sequence, embedding).
resnet_feature_extractor <- torch::nn_module(
  "resnet_feature_extractor",
  initialize = function(config) {
    # Use a ResNet model for visual feature embedding (remove the
    # classification head): extract resnet50 `layer 4`.
    # `pretrained` is spelled out in full instead of relying on partial
    # argument matching.
    backbone <- torchvision::model_resnet50(pretrained = TRUE)
    self$resnet50 <- torch::nn_prune_head(backbone, 2)
    # Apply a convolution and a linear layer.
    self$conv1 <- torch::nn_conv2d(2048, config$hidden_size, kernel_size = 1)
    self$relu1 <- torch::nn_relu()
    # The input width must equal the flattened spatial size produced in
    # `forward()` (wl * hl, 192 according to the note there).
    self$linear1 <- torch::nn_linear(
      config$hidden_size %/% config$intermediate_ff_size_factor,
      config$max_position_embeddings
    )
    # TODO adapt the output according to https://github.com/microsoft/unilm/blob/9865272c76829757b13292f1b51d2fcd7b5fa401/layoutlmft/layoutlmft/models/layoutlmv2/modeling_layoutlmv2.py#L601
  },
  forward = function(x) {
    x <- self$resnet50(x / 255)
    x <- self$conv1(x)
    x <- self$relu1(x)
    # "b e wl hl -> b e (wl.hl)": batch, embedding, width_low, height_low,
    # wl * hl = 192
    y <- x$reshape(c(x$shape[1:2], -1))
    y <- self$linear1(y)
    # "b e s -> b s e": batch, embedding, sequence (permute is 1-indexed)
    y <- y$permute(c(1, 3, 2))
    # TODO adapt the output according to https://github.com/microsoft/unilm/blob/9865272c76829757b13292f1b51d2fcd7b5fa401/layoutlmft/layoutlmft/models/layoutlmv2/modeling_layoutlmv2.py#L601
    y
  }
)
# Spatial embeddings of the token bounding boxes, computed twice with
# separate weights: once for the visual stream (`*_v`) and once for the text
# stream (`*_t`).
#
# @param config list-like object providing `hidden_size`,
#   `max_position_embeddings`, `max_2d_position_embeddings`,
#   `coordinate_size` and `shape_size`.
#
# For each axis the six per-feature embeddings are concatenated on the last
# dimension (2 * coordinate_size + 4 * shape_size) and then added to a
# positional encoding of size `hidden_size`, so the two sizes must agree.
docformer_embeddings <- torch::nn_module(
  "docformer_embeddings",
  initialize = function(config) {
    self$config <- config
    max_2d_p_emb <- config$max_2d_position_embeddings
    # Distances to the previous box are signed, hence twice the range (+ 0).
    rel_max_2d_p_emb <- 2 * max_2d_p_emb + 1
    # self$word_embedding <- torch::nn_embedding(config$vocab_size, config$hidden_size, padding_idx = config$pad_token_id)

    # Visual stream ----
    self$position_embedding_v <- positional_encoding(
      d_model = config$hidden_size,
      dropout = 0.1,
      max_len = config$max_position_embeddings
    )
    self$x_topleft_pos_embeddings_v <-
      torch::nn_embedding(max_2d_p_emb, config$coordinate_size)
    self$x_bottomright_pos_embeddings_v <-
      torch::nn_embedding(max_2d_p_emb, config$coordinate_size)
    self$w_pos_embeddings_v <-
      torch::nn_embedding(max_2d_p_emb, config$shape_size)
    self$x_topleft_dist_to_prev_embeddings_v <-
      torch::nn_embedding(rel_max_2d_p_emb, config$shape_size)
    self$x_bottomright_dist_to_prev_embeddings_v <-
      torch::nn_embedding(rel_max_2d_p_emb, config$shape_size)
    self$x_centroid_dist_to_prev_embeddings_v <-
      torch::nn_embedding(rel_max_2d_p_emb, config$shape_size)
    self$y_topleft_pos_embeddings_v <-
      torch::nn_embedding(max_2d_p_emb, config$coordinate_size)
    self$y_bottomright_pos_embeddings_v <-
      torch::nn_embedding(max_2d_p_emb, config$coordinate_size)
    self$h_pos_embeddings_v <-
      torch::nn_embedding(max_2d_p_emb, config$shape_size)
    self$y_topleft_dist_to_prev_embeddings_v <-
      torch::nn_embedding(rel_max_2d_p_emb, config$shape_size)
    self$y_bottomright_dist_to_prev_embeddings_v <-
      torch::nn_embedding(rel_max_2d_p_emb, config$shape_size)
    self$y_centroid_dist_to_prev_embeddings_v <-
      torch::nn_embedding(rel_max_2d_p_emb, config$shape_size)

    # Text stream ----
    self$position_embedding_t <- positional_encoding(
      d_model = config$hidden_size,
      dropout = 0.1,
      max_len = config$max_position_embeddings
    )
    self$x_topleft_pos_embeddings_t <-
      torch::nn_embedding(max_2d_p_emb, config$coordinate_size)
    self$x_bottomright_pos_embeddings_t <-
      torch::nn_embedding(max_2d_p_emb, config$coordinate_size)
    self$w_pos_embeddings_t <-
      torch::nn_embedding(max_2d_p_emb, config$shape_size)
    self$x_topleft_dist_to_prev_embeddings_t <-
      torch::nn_embedding(rel_max_2d_p_emb, config$shape_size)
    self$x_bottomright_dist_to_prev_embeddings_t <-
      torch::nn_embedding(rel_max_2d_p_emb, config$shape_size)
    self$x_centroid_dist_to_prev_embeddings_t <-
      torch::nn_embedding(rel_max_2d_p_emb, config$shape_size)
    self$y_topleft_pos_embeddings_t <-
      torch::nn_embedding(max_2d_p_emb, config$coordinate_size)
    self$y_bottomright_pos_embeddings_t <-
      torch::nn_embedding(max_2d_p_emb, config$coordinate_size)
    self$h_pos_embeddings_t <-
      torch::nn_embedding(max_2d_p_emb, config$shape_size)
    self$y_topleft_dist_to_prev_embeddings_t <-
      torch::nn_embedding(rel_max_2d_p_emb, config$shape_size)
    self$y_bottomright_dist_to_prev_embeddings_t <-
      torch::nn_embedding(rel_max_2d_p_emb, config$shape_size)
    self$y_centroid_dist_to_prev_embeddings_t <-
      torch::nn_embedding(rel_max_2d_p_emb, config$shape_size)
  },
  forward = function(x_feature, y_feature) {
    # Arguments:
    #   x_feature of shape (batch_size, seq_len, 6)
    #   y_feature of shape (batch_size, seq_len, 6)
    #
    # Output:
    #   list(V-bar-s, T-bar-s), each of shape (batch_size, 512, 768) with
    #   the default configuration
    #
    # What are the 6 features:
    #   1 -> top left x/y
    #   2 -> bottom right x/y
    #   3 -> width/height
    #   4 -> diff top left x/y
    #   5 -> diff bottom right x/y
    #   6 -> centroids diff x/y
    mbox_max <- self$config$max_2d_position_embeddings

    # `[<-` on a torch tensor writes in place. Work on private copies so the
    # caller's tensors are not clamped/shifted again on every forward pass.
    x_feature <- x_feature$clone()
    y_feature <- y_feature$clone()

    # Clamp and add a bias for handling negative values, workaround to
    # support float16()
    x_feature[, , 1:3] <- x_feature[, , 1:3]$clamp(1L, mbox_max)
    y_feature[, , 1:3] <- y_feature[, , 1:3]$clamp(1L, mbox_max)
    # NOTE(review): a distance clamped to -mbox_max becomes index 0 here,
    # while R torch embeddings look to be 1-based — confirm whether an extra
    # `+ 1L` is needed.
    x_feature[, , 4:N] <-
      x_feature[, , 4:N]$clamp(-mbox_max, mbox_max) + mbox_max
    y_feature[, , 4:N] <-
      y_feature[, , 4:N]$clamp(-mbox_max, mbox_max) + mbox_max

    # Embed feature i with layers[[i]] and concatenate the six results on
    # the last dimension.
    embed_features <- function(feature, layers) {
      pieces <- lapply(
        seq_along(layers),
        function(i) layers[[i]](feature[, , i])
      )
      torch::torch_cat(pieces, dim = -1)
    }

    x_calculated_embedding_v <- embed_features(x_feature, list(
      self$x_topleft_pos_embeddings_v,
      self$x_bottomright_pos_embeddings_v,
      self$w_pos_embeddings_v,
      self$x_topleft_dist_to_prev_embeddings_v,
      self$x_bottomright_dist_to_prev_embeddings_v,
      self$x_centroid_dist_to_prev_embeddings_v
    ))
    y_calculated_embedding_v <- embed_features(y_feature, list(
      self$y_topleft_pos_embeddings_v,
      self$y_bottomright_pos_embeddings_v,
      self$h_pos_embeddings_v,
      self$y_topleft_dist_to_prev_embeddings_v,
      self$y_bottomright_dist_to_prev_embeddings_v,
      self$y_centroid_dist_to_prev_embeddings_v
    ))
    v_feat_s <- x_calculated_embedding_v + y_calculated_embedding_v +
      self$position_embedding_v()

    x_calculated_embedding_t <- embed_features(x_feature, list(
      self$x_topleft_pos_embeddings_t,
      self$x_bottomright_pos_embeddings_t,
      self$w_pos_embeddings_t,
      self$x_topleft_dist_to_prev_embeddings_t,
      self$x_bottomright_dist_to_prev_embeddings_t,
      self$x_centroid_dist_to_prev_embeddings_t
    ))
    y_calculated_embedding_t <- embed_features(y_feature, list(
      self$y_topleft_pos_embeddings_t,
      self$y_bottomright_pos_embeddings_t,
      self$h_pos_embeddings_t,
      self$y_topleft_dist_to_prev_embeddings_t,
      self$y_bottomright_dist_to_prev_embeddings_t,
      self$y_centroid_dist_to_prev_embeddings_t
    ))
    t_feat_s <- x_calculated_embedding_t + y_calculated_embedding_t +
      self$position_embedding_t()

    list(v_feat_s, t_feat_s)
  }
)
# Pre-normalisation wrapper: layer-normalise the input, then hand it to the
# wrapped callable.
#
# @param dim size passed to `torch::nn_layer_norm()`.
# @param fn module (or function) applied to the normalised input.
pre_norm <- torch::nn_module(
  "pre_norm",
  initialize = function(dim, fn) {
    self$norm <- torch::nn_layer_norm(dim)
    self$fn <- fn
  },
  forward = function(x) {
    normalised <- self$norm(x)
    self$fn(normalised)
  }
)
# Pre-normalisation wrapper for the four-input attention layer: each input
# goes through its own layer norm before being passed, in the same order,
# to the wrapped callable.
#
# @param dim size passed to each `torch::nn_layer_norm()`.
# @param fn module (or function) taking the four normalised tensors.
pre_norm_attention <- torch::nn_module(
  "pre_norm_attention",
  initialize = function(dim, fn) {
    self$norm_t_feat <- torch::nn_layer_norm(dim)
    self$norm_v_feat <- torch::nn_layer_norm(dim)
    self$norm_t_feat_s <- torch::nn_layer_norm(dim)
    self$norm_v_feat_s <- torch::nn_layer_norm(dim)
    self$fn <- fn
  },
  forward = function(text_feat, img_feat, text_spatial_feat, img_spatial_feat) {
    text_normed <- self$norm_t_feat(text_feat)
    img_normed <- self$norm_v_feat(img_feat)
    text_spatial_normed <- self$norm_t_feat_s(text_spatial_feat)
    img_spatial_normed <- self$norm_v_feat_s(img_spatial_feat)
    self$fn(text_normed, img_normed, text_spatial_normed, img_spatial_normed)
  }
)
# Position-wise feed-forward network:
# linear -> GELU -> dropout -> linear -> dropout.
#
# @param dim input and output size.
# @param hidden_dim size of the intermediate layer.
# @param dropout probability used by both dropout layers.
feed_forward <- torch::nn_module(
  "feed_forward",
  initialize = function(dim, hidden_dim, dropout = 0) {
    # Created in the same order as they are applied.
    expand <- torch::nn_linear(dim, hidden_dim)
    activation <- torch::nn_gelu()
    inner_dropout <- torch::nn_dropout(dropout)
    contract <- torch::nn_linear(hidden_dim, dim)
    outer_dropout <- torch::nn_dropout(dropout)
    self$net <- torch::nn_sequential(
      expand,
      activation,
      inner_dropout,
      contract,
      outer_dropout
    )
  },
  forward = function(x) {
    self$net(x)
  }
)
# Learned 1D relative-position embeddings.
#
# @param num_units width of each embedding row.
# @param max_relative_position distances are clipped to
#   [-max_relative_position, max_relative_position].
# @param max_seq_length size of the precomputed (query, key) index matrix.
#
# `forward(length_q, length_k)` indexes the embedding table with the top-left
# (length_q, length_k) corner of the precomputed index matrix.
relative_position <- torch::nn_module(
  "relative_position",
  initialize = function(num_units, max_relative_position, max_seq_length) {
    self$num_units <- num_units
    self$max_relative_position <- max_relative_position
    # One row per clipped distance: -max .. 0 .. +max.
    table_shape <- c(max_relative_position * 2 + 1, num_units)
    # NOTE(review): this module never assigns `self$config`, so the device
    # below presumably resolves to NULL (default device) — confirm.
    self$embeddings_table <- torch::nn_parameter(
      torch::torch_zeros(table_shape, device = self$config$device)
    )
    torch::nn_init_xavier_uniform_(self$embeddings_table)
    self$max_length <- max_seq_length
    query_pos <- torch::torch_arange(1, max_seq_length)
    key_pos <- torch::torch_arange(1, max_seq_length)
    # Entry [q, k] is key position minus query position.
    offsets <- key_pos[NULL, ..] - query_pos[.., NULL]
    clipped <- torch::torch_clamp(
      offsets,
      -self$max_relative_position,
      self$max_relative_position
    )
    # Shift into 1 .. 2 * max + 1 so the values can index the table rows.
    row_index <- clipped + self$max_relative_position + 1
    self$final_mat <- row_index$to(torch::torch_long())
  },
  forward = function(length_q, length_k) {
    lookup <- self$final_mat[1:length_q, 1:length_k]
    self$embeddings_table[lookup]
  }
)
# Multi-modal self-attention layer: text and image streams each get their own
# q/k/v projections and relative-position table, share the spatial q/k
# projections, and their attention contexts are summed before the output
# projection.
#
# @param embed_dim model width; split evenly across `n_heads`.
# @param n_heads number of attention heads.
# @param max_relative_position clipping distance of the relative positions.
# @param max_seq_length longest sequence the relative-position tables cover.
# @param dropout probability used after the softmax and after `to_out`.
#
# `forward()` takes four tensors laid out as (batch, length, embed_dim) and
# returns a tensor with the same layout.
multimodal_attention_layer <- torch::nn_module(
  "multimodal_attention_layer",
  initialize = function(embed_dim, n_heads, max_relative_position,
                        max_seq_length, dropout) {
    self$embed_dim <- embed_dim
    self$n_heads <- n_heads
    self$head_dim <- embed_dim %/% n_heads
    self$relative_positions_text <- relative_position(
      self$head_dim, max_relative_position, max_seq_length
    )
    self$relative_positions_img <- relative_position(
      self$head_dim, max_relative_position, max_seq_length
    )
    # text qkv embeddings
    self$fc_k_text <- torch::nn_linear(embed_dim, embed_dim)
    self$fc_q_text <- torch::nn_linear(embed_dim, embed_dim)
    self$fc_v_text <- torch::nn_linear(embed_dim, embed_dim)
    # image qkv embeddings
    self$fc_k_img <- torch::nn_linear(embed_dim, embed_dim)
    self$fc_q_img <- torch::nn_linear(embed_dim, embed_dim)
    self$fc_v_img <- torch::nn_linear(embed_dim, embed_dim)
    # spatial qk embeddings (shared for visual and text)
    self$fc_k_spatial <- torch::nn_linear(embed_dim, embed_dim)
    self$fc_q_spatial <- torch::nn_linear(embed_dim, embed_dim)
    self$softmax_dropout <- torch::nn_sequential(
      torch::nn_softmax(dim = -1),
      torch::nn_dropout(dropout)
    )
    self$to_out <- torch::nn_sequential(
      torch::nn_linear(embed_dim, embed_dim),
      torch::nn_dropout(dropout)
    )
    # Plain R scalar: it divides tensors on any device, unlike a tensor
    # created here that would not follow the module when it is moved.
    self$scale <- sqrt(embed_dim)
  },
  forward = function(text_feat, img_feat, text_spatial_feat, img_spatial_feat) {
    seq_length <- text_feat$shape[2]
    n_heads <- self$n_heads
    # b -> batch, t -> time steps (l -> length has same meaning),
    # head -> # of heads, k -> head dim.
    # 'b t (head k) -> head b t k'
    split_heads <- function(x) {
      x$reshape(c(x$shape[1:2], n_heads, -1))$permute(c(3, 1, 2, 4))
    }

    # self attention of text
    key_text_nh <- split_heads(self$fc_k_text(text_feat))
    query_text_nh <- split_heads(self$fc_q_text(text_feat))
    value_text_nh <- split_heads(self$fc_v_text(text_feat))
    dots_text <- torch::torch_einsum(
      "hblk,hbtk->hblt", list(query_text_nh, key_text_nh)
    ) / self$scale
    # 1D relative positions (query, key)
    rel_pos_embed_text <- self$relative_positions_text(seq_length, seq_length)
    rel_pos_key_text <- torch::torch_einsum(
      "bhrd,lrd->bhlr", list(key_text_nh, rel_pos_embed_text)
    )
    rel_pos_query_text <- torch::torch_einsum(
      "bhld,lrd->bhlr", list(query_text_nh, rel_pos_embed_text)
    )
    # shared spatial <-> text hidden features
    key_spatial_text_nh <- split_heads(self$fc_k_spatial(text_spatial_feat))
    query_spatial_text_nh <- split_heads(self$fc_q_spatial(text_spatial_feat))
    dots_text_spatial <- torch::torch_einsum(
      "hblk,hbtk->hblt", list(query_spatial_text_nh, key_spatial_text_nh)
    ) / self$scale
    # Line 38 of pseudo-code
    text_attn_scores <- dots_text + rel_pos_key_text + rel_pos_query_text +
      dots_text_spatial

    # self-attention of image
    key_img_nh <- split_heads(self$fc_k_img(img_feat))
    query_img_nh <- split_heads(self$fc_q_img(img_feat))
    value_img_nh <- split_heads(self$fc_v_img(img_feat))
    dots_img <- torch::torch_einsum(
      "hblk,hbtk->hblt", list(query_img_nh, key_img_nh)
    ) / self$scale
    # 1D relative positions (query, key): use the image table, not the text
    # one, so `relative_positions_img` actually contributes to the scores.
    rel_pos_embed_img <- self$relative_positions_img(seq_length, seq_length)
    rel_pos_key_img <- torch::torch_einsum(
      "bhrd,lrd->bhlr", list(key_img_nh, rel_pos_embed_img)
    )
    rel_pos_query_img <- torch::torch_einsum(
      "bhld,lrd->bhlr", list(query_img_nh, rel_pos_embed_img)
    )
    # shared spatial <-> image features
    key_spatial_img_nh <- split_heads(self$fc_k_spatial(img_spatial_feat))
    query_spatial_img_nh <- split_heads(self$fc_q_spatial(img_spatial_feat))
    dots_img_spatial <- torch::torch_einsum(
      "hblk,hbtk->hblt", list(query_spatial_img_nh, key_spatial_img_nh)
    ) / self$scale
    # Line 59 of pseudo-code
    img_attn_scores <- dots_img + rel_pos_key_img + rel_pos_query_img +
      dots_img_spatial

    text_attn_probs <- self$softmax_dropout(text_attn_scores)
    img_attn_probs <- self$softmax_dropout(img_attn_scores)
    text_context <- torch::torch_einsum(
      "hblt,hbtv->hblv", list(text_attn_probs, value_text_nh)
    )
    img_context <- torch::torch_einsum(
      "hblt,hbtv->hblv", list(img_attn_probs, value_img_nh)
    )
    context <- text_context + img_context
    # 'head b t d -> b t (head d)'
    merged_shape <- c(context$shape[2:3], -1)
    embeddings <- context$permute(c(2, 3, 1, 4))$reshape(merged_shape)
    self$to_out(embeddings)
  }
)
# Stack of `config$num_hidden_layers` encoder blocks. Each block is a pair:
# a pre-normalised multi-modal attention layer and a pre-normalised
# feed-forward network, both with residual connections.
#
# @param config list-like object providing `hidden_size`,
#   `num_hidden_layers`, `num_attention_heads`, `max_relative_positions`,
#   `max_position_embeddings`, `hidden_dropout_prob` and `intermediate_size`.
#
# `forward()` returns the output of the last block; with zero blocks it
# returns `text_feat` unchanged.
docformer_encoder <- torch::nn_module(
  "docformer_encoder",
  initialize = function(config) {
    self$config <- config
    hidden_size <- config$hidden_size
    self$layers <- torch::nn_module_list()
    # seq_len() yields an empty sequence for 0 layers, where seq() would
    # yield c(1, 0).
    for (i in seq_len(config$num_hidden_layers)) {
      attention <- pre_norm_attention(
        hidden_size,
        multimodal_attention_layer(
          hidden_size,
          config$num_attention_heads,
          config$max_relative_positions,
          config$max_position_embeddings,
          config$hidden_dropout_prob
        )
      )
      feed_fwd <- pre_norm(
        hidden_size,
        feed_forward(
          hidden_size,
          config$intermediate_size,
          dropout = config$hidden_dropout_prob
        )
      )
      self$layers$append(torch::nn_module_list(list(attention, feed_fwd)))
    }
  },
  forward = function(text_feat, # text feat or output from last encoder block
                     img_feat,
                     text_spatial_feat,
                     img_spatial_feat) {
    for (id in seq_along(self$layers)) {
      # The residual adds all four inputs of the block.
      skip <- text_feat + img_feat + text_spatial_feat + img_spatial_feat
      attn <- self$layers[[id]][[1]]
      ff <- self$layers[[id]][[2]]
      x <- attn(text_feat, img_feat, text_spatial_feat, img_spatial_feat) + skip
      x <- ff(x) + x
      # Only the text slot is replaced; the other three inputs are reused
      # unchanged by every block.
      text_feat <- x
    }
    text_feat
  }
)
# Token embedding layer whose weights are taken from the word embeddings of
# a pretrained LayoutLM network.
#
# @param config list-like object providing `pretrained_model_name`,
#   `vocab_size` and `hidden_size`.
language_feature_extractor <- torch::nn_module(
  "language_feature_extractor",
  initialize = function(config) {
    pretrained_net <- LayoutLMForTokenClassification(config)$from_pretrained(config$pretrained_model_name)
    word_weights <- pretrained_net$layoutlm$embeddings$word_embeddings$weight
    self$embedding_vector <- torch::nn_embedding(
      config$vocab_size,
      config$hidden_size,
      .weight = word_weights
    )
  },
  forward = function(x) {
    # shift text idx values to be 1-indexed
    one_based_ids <- x + 1L
    embedded <- self$embedding_vector(one_based_ids)
    embedded$squeeze(3)
  }
)
# Runs the three feature extractors on one encoded sample.
#
# @param config configuration forwarded unchanged to each extractor.
#
# `forward(encoding)` reads `encoding$image`, `encoding$text`,
# `encoding$x_features` and `encoding$y_features`, and returns
# list(visual, text, visual-spatial, text-spatial) features, in that order.
extract_features <- torch::nn_module(
  "extract_features",
  initialize = function(config) {
    self$visual_feature <- resnet_feature_extractor(config)
    self$language_feature <- language_feature_extractor(config)
    self$spatial_feature <- docformer_embeddings(config)
  },
  forward = function(encoding) {
    visual <- self$visual_feature(encoding$image)
    textual <- self$language_feature(encoding$text)
    # The spatial module returns the pair (visual-spatial, text-spatial).
    spatial_pair <- self$spatial_feature(
      encoding$x_features,
      encoding$y_features
    )
    list(visual, textual, spatial_pair[[1]], spatial_pair[[2]])
  }
)
# DocFormer backbone: feature extraction, multi-modal encoder, then dropout.
#
# @param config configuration forwarded to `extract_features()` and
#   `docformer_encoder()`; `hidden_dropout_prob` sets the final dropout.
#
# `forward(x)` takes the encoded sample expected by `extract_features()` and
# returns the encoder output after dropout.
docformer <- torch::nn_module(
  "docformer",
  initialize = function(config) {
    self$config <- config
    self$extract_feature <- extract_features(config)
    self$encoder <- docformer_encoder(config)
    self$dropout <- torch::nn_dropout(config$hidden_dropout_prob)
  },
  forward = function(x) {
    # extract_features() returns (visual, text, visual-spatial,
    # text-spatial); the encoder expects (text, img, text_spatial,
    # img_spatial), so the pairs are swapped when passing them on.
    feats <- self$extract_feature(x)
    v_feat <- feats[[1]]
    t_feat <- feats[[2]]
    v_feat_s <- feats[[3]]
    t_feat_s <- feats[[4]]
    x_enc <- self$encoder(t_feat, v_feat, t_feat_s, v_feat_s)
    # Return the value visibly (no trailing assignment).
    self$dropout(x_enc)
  }
)
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.