knitr::opts_chunk$set(
  collapse = TRUE,
  eval = FALSE,
  comment = "#>"
)
library(torch)
library(torchtransformers)
library(luz)

library(dlr)
library(dplyr)
library(jsonlite)
library(wordpiece)
library(wordpiece.data)

This vignette walks through the process of training a "triplet" model to produce sentence embeddings for some task. It also demonstrates techniques for ensuring that data is tokenized appropriately for a model.

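Each training example for this type of model consists of three pieces of input text: the anchor, the positive example, and the negative example. The training loss for the model is based on the difference between two distances in embedding space: the distance from the anchor to the positive example, and the distance from the anchor to the negative example.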

The loss is minimized during training, so the model learns to put the anchor closer to the positive example than to the negative example in embedding space.
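
As a rough illustration, a triplet margin loss for a single example can be computed by hand in base R. This is a minimal sketch, assuming Euclidean distance and a margin of 1 (the defaults of torch::nn_triplet_margin_loss(), ignoring its small numerical epsilon). The helper names euclidean() and triplet_loss() and the toy vectors are hypothetical and are not part of any package used here.

```r
# Euclidean distance between two embedding vectors.
euclidean <- function(a, b) sqrt(sum((a - b)^2))

# Triplet margin loss for one example. The loss reaches zero once the negative
# is at least `margin` farther from the anchor than the positive is.
triplet_loss <- function(anchor, positive, negative, margin = 1) {
  max(euclidean(anchor, positive) - euclidean(anchor, negative) + margin, 0)
}

anchor   <- c(0, 0)
positive <- c(0, 1) # distance 1 from the anchor
negative <- c(3, 4) # distance 5 from the anchor

# Well-separated triplet: 1 - 5 + 1 = -3, clipped to 0.
triplet_loss(anchor, positive, negative)

# Swapping the roles gives a badly ordered triplet: 5 - 1 + 1 = 5.
triplet_loss(anchor, negative, positive)
```

A loss of zero means the triplet is already ordered correctly with room to spare, so it contributes no gradient; only triplets that violate the margin push the embeddings around.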

After training, the hope is that the model has learned to make useful representations for the input examples. The model can then be used to generate sentence-level embedding vectors for any input sentences.

Get the data

The data we will use for this vignette is derived from NLI datasets, and can be obtained online. We use the read_or_cache() function from {dlr} to avoid repeatedly processing the same dataset.

# dlr uses a processor function that takes the path to a temp file as its first
# argument. Here we get the data to the sort of point from which you might start
# with your own text data.
paraphrase_processor <- function(source_file) {
  return(
    jsonlite::stream_in(file(source_file)) %>% 
      dplyr::as_tibble() %>% 
      dplyr::rename(
        anchor = "V1",
        positive = "V2",
        negative = "V3"
      ) %>% 
      # The dataset contains some long, rambling sentences that we don't want
      # for this demonstration. Filter out those entire rows.
      dplyr::mutate(
        max_chars = pmax(
          nchar(anchor), nchar(positive), nchar(negative)
        )
      ) %>% 
      dplyr::filter(max_chars <= 300) %>% 
      dplyr::select(-max_chars)
  )
}

data_url <- "https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/AllNLI.jsonl.gz"

paraphrases <- dlr::read_or_cache(
  source_path = data_url,
  appname = "torchtransformers",
  process_f = paraphrase_processor
)

paraphrases$anchor[[1]]
#> [1] "A person on a horse jumps over a broken down airplane."

paraphrases$positive[[1]]
#> [1] "A person is outdoors, on a horse."

paraphrases$negative[[1]]
#> [1] "A person is at a diner, ordering an omelette."
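
The length filter inside the processor can be tried on its own. The following is a small base R sketch of the same idea, using a hypothetical two-row toy data frame rather than the real dataset: a row is dropped if any one of its three sentences is longer than 300 characters.

```r
# Hypothetical toy data; the second row has an over-long negative sentence.
toy <- data.frame(
  anchor = c("A dog runs.", "A cat sleeps."),
  positive = c("An animal moves.", "An animal rests."),
  negative = c("Nobody is here.", strrep("x", 301)),
  stringsAsFactors = FALSE
)

# Length of the longest of the three sentences in each row.
max_chars <- pmax(nchar(toy$anchor), nchar(toy$positive), nchar(toy$negative))

# Keep only rows in which every sentence has at most 300 characters.
kept <- toy[max_chars <= 300, ]
nrow(kept)
```

Filtering whole rows, rather than truncating individual sentences, keeps each remaining triplet intact.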

Define the model

Since the dataset is more complicated than other examples, we'll first define our model, and then construct a dataset that will work properly for that model.

We're going to build a "triplet" model with a BERT pre-trained "spine" (specifically "bert_tiny_uncased"). We use the model_bert_pretrained() function to create a BERT module with pre-trained weights loaded. We take the embedding of the [CLS] token (the first token of each example) from the final BERT layer as the output, and then apply a dense layer to map into our paraphrase embedding space.

spine_with_pooling <- torch::nn_module(
  classname = "pooled_spine",
  initialize = function() {
    bert_type <- "bert_tiny_uncased"

    self$bert <- model_bert_pretrained(bert_type)

    # We'll add a final dense layer after the pretrained model to map the [CLS]
    # handler token into our embedding_size-dimensional space.
    embedding_size <- config_bert(
      bert_type = bert_type, 
      parameter = "embedding_size"
    ) 

    self$linear <- torch::nn_linear(
      in_features = embedding_size, 
      out_features = embedding_size
    )
  },
  forward = function(x) {
    # Process the input through the pretrained BERT model.
    output <- self$bert(x)

    # Take the output embeddings from the final BERT layer.
    output <- output$output_embeddings
    output <- output[[length(output)]]

    # Take the [CLS] token embedding (token 1 in each example) for pooling. The
    # three dimensions are: example (length = batch_size), token (length =
    # n_tokens), embeddings (length = defined by the model type).
    output <- output[ , 1, ]

    # Apply the final dense layer to the pooled output.
    output <- self$linear(output)
    return(output)
  }
)
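
The pooling step in forward() is just tensor indexing. Here is a minimal sketch of what it does to the tensor shapes, assuming {torch} is installed; the sizes (4 examples, 6 tokens, 8 embedding dimensions) and the random tensor standing in for the final BERT layer are hypothetical.

```r
library(torch)

# Hypothetical sizes: 4 examples, 6 tokens each, 8-dimensional embeddings.
final_layer <- torch_randn(4, 6, 8)

# Keep token 1 ([CLS]) of every example. The token dimension is dropped,
# leaving one embedding vector per example.
pooled <- final_layer[, 1, ]
dim(pooled)

# A dense layer that maps each pooled vector to a vector of the same size.
dense <- nn_linear(in_features = 8, out_features = 8)
dim(dense(pooled))
```

Both dim() calls should report 4 rows and 8 columns: one row per example, one column per embedding dimension.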

The full model uses that same spine for the three different inputs. We use torch::nn_triplet_margin_loss to measure the loss using the three outputs.

triplet_text_model <- torch::nn_module(
  classname = "triplet_text",
  initialize = function() {
    # Share the same spine parameters between all three inputs.
    self$spine <- spine_with_pooling()
    self$criterion <- torch::nn_triplet_margin_loss()
  },
  # We define a loss method for the model. {luz} will use this method during
  # fitting.
  loss = function(input, ...) {
    # input contains anchor, positive, and negative, each of which is composed
    # of token_ids and token_type_ids. Run each input through the spine.
    embeddings <- lapply(input, self$spine)

    # The `nn_triplet_margin_loss` criterion measures the distance between the
    # three samples, and attempts to ensure that the positive example is closer
    # to the anchor than the negative example is.
    loss <- self$criterion(
      embeddings$anchor,
      embeddings$positive,
      embeddings$negative
    )
    return(loss)
  }
)
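
The mechanics of the loss method can be checked without downloading any BERT weights. This is a sketch under stated assumptions: toy_spine is a hypothetical stand-in for the real spine (a single small dense layer), and the inputs are random tensors rather than token ids.

```r
library(torch)

# Hypothetical stand-in for the BERT spine: one small dense layer.
toy_spine <- nn_linear(in_features = 5, out_features = 3)

# A named list of three inputs, each a batch of 2 examples.
input <- list(
  anchor = torch_randn(2, 5),
  positive = torch_randn(2, 5),
  negative = torch_randn(2, 5)
)

# lapply() keeps the list names, so the outputs can be accessed by name. The
# same module, and so the same weights, processes all three inputs.
embeddings <- lapply(input, toy_spine)
names(embeddings)

# With default settings the criterion averages the loss over the batch.
criterion <- nn_triplet_margin_loss()
loss <- criterion(embeddings$anchor, embeddings$positive, embeddings$negative)
loss$item()
```

Because one module instance handles every input, gradients from all three branches accumulate into a single set of parameters.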

Prepare the data for torch

The input to our triplet_text_model is a list with elements anchor, positive, and negative. We define a custom dataset constructor to prepare our data for that model.

In this model, each of the three elements functions as a separate input to a BERT model. Therefore our dataset constructor will prepare each one as if it is a separate dataset.

triplet_text_dataset <- torch::dataset(
  name = "triplet_text_dataset",
  initialize = function(df) {
    # Use dataset_bert_pretrained() to prepare each dataset. We'll prepare each
    # for a "bert_tiny_uncased" pretrained model. Note that we only want to keep
    # part of each dataset, so we construct them and then save the torch_data
    # from each to this dataset.
    bert_type <- "bert_tiny_uncased"
    # We don't need the full 512 tokens this bert_type allows.
    n_tokens <- 128L
    anchor_ds <- dataset_bert_pretrained(
      x = df$anchor,
      bert_type = bert_type,
      n_tokens = n_tokens
    )
    self$anchor <- anchor_ds$torch_data$x$token_ids
    positive_ds <- dataset_bert_pretrained(
      x = df$positive,
      bert_type = bert_type,
      n_tokens = n_tokens
    )
    self$positive <- positive_ds$torch_data$x$token_ids
    negative_ds <- dataset_bert_pretrained(
      x = df$negative,
      bert_type = bert_type,
      n_tokens = n_tokens
    )
    self$negative <- negative_ds$torch_data$x$token_ids

    # These all have the same token_type_ids.
    self$token_type_ids <- anchor_ds$torch_data$x$token_type_ids
  },
  # We extract subsets of this data using an index vector.
  .getbatch = function(index) {
    return(
      list(
        list(
          anchor = list(
            token_ids = self$anchor[index, ],
            token_type_ids = self$token_type_ids[index, ]
          ),
          positive = list(
            token_ids = self$positive[index, ],
            token_type_ids = self$token_type_ids[index, ]
          ),
          negative = list(
            token_ids = self$negative[index, ],
            token_type_ids = self$token_type_ids[index, ]
          )
        ),
        # We don't use a separate target for this model.
        list() 
      )
    )
  },
  # The dataset also needs a method to determine the length of the entire
  # dataset.
  .length = function() {
    dim(self$anchor)[[1]]
  }
)
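
The batch structure is easier to see on a toy dataset that skips tokenization. The sketch below assumes only {torch}; the toy_tokens constructor and its all-ones "token ids" are hypothetical. It follows the same pattern as above: .getbatch() returns a list of two elements (input and target), with an empty list as the target.

```r
library(torch)

toy_tokens <- dataset(
  name = "toy_tokens",
  initialize = function(n_rows, n_tokens) {
    # Placeholder values standing in for real token ids.
    self$ids <- torch_ones(n_rows, n_tokens)
  },
  # Return (input, target) for the requested rows; there is no target.
  .getbatch = function(index) {
    list(
      list(token_ids = self$ids[index, ]),
      list()
    )
  },
  .length = function() {
    dim(self$ids)[[1]]
  }
)

ds <- toy_tokens(n_rows = 5, n_tokens = 3)
length(ds)

# Ask for the first two rows directly, as a dataloader would.
batch <- ds$.getbatch(1:2)
dim(batch[[1]]$token_ids)
```

Defining .getbatch() rather than .getitem() lets the dataset slice a whole batch from the stored tensors in one indexing operation.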

We use the constructor to create a training dataset. We'll let {luz} split this dataset into training and validation. The dataset constructor performs the tokenization, so this step will be somewhat slow.

train_ds <- triplet_text_dataset(paraphrases)

Fit the model

We use {luz} to perform the actual model fitting. With these settings, this model gets about as good as it's going to get after a single epoch, so there's no need to train for additional epochs.

set.seed(12345)
torch::torch_manual_seed(123456)
fitted <- triplet_text_model %>% 
  luz::setup(optimizer = torch::optim_adam) %>% 
  fit(
    train_ds, 
    epochs = 1, 
    valid_data = 0.1,
    dataloader_options = list(batch_size = 256L)
  )

After a single epoch of training, the loss on the validation set is about 0.23.

Using the model

We'll use the weights from the spine of our model to map an input into embedding space.

model_weights <- fitted$model$state_dict()

We define a simpler model that takes a single type of input, and outputs sentence vectors to map our inputs into our paraphrase embedding space. This model uses the same parameters as the triplet model, but with a single call to the spine.

# The vectorizer model is the same spine that we used for the three inputs. We
# wrap it in a larger constructor so the weight names line up for easier
# loading.
vectorizer <- torch::nn_module(
  "vectorizer",
  initialize = function() {
    self$spine <- spine_with_pooling()
  },
  forward = function(input) {
    return(self$spine(input))
  }
)

vectorizer_model <- vectorizer()

# Load our trained weights into this model.
vectorizer_model$load_state_dict(model_weights)
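
The reason for the wrapper is that parameter names in a state dict are prefixed with the name of the field that holds each submodule. A minimal sketch of that behaviour, using hypothetical toy modules (inner, trainer, and embedder) in place of the real spine and models:

```r
library(torch)

# Hypothetical stand-in for the spine.
inner <- nn_module(
  "inner",
  initialize = function() {
    self$linear <- nn_linear(2, 2)
  },
  forward = function(x) self$linear(x)
)

# Two different wrappers that both store the inner module as `spine`.
trainer <- nn_module(
  "trainer",
  initialize = function() {
    self$spine <- inner()
  },
  forward = function(x) self$spine(x)
)
embedder <- nn_module(
  "embedder",
  initialize = function() {
    self$spine <- inner()
  },
  forward = function(x) self$spine(x)
)

trained <- trainer()
fresh <- embedder()

# Parameter names carry the `spine.` prefix in both wrappers.
names(trained$state_dict())

# Because the names match, the weights load without any renaming.
fresh$load_state_dict(trained$state_dict())
torch_equal(fresh$spine$linear$weight, trained$spine$linear$weight)
```

Loading the same state dict directly into a bare inner() module would not line up, because its parameter names lack the spine. prefix.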

We can use that model to map the positive paraphrases from our dataset into an embedding space. We'll work with a subset to make sure they can all fit in RAM.

vectorizer_model$eval()

paraphrase_space <- vectorizer_model(
  list(
    token_ids = train_ds$positive[1:10000, ], 
    token_type_ids = train_ds$token_type_ids[1:10000, ]
  )
)

A given anchor should, in theory, be closest to its associated paraphrase in that space.

# Map some of the anchors into that same embedding space.
anchors_in_space <- vectorizer_model(
  list(
    token_ids = train_ds$anchor[1:100, ],
    token_type_ids = train_ds$token_type_ids[1:100, ]
  )
)

# Create a simple pairwise distance calculator.
pair_dist <- torch::nn_pairwise_distance()

# I picked out the 8th example because its result is somewhat interesting.
anchor_id <- 8
distances <- pair_dist(anchors_in_space[anchor_id, ], paraphrase_space)
closest_match <- which(
  torch::as_array(distances) == min(torch::as_array(distances))
)

paraphrases$anchor[[anchor_id]]
#> [1] "Two women who just had lunch hugging and saying goodbye."

# The actual "correct" paraphrase.
paraphrases$positive[[anchor_id]]
#> [1] "There are two woman in this picture."

# The closest paraphrase in our embedding space.
paraphrases$positive[[closest_match]]
#> [1] "A couple is having lunch"

While this is incorrect, you can see that the model picked up on the relationship between "two" and "a couple", and also that the sentence had something to do with having lunch. Further refinement of this model would be necessary to make the results actually useful, but the general approach can lead to models that are useful for quickly matching a new piece of text to a large collection of known text. If we were using this for something real, we would want a separate test set that maps to the same set of paraphrases.
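
The nearest-match lookup used above can also be sketched in base R, which may help when checking the logic on small inputs. The three 2-dimensional "embeddings" and the anchor vector below are hypothetical toy values, not model output.

```r
# Hypothetical embedding space: one row per known paraphrase.
paraphrase_space <- rbind(
  c(0, 0),
  c(1, 1),
  c(5, 5)
)

# Hypothetical embedding of a new anchor sentence.
anchor <- c(0.9, 1.2)

# Euclidean distance from the anchor to every row of the space.
distances <- sqrt(rowSums(sweep(paraphrase_space, 2, anchor)^2))

# Index of the closest paraphrase. Unlike which(x == min(x)), which.min()
# always returns a single index, even when there are ties.
closest_match <- which.min(distances)
closest_match
```

Here the anchor lies nearest the second row, so closest_match is 2.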



macmillancontentscience/torchtransformers documentation built on Aug. 6, 2023, 5:35 a.m.