tabnet: Parsnip compatible tabnet model

View source: R/parsnip.R


Parsnip compatible tabnet model

Description

Parsnip compatible tabnet model

Usage

tabnet(
  mode = "unknown",
  epochs = NULL,
  penalty = NULL,
  batch_size = NULL,
  learn_rate = NULL,
  decision_width = NULL,
  attention_width = NULL,
  num_steps = NULL,
  feature_reusage = NULL,
  virtual_batch_size = NULL,
  num_independent = NULL,
  num_shared = NULL,
  momentum = NULL
)

Arguments

mode

A single character string for the type of model. Possible values for this model are "unknown", "regression", or "classification".

epochs

(int) Number of training epochs.

penalty

This is the extra sparsity loss coefficient, as proposed in the original paper. The larger this coefficient, the sparser your model will be in terms of feature selection. Depending on the difficulty of your problem, reducing this value could help.

batch_size

(int) Number of examples per batch; large batch sizes are recommended. (default: 1024^2)

learn_rate

Initial learning rate for the optimizer.

decision_width

(int) Width of the decision prediction layer. Bigger values give the model more capacity, with the risk of overfitting. Values typically range from 8 to 64.

attention_width

(int) Width of the attention embedding for each mask. According to the paper, n_d = n_a is usually a good choice. (default=8)

num_steps

(int) Number of steps in the architecture (usually between 3 and 10).

feature_reusage

(float) This is the coefficient for feature reusage in the masks. A value close to 1 makes mask selection least correlated between layers. Values range from 1.0 to 2.0.

virtual_batch_size

(int) Size of the mini-batches used for "Ghost Batch Normalization". (default=256^2)

num_independent

Number of independent Gated Linear Unit (GLU) layers at each step. Usual values range from 1 to 5.

num_shared

Number of shared Gated Linear Unit (GLU) layers at each step. Usual values range from 1 to 5.

momentum

Momentum for batch normalization; typically ranges from 0.01 to 0.4. (default=0.02)
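
As an illustration of how these arguments map onto a model specification, the sketch below sets a few of them to non-default values. The numbers are arbitrary and only meant to show the interface, not recommended settings.

library(parsnip)
library(tabnet)

# Specification with a handful of non-default hyperparameters
# (values are illustrative only)
spec <- tabnet(
  epochs = 50,
  penalty = 1e-3,
  batch_size = 1024,
  learn_rate = 0.02,
  decision_width = 16,
  attention_width = 16,
  num_steps = 5
) %>%
  set_mode("classification") %>%
  set_engine("torch")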

Value

A TabNet parsnip instance. It can be used to fit tabnet models using parsnip machinery.
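
The specification can also be bundled with a preprocessor via the tidymodels workflows package. A minimal sketch, assuming the workflows and modeldata packages are installed:

library(parsnip)
library(tabnet)
library(workflows)

data("ames", package = "modeldata")

wf <- workflow() %>%
  add_formula(Sale_Price ~ .) %>%
  add_model(tabnet() %>% set_mode("regression") %>% set_engine("torch"))

# fit(wf, data = ames) then behaves like any other fitted workflow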

Threading

TabNet uses torch as its computational backend, and torch uses all available threads by default.

You can control the number of threads used by torch with:

torch::torch_set_num_threads(1)          # threads used within individual operations
torch::torch_set_num_interop_threads(1)  # threads used across independent operations

See Also

tabnet_fit

Examples

if (torch::torch_is_installed()) {
library(parsnip)
data("ames", package = "modeldata")

# Model specification: regression mode with the torch engine
model <- tabnet() %>%
  set_mode("regression") %>%
  set_engine("torch")

# Fit on the Ames housing data
model %>%
  fit(Sale_Price ~ ., data = ames)
}
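
Once fitted, the model follows the usual parsnip prediction interface. A minimal sketch, where fitted_model is a hypothetical name for the object returned by fit() above:

fitted_model <- model %>% fit(Sale_Price ~ ., data = ames)
predict(fitted_model, new_data = ames)  # returns a tibble with a .pred column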

