This example is a port of the 'Text classification from scratch' example from the Keras documentation, by Mark Omernick and François Chollet.
First we implement a torch dataset that downloads and pre-processes the data. The `initialize` method is called when we instantiate a dataset. Our implementation:

- downloads and extracts the IMDB dataset into the `root` directory if it isn't already there;
- lists the review files for the requested split (`train` or `test`) together with their labels;
- trains a BPE tokenizer on the training data, or loads one previously saved in the `root` directory, and enables padding and truncation to `output_length`.

We also implement the `.getitem` method that is used to extract a single element from the dataset and pre-process the file contents.
```r
library(torch)
library(tok)
library(luz)

vocab_size <- 20000   # maximum number of items in the vocabulary
output_length <- 500  # padding and truncation length
embedding_dim <- 128  # size of the embedding vectors

imdb_dataset <- dataset(
  initialize = function(output_length, vocab_size, root, split = "train", download = TRUE) {
    url <- "https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"
    fpath <- file.path(root, "aclImdb")

    # download if file doesn't exist yet
    if (!dir.exists(fpath) && download) {
      # download into tempdir, then extract and move to the root dir
      withr::with_tempfile("file", {
        download.file(url, file)
        untar(file, exdir = root)
      })
    }

    # now list files for the split
    self$data <- rbind(
      data.frame(
        fname = list.files(file.path(fpath, split, "pos"), full.names = TRUE),
        y = 1
      ),
      data.frame(
        fname = list.files(file.path(fpath, split, "neg"), full.names = TRUE),
        y = 0
      )
    )

    # train a tokenizer on the train data (if one doesn't exist yet)
    tokenizer_path <- file.path(root, glue::glue("tokenizer-{vocab_size}.json"))
    if (!file.exists(tokenizer_path)) {
      self$tok <- tok::tokenizer$new(tok::model_bpe$new())
      self$tok$pre_tokenizer <- tok::pre_tokenizer_whitespace$new()

      files <- list.files(file.path(fpath, "train"), recursive = TRUE, full.names = TRUE)
      self$tok$train(files, tok::trainer_bpe$new(vocab_size = vocab_size))
      self$tok$save(tokenizer_path)
    } else {
      self$tok <- tok::tokenizer$from_file(tokenizer_path)
    }

    self$tok$enable_padding(length = output_length)
    self$tok$enable_truncation(max_length = output_length)
  },
  .getitem = function(i) {
    item <- self$data[i,]

    # takes item i, reads the file content into a char string,
    # then makes everything lower case and removes html + punctuation.
    # next uses the tokenizer to encode the text.
    text <- item$fname %>%
      readr::read_file() %>%
      stringr::str_to_lower() %>%
      stringr::str_replace_all("<br />", " ") %>%
      stringr::str_remove_all("[:punct:]") %>%
      self$tok$encode()

    list(
      x = text$ids + 1L,
      y = item$y
    )
  },
  .length = function() {
    nrow(self$data)
  }
)

train_ds <- imdb_dataset(output_length, vocab_size, "./imdb", split = "train")
test_ds <- imdb_dataset(output_length, vocab_size, "./imdb", split = "test")
```
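As a quick sanity check (assuming the download and tokenizer training above succeeded), we can pull a single element from the dataset and look at its length and label:

```r
# fetch the first element; indexing a torch dataset calls `.getitem()`
item <- train_ds[1]

length(item$x)  # number of token ids, equal to output_length (500)
item$y          # 1 for a positive review, 0 for a negative one
```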
We now define the model we want to train. The model is a 1D convnet that starts with an embedding layer, followed by convolutional blocks, with a classifier plugged at the output.
```r
model <- nn_module(
  initialize = function(vocab_size, embedding_dim) {
    self$embedding <- nn_sequential(
      nn_embedding(num_embeddings = vocab_size, embedding_dim = embedding_dim),
      nn_dropout(0.5)
    )

    self$convs <- nn_sequential(
      nn_conv1d(embedding_dim, 128, kernel_size = 7, stride = 3, padding = "valid"),
      nn_relu(),
      nn_conv1d(128, 128, kernel_size = 7, stride = 3, padding = "valid"),
      nn_relu(),
      nn_adaptive_max_pool2d(c(128, 1)) # reduces the length dimension
    )

    self$classifier <- nn_sequential(
      nn_flatten(),
      nn_linear(128, 128),
      nn_relu(),
      nn_dropout(0.5),
      nn_linear(128, 1)
    )
  },
  forward = function(x) {
    emb <- self$embedding(x)
    out <- emb$transpose(2, 3) %>%
      self$convs() %>%
      self$classifier()

    # we drop the last dimension so we get (B) instead of (B, 1)
    out$squeeze(2)
  }
)

# test the model for a single example batch
# m <- model(vocab_size, embedding_dim)
# x <- torch_randint(1, 20000, size = c(32, 500), dtype = "int")
# m(x)
```
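Before training, it can also be useful to check that a real batch flows through the model with the expected shape. The sketch below assumes `train_ds` from the previous step and runs one batch through an untrained model:

```r
# build a dataloader over the training dataset and grab a single batch
dl <- dataloader(train_ds, batch_size = 32)
batch <- dataloader_next(dataloader_make_iter(dl))

# instantiate the model and run the batch through it
m <- model(vocab_size, embedding_dim)
m(batch$x)$shape  # should be 32, i.e. one logit per example in the batch
```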
We can now train the model:
```r
fitted_model <- model %>%
  setup(
    loss = nnf_binary_cross_entropy_with_logits,
    optimizer = optim_adam,
    metrics = luz_metric_binary_accuracy_with_logits()
  ) %>%
  set_hparams(vocab_size = vocab_size, embedding_dim = embedding_dim) %>%
  fit(train_ds, epochs = 3)
```
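luz records the loss and metrics computed at each epoch. If you want to inspect them after training, `get_metrics()` returns them as a data frame:

```r
# per-epoch loss and accuracy recorded during fitting
get_metrics(fitted_model)
```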
Finally, we can obtain the metrics on the test dataset:
```r
fitted_model %>% evaluate(test_ds)
```
Remember that in order to predict on new texts, we need to apply the same pre-processing used in the dataset definition.
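For example, a minimal sketch of scoring a new review could look like the following. It assumes the tokenizer trained above was saved to `./imdb/tokenizer-20000.json` (the path built in the dataset definition) and uses a made-up review text:

```r
# load the saved tokenizer and configure it exactly like in the dataset
tok <- tok::tokenizer$from_file("./imdb/tokenizer-20000.json")
tok$enable_padding(length = output_length)
tok$enable_truncation(max_length = output_length)

# apply the same cleaning steps used in `.getitem`
text <- "This movie was a complete waste of time.<br />Avoid it." %>%
  stringr::str_to_lower() %>%
  stringr::str_replace_all("<br />", " ") %>%
  stringr::str_remove_all("[:punct:]")

# encode and apply the same +1 shift used in the dataset
ids <- tok$encode(text)$ids + 1L

# shape (1, output_length): a batch with a single example
x <- torch_tensor(matrix(ids, nrow = 1))

preds <- predict(fitted_model, x)
torch_sigmoid(preds)  # probability that the review is positive
```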