mlr_learners.ft_transformer: FT-Transformer

FT-Transformer

Description

Feature-Tokenizer Transformer for tabular data that can work either on lazy_tensor inputs or on standard tabular features.

Some differences from the implementation in the paper: there is no attention compression and no option for pre-normalization in the first layer.

If training is unstable, consider a combination of standardizing features (e.g. using po("scale")), using an adaptive optimizer (e.g. Adam), reducing the learning rate, and using a learning rate scheduler (see CallbackSetLRScheduler for options).
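
For instance, a minimal sketch of such a setup, assuming mlr3pipelines is available and that the default Adam optimizer exposes its learning rate as opt.lr:

library(mlr3)
library(mlr3torch)
library(mlr3pipelines)

# po("scale") standardizes numeric features before they reach the network;
# opt.lr (assumed parameter name) lowers the learning rate of the default
# Adam optimizer.
learner = as_learner(
  po("scale") %>>% lrn("classif.ft_transformer",
    epochs = 10, batch_size = 32, device = "cpu",
    opt.lr = 1e-4
  )
)
learner$train(tsk("iris"))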

Dictionary

This Learner can be instantiated using the sugar function lrn():

lrn("classif.ft_transformer", ...)
lrn("regr.ft_transformer", ...)

Properties

  • Supported task types: 'classif', 'regr'

  • Predict Types:

    • classif: 'response', 'prob'

    • regr: 'response'

  • Feature Types: 'logical', 'integer', 'numeric', 'factor', 'ordered', 'lazy_tensor'

  • Required Packages: mlr3, mlr3torch, torch
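
These properties can also be queried from a constructed learner, e.g.:

library(mlr3torch)

learner = lrn("classif.ft_transformer")
learner$feature_types  # feature types the learner accepts
learner$predict_types  # available predict types
learner$packages       # required packages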

Parameters

Parameters from LearnerTorch and PipeOpTorchFTTransformerBlock, as well as:

  • n_blocks :: integer(1)
    The number of transformer blocks.

  • d_token :: integer(1)
    The dimension of the embedding.

  • cardinalities :: integer()
    The number of categories for each categorical feature. This only needs to be specified when working with lazy_tensor inputs.

  • init_token :: character(1)
    The initialization method for the embedding weights. Either "uniform" (default) or "normal".

  • ingress_tokens :: named list() or NULL
    A list of TorchIngressTokens. Only required when using lazy tensor features. The names must be "num.input" and/or "categ.input", and the values are lazy tensor ingress tokens constructed by, e.g., ingress_ltnsr(<num_feat_name>); see the sketch below.
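
As an illustration, the following is a minimal sketch of training on lazy tensor input, assuming the lazy_iris toy task shipped with mlr3torch, whose numeric features are held in a single lazy_tensor column named "x":

library(mlr3torch)

# "x" is the (assumed) name of the lazy_tensor column of tsk("lazy_iris")
learner = lrn("classif.ft_transformer",
  epochs = 1, batch_size = 16, device = "cpu",
  ingress_tokens = list(num.input = ingress_ltnsr("x"))
)
learner$train(tsk("lazy_iris"))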

Super classes

mlr3::Learner -> mlr3torch::LearnerTorch -> LearnerTorchFTTransformer

Methods

Public methods

Method new()

Creates a new instance of this R6 class.

Usage
LearnerTorchFTTransformer$new(
  task_type,
  optimizer = NULL,
  loss = NULL,
  callbacks = list()
)
Arguments
task_type

(character(1))
The task type, either "classif" or "regr".

optimizer

(TorchOptimizer)
The optimizer to use for training. By default, adam is used.

loss

(TorchLoss)
The loss used to train the network. By default, mse is used for regression and cross_entropy for classification.

callbacks

(list() of TorchCallbacks)
The callbacks. Must have unique ids.
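
For example, a sketch of construction with a non-default optimizer and an explicit loss, assuming the t_opt() and t_loss() dictionary sugar from mlr3torch:

library(mlr3torch)

learner = LearnerTorchFTTransformer$new(
  task_type = "classif",
  optimizer = t_opt("sgd"),       # swap the default adam for SGD
  loss = t_loss("cross_entropy")  # explicit classification loss
)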


Method clone()

The objects of this class are cloneable with this method.

Usage
LearnerTorchFTTransformer$clone(deep = FALSE)
Arguments
deep

Whether to make a deep clone.

References

Gorishniy Y, Rubachev I, Khrulkov V, Babenko A (2021). “Revisiting Deep Learning Models for Tabular Data.” arXiv, 2106.11959.

See Also

Other Learner: mlr_learners.mlp, mlr_learners.module, mlr_learners.tab_resnet, mlr_learners.torch_featureless, mlr_learners_torch, mlr_learners_torch_image, mlr_learners_torch_model

Examples


# Define the Learner and set parameter values
learner = lrn("classif.ft_transformer")
learner$param_set$set_values(
  epochs = 1, batch_size = 16, device = "cpu",
  n_blocks = 2, d_token = 32, ffn_d_hidden_multiplier = 4/3
)

# Define a Task
task = tsk("iris")

# Create train and test set
ids = partition(task)

# Train the learner on the training ids
learner$train(task, row_ids = ids$train)

# Make predictions for the test rows
predictions = learner$predict(task, row_ids = ids$test)

# Score the predictions
predictions$score()
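
# Score with an explicit measure instead of the task-type default
# (assuming the standard accuracy measure from mlr3)
predictions$score(msr("classif.acc"))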

