# Global chunk options for the rendered document: 5 x 5 figures, output lines
# prefixed with "#>", and source/output collapsed into a single block.
knitr::opts_chunk$set(
  fig.width = 5,
  fig.height = 5,
  comment = "#>",
  collapse = TRUE
)

Setup

We're going to use the MADGRAD optimizer [@Defazio].

library(torchtabular)
library(tidymodels)
library(tidyverse)
library(torch)
library(luz)
library(madgrad)

We will set the seeds to make our results reproducible.

# Seed both random number generators (base R and torch) with the same value.
set.seed(154)
torch_manual_seed(154)

Check for GPU and assign device

# Use the GPU when torch reports that CUDA is available, otherwise the CPU.
# The condition is a single logical value, so plain `if`/`else` is used
# instead of the vectorised `ifelse()`.
device <- if (cuda_is_available()) "cuda" else "cpu"

Load data

The income dataset is included with the torchtabular package.

# Load the `income` data set into the workspace and print a compact overview
# of its columns.
data("income")
glimpse(income)

Prepare data

First we will convert character columns to factors, so that our tabular dataset will identify them correctly, and recode the target variable as a number (0 or 1).

# Step 1: turn every character column into a factor.
# Step 2: recode the target by subtracting one from its numeric codes
# (assumes `income` is a two-level character/factor column, giving 0 and 1).
income <- income %>%
  mutate(across(where(is.character), as_factor)) %>%
  mutate(income = as.numeric(income) - 1)

glimpse(income)

We can now split the data into train and test sets.

# Random 70/30 partition of the rows: 70% for training, the rest held out
# for validation.
split <- initial_split(income, prop = 0.7)
train <- split %>% training()
valid <- split %>% testing()

When we supply a recipe, the tabular_dataset function will automatically recognise categorical vs. continuous predictors.

# Preprocessing specification: `income` is the outcome and every other column
# is a predictor. Numeric predictors are scaled and nominal predictors are
# converted to integer codes.
recipe <- recipe(income ~ ., data = income) %>%
  step_scale(all_numeric_predictors()) %>%
  step_integer(all_nominal_predictors())

We can then pass this recipe to tabular_dataset with the relevant split.

# Build one dataset per split from the same recipe.
# NOTE(review): how the recipe is prepped is decided inside tabular_dataset();
# confirm that the validation set reuses statistics estimated on training data.
train_dset <- tabular_dataset(recipe, train)
valid_dset <- tabular_dataset(recipe, valid)

Finally, we make a dataloader.

# Training batches are small (64 rows) and shuffled; validation batches are
# larger (512 rows) and keep the original row order.
train_dl <- dataloader(train_dset, batch_size = 64, shuffle = TRUE)
valid_dl <- dataloader(valid_dset, batch_size = 512, shuffle = FALSE)

Training

We can now train our model using luz for 5 epochs.

# Number of training epochs; passed to fit() as `epochs` below.
n_epochs <- 5

Let's set up the model with our hyperparameters. We will use MADGRAD as our optimizer - it works well and converges rapidly.

# Attach the loss, optimizer and metric to the tabtransformer module, then
# record the module's hyperparameters and the optimizer's learning rate.
model_setup <- tabtransformer %>%
  setup(
    # The loss works on raw logits, and the AUROC metric is told to expect them.
    loss = nn_bce_with_logits_loss(),
    optimizer = madgrad::optim_madgrad,
    metrics = list(luz_metric_binary_auroc(from_logits = TRUE))
  ) %>%
  set_hparams(
    # Sizes taken from the training dataset, plus a single output unit.
    categories = train_dset$categories,
    num_continuous = train_dset$num_continuous,
    dim_out = 1,
    # Attention settings (self-attention and intersample attention).
    attention = "both",
    attention_type = "signed",
    is_first = TRUE,
    dim = 32,
    depth = 1,
    heads_selfattn = 32,
    heads_intersample = 32,
    dim_heads_selfattn = 16,
    dim_heads_intersample = 64,
    # Dropout rates.
    attn_dropout = 0.1,
    ff_dropout = 0.8,
    embedding_dropout = 0.0,
    mlp_dropout = 0.0,
    # Remaining model options and the device selected earlier.
    mlp_hidden_mult = c(6, 4, 2),
    softmax_mod = 1.0,
    is_softmax_mod = 1.0,
    skip = FALSE,
    device = device
  ) %>%
  set_opt_hparams(lr = 2e-3)

Finally, we can fit the model. We have set verbose to FALSE so it doesn't fill our console. We can plot the loss and metrics after training to inspect how we did.

# Train for n_epochs on train_dl, evaluating on valid_dl, with console
# progress output switched off.
fitted <- fit(
  model_setup,
  train_dl,
  epochs = n_epochs,
  valid_data = valid_dl,
  verbose = FALSE
)

Plotting the training performance progress:

# Collect the recorded metrics from the fitted model.
metrics <- get_metrics(fitted)

# One panel per metric (each with its own y scale), one line per data set.
ggplot(metrics, aes(x = epoch, y = value, colour = set)) +
  geom_line() +
  facet_wrap(~ metric, scales = "free_y") +
  theme_bw()

Improve Performance with Large Batch Testing

We can improve the quality of our prediction by using large batches. The inter-sample attention layer can pay attention to all the other data points in a batch to make each prediction.

# Dataloader dedicated to prediction: large batches of 5000 rows, in the
# original row order.
pred_dl <- dataloader(valid_dset,
                      batch_size = 5000,
                      shuffle = FALSE)

# Predict from pred_dl rather than the 512-row valid_dl, so that the large
# batches are actually used. squeeze(-1) then drops the last dimension.
preds <- predict(fitted, pred_dl)$squeeze(-1)

# Convert the torch tensor of predictions into an R array.
preds <- as_array(preds)

# Fix the level order explicitly so that "High" is always the second level.
# Building the factor by order of first appearance would make the class that
# `event_level = "second"` refers to depend on the first row of `valid`.
truth <- factor(
  ifelse(valid$income == 1, "High", "Low"),
  levels = c("Low", "High")
)

# Area under the ROC curve, treating "High" (the second level) as the event.
roc_auc_vec(truth = truth, estimate = preds, event_level = "second")

Investigate the Model

We can now interrogate our model a little further by looking at the attention heads.

Attention heads

The attention_heads function can be used to pull the attention heads from the first 2000 rows in the validation dataset. These attention heads are averaged to get the average attention weights between two variables. We want to run this with a large batch, so will run this on the cpu to take advantage of the larger RAM.

# Pull attention weights from the fitted model for the validation dataset.
# NOTE(review): `n = 2000` presumably limits this to the first 2000 rows —
# confirm against the attention_heads() documentation.
heads <- attention_heads(fitted, valid_dset, n = 2000)

This data is represented nicely using a heatmap.

# Draw the attention weights as a heat map.
# NOTE(review): if this resolves to stats::heatmap(), rows and columns are
# reordered by hierarchical clustering by default — confirm that is intended.
heatmap(heads)

Intersample attention heads

The intersample attention heads can be pulled using the intersample_attention_heads function.

# Pull the intersample attention weights from the fitted model for the
# validation dataset, with the same `n = 2000` as for the attention heads.
# NOTE(review): presumably one row per observation — confirm the shape.
is_heads <- intersample_attention_heads(fitted, valid_dset, n = 2000)

These attention heads lend themselves to clustering. We will start by reducing the number of dimensions using UMAP.

library(uwot)
library(dbscan)
library(fpc)

# Seed R's random number generator before computing the UMAP embedding.
set.seed(132)

# Embed the intersample attention heads with UMAP using correlation distance
# and no PCA pre-reduction.
mapped <- umap(
  is_heads,
  pca = NULL,
  n_threads = 4,
  n_epochs = 500,
  min_dist = 0.0,
  n_neighbors = 30,
  negative_sample_rate = 15,
  local_connectivity = 2,
  spread = 3,
  metric = "correlation"
)

# Name the two embedding columns C1 and C2.
umap_comp <- as_tibble(mapped, .name_repair = ~ c("C1", "C2"))

# Put the embedding next to the first 2000 validation rows and make the target
# a factor so that it maps to a discrete colour scale.
plotting_data <- bind_cols(umap_comp, valid[1:2000, ]) %>%
  mutate(income = as_factor(income))

# Embedding coloured by income class.
ggplot(plotting_data, aes(x = C1, y = C2, colour = income)) +
  geom_point()
# k-nearest-neighbour distance plot (k = 10) to help choose `eps`; the dashed
# horizontal line marks the value used below (eps = 2).
dbscan::kNNdistplot(umap_comp, k = 10)
abline(h = 2, lty = 2)

# Both the dbscan and fpc packages export a dbscan() function, so the one that
# an unqualified call reaches depends on the order of the library() calls.
# `MinPts` is fpc's argument name, so the fpc version is called explicitly.
scanned <- fpc::dbscan(umap_comp, eps = 2, MinPts = 10)

# Embedding coloured by DBSCAN cluster, with a readable legend title.
plotting_data %>%
  ggplot(aes(x = C1, y = C2, colour = as_factor(scanned$cluster))) +
  geom_point() +
  labs(colour = "cluster")

We can now look to see if any of our predictors differ between these clusters.

library(patchwork)
# Embedding coloured by relationship (default discrete colours).
p1 <- ggplot(plotting_data, aes(x = C1, y = C2, colour = relationship)) +
  geom_point()

# Embedding coloured by marital status (discrete viridis palette).
p2 <- ggplot(plotting_data, aes(x = C1, y = C2, colour = `marital-status`)) +
  geom_point() +
  scale_color_viridis_d()

# Embedding coloured by hours worked per week (continuous viridis palette).
p3 <- ggplot(plotting_data, aes(x = C1, y = C2, colour = `hours-per-week`)) +
  geom_point() +
  scale_color_viridis_c()

# Stack the three panels in a single column.
p1 + p2 + p3 + plot_layout(ncol = 1)


cmcmaster1/torchtabular documentation built on Dec. 19, 2021, 5:17 p.m.