mlr_learners_classif.tabpfn: TabPFN Classification Learner

Description

Foundation model for tabular data. Uses reticulate to interface with the tabpfn Python package.

Installation

The Python dependencies are handled via reticulate::py_require(). Alternatively, you can specify a virtual environment manually by calling reticulate::use_virtualenv() before calling $train(). The tabpfn package and its dependencies must be installed in that virtual environment.

Saving a Learner

To save a LearnerClassifTabPFN for later use, you must call its $marshal() method before writing it to disk. Otherwise the object is not saved correctly, because the fitted model is a Python object that R cannot serialize directly. After loading a marshaled LearnerClassifTabPFN back into R, call $unmarshal() to restore it to a usable state.
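
The marshal, save, load, unmarshal sequence described above could be wrapped in two small helper functions, as in the following minimal sketch. The names save_tabpfn_learner() and load_tabpfn_learner() are hypothetical and not part of mlr3extralearners. The helpers rely only on $marshal(), $unmarshal(), and the marshaled active binding documented on this page, together with base R's saveRDS() and readRDS().

```r
# Hypothetical helpers (not exported by mlr3extralearners) for writing a
# trained LearnerClassifTabPFN to disk and reading it back.

save_tabpfn_learner = function(learner, path) {
  # Marshal first so that the Python-backed model is stored in a form
  # that saveRDS() can write; skip this if it is already marshaled.
  if (!learner$marshaled) {
    learner$marshal()
  }
  saveRDS(learner, path)
  invisible(path)
}

load_tabpfn_learner = function(path) {
  learner = readRDS(path)
  # Unmarshal so that the learner is usable again (e.g. for $predict()).
  if (learner$marshaled) {
    learner$unmarshal()
  }
  learner
}

# Usage, assuming `learner` has been trained as in the Examples section:
# path = tempfile(fileext = ".rds")
# save_tabpfn_learner(learner, path)
# restored = load_tabpfn_learner(path)
```

Because R6 objects have reference semantics, save_tabpfn_learner() also leaves the in-memory learner marshaled. Call $unmarshal() on it before predicting with it again in the same session.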

Initial parameter values

  • n_jobs is initialized to 1 to avoid threading conflicts with parallelization via the future package.

Custom mlr3 parameters

  • categorical_features_indices uses R's one-based indexing instead of Python's zero-based indexing.

  • device must be a string. If set to "auto", the behavior is the same as in the original tabpfn package. Otherwise, the string is passed as the argument to torch.device() to create a device.

  • inference_precision must be "auto" or "autocast". Passing a torch.dtype is currently not supported.

  • inference_config is currently not supported.
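
Because categorical_features_indices expects one-based positions, base R's match() can translate column names into the required indices. The following is a minimal sketch that assumes the positions refer to the task's feature order (task$feature_names); the helper categorical_indices() and the feature names are hypothetical. Since the learner only accepts logical, integer, and numeric features, such categorical columns would have to be integer-encoded beforehand.

```r
# Hypothetical helper: map categorical feature names to the one-based
# positions expected by categorical_features_indices.
categorical_indices = function(feature_names, categorical) {
  idx = match(categorical, feature_names)
  if (anyNA(idx)) {
    stop("Unknown feature(s): ", paste(categorical[is.na(idx)], collapse = ", "))
  }
  idx
}

# Illustrative feature order (in practice, use task$feature_names).
features = c("age", "income", "region", "segment")
cat_idx = categorical_indices(features, c("region", "segment"))
cat_idx      # one-based R positions: 3 4
cat_idx - 1L # corresponding zero-based Python positions: 2 3
```

The result could then be passed at construction, e.g. lrn("classif.tabpfn", categorical_features_indices = cat_idx). The subtraction in the last line only shows which zero-based positions the Python side corresponds to; since the learner performs this conversion itself, do not apply it yourself.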

Dictionary

This Learner can be instantiated via lrn():

lrn("classif.tabpfn")

Meta Information

  • Task type: “classif”

  • Predict Types: “response”, “prob”

  • Feature Types: “logical”, “integer”, “numeric”

  • Required Packages: mlr3

Parameters

Id                            Type       Default            Levels                                         Range
n_estimators                  integer    4                  -                                              [1, ∞)
categorical_features_indices  untyped    -                  -                                              -
softmax_temperature           numeric    0.9                -                                              [0, ∞)
balance_probabilities         logical    FALSE              TRUE, FALSE                                    -
average_before_softmax        logical    FALSE              TRUE, FALSE                                    -
model_path                    untyped    "auto"             -                                              -
device                        untyped    "auto"             -                                              -
ignore_pretraining_limits     logical    FALSE              TRUE, FALSE                                    -
inference_precision           character  auto               auto, autocast                                 -
fit_mode                      character  fit_preprocessors  low_memory, fit_preprocessors, fit_with_cache  -
memory_saving_mode            untyped    "auto"             -                                              -
random_state                  integer    0                  -                                              (-∞, ∞)
n_jobs                        integer    -                  -                                              [1, ∞)

Super classes

mlr3::Learner -> mlr3::LearnerClassif -> LearnerClassifTabPFN

Active bindings

marshaled

(logical(1))
Whether the learner has been marshaled.

Methods

Method new()

Creates a new instance of this R6 class.

Usage
LearnerClassifTabPFN$new()

Method marshal()

Marshal the learner's model.

Usage
LearnerClassifTabPFN$marshal(...)
Arguments
...

(any)
Additional arguments passed to marshal_model().


Method unmarshal()

Unmarshal the learner's model.

Usage
LearnerClassifTabPFN$unmarshal(...)
Arguments
...

(any)
Additional arguments passed to unmarshal_model().


Method clone()

The objects of this class are cloneable with this method.

Usage
LearnerClassifTabPFN$clone(deep = FALSE)
Arguments
deep

Whether to make a deep clone.

Author(s)

b-zhou

References

Hollmann, Noah, Müller, Samuel, Purucker, Lennart, Krishnakumar, Arjun, Körfer, Max, Hoo, Bin S, Schirrmeister, Tibor R, Hutter, Frank (2025). “Accurate predictions on small data with a tabular foundation model.” Nature. doi:10.1038/s41586-024-08328-6, https://www.nature.com/articles/s41586-024-08328-6.

Hollmann, Noah, Müller, Samuel, Eggensperger, Katharina, Hutter, Frank (2023). “TabPFN: A transformer that solves small tabular classification problems in a second.” In International Conference on Learning Representations 2023.

Examples


# Load mlr3extralearners so that "classif.tabpfn" is registered in the dictionary
library(mlr3extralearners)

# Define the Learner
learner = mlr3::lrn("classif.tabpfn")
print(learner)

# Define a Task
task = mlr3::tsk("sonar")

# Create train and test set (by default, about 2/3 of the rows go to training)
ids = mlr3::partition(task)

# Train the learner on the training ids
learner$train(task, row_ids = ids$train)

print(learner$model)

# Make predictions for the test rows
predictions = learner$predict(task, row_ids = ids$test)

# Score the predictions (classification error by default)
predictions$score()


mlr-org/mlr3extralearners documentation built on June 11, 2025, 7:06 p.m.