#' Transforms input data into a list of tensors and parameters for model input
#'
#' The three torch tensors are
#' `$x`, `$x_na_mask` and `$y`,
#' and the parameters are
#' `cat_idx`, the vector of indices of the categorical predictors in `x`,
#' `cat_dims`, the vector of the number of levels of each categorical predictor in `x`,
#' `input_dim`, the number of columns in `x`, and
#' `output_dim`, the `ncol(y)` in case of (multi-outcome) regression,
#' the `nlevels(y)` in case of classification, or
#' the vector of `nlevels(y)` in case of multi-outcome classification.
#' (A usage sketch follows the function definition below.)
#'
#' @param x a data frame
#' @param y a response vector or data frame
#' @noRd
resolve_data <- function(x, y) {
cat_idx <- which(sapply(x, is.factor))
cat_dims <- sapply(cat_idx, function(i) nlevels(x[[i]]))
# convert factors into integers
if (length(cat_idx) > 0) {
x[,cat_idx] <- sapply(cat_idx, function(i) as.integer(x[[i]]))
} else {
# prevent empty cat idx
cat_idx <- 0L
cat_dims <- 0L
}
x_tensor <- torch::torch_tensor(as.matrix(x), dtype = torch::torch_float())
x_na_mask <- x %>% is.na %>% as.matrix %>% torch::torch_tensor(dtype = torch::torch_bool())
# convert y factors to integers, based on the class of the first column of the target
# TODO do not assume but assert type-consistency of all y cols
# and record output_dim
if (is.factor(y[[1]])) {
y_tensor <- torch::torch_tensor(sapply(y, function(i) as.integer(i)), dtype = torch::torch_long())
if (is.atomic(y)) {
output_dim <- nlevels(y)
} else {
output_dim <- sapply(y, function(i) nlevels(i))
}
} else {
y_tensor <- torch::torch_tensor(as.matrix(y), dtype = torch::torch_float())
output_dim <- ncol(y)
}
input_dim <- ncol(x)
list(x = x_tensor, x_na_mask = x_na_mask, y = y_tensor,
cat_idx = cat_idx,
output_dim = output_dim,
input_dim = input_dim, cat_dims = cat_dims)
}
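# A minimal sketch (hypothetical toy data) of the structure resolve_data() returns,
# wrapped in `if (FALSE)` so it is never evaluated when the package is loaded.
if (FALSE) {
  toy_x <- data.frame(num = c(1.5, NA, 0.7), cat = factor(c("a", "b", "a")))
  toy_y <- data.frame(outcome = c(0.1, 0.7, 0.3))
  resolved <- resolve_data(toy_x, toy_y)
  resolved$input_dim   # 2, the number of columns in x
  resolved$cat_idx     # 2, the index of the factor column
  resolved$cat_dims    # 2, the number of levels of that factor
  resolved$output_dim  # 1, single-outcome regression (y is numeric)
  resolved$x_na_mask   # 3 x 2 boolean tensor, TRUE where x is missing
}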
#' Configuration for TabNet models
#'
#' @param batch_size (int) Number of examples per batch; large batch sizes are
#' recommended. (default: 1024^2)
#' @param penalty This is the extra sparsity loss coefficient as proposed
#' in the original paper. The bigger this coefficient is, the sparser your model
#' will be in terms of feature selection. Depending on the difficulty of your
#' problem, reducing this value could help.
#' @param clip_value If a float is given this will clip the gradient at
#' clip_value. Pass `NULL` to not clip.
#' @param loss (character or function) Loss function for training (defaults to mse
#' for regression and cross entropy for classification)
#' @param epochs (int) Number of training epochs.
#' @param drop_last (logical) Whether to drop the last batch if it is not complete
#' during training
#' @param decision_width (int) Width of the decision prediction layer. Bigger values give
#' more capacity to the model, with the risk of overfitting. Values typically
#' range from 8 to 64.
#' @param attention_width (int) Width of the attention embedding for each mask. According to
#' the paper n_d=n_a is usually a good choice. (default=8)
#' @param num_steps (int) Number of steps in the architecture
#' (usually between 3 and 10)
#' @param feature_reusage (float) This is the coefficient for feature reusage in the masks.
#' A value close to 1 will make mask selection least correlated between layers.
#' Values range from 1.0 to 2.0.
#' @param mask_type (character) Final layer of feature selector in the attentive_transformer
#' block, either `"sparsemax"` or `"entmax"`. Defaults to `"sparsemax"`.
#' @param virtual_batch_size (int) Size of the mini batches used for
#' "Ghost Batch Normalization" (default=256^2)
#' @param learn_rate initial learning rate for the optimizer.
#' @param optimizer the optimization method. Currently only 'adam' is supported;
#' you can also pass any torch optimizer function.
#' @param valid_split (float) The fraction of the dataset used for validation.
#' @param num_independent Number of independent Gated Linear Units layers at each step.
#' Usual values range from 1 to 5.
#' @param num_shared Number of shared Gated Linear Units at each step. Usual values
#' range from 1 to 5.
#' @param verbose (logical) Whether to print progress and loss values during
#' training.
#' @param lr_scheduler if `NULL`, no learning rate decay is used. if "step"
#' decays the learning rate by `lr_decay` every `step_size` epochs. It can
#' also be a [torch::lr_scheduler] function that only takes the optimizer
#' as parameter. The `step` method is called once per epoch.
#' @param lr_decay multiplies the initial learning rate by `lr_decay` every
#' `step_size` epochs. Unused if `lr_scheduler` is a `torch::lr_scheduler`
#' or `NULL`.
#' @param step_size the learning rate scheduler step size. Unused if
#' `lr_scheduler` is a `torch::lr_scheduler` or `NULL`.
#' @param cat_emb_dim Embedding size for categorical features (default=1)
#' @param momentum Momentum for batch normalization, typically ranges from 0.01
#' to 0.4 (default=0.02)
#' @param pretraining_ratio Ratio of features to mask for reconstruction during
#' pretraining. Ranges from 0 to 1 (default=0.5)
#' @param checkpoint_epochs checkpoint model weights and architecture every
#' `checkpoint_epochs`. (default is 10). This may cause large memory usage.
#' Use `0` to disable checkpoints.
#' @param device the device to use for training: "cpu", "cuda" or "mps". The default
#' ("auto") uses "cuda" if it is available, then "mps", otherwise "cpu".
#' @param importance_sample_size sample of the dataset to compute importance metrics.
#' If the dataset is larger than 1e5 obs we will use a sample of size 1e5 and
#' display a warning.
#' @param early_stopping_monitor Metric to monitor for early_stopping. One of "valid_loss", "train_loss" or "auto" (defaults to "auto").
#' @param early_stopping_tolerance Minimum relative improvement to reset the patience counter.
#' 0.01 corresponds to a 1% tolerance. (default: 0)
#' @param early_stopping_patience Number of epochs without improvement before training stops.
#' (default: 0, which disables early stopping)
#' @param num_workers (int, optional): how many subprocesses to use for data
#' loading. 0 means that the data will be loaded in the main process.
#' (default: `0`)
#' @param skip_importance (logical) whether to skip the feature importance calculation (default: `FALSE`)
#' @return A named list with all hyperparameters of the TabNet implementation.
#'
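#' @examples
#' # A brief sketch of overriding a few defaults; the values below are
#' # illustrative only, not tuned recommendations.
#' config <- tabnet_config(
#'   epochs = 10,
#'   valid_split = 0.2,
#'   decision_width = 16,
#'   learn_rate = 1e-2,
#'   lr_scheduler = "step",
#'   step_size = 5
#' )
#' # internally, `decision_width` and `attention_width` map to `n_d` and `n_a`
#' config$n_d
#' config$n_a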
#' @export
tabnet_config <- function(batch_size = 1024^2,
penalty = 1e-3,
clip_value = NULL,
loss = "auto",
epochs = 5,
drop_last = FALSE,
decision_width = NULL,
attention_width = NULL,
num_steps = 3,
feature_reusage = 1.3,
mask_type = "sparsemax",
virtual_batch_size = 256^2,
valid_split = 0,
learn_rate = 2e-2,
optimizer = "adam",
lr_scheduler = NULL,
lr_decay = 0.1,
step_size = 30,
checkpoint_epochs = 10,
cat_emb_dim = 1,
num_independent = 2,
num_shared = 2,
momentum = 0.02,
pretraining_ratio = 0.5,
verbose = FALSE,
device = "auto",
importance_sample_size = NULL,
early_stopping_monitor = "auto",
early_stopping_tolerance = 0,
early_stopping_patience = 0L,
num_workers = 0L,
skip_importance = FALSE) {
if (is.null(decision_width) && is.null(attention_width)) {
decision_width <- 8 # default is 8
}
if (is.null(attention_width))
attention_width <- decision_width
if (is.null(decision_width))
decision_width <- attention_width
list(
batch_size = batch_size,
lambda_sparse = penalty,
clip_value = clip_value,
loss = loss,
epochs = epochs,
drop_last = drop_last,
n_d = decision_width,
n_a = attention_width,
n_steps = num_steps,
gamma = feature_reusage,
mask_type = mask_type,
virtual_batch_size = virtual_batch_size,
valid_split = valid_split,
learn_rate = learn_rate,
optimizer = optimizer,
lr_scheduler = lr_scheduler,
lr_decay = lr_decay,
step_size = step_size,
checkpoint_epochs = checkpoint_epochs,
cat_emb_dim = cat_emb_dim,
n_independent = num_independent,
n_shared = num_shared,
momentum = momentum,
pretraining_ratio = pretraining_ratio,
verbose = verbose,
device = device,
importance_sample_size = importance_sample_size,
early_stopping_monitor = resolve_early_stop_monitor(early_stopping_monitor, valid_split),
early_stopping_tolerance = early_stopping_tolerance,
early_stopping_patience = early_stopping_patience,
early_stopping = !(early_stopping_tolerance==0 || early_stopping_patience==0),
num_workers = num_workers,
skip_importance = skip_importance
)
}
resolve_loss <- function(loss, dtype) {
if (is.function(loss))
loss_fn <- loss
else if (loss %in% c("mse", "auto") && !dtype == torch::torch_long())
loss_fn <- torch::nn_mse_loss()
else if (loss %in% c("bce", "cross_entropy", "auto") && dtype == torch::torch_long())
loss_fn <- torch::nn_cross_entropy_loss()
else
rlang::abort(paste0(loss," is not a valid loss for outcome of type ",dtype))
loss_fn
}
resolve_early_stop_monitor <- function(early_stopping_monitor, valid_split) {
if (early_stopping_monitor %in% c("valid_loss", "auto") && valid_split > 0)
early_stopping_monitor <- "valid_loss"
else if (early_stopping_monitor %in% c("train_loss", "auto"))
early_stopping_monitor <- "train_loss"
else
rlang::abort(paste0(early_stopping_monitor," is not a valid early stopping metric to monitor with `valid_split`=",valid_split))
early_stopping_monitor
}
train_batch <- function(network, optimizer, batch, config) {
# forward pass
output <- network(batch$x, batch$x_na_mask)
# if the target is multi-outcome, the loss has to be applied to each label-group
if (max(batch$output_dim$shape) > 1) {
# TODO maybe torch_stack here would help loss$backward and better to shift right torch_sum at the end ?
outcome_nlevels <- as.numeric(batch$output_dim$to(device="cpu"))
loss <- torch::torch_sum(torch::torch_stack(purrr::pmap(
list(
torch::torch_split(output[[1]], outcome_nlevels, dim = 2),
torch::torch_split(batch$y, rep(1, length(outcome_nlevels)), dim = 2)
),
~config$loss_fn(.x, .y$squeeze(2))
)),
dim = 1)
} else {
if (batch$y$dtype == torch::torch_long()) {
# classifier needs a squeeze for bce loss
loss <- config$loss_fn(output[[1]], batch$y$squeeze(2))
} else {
loss <- config$loss_fn(output[[1]], batch$y)
}
}
# Add the overall sparsity loss
loss <- loss - config$lambda_sparse * output[[2]]
# step of the optimization
optimizer$zero_grad()
loss$backward()
if (!is.null(config$clip_value)) {
torch::nn_utils_clip_grad_norm_(network$parameters, config$clip_value)
}
optimizer$step()
list(
loss = loss$item()
)
}
valid_batch <- function(network, batch, config) {
# forward pass
output <- network(batch$x, batch$x_na_mask)
# loss has to be applied to each label-group when output_dim is a vector
if (max(batch$output_dim$shape) > 1) {
# TODO maybe torch_stack here would help loss$backward and better to shift right torch_sum at the end ?
outcome_nlevels <- as.numeric(batch$output_dim$to(device="cpu"))
loss <- torch::torch_sum(torch::torch_stack(purrr::pmap(
list(
torch::torch_split(output[[1]], outcome_nlevels, dim = 2),
torch::torch_split(batch$y, rep(1, length(outcome_nlevels)), dim = 2)
),
~config$loss_fn(.x, .y$squeeze(2))
)),
dim = 1)
} else {
if (batch$y$dtype == torch::torch_long()) {
# classifier needs a squeeze for bce loss
loss <- config$loss_fn(output[[1]], batch$y$squeeze(2))
} else {
loss <- config$loss_fn(output[[1]], batch$y)
}
}
# Add the overall sparsity loss
loss <- loss - config$lambda_sparse * output[[2]]
list(
loss = loss$item()
)
}
get_device_from_config <- function(config) {
if (config$device == "auto") {
if (torch::cuda_is_available()){
device <- "cuda"
} else if (torch::backends_mps_is_available()) {
device <- "mps"
} else {
device <- "cpu"
}
} else {
device <- config$device
}
device
}
tabnet_initialize <- function(x, y, config = tabnet_config()) {
torch::torch_manual_seed(sample.int(1e6, 1))
has_valid <- config$valid_split > 0
device <- get_device_from_config(config)
if (has_valid) {
n <- nrow(x)
valid_idx <- sample.int(n, n*config$valid_split)
valid_x <- x[valid_idx, ]
valid_y <- y[valid_idx, ]
train_y <- y[-valid_idx, ]
valid_ds <- torch::dataset(
initialize = function() {},
.getbatch = function(batch) {resolve_data(valid_x[batch,], valid_y[batch, ])},
.length = function() {nrow(valid_x)}
)()
x <- x[-valid_idx, ]
y <- train_y
}
# training dataset
train_ds <- torch::dataset(
initialize = function() {},
.getbatch = function(batch) {resolve_data(x[batch, ], y[batch, ])},
.length = function() {nrow(x)}
)()
# we can get the training-set parameters from the first 2 samples
train <- train_ds$.getbatch(batch = c(1:2))
# resolve loss
config$loss_fn <- resolve_loss(config$loss, train$y$dtype)
# create network
network <- tabnet_nn(
input_dim = train$input_dim,
output_dim = train$output_dim,
cat_idxs = train$cat_idx,
cat_dims = train$cat_dims,
n_d = config$n_d,
n_a = config$n_a,
n_steps = config$n_steps,
gamma = config$gamma,
virtual_batch_size = config$virtual_batch_size,
cat_emb_dim = config$cat_emb_dim,
n_independent = config$n_independent,
n_shared = config$n_shared,
momentum = config$momentum,
mask_type = config$mask_type
)
# main loop
metrics <- list()
checkpoints <- list()
importances <- tibble::tibble(
variables = colnames(x),
importance = NA
)
list(
network = network,
metrics = metrics,
config = config,
checkpoints = checkpoints,
importances = importances
)
}
tabnet_train_supervised <- function(obj, x, y, config = tabnet_config(), epoch_shift = 0L) {
stopifnot("tabnet_model shall be initialised or pretrained" = (length(obj$fit$network) > 0))
torch::torch_manual_seed(sample.int(1e6, 1))
device <- get_device_from_config(config)
# validation dataset & dataloaders
has_valid <- config$valid_split > 0
if (has_valid) {
n <- nrow(x)
valid_idx <- sample.int(n, n*config$valid_split)
valid_x <- x[valid_idx, ]
valid_y <- y[valid_idx, ]
train_y <- y[-valid_idx, ]
valid_ds <- torch::dataset(
initialize = function() {},
.getbatch = function(batch) {resolve_data(valid_x[batch,], valid_y[batch,])},
.length = function() {nrow(valid_x)}
)()
valid_dl <- torch::dataloader(
valid_ds,
batch_size = config$batch_size,
shuffle = FALSE,
num_workers = config$num_workers
)
x <- x[-valid_idx, ]
y <- train_y
}
# training dataset & dataloader
train_ds <- torch::dataset(
initialize = function() {},
.getbatch = function(batch) {resolve_data(x[batch,], y[batch,])},
.length = function() {nrow(x)}
)()
train_dl <- torch::dataloader(
train_ds,
batch_size = config$batch_size,
drop_last = config$drop_last,
shuffle = TRUE,
num_workers = config$num_workers
)
# resolve loss
config$loss_fn <- resolve_loss(config$loss, train_ds$.getbatch(batch = c(1:2))$y$dtype)
# restore network from model and send it to device
network <- obj$fit$network
network$to(device = device)
# define optimizer
if (rlang::is_function(config$optimizer)) {
optimizer <- config$optimizer(network$parameters, config$learn_rate)
} else if (rlang::is_scalar_character(config$optimizer)) {
if (config$optimizer == "adam")
optimizer <- torch::optim_adam(network$parameters, lr = config$learn_rate)
else
rlang::abort("Currently only the 'adam' optimizer is supported.")
}
# define scheduler
if (is.null(config$lr_scheduler)) {
scheduler <- list(step = function() {})
} else if (rlang::is_function(config$lr_scheduler)) {
scheduler <- config$lr_scheduler(optimizer)
} else if (config$lr_scheduler == "step") {
scheduler <- torch::lr_step(optimizer, config$step_size, config$lr_decay)
}
# restore previous metrics & checkpoints
metrics <- obj$fit$metrics
checkpoints <- obj$fit$checkpoints
patience_counter <- 0L
# main loop
for (epoch in seq_len(config$epochs) + epoch_shift) {
metrics[[epoch]] <- list()
train_metrics <- c()
valid_metrics <- c()
network$train()
if (config$verbose)
pb <- progress::progress_bar$new(
total = length(train_dl),
format = "[:bar] loss= :loss"
)
coro::loop(for (batch in train_dl) {
m <- train_batch(network, optimizer, to_device(batch, device), config)
if (config$verbose) pb$tick(tokens = m)
train_metrics <- c(train_metrics, m)
})
metrics[[epoch]][["train"]] <- transpose_metrics(train_metrics)
if (config$checkpoint_epochs > 0 && epoch %% config$checkpoint_epochs == 0) {
network$to(device = "cpu")
checkpoints[[length(checkpoints) + 1]] <- model_to_raw(network)
network$to(device = device)
}
network$eval()
if (has_valid) {
coro::loop(for (batch in valid_dl) {
m <- valid_batch(network, to_device(batch, device), config)
valid_metrics <- c(valid_metrics, m)
})
metrics[[epoch]][["valid"]] <- transpose_metrics(valid_metrics)
}
message <- sprintf("[Epoch %03d] Loss: %3f", epoch, mean(metrics[[epoch]]$train$loss))
if (has_valid)
message <- paste0(message, sprintf(" Valid loss: %3f", mean(metrics[[epoch]]$valid$loss)))
if (config$verbose)
rlang::inform(message)
# Early-stopping checks
if (config$early_stopping && config$early_stopping_monitor=="valid_loss"){
current_loss <- mean(metrics[[epoch]]$valid$loss)
} else {
current_loss <- mean(metrics[[epoch]]$train$loss)
}
if (config$early_stopping && epoch > 1+epoch_shift) {
# compute relative change, and compare to best_metric
change <- (current_loss - best_metric) / current_loss
if (change > config$early_stopping_tolerance){
patience_counter <- patience_counter + 1
if (patience_counter >= config$early_stopping_patience){
if (config$verbose)
rlang::inform(sprintf("Early stopping at epoch %03d", epoch))
break
}
} else {
# reset the patience counter
best_metric <- current_loss
patience_counter <- 0L
}
}
if (config$early_stopping && epoch == 1+epoch_shift) {
# initialise best_metric
best_metric <- current_loss
}
scheduler$step()
}
network$to(device = "cpu")
if (!config$skip_importance) {
importance_sample_size <- config$importance_sample_size
if (is.null(config$importance_sample_size) && train_ds$.length() > 1e5) {
rlang::warn(c(glue::glue("Computing importances for a dataset with size {train_ds$.length()}."),
"This can consume too much memory. We are going to use a sample of size 1e5",
"You can disable this message by using the `importance_sample_size` argument."))
importance_sample_size <- 1e5
}
indexes <- as.numeric(torch::torch_randint(
1, train_ds$.length(), min(importance_sample_size, train_ds$.length()),
dtype = torch::torch_long()
))
importances <- tibble::tibble(
variables = colnames(x),
importance = compute_feature_importance(
network,
train_ds$.getbatch(batch = indexes)$x$to(device = "cpu"),
train_ds$.getbatch(batch = indexes)$x_na_mask$to(device = "cpu"))
)
} else {
importances <- NULL
}
list(
network = network,
metrics = metrics,
config = config,
checkpoints = checkpoints,
importances = importances
)
}
predict_impl <- function(obj, x, batch_size = 1e5) {
# prediction dataset
device <- obj$fit$config$device
predict_ds <- torch::dataset(
initialize = function() {},
.getbatch = function(batch) {resolve_data(x[batch,], rep(1, nrow(x)))},
.length = function() {nrow(x)}
)()
network <- obj$fit$network
num_workers <- obj$fit$config$num_workers
yhat <- c()
network$eval()
predict_dl <- torch::dataloader(
predict_ds,
batch_size = batch_size,
drop_last = FALSE,
shuffle = FALSE ,
# num_workers = num_workers
num_workers = 0L
)
coro::loop(for (batch in predict_dl) {
batch <- to_device(batch, device)
yhat <- c(yhat, network(batch$x, batch$x_na_mask)[[1]])
})
# bind rows of the batches
torch::torch_cat(yhat)
}
predict_impl_numeric <- function(obj, x, batch_size) {
p <- as.matrix(predict_impl(obj, x, batch_size))
hardhat::spruce_numeric(as.numeric(p))
}
predict_impl_numeric_multiple <- function(obj, x, batch_size) {
p <- as.matrix(predict_impl(obj, x, batch_size))
# TODO use a cleaner function to turn matrix into vectors
hardhat::spruce_numeric_multiple(!!!purrr::map(1:ncol(p), ~p[,.x]))
}
#' Single-outcome levels from the blueprint
#'
#' @param obj a tabnet object
#'
#' @return the outcome levels
#' @noRd
get_blueprint_levels <- function(obj) {
levels(obj$blueprint$ptypes$outcomes[[1]])
}
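# A minimal sketch (not evaluated): it assumes a fitted tabnet model named `fit`,
# e.g. the result of `tabnet_fit()`, with a single factor outcome.
if (FALSE) {
  get_blueprint_levels(fit)
  # e.g. c("setosa", "versicolor", "virginica")
}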
#' Multi-outcome levels from the blueprint
#'
#' @param obj a tabnet object
#'
#' @return a list of level vectors, one per outcome
#' @noRd
get_blueprint_levels_multiple <- function(obj) {
purrr::map(obj$blueprint$ptypes$outcomes, levels) %>%
rlang::set_names(names(obj$blueprint$ptypes$outcomes))
}
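# A minimal sketch (not evaluated), again assuming a hypothetical fitted model
# `fit`, this time with several factor outcomes.
if (FALSE) {
  get_blueprint_levels_multiple(fit)
  # e.g. list(Species = c("setosa", "versicolor", "virginica"),
  #           Size    = c("small", "large"))
  # one character vector of levels per outcome column
}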
predict_impl_prob <- function(obj, x, batch_size) {
p <- predict_impl(obj, x, batch_size)
p <- torch::nnf_softmax(p, dim = 2)
p <- as.matrix(p)
hardhat::spruce_prob(get_blueprint_levels(obj), p)
}
predict_impl_prob_multiple <- function(obj, x, batch_size, outcome_nlevels) {
p <- predict_impl(obj, x, batch_size)
p <- torch::nnf_softmax(p, dim = 2)
p <- as.matrix(p)
# TODO use a cleaner function to turn matrix into vectors
p_blueprint <- get_blueprint_levels_multiple(obj)
p_probs <- purrr::map(torch::torch_split(p, outcome_nlevels, dim = 2),
as.matrix)
hardhat::spruce_prob_multiple(!!!purrr::pmap(
list(p_blueprint, p_probs),
# TODO BUG each element of `...` must be a tibble, not a list.
~hardhat::spruce_prob(.x, .y)) %>%
rlang::set_names(names(p_blueprint))
)
}
predict_impl_class <- function(obj, x, batch_size) {
p <- predict_impl(obj, x, batch_size)
p_idx <- as.integer(torch::torch_max(p, dim = 2)[[2]])
p_idx <- get_blueprint_levels(obj)[p_idx]
p <- factor(p_idx, levels = get_blueprint_levels(obj))
hardhat::spruce_class(p)
}
predict_impl_class_multiple <- function(obj, x, batch_size, outcome_nlevels) {
p <- predict_impl(obj, x, batch_size)
p_levels <- get_blueprint_levels_multiple(obj)
p_idx <- purrr::map(
torch::torch_split(p, outcome_nlevels, dim = 2),
~as.integer(torch::torch_max(.x, dim = 2)[[2]])
) %>% rlang::set_names(names(p_levels))
p_factor_lst <- purrr::pmap(
list(p_idx, p_levels),
~factor(.y[.x], levels = .y)
)
hardhat::spruce_class_multiple(!!!p_factor_lst)
}
to_device <- function(x, device) {
lapply(x, function(x) {
if (inherits(x, "torch_tensor")) {
x$to(device=device)
} else if (is.list(x)) {
to_device(x, device)
} else {
x
}
})
}