# Chunk options for this vignette: merge each chunk's source and output into
# one block, and prefix printed results with "#>".
knitr::opts_chunk$set(
  comment = "#>",
  collapse = TRUE
)
# Packages used throughout the vignette. One call per line: several calls
# on one line without separators do not parse.
library(docformer)
library(luz)
library(torch)
library(sentencepiece)
library(tidyverse)
# Load a sentencepiece tokenizer and prepend the missing <mask> and <pad>
# special tokens to its vocabulary.
tok_model <- sentencepiece_load_model(
  system.file(package = "sentencepiece", "models/nl-fr-dekamer.model")
)

# Special tokens are prepended in this order, so after renumbering from 0
# below, "<mask>" gets id 0 and "<pad>" gets id 1.
special_tokens <- c("<mask>", "<pad>")

# Grow the declared vocabulary size by exactly the number of tokens added,
# so the two cannot drift apart.
tok_model$vocab_size <- tok_model$vocab_size + length(special_tokens)

# Rebuild the vocabulary as the special tokens followed by the original
# subwords, with ids renumbered from 0. rowid_to_column() yields integer ids;
# subtracting `1L` (not `1`) keeps the column integer.
tok_model$vocabulary <- rbind(
  data.frame(subword = special_tokens),
  tok_model$vocabulary["subword"]
) %>%
  tibble::rowid_to_column("id") %>%
  dplyr::mutate(id = id - 1L)

# Configure the PDF repository: the directory holding the arXiv PDF files.
root <- "~/R/dataset/arxiv/"
# Torch dataset over the arXiv PDFs found under `root`, labelled with the
# `major` column of docformer's `arXiv_classification` table.
#
# Arguments:
#   root      directory searched recursively for `*.pdf` files.
#   split     value of the `set` column of `arXiv_classification` to keep
#             (e.g. "train" or "test"). Unknown values raise an error instead
#             of silently keeping every document.
#   indexes   NOTE(review): currently unused; kept for interface compatibility.
#   tokenizer tokenizer passed on to `create_features_from_doc()`.
#   download  NOTE(review): currently unused; nothing is downloaded.
#
# Each item is a list with `x` (features from `create_features_from_doc()`),
# `y` (the document's `major` label) and `id` (the PDF's full path).
arxiv_dataset <- torch::dataset(
  "arxiv_cs",
  initialize = function(root, split = "train", indexes = NULL,
                        tokenizer = NULL, download = FALSE) {
    # Load the labels into this function's environment rather than the
    # global environment, which is data()'s default `envir`.
    data("arXiv_classification", package = "docformer", envir = environment())

    stopifnot(is.character(split), length(split) == 1L)
    known_splits <- unique(arXiv_classification$set)
    if (!split %in% known_splits) {
      stop(
        "`split` must be one of: ", paste(known_splits, collapse = ", "),
        call. = FALSE
      )
    }

    pdfs <- tibble(
      fullpath = list.files(
        path = root, full.names = TRUE, recursive = TRUE,
        pattern = "\\.pdf$"
      )
    ) %>%
      mutate(path = fs::path_file(fullpath))

    # Keep only labelled documents of the requested split that exist on disk,
    # then replace the bare file name in `path` with the full path.
    # `.env$split` guards against a data column that happens to be named `split`.
    self$docs <- arXiv_classification %>%
      filter(set == .env$split) %>%
      inner_join(pdfs, by = "path") %>%
      select(-path, path = fullpath)
    self$tokenizer <- tokenizer
  },
  .getitem = function(index) {
    sample <- self$docs[index, ]
    x <- create_features_from_doc(sample$path, tokenizer = self$tokenizer)
    list(x = x, y = sample$major, id = sample$path)
  },
  .length = function() {
    nrow(self$docs)
  }
)

train_ds <- arxiv_dataset(
  root,
  tokenizer = tok_model,
  download = FALSE,
  split = "train"
)
# Validate on the held-out "test" split; the split handling above only ever
# recognised "train" and "test".
# NOTE(review): switch to a dedicated validation split if `arXiv_classification`
# provides one.
valid_ds <- arxiv_dataset(
  root,
  tokenizer = tok_model,
  download = FALSE,
  split = "test"
)

train_dl <- torch::dataloader(train_ds, batch_size = 2, shuffle = TRUE)
valid_dl <- torch::dataloader(valid_ds, batch_size = 2)
# Build the pretraining network from the named pretrained model. `mask_id`
# presumably is the id of the tokenizer's "<mask>" token (see `.mask_id()`).
# tic()
# network_module <- docformer:::docformer_for_masked_lm(config, .mask_id(tok_model))
network_module <- docformer_pretrain(
  pretrained_model_name = "allenai/hvila-row-layoutlm-finetuned-docbank",
  mask_id = .mask_id(tok_model)
)
# toc() # 30s
# Train for 10 epochs with `valid_dl` as validation data, then plot the
# recorded training history.
fitted <- fit(
  network_module,
  train_dl,
  epochs = 10,
  valid_data = valid_dl
)
plot(fitted)
# NOTE(review): this commented-out snippet treats the predictions as a
# segmentation mask overlaid on an image tensor; confirm it applies to this
# model before re-enabling it.
# library(raster)
# preds <- predict(fitted, dataloader(dataset_subset(valid_ds, 2)))
#
# mask <- as.array(torch_argmax(preds[1,..], 1)$to(device = "cpu"))
# mask <- raster::ratify(raster::raster(mask))
#
# img <- raster::brick(as.array(valid_ds[2][[1]]$permute(c(2,3,1))))
# raster::plotRGB(img, scale = 1)
# plot(mask, alpha = 0.4, legend = FALSE, axes = FALSE, add = TRUE)
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.