View source: R/extract_features.R
extract_features    R Documentation
Given example sentences (as a list of InputExample_EF objects), apply an existing BERT model and capture certain output layers. (These could potentially be used as features in downstream tasks.)
extract_features(
  examples,
  model = c("bert_base_uncased", "bert_base_cased", "bert_large_uncased",
    "bert_large_cased", "bert_large_uncased_wwm", "bert_large_cased_wwm",
    "bert_base_multilingual_cased", "bert_base_chinese",
    "scibert_scivocab_uncased", "scibert_scivocab_cased",
    "scibert_basevocab_uncased", "scibert_basevocab_cased"),
  ckpt_dir = NULL,
  vocab_file = find_vocab(ckpt_dir),
  bert_config_file = find_config(ckpt_dir),
  init_checkpoint = find_ckpt(ckpt_dir),
  output_file = NULL,
  max_seq_length = 128L,
  layer_indexes = -4:-1,
  batch_size = 2L,
  features = c("output", "attention"),
  verbose = FALSE
)
examples: List of InputExample_EF objects.
model: Character; which model checkpoint to use. If specified,
ckpt_dir: Character; path to checkpoint directory. If specified, any other checkpoint files required by this function (vocab_file, bert_config_file, and init_checkpoint) are located within this directory by default.
vocab_file: Character; path to the vocabulary file. The file is assumed to be a text file with one token per line, where the line number corresponds to the index of that token in the vocabulary.
bert_config_file: Character; path to a JSON config file.
init_checkpoint: Character; path to the checkpoint directory, plus checkpoint name stub (e.g. "bert_model.ckpt"). The path must be absolute and explicit, starting with "/".
output_file: Character (optional); file path (stub) to write output to.
max_seq_length: Integer; the maximum number of tokens that will be considered together.
layer_indexes: Integer vector; indexes (positive, or negative counting back from the end) indicating which layers to extract as "output features". The "zeroth" layer is the input embedding vectors to the first layer.
batch_size: Integer; how many examples to process per batch.
features: Character; whether to return "output" (layer outputs, the default), "attention" (attention probabilities), or both.
verbose: Logical; if FALSE, suppresses most of the TensorFlow chatter by temporarily setting the logging threshold to its highest level. If TRUE, keeps the current logging threshold, which defaults to "WARN". To change the logging threshold of the current session, run
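Because layer_indexes can mix positive and negative values, a minimal R sketch of converting them to absolute layer numbers may help. The helper resolve_layer_indexes() is hypothetical and not part of this package. It assumes layer 0 is the embedding output, layers 1 to num_layers are the transformer layers (12 for the base models, 24 for the large ones), and -1 denotes the last layer, so the default -4:-1 selects the last four layers. The package's internal conversion may differ.

```r
# Hypothetical helper (not part of this package): convert layer indexes to
# absolute layer numbers. Assumes layer 0 is the embedding output, layers
# 1..num_layers are transformer layers, and -1 is the last layer.
resolve_layer_indexes <- function(layer_indexes, num_layers = 12L) {
  layer_indexes <- as.integer(layer_indexes)
  num_layers <- as.integer(num_layers)
  # Negative indexes count back from the last layer: -1 -> num_layers.
  neg <- layer_indexes < 0L
  layer_indexes[neg] <- num_layers + 1L + layer_indexes[neg]
  # Anything outside 0..num_layers does not refer to an existing layer.
  if (any(layer_indexes < 0L | layer_indexes > num_layers)) {
    stop("layer index out of range for a model with ", num_layers, " layers")
  }
  layer_indexes
}

resolve_layer_indexes(-4:-1)             # 9 10 11 12 for a 12-layer base model
resolve_layer_indexes(c(0L, -1L), 24L)   # 0 24 for a 24-layer large model
```

Under these assumptions, -4:-1 and 9:12 select the same layers of a base model, while only the negative form keeps meaning "the last four layers" when switching to a large model.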
Returns a list with elements "output" (the layer outputs as a tibble) and/or "attention" (the attention weights as a tibble), depending on features.
## Not run:
BERT_PRETRAINED_DIR <- download_BERT_checkpoint("bert_base_uncased")
examples <- c("I saw the branch on the bank.",
              "I saw the branch of the bank.")

# Just specify checkpoint directory.
feats <- extract_features(
  examples = examples,
  ckpt_dir = BERT_PRETRAINED_DIR
)

# Can also just specify the model, if you have it downloaded.
# In interactive mode, you'll be prompted to download the model if you do not
# have it.
feats <- extract_features(
  examples = examples,
  model = "bert_base_uncased"
)
## End(Not run)
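The vocab_file format described above (one token per line, with the line number giving the token's index) can be sketched in R as a named integer vector. read_vocab_ids() is a hypothetical helper, not part of this package, and it assumes the zero-based token ids used by the original BERT vocab.txt files, so the token on line 1 gets id 0.

```r
# Hypothetical helper (not part of this package): read a BERT-style
# vocabulary file, one token per line, into a named integer vector.
# Assumes zero-based ids, so the token on line 1 gets id 0.
read_vocab_ids <- function(vocab_file) {
  tokens <- readLines(vocab_file, encoding = "UTF-8")
  ids <- seq_along(tokens) - 1L
  names(ids) <- tokens
  ids
}

# Usage with a small temporary vocabulary file.
tmp <- tempfile(fileext = ".txt")
writeLines(c("[PAD]", "[UNK]", "[CLS]", "[SEP]", "bank"), tmp)
vocab <- read_vocab_ids(tmp)
vocab[["bank"]]  # 4
```

A named vector keeps the lookup in base R; for a full 30,000-token vocabulary an environment or hash map would give faster repeated lookups.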