create_model_transformer: Create transformer model

View source: R/create_model_transformer.R


Create transformer model

Description

Creates a transformer network for classification. The model can consist of several stacked attention blocks.

Usage

create_model_transformer(
  maxlen,
  vocabulary_size = 4,
  embed_dim = 64,
  pos_encoding = "embedding",
  head_size = 4L,
  num_heads = 5L,
  ff_dim = 8,
  dropout = 0,
  n = 10000,
  layer_dense = 2,
  dropout_dense = NULL,
  flatten_method = "flatten",
  last_layer_activation = "softmax",
  loss_fn = "categorical_crossentropy",
  solver = "adam",
  learning_rate = 0.01,
  label_noise_matrix = NULL,
  bal_acc = FALSE,
  f1_metric = FALSE,
  auc_metric = FALSE,
  label_smoothing = 0,
  verbose = TRUE,
  model_seed = NULL,
  mixed_precision = FALSE,
  mirrored_strategy = NULL
)

Arguments

maxlen

Length of the predictor sequence.

vocabulary_size

Number of unique characters in the vocabulary.

embed_dim

Dimension of the token embedding. If set to 0, no embedding is used. An embedding should be used when the input is not one-hot encoded (i.e. when it is an integer sequence).
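To illustrate the two input types, the base R sketch below encodes a short sequence both as integers and as one-hot rows, then applies a toy embedding. The vocabulary order, the 1-based integer offset and the random matrix E (standing in for learned embedding weights) are assumptions for illustration; deepG's own encoding may differ, and the code does not call deepG.

```r
# Hypothetical 4-letter vocabulary (vocabulary_size = 4); the order is an assumption
vocabulary <- c("A", "C", "G", "T")
chars <- strsplit("ACGTTG", "")[[1]]

# Integer encoding (1-based here; deepG's own offset may differ)
int_seq <- match(chars, vocabulary)

# One-hot encoding: one row per position, one column per character
one_hot <- diag(length(vocabulary))[int_seq, ]

# Toy embedding matrix with embed_dim = 3 (random stand-in for learned weights)
set.seed(1)
embed_dim <- 3
E <- matrix(rnorm(length(vocabulary) * embed_dim), nrow = length(vocabulary))

# Embedding lookup: one embed_dim-vector per sequence position
embedded <- E[int_seq, ]

# Multiplying the one-hot matrix by E selects the same rows
all.equal(one_hot %*% E, embedded)
```

The last line shows why an embedding layer on integer input is equivalent to a dense projection of one-hot input, just cheaper to compute.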

pos_encoding

How to add positional information; either "sinusoid" or "embedding". If "sinusoid", sine waves of different frequencies are added to the input. If "embedding", the model learns a positional embedding.

head_size

Dimension of the attention key.

num_heads

Number of attention heads.
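The roles of head_size and num_heads can be sketched in base R as scaled dot-product attention repeated over several heads. This assumes head_size plays the role of the per-head key size (as key_dim does in Keras's MultiHeadAttention layer) and that the concatenated heads are projected back to the input width; the random weights stand in for learned ones, and this is not deepG's internal code.

```r
# Row-wise softmax, shifted by the row maximum for numerical stability
softmax_rows <- function(m) {
  e <- exp(m - apply(m, 1, max))
  e / rowSums(e)
}

# One attention head: scaled dot-product attention with key size head_size
attention_head <- function(x, head_size) {
  d <- ncol(x)
  # Random projections stand in for learned query/key/value weights
  w_q <- matrix(rnorm(d * head_size, sd = 0.1), d, head_size)
  w_k <- matrix(rnorm(d * head_size, sd = 0.1), d, head_size)
  w_v <- matrix(rnorm(d * head_size, sd = 0.1), d, head_size)
  q <- x %*% w_q
  k <- x %*% w_k
  v <- x %*% w_v
  weights <- softmax_rows(q %*% t(k) / sqrt(head_size))  # maxlen x maxlen
  weights %*% v                                           # maxlen x head_size
}

# num_heads independent heads, concatenated and projected back to ncol(x)
multi_head_attention <- function(x, num_heads, head_size) {
  heads <- lapply(seq_len(num_heads), function(h) attention_head(x, head_size))
  concat <- do.call(cbind, heads)  # maxlen x (num_heads * head_size)
  w_o <- matrix(rnorm(ncol(concat) * ncol(x), sd = 0.1), ncol(concat), ncol(x))
  concat %*% w_o
}

set.seed(42)
x <- matrix(rnorm(50 * 64), nrow = 50)  # maxlen = 50, embed_dim = 64
out <- multi_head_attention(x, num_heads = 5, head_size = 4)
dim(out)
```

With the default num_heads = 5 and head_size = 4, each position attends through five 4-dimensional heads, and the output keeps the input shape.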

ff_dim

Number of units in the first dense layer after the attention block(s).

dropout

Vector of dropout rates applied after the attention block(s).

n

Base constant that sets the frequencies of the sine waves used for positional encoding. Only applied if pos_encoding = "sinusoid".
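A minimal base R sketch of sinusoidal positional encoding, assuming the standard formulation in which dimension pairs (2k, 2k+1) use sin and cos of pos / n^(2k / embed_dim). The function name sinusoid_encoding is hypothetical, and deepG's exact arrangement of sine and cosine terms may differ.

```r
# Hypothetical helper: standard sinusoidal positional encoding with base n
sinusoid_encoding <- function(maxlen, embed_dim, n = 10000) {
  pos <- 0:(maxlen - 1)
  i <- 0:(embed_dim - 1)
  # Each pair of dimensions (2k, 2k + 1) shares the same frequency
  angle_rates <- 1 / n^((2 * (i %/% 2)) / embed_dim)
  angles <- outer(pos, angle_rates)  # maxlen x embed_dim
  pe <- angles
  even <- i %% 2 == 0
  pe[, even] <- sin(angles[, even])   # sine on even (0-based) dimensions
  pe[, !even] <- cos(angles[, !even]) # cosine on odd (0-based) dimensions
  pe
}

pe <- sinusoid_encoding(maxlen = 50, embed_dim = 64)
# The encoding is added element-wise to the (embedded) input: x + pe
```

A larger n stretches the longest wavelength, so positions further apart remain distinguishable.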

layer_dense

Vector specifying the number of neurons per dense layer after the last attention block.

dropout_dense

Dropout for dense layers.

flatten_method

How to process the output of the last attention block. Can be "max_ch_first", "max_ch_last", "average_ch_first", "average_ch_last", "both_ch_first", "both_ch_last", "all", "none" or "flatten". "average_ch_first" / "average_ch_last" and "max_ch_first" / "max_ch_last" apply global average or max pooling; the suffix _ch_first or _ch_last decides along which axis. "both_ch_first" / "both_ch_last" use max and average pooling together. "all" uses all four global pooling options together. "flatten" flattens the output of the last attention block. "none" applies no flattening.
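For one sample, the effect of each option on the output shape can be sketched in base R. This assumes Keras-style semantics, where "channels last" pools over sequence positions (one value per feature) and "channels first" pools over features (one value per position); the concatenation order for "both_*" and "all" is also an assumption, and pool is a hypothetical helper, not a deepG function.

```r
# Hypothetical single-sample output of the last attention block:
# rows = sequence positions (maxlen), columns = features
set.seed(1)
x <- matrix(rnorm(5 * 3), nrow = 5, ncol = 3)

pool <- function(x, method) {
  switch(method,
    average_ch_last  = colMeans(x),            # pool over positions
    max_ch_last      = apply(x, 2, max),
    average_ch_first = rowMeans(x),            # pool over features
    max_ch_first     = apply(x, 1, max),
    both_ch_last     = c(colMeans(x), apply(x, 2, max)),
    both_ch_first    = c(rowMeans(x), apply(x, 1, max)),
    all              = c(colMeans(x), apply(x, 2, max),
                         rowMeans(x), apply(x, 1, max)),
    flatten          = as.vector(t(x)),        # row-major, like a Keras flatten
    none             = x,
    stop("unknown method: ", method))
}

sapply(c("average_ch_last", "max_ch_first", "both_ch_last", "all", "flatten"),
       function(m) length(pool(x, m)))
```

Pooling keeps the dense head small and independent of maxlen for the _ch_last variants, while "flatten" produces maxlen times the feature count inputs.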

last_layer_activation

Activation function of output layer(s). For example "sigmoid" or "softmax".

loss_fn

Either "categorical_crossentropy" or "binary_crossentropy". If label_noise_matrix is given, the custom "noisy_loss" is used instead.
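The two standard losses can be written in a few lines of base R for a single sample. This sketch uses the textbook definitions, with binary cross-entropy averaged over outputs as Keras does; the clipping constant eps is an assumption, and the custom "noisy_loss" is not covered.

```r
# Categorical cross-entropy for one sample: y is one-hot, p sums to 1
categorical_crossentropy <- function(y, p, eps = 1e-7) {
  p <- pmin(pmax(p, eps), 1 - eps)  # clip to avoid log(0)
  -sum(y * log(p))
}

# Binary cross-entropy averaged over independent (sigmoid) outputs
binary_crossentropy <- function(y, p, eps = 1e-7) {
  p <- pmin(pmax(p, eps), 1 - eps)
  -mean(y * log(p) + (1 - y) * log(1 - p))
}

categorical_crossentropy(c(0, 1, 0), c(0.2, 0.7, 0.1))  # same as -log(0.7)
binary_crossentropy(c(1, 0), c(0.9, 0.2))               # same as -(log(0.9) + log(0.8)) / 2
```

Categorical cross-entropy pairs with a "softmax" output layer (exactly one class per sample), binary cross-entropy with "sigmoid" (independent labels).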

solver

Optimization method; options are "adam", "adagrad", "rmsprop" or "sgd".

learning_rate

Learning rate for optimizer.

label_noise_matrix

Matrix of label noise. Each row stands for one true class, and each column gives the fraction of that class's samples carrying the corresponding label. For example, if the first class contains 5 percent wrong labels and the second class has no noise:

label_noise_matrix <- matrix(c(0.95, 0.05, 0, 1), nrow = 2, byrow = TRUE)
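One common way such a matrix enters a loss is forward correction: the model's clean class probabilities are multiplied by the noise matrix to obtain probabilities for the observed, possibly wrong, labels, and the loss is computed against those. The sketch below shows only that step in base R; whether deepG's "noisy_loss" works exactly this way is not stated here, and p_clean is a hypothetical prediction.

```r
label_noise_matrix <- matrix(c(0.95, 0.05, 0, 1), nrow = 2, byrow = TRUE)

# Each row describes one true class and must sum to 1
stopifnot(all(abs(rowSums(label_noise_matrix) - 1) < 1e-12))

# Clean class probabilities predicted for one sample (hypothetical values)
p_clean <- c(0.8, 0.2)

# Forward correction: probability of each observed (possibly noisy) label
p_noisy <- as.vector(p_clean %*% label_noise_matrix)
p_noisy
```

Here 5 percent of the mass predicted for class 1 is shifted to label 2, giving 0.76 and 0.24.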

bal_acc

Whether to add balanced accuracy as a metric.

f1_metric

Whether to add F1 metric.

auc_metric

Whether to add AUC metric.
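For reference, balanced accuracy and F1 can be computed from a confusion matrix in base R using their standard definitions: balanced accuracy is the mean per-class recall, and macro F1 averages the per-class harmonic mean of precision and recall. The labels are made up, and deepG's metric implementations (for example the averaging used for multi-class F1) may differ; AUC is not shown.

```r
# Hypothetical true and predicted class labels
y_true <- c(1, 1, 1, 2, 2, 3)
y_pred <- c(1, 1, 2, 2, 2, 3)

classes <- sort(unique(c(y_true, y_pred)))
cm <- table(factor(y_true, classes), factor(y_pred, classes))  # rows = true, cols = predicted

recall    <- diag(cm) / rowSums(cm)
precision <- diag(cm) / colSums(cm)

# Balanced accuracy: mean recall over classes, robust to class imbalance
balanced_accuracy <- mean(recall)

# Macro F1: per-class F1, then the unweighted mean
f1_per_class <- 2 * precision * recall / (precision + recall)
macro_f1 <- mean(f1_per_class)

c(balanced_accuracy = balanced_accuracy, macro_f1 = macro_f1)
```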

label_smoothing

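Float in [0, 1]. If 0, no smoothing is applied. If > 0, the loss is computed between the predicted labels and a smoothed version of the true labels, where the smoothing squeezes the labels towards 0.5 (for binary cross-entropy; for categorical cross-entropy the labels are squeezed towards 1 / number of classes). The closer the argument is to 1, the more the labels are smoothed.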
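The smoothing rule can be sketched in base R, assuming the Keras convention: binary targets become y * (1 - s) + 0.5 * s, and categorical targets become y * (1 - s) + s / k for k classes. The helper smooth_labels is hypothetical and only illustrates the transformed targets, not deepG's loss code.

```r
# Hypothetical helper reproducing Keras-style label smoothing
smooth_labels <- function(y, label_smoothing, loss = c("categorical", "binary")) {
  loss <- match.arg(loss)
  if (loss == "binary") {
    y * (1 - label_smoothing) + 0.5 * label_smoothing  # squeeze towards 0.5
  } else {
    k <- ncol(y)                                       # number of classes
    y * (1 - label_smoothing) + label_smoothing / k    # squeeze towards 1 / k
  }
}

y <- matrix(c(1, 0, 0,
              0, 1, 0), nrow = 2, byrow = TRUE)  # one-hot targets, 3 classes
smooth_labels(y, 0.1)
smooth_labels(c(1, 0), 0.1, "binary")
```

Smoothed categorical rows still sum to 1, so they remain valid target distributions.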

verbose

Boolean.

model_seed

If not NULL, sets the seed for the model parameters in TensorFlow.

mixed_precision

Whether to use mixed precision (https://www.tensorflow.org/guide/mixed_precision).

mirrored_strategy

Whether to use a distributed mirrored strategy. If NULL, a mirrored strategy is used only if more than one GPU is available.

Value

A keras model implementing the transformer architecture.

Examples


library(keras)
library(deepG)

maxlen <- 50
# Vectors of length 2 specify two stacked attention blocks
model <- create_model_transformer(maxlen = maxlen,
                                  head_size = c(10, 12),
                                  num_heads = c(7, 8),
                                  ff_dim = c(5, 9),
                                  dropout = c(0.3, 0.5))


GenomeNet/deepG documentation built on Dec. 24, 2024, 12:11 p.m.