transformer: Load a Transformer model


View source: R/embed.R

Description

Load a Transformer model stored on disk

Usage

transformer(
  model_name,
  architecture = "BERT",
  path = system.file(package = "golgotha", "models")
)

Arguments

model_name

Character string naming the chosen model within the architecture family, e.g. 'bert-base-uncased', 'bert-base-multilingual-uncased', 'bert-base-multilingual-cased' or 'bert-base-dutch-cased' for the 'BERT' architecture family. Defaults to 'bert-base-multilingual-uncased'.

architecture

Character string naming the model architecture family. Currently supported architectures are 'BERT', 'GPT', 'GPT-2', 'CTRL', 'Transformer-XL', 'XLNet', 'XLM', 'DistilBERT', 'RoBERTa' and 'XLM-RoBERTa'. Defaults to 'BERT'.

path

Path to the directory on disk where the model is stored.
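
The arguments can be sanity-checked before a model is loaded. The following is a minimal sketch in base R, not part of golgotha: the helper name check_transformer_args() is hypothetical, the list of architectures is copied from the argument description, and the only check on path is whether the directory exists.

```r
# Architecture family names as listed in the argument description
supported_architectures <- c("BERT", "GPT", "GPT-2", "CTRL", "Transformer-XL",
                             "XLNet", "XLM", "DistilBERT", "RoBERTa", "XLM-RoBERTa")

# Hypothetical helper (not part of golgotha): validate the arguments
# that would be passed on to transformer()
check_transformer_args <- function(model_name, architecture = "BERT", path) {
  stopifnot(is.character(model_name), length(model_name) == 1)
  if (!architecture %in% supported_architectures) {
    stop("Unsupported architecture: ", architecture)
  }
  list(model_name   = model_name,
       architecture = architecture,
       path         = path,
       path_exists  = dir.exists(path))
}

# tempdir() is used here only as a directory that is known to exist
args <- check_transformer_args("bert-base-uncased", "BERT", path = tempdir())
str(args)
```

An unsupported family name such as 'ELMo' makes the helper stop with an error instead of returning a list.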

Value

an object of class Transformer

Examples

# Download the default multilingual BERT model and load it
transformer_download_model("bert-base-multilingual-uncased")
model <- transformer("bert-base-multilingual-uncased")

# Two example documents to tokenise and embed
x <- data.frame(doc_id = c("doc_1", "doc_2"),
                text = c("provide some words to embed", "another sentence of text"),
                stringsAsFactors = FALSE)
predict(model, x, type = "tokenise")
embedding <- predict(model, x, type = "embed-sentence")
dim(embedding)
embedding <- predict(model, x, type = "embed-token")
str(embedding)

# Download a DistilBERT model to a custom directory and load it from there
model_dir <- file.path(getwd(), "inst", "models")
transformer_download_model(architecture = "DistilBERT",
                           model_name = "distilbert-base-multilingual-cased",
                           path = model_dir)
path  <- file.path(model_dir, "distilbert-base-multilingual-cased")
model <- transformer(model_name = "distilbert-base-multilingual-cased",
                     architecture = "DistilBERT", path = path)
predict(model, x, type = "tokenise")
embedding <- predict(model, x, type = "embed-sentence")
dim(embedding)
embedding <- predict(model, x, type = "embed-token")
str(embedding)

# Clean up: remove models from the package's models directory
unlink(file.path(system.file(package = "golgotha", "models"),
       "bert-base-multilingual-uncased"), recursive = TRUE)
unlink(file.path(system.file(package = "golgotha", "models"),
       "bert-base-multilingual-cased"), recursive = TRUE)

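Sentence embeddings such as those computed in the examples are often compared with cosine similarity. The sketch below is an illustration in base R only. It assumes that type = "embed-sentence" yields a numeric matrix with one row per document, and it uses a random stand-in matrix (768 columns, the hidden size of BERT base models) so that it runs without downloading a model. The helper cosine_similarity() is hypothetical and not part of golgotha.

```r
# Stand-in for the result of predict(model, x, type = "embed-sentence"),
# assumed to be a numeric matrix with one row per document
set.seed(42)
embedding <- matrix(rnorm(2 * 768), nrow = 2,
                    dimnames = list(c("doc_1", "doc_2"), NULL))

# Hypothetical helper: cosine similarity between all pairs of rows
cosine_similarity <- function(m) {
  normalised <- m / sqrt(rowSums(m^2))  # scale each row to unit length
  tcrossprod(normalised)                # pairwise dot products of the unit rows
}

similarity <- cosine_similarity(embedding)
similarity["doc_1", "doc_2"]
```

With real embeddings, replace the stand-in matrix with the output of predict(); the diagonal of the result is always 1 because each document is identical to itself.
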
bnosac/golgotha documentation built on May 28, 2020, 4:06 a.m.