# Chunk options applied to every knitr chunk of this document.
knitr::opts_chunk$set(collapse = TRUE, comment = "#>")

# Packages ----------------------------------------------------------------

library(docformer)
library(luz)
library(torch)
library(sentencepiece)
library(tidyverse)

# Datasets and loaders ----------------------------------------------------

# Load the sentencepiece model shipped with the {sentencepiece} package; its
# vocabulary lacks the <mask> and <pad> tokens, so they are added below.
tok_model_path <- system.file(package = "sentencepiece", "models/nl-fr-dekamer.model")
tok_model <- sentencepiece_load_model(tok_model_path)

# Two tokens are prepended, so the declared vocabulary size grows by two.
tok_model$vocab_size <- tok_model$vocab_size + 2L

# Put <mask> and <pad> in front of the existing subwords, then renumber the
# ids from 0 so that <mask> gets id 0 and <pad> gets id 1.
special_tokens <- data.frame(subword = c("<mask>", "<pad>"))
extended_vocabulary <- rbind(special_tokens, tok_model$vocabulary["subword"])
extended_vocabulary <- tibble::rowid_to_column(extended_vocabulary, "id")
tok_model$vocabulary <- dplyr::mutate(extended_vocabulary, id = id - 1)

# configure pdf repository
root <- "~/R/dataset/arxiv/" # dataset directory

# torch dataset over the arXiv classification PDF files found under `root`.
#
# initialize() arguments:
#   root       directory searched recursively for PDF files.
#   split      "train" or "test" keeps only the rows whose `set` column has
#              that value; any other value keeps every matched document and
#              raises a warning.
#   indexes    accepted but not used by this implementation.
#   tokenizer  tokenizer forwarded to create_features_from_doc().
#   download   accepted but not used by this implementation.
#
# Each item is list(x = <document features>, y = <major>, id = <pdf path>).
arxiv_dataset <- torch::dataset(
  "arxiv_cs",
  initialize = function(root, split = "train", indexes = NULL, tokenizer = NULL, download = FALSE) {

    # datasets -------------------------------------------------
    data("arXiv_classification", package = "docformer")

    # Index the PDF files on disk by bare file name, which is the join key
    # shared with arXiv_classification.
    pdf_files <- tibble(
      fullpath = list.files(
        path = root,
        # escape the dot so that only a genuine ".pdf" extension matches
        pattern = "\\.pdf$",
        full.names = TRUE,
        recursive = TRUE
      )
    ) %>%
      mutate(path = fs::path_file(fullpath))

    # Keep only the labelled documents actually present on disk, and expose
    # their full path under the `path` column.
    self$docs <- arXiv_classification %>%
      inner_join(pdf_files, by = "path") %>%
      select(-path, path = fullpath)

    self$tokenizer <- tokenizer

    if (split == "train") {
      self$docs <- self$docs %>% filter(set == "train")
    } else if (split == "test") {
      self$docs <- self$docs %>% filter(set == "test")
    } else {
      # Same data as before (no filtering), but no longer silent: an
      # unrecognised split mixes training and held-out documents.
      warning(
        "Unknown split '", split, "': expected \"train\" or \"test\"; ",
        "keeping all documents.",
        call. = FALSE
      )
    }
  },


  # Build the features of the document stored at row `index`.
  .getitem = function(index) {

    force(index)
    doc <- self$docs[index, ]
    x <- create_features_from_doc(doc$path, tokenizer = self$tokenizer)

    list(x = x, y = doc$major, id = doc$path)
  },

  # Number of documents kept after the split filter.
  .length = function() {
    nrow(self$docs)
  }
)


# Training documents.
train_ds <- arxiv_dataset(
  root,
  tokenizer = tok_model,
  download = FALSE,
  split = "train"
)

# Held-out documents used for validation. arxiv_dataset only filters on the
# "train" and "test" splits: any other value (such as "valid") keeps every
# document, which would put the training documents in the validation set.
valid_ds <- arxiv_dataset(
  root,
  tokenizer = tok_model,
  download = FALSE,
  split = "test"
)


# Batches of two documents; only the training loader is shuffled.
train_dl <- torch::dataloader(train_ds, batch_size = 2, shuffle = TRUE)
valid_dl <- torch::dataloader(valid_ds, batch_size = 2)

# Define the network ------------------------------------------------------

# tic()
# network_module <- docformer:::docformer_for_masked_lm(config, .mask_id(tok_model))
# Pretraining module built from a named pretrained model; the mask id is
# obtained from the tokenizer extended above.
network_module <- docformer_pretrain(
  pretrained_model_name = "allenai/hvila-row-layoutlm-finetuned-docbank",
  mask_id = .mask_id(tok_model)
)
# toc() # 30s

# Train ---------------------------------------------

# We train using the cross entropy loss. We could have used the dice loss
# too, but it's harder to optimize.

# Train for 10 epochs, evaluating on the validation loader, then plot the
# fitted object.
fitted <- fit(network_module, train_dl, epochs = 10, valid_data = valid_dl)

plot(fitted)

# Plot validation image ---------------------

# library(raster)
# preds <- predict(fitted, dataloader(dataset_subset(valid_ds, 2)))
# 
# mask <- as.array(torch_argmax(preds[1,..], 1)$to(device = "cpu"))
# mask <- raster::ratify(raster::raster(mask))
# 
# img <- raster::brick(as.array(valid_ds[2][[1]]$permute(c(2,3,1))))
# raster::plotRGB(img, scale = 1)
# plot(mask, alpha = 0.4, legend = FALSE, axes = FALSE, add = TRUE)


# cregouby/docformer documentation built on May 27, 2023, 11:19 p.m.