# Default knitr chunk options for every chunk in this vignette.
knitr::opts_chunk$set(eval = FALSE, comment = "#>", collapse = TRUE)
library(torch)
library(torchtransformers)
library(luz)

library(dlr)
library(dplyr)
library(yardstick)

Textual entailment is a common NLP task, and is included in the GLUE and SuperGLUE NLP benchmarks. The task consists of two pieces of text, a premise and a hypothesis. For example, this is a premise/hypothesis pair from the MultiNLI dataset (MNLI, described in more detail below):

In this case, the premise entails the hypothesis. This means that the hypothesis follows from the premise.

In contrast, this is another premise/hypothesis pair from the same dataset:

In this case, the premise contradicts the hypothesis. The premise lists things that are displayed in the museum, while the hypothesis asserts that the museum is empty.

Finally, this is another premise/hypothesis pair from MNLI:

While Horus was an Egyptian god, the premise doesn't mention that, so the premise neither entails nor contradicts the hypothesis. This pair is said to be neutral.

In this vignette, we'll use the MNLI dataset to fine-tune a BERT model for an entailment task.

The MNLI Dataset

The Multi-Genre Natural Language Inference (MultiNLI or MNLI) corpus was described in A Broad-Coverage Challenge Corpus for Sentence Understanding through Inference (Williams et al., NAACL 2018). It includes 433k premise-hypothesis pairs, annotated with entailment information. The premises are divided into 10 genres. Five of the genres ("fiction", "government", "slate", "telephone", and "travel") are included in the training dataset, and the other five genres ("facetoface", "letters", "nineeleven", "oup", and "verbatim") are not.

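The data are subdivided into five datasets: train.tsv, dev_matched.tsv, dev_mismatched.tsv, test_matched.tsv, and test_mismatched.tsv. The "matched" sets use the same five genres as the training set, while the "mismatched" sets use the five genres that are not in the training set.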

The test sets are for scoring your model on Kaggle, so we'll skip those.

We'll train our model using train.tsv, and test it using dev_matched.tsv and dev_mismatched.tsv.

# Processor function for {dlr}: read every MNLI split out of the zip archive.
#
# source_file: path to the MNLI zip archive; each split is read from
#   "MNLI/<split>.tsv" inside it.
#
# Returns a named list of tibbles, one per split, named after the split. The
# index, promptID, pairID, "*_parse" and "label*" columns are dropped, and
# gold_label (when the split has one) is a factor with fixed levels.
process_mnli <- function(source_file) {
  split_names <- c(
    "train",
    "dev_matched",
    "dev_mismatched",
    "test_matched",
    "test_mismatched"
  )

  read_split <- function(split_name) {
    # The column layout differs between the dev, test, and train files, so
    # pick explicit column types to make sure things come in as we expect.
    if (stringr::str_starts(split_name, "dev_")) {
      col_types <- "iicccccccccccccc"
    } else if (stringr::str_starts(split_name, "test_")) {
      col_types <- "iiiccccccc"
    } else {
      col_types <- "iicccccccccc"
    }

    split_data <- readr::read_tsv(
      unz(source_file, fs::path("MNLI", split_name, ext = "tsv")),
      col_types = col_types,
      # A couple of lines are misread when a quote character is in effect, so
      # turn quoting off.
      quote = ""
    )

    # Only some splits carry labels. Where they do, use one fixed set of
    # factor levels so the coding is the same in every split.
    if ("gold_label" %in% colnames(split_data)) {
      split_data[["gold_label"]] <- factor(
        split_data[["gold_label"]],
        levels = c("entailment", "neutral", "contradiction")
      )
    }

    # Drop the identifier, parse-tree, and per-annotator label columns.
    dplyr::select(
      split_data,
      -index,
      -promptID,
      -pairID,
      -dplyr::ends_with("_parse"),
      -dplyr::starts_with("label")
    )
  }

  # Name the input vector after itself so purrr names the resulting list.
  purrr::map(stats::setNames(split_names, split_names), read_split)
}

# By default downloading large files often fails, so increase the timeout.
# options() returns the previous value, which we keep so it can be restored.
old_timeout <- options(timeout = 1000)

data_url <- "https://dl.fbaipublicfiles.com/glue/data/MNLI.zip"

# Fetch the archive through {dlr} and process it with process_mnli().
# `finally` restores the timeout whether or not the download/processing
# succeeds, so a failure here does not leave the session with a changed option.
mnli_tibbles <- tryCatch(
  dlr::read_or_cache(
    source_path = data_url,
    appname = "torchtransformers",
    process_f = process_mnli
  ),
  finally = options(old_timeout)
)

We need to set these datasets up for use with {luz}. We can use dataset_bert_pretrained() to process the train, matched, and mismatched datasets.

# Build one dataset per split we use: the two sentence columns are the
# predictors (x) and gold_label is the outcome (y).
train_ds <- mnli_tibbles$train %>%
  dplyr::select(sentence1, sentence2) %>%
  dataset_bert_pretrained(x = ., y = mnli_tibbles$train$gold_label)

test_matched_ds <- mnli_tibbles$dev_matched %>%
  dplyr::select(sentence1, sentence2) %>%
  dataset_bert_pretrained(x = ., y = mnli_tibbles$dev_matched$gold_label)

test_mismatched_ds <- mnli_tibbles$dev_mismatched %>%
  dplyr::select(sentence1, sentence2) %>%
  dataset_bert_pretrained(x = ., y = mnli_tibbles$dev_mismatched$gold_label)

Note that we do not tokenize the data at this point. We'll let the model trigger tokenization to make sure the data is in the format the model expects.

Model

We'll construct a model based on BERT, with a linear layer to score the input on the three label dimensions.

# A pretrained BERT model followed by a single linear layer that produces one
# score per label.
entailment_classifier <- torch::nn_module(
  "entailment_classifier",
  # bert_type: name of the pretrained BERT variant to load; it also determines
  #   the input width of the final linear layer.
  initialize = function(bert_type = "bert_tiny_uncased") {
    self$bert <- model_bert_pretrained(bert_type)
    # The dense layer maps one embedding vector to the 3 possible labels.
    self$linear <- torch::nn_linear(
      in_features = config_bert(bert_type, "embedding_size"),
      out_features = 3L
    )
  },
  # x: the input passed straight through to the BERT submodule.
  # Returns the output of the linear layer (3 scores per input).
  forward = function(x) {
    layer_embeddings <- self$bert(x)$output_embeddings

    # Keep only the embeddings from the last layer.
    last_layer <- layer_embeddings[[length(layer_embeddings)]]

    # The first token position holds the [CLS] token; use its embedding for
    # classification.
    cls_embedding <- last_layer[, 1, ]

    # Score the [CLS] embedding with the final dense layer.
    self$linear(cls_embedding)
  }
)

We fit the model using {luz}. We only fit for one epoch as a proof of concept.

# Seed torch before fitting.
torch::torch_manual_seed(123456)

# Attach the loss, optimizer, and metric to the module.
classifier_spec <- luz::setup(
  entailment_classifier,
  loss = torch::nn_cross_entropy_loss(),
  optimizer = torch::optim_adam,
  metrics = list(
    luz::luz_metric_accuracy()
  )
)

# Tokenization is triggered through this callback, using the "bert" submodel.
tokenize_callback <- luz_callback_bert_tokenize(
  submodel_name = "bert",
  n_tokens = 128L # We don't want the full 512 for this example.
)

# A single epoch only, as a proof of concept.
fitted <- fit(
  classifier_spec,
  train_ds,
  epochs = 1,
  callbacks = list(tokenize_callback),
  valid_data = 0.1,
  dataloader_options = list(batch_size = 256L)
)

Results

We predict the two test datasets, and measure the results.

# Score the matched dev set with the fitted model, tokenizing with the same
# settings used for fitting.
scores_matched <- predict(
  fitted,
  test_matched_ds,
  callbacks = list(
    luz_callback_bert_tokenize(
      submodel_name = "bert",
      n_tokens = 128L
    )
  )
)

# Softmax along dimension 2, then the index of the largest value along it.
probabilities_matched <- torch::nnf_softmax(scores_matched, 2)
predictions_matched <- torch::torch_argmax(probabilities_matched, 2)

# Move the result to the CPU and convert it to an R array.
predictions_matched <- torch::as_array(
  predictions_matched$to(device = "cpu")
)

# Label the predicted indices 1:3 with the levels of gold_label so that the
# prediction and the truth are factors with the same levels.
dev_matched <- dplyr::mutate(
  mnli_tibbles$dev_matched,
  .pred = factor(
    predictions_matched,
    levels = 1:3,
    labels = levels(gold_label)
  )
)

dev_matched %>%
  yardstick::accuracy(gold_label, .pred)
#> # A tibble: 1 × 3
#>   .metric  .estimator .estimate
#>   <chr>    <chr>          <dbl>
#> 1 accuracy multiclass     0.659

# Score the mismatched dev set with the fitted model, tokenizing with the same
# settings used for fitting.
scores_mismatched <- predict(
  fitted,
  test_mismatched_ds,
  callbacks = list(
    luz_callback_bert_tokenize(
      submodel_name = "bert",
      n_tokens = 128L
    )
  )
)

# Softmax along dimension 2, then the index of the largest value along it.
probabilities_mismatched <- torch::nnf_softmax(scores_mismatched, 2)
predictions_mismatched <- torch::torch_argmax(probabilities_mismatched, 2)

# Move the result to the CPU and convert it to an R array.
predictions_mismatched <- torch::as_array(
  predictions_mismatched$to(device = "cpu")
)

# Label the predicted indices 1:3 with the levels of gold_label so that the
# prediction and the truth are factors with the same levels.
dev_mismatched <- dplyr::mutate(
  mnli_tibbles$dev_mismatched,
  .pred = factor(
    predictions_mismatched,
    levels = 1:3,
    labels = levels(gold_label)
  )
)

dev_mismatched %>%
  yardstick::accuracy(gold_label, .pred)
#> # A tibble: 1 × 3
#>   .metric  .estimator .estimate
#>   <chr>    <chr>          <dbl>
#> 1 accuracy multiclass     0.661

The published results for bert_tiny_uncased on these datasets are 0.72 and 0.73, so our accuracy of 0.66 on both datasets after a single epoch of fine-tuning is on track.



macmillancontentscience/torchtransformers documentation built on Aug. 6, 2023, 5:35 a.m.