```r
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  fig.width = 5,
  fig.height = 5
)
```
We're going to use the MADGRAD optimizer [@Defazio], provided by the `madgrad` package.
```r
library(torchtabular)
library(tidymodels)
library(tidyverse)
library(torch)
library(luz)
library(madgrad)
```
We will set the seeds to make our results reproducible.
```r
torch_manual_seed(seed = 154)
set.seed(154)
```
```r
device <- ifelse(cuda_is_available(), 'cuda', 'cpu')
```
The income dataset is included with the torchtabular package.
```r
data('income')
glimpse(income)
```
First, we will convert the target variable to a numeric 0/1 encoding, and convert character columns to factors so that our tabular dataset identifies them correctly.
```r
income <- income %>%
  mutate(
    across(where(is.character), as_factor),
    income = as.numeric(income) - 1
  )
glimpse(income)
```
We can now split the data into train and test sets.
```r
split <- initial_split(income, prop = 0.7)
train <- training(split)
valid <- testing(split)
```
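If the classes are imbalanced, it can help to stratify the split on the outcome. This is an optional aside, not part of the original workflow, using `initial_split()`'s `strata` argument:

```r
# Optional: stratify the train/test split by the outcome so both sets
# keep a similar proportion of high and low incomes.
split <- initial_split(income, prop = 0.7, strata = income)
```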
By creating a recipe, the `tabular_dataset()` function will automatically recognise categorical vs. continuous predictors.
```r
recipe <- recipe(income, income ~ .) %>%
  step_scale(all_numeric_predictors()) %>%
  step_integer(all_nominal_predictors())
```
We can then pass this recipe to `tabular_dataset()` with the relevant split.
```r
train_dset <- tabular_dataset(recipe, train)
valid_dset <- tabular_dataset(recipe, valid)
```
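It can be worth a quick sanity check of the resulting dataset objects before training. A small sketch; the `categories` and `num_continuous` fields are the same ones used in the hyperparameter setup below:

```r
# Basic checks on the tabular datasets.
length(train_dset)          # number of training rows
train_dset$categories       # cardinalities of the categorical predictors
train_dset$num_continuous   # number of continuous predictors
```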
Finally, we make dataloaders for the training and validation sets.
```r
train_dl <- dataloader(train_dset, batch_size = 64, shuffle = TRUE)
valid_dl <- dataloader(valid_dset, batch_size = 512, shuffle = FALSE)
```
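If you want to see what a batch looks like, torch's iterator helpers can pull one out. The exact structure of the batch depends on how `tabular_dataset` is implemented, so treat this as a sketch:

```r
# Pull a single batch from the training dataloader and inspect it.
batch <- train_dl %>%
  dataloader_make_iter() %>%
  dataloader_next()
str(batch)
```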
We can now train our model using luz for 5 epochs.

```r
n_epochs <- 5
```
Let's set up the model with our hyperparameters. We will use MADGRAD as our optimizer; it works well and converges rapidly.
```r
model_setup <- tabtransformer %>%
  setup(
    loss = nn_bce_with_logits_loss(),
    optimizer = madgrad::optim_madgrad,
    metrics = list(
      luz_metric_binary_auroc(from_logits = TRUE)
    )
  ) %>%
  set_hparams(
    categories = train_dset$categories,
    num_continuous = train_dset$num_continuous,
    dim_out = 1,
    attention = "both",
    attention_type = "signed",
    is_first = TRUE,
    dim = 32,
    depth = 1,
    heads_selfattn = 32,
    heads_intersample = 32,
    dim_heads_selfattn = 16,
    dim_heads_intersample = 64,
    attn_dropout = 0.1,
    ff_dropout = 0.8,
    embedding_dropout = 0.0,
    mlp_dropout = 0.0,
    mlp_hidden_mult = c(6, 4, 2),
    softmax_mod = 1.0,
    is_softmax_mod = 1.0,
    skip = FALSE,
    device = device
  ) %>%
  set_opt_hparams(lr = 2e-3)
```
Finally, we can fit the model. We set `verbose = FALSE` so training output doesn't fill the console; we can plot the loss and metrics after training to inspect how we did.
```r
fitted <- model_setup %>%
  fit(
    train_dl,
    epochs = n_epochs,
    valid_data = valid_dl,
    verbose = FALSE
  )
```
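If you want to keep the fitted model around for later use, luz provides `luz_save()` and `luz_load()`. A minimal sketch; the file path is arbitrary:

```r
# Save the fitted luz model to disk and reload it later.
luz_save(fitted, "tabtransformer_income.rds")
fitted <- luz_load("tabtransformer_income.rds")
```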
Plotting the training progress:
```r
metrics <- get_metrics(fitted)

metrics %>%
  ggplot(aes(x = epoch, y = value, col = set)) +
  geom_line() +
  facet_wrap(vars(metric), scales = "free_y") +
  theme_bw()
```
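The same data frame can also be summarised numerically, for example by pulling out the final-epoch values:

```r
# Metric values at the last epoch, for a quick numeric summary.
metrics %>%
  filter(epoch == max(epoch))
```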
We can improve the quality of our predictions by using large batches: the inter-sample attention layer can attend to all the other data points in a batch when making each prediction.
```r
pred_dl <- dataloader(valid_dset, batch_size = 5000, shuffle = FALSE)

preds <- predict(fitted, pred_dl)$squeeze(-1)
preds <- as_array(preds)

truth <- as_factor(ifelse(valid$income == 1, "High", "Low"))
roc_auc_vec(truth = truth, estimate = preds, event_level = "second")
```
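Beyond AUC, we can also look at hard class predictions. Since the model was trained with `nn_bce_with_logits_loss()`, the predictions are logits, so thresholding at 0 corresponds to a probability of 0.5. A sketch using yardstick:

```r
# Hard class predictions at a 0.5 probability threshold (0 on the logit scale).
pred_class <- factor(ifelse(preds > 0, "High", "Low"), levels = levels(truth))

accuracy_vec(truth = truth, estimate = pred_class)
conf_mat(tibble(truth = truth, estimate = pred_class), truth, estimate)
```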
We can now interrogate our model a little further by looking at the attention heads.
The `attention_heads()` function can be used to pull the attention heads for the first 2000 rows of the validation dataset. These attention heads are averaged to get the average attention weights between each pair of variables. We want to run this with a large batch, so we will run it on the CPU to take advantage of the larger RAM.
```r
heads <- attention_heads(fitted, valid_dset, n = 2000)
```
These attention weights are nicely represented as a heatmap.
```r
heatmap(heads)
```
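If you prefer a tabular view, the same matrix can be reshaped to list the most strongly attended pairs. This sketch assumes `heads` is a square matrix carrying the variable names as its dimnames:

```r
# List the ten strongest off-diagonal attention weights.
as.data.frame(as.table(heads)) %>%
  rename(from = Var1, to = Var2, weight = Freq) %>%
  filter(from != to) %>%
  arrange(desc(weight)) %>%
  head(10)
```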
The intersample attention heads can be pulled using the `intersample_attention_heads()` function.
```r
is_heads <- intersample_attention_heads(fitted, valid_dset, n = 2000)
```
These attention heads lend themselves to clustering. We will start by reducing the number of dimensions using UMAP.
```r
library(uwot)
library(dbscan)
library(fpc)

set.seed(132)
mapped <- umap(
  is_heads,
  pca = NULL,
  n_threads = 4,
  n_epochs = 500,
  min_dist = 0.0,
  n_neighbors = 30,
  negative_sample_rate = 15,
  local_connectivity = 2,
  spread = 3,
  metric = 'correlation'
)

umap_comp <- as_tibble(mapped, .name_repair = ~ paste0("C", 1:2))

plotting_data <- umap_comp %>%
  bind_cols(valid[1:2000, ]) %>%
  mutate(income = as_factor(income))

plotting_data %>%
  ggplot(aes(x = C1, y = C2, col = income)) +
  geom_point()
```
The k-nearest-neighbour distance plot helps us choose a sensible `eps` for DBSCAN; the dashed line marks the value we will use.

```r
kNNdistplot(umap_comp, k = 10)
abline(h = 2, lty = 2)
```
We can now cluster the UMAP embedding with DBSCAN and colour the points by cluster.

```r
scanned <- dbscan(umap_comp, eps = 2, MinPts = 10)

plotting_data %>%
  ggplot(aes(x = C1, y = C2, col = as_factor(scanned$cluster))) +
  geom_point()
```
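As a quick check on how the clusters relate to the outcome, we can tabulate cluster sizes and the share of high incomes in each. A small sketch; cluster 0 is DBSCAN's noise cluster:

```r
# Cluster sizes and proportion of high incomes per DBSCAN cluster.
plotting_data %>%
  mutate(cluster = as_factor(scanned$cluster)) %>%
  group_by(cluster) %>%
  summarise(n = n(), prop_high_income = mean(income == "1"))
```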
We can now check whether any of our predictors differ between these clusters.
```r
library(patchwork)

p1 <- plotting_data %>%
  ggplot(aes(x = C1, y = C2, col = relationship)) +
  geom_point()

p2 <- plotting_data %>%
  ggplot(aes(x = C1, y = C2, col = `marital-status`)) +
  geom_point() +
  scale_color_viridis_d()

p3 <- plotting_data %>%
  ggplot(aes(x = C1, y = C2, col = `hours-per-week`)) +
  geom_point() +
  scale_color_viridis_c()

p1 + p2 + p3 + plot_layout(ncol = 1)
```