Whereas the vignette about GPT-2 presents a very popular way to calculate word probabilities using GPT-like models, masked models offer an alternative, especially when we care only about the final word that follows a given context.
A masked language model (also called a BERT-like or encoder model) is a type of large language model that can be used to predict the content of a masked position in a sentence. BERT is an example of a masked language model [see also @Devlinetal2018].
First load the following packages:
library(pangoling)
library(tidytable) # fast alternative to dplyr
Notice the following potential pitfall: this would be a bad approach for making predictions with a masked model:
masked_tokens_pred_tbl("The apple doesn't fall far from the [MASK]") #> Processing using masked model 'bert-base-uncased/' ... #> # A tidytable: 30,522 × 4 #> masked_sentence token pred mask_n #> <chr> <chr> <dbl> <int> #> 1 The apple doesn't fall far from the [MASK] . -0.0579 1 #> 2 The apple doesn't fall far from the [MASK] ; -3.21 1 #> 3 The apple doesn't fall far from the [MASK] ! -4.83 1 #> 4 The apple doesn't fall far from the [MASK] ? -5.33 1 #> 5 The apple doesn't fall far from the [MASK] ... -7.84 1 #> 6 The apple doesn't fall far from the [MASK] | -8.11 1 #> 7 The apple doesn't fall far from the [MASK] tree -8.76 1 #> 8 The apple doesn't fall far from the [MASK] - -9.69 1 #> 9 The apple doesn't fall far from the [MASK] ' -9.87 1 #> 10 The apple doesn't fall far from the [MASK] : -10.5 1 #> # ℹ 30,512 more rows
(The pretrained models and tokenizers will be downloaded from https://huggingface.co/ the first time they are used.)
The most likely predictions are punctuation marks, because BERT uses both the left and the right context: here, the absence of any right context indicates that the mask is the final token of the sentence. More sensible results are obtained in the following way:
masked_tokens_pred_tbl("The apple doesn't fall far from the [MASK].") #> Processing using masked model 'bert-base-uncased/' ... #> # A tidytable: 30,522 × 4 #> masked_sentence token pred mask_n #> <chr> <chr> <dbl> <int> #> 1 The apple doesn't fall far from the [MASK]. tree -0.691 1 #> 2 The apple doesn't fall far from the [MASK]. ground -1.98 1 #> 3 The apple doesn't fall far from the [MASK]. sky -2.13 1 #> 4 The apple doesn't fall far from the [MASK]. table -4.02 1 #> 5 The apple doesn't fall far from the [MASK]. floor -4.31 1 #> 6 The apple doesn't fall far from the [MASK]. top -4.48 1 #> 7 The apple doesn't fall far from the [MASK]. ceiling -4.62 1 #> 8 The apple doesn't fall far from the [MASK]. window -4.87 1 #> 9 The apple doesn't fall far from the [MASK]. trees -4.94 1 #> 10 The apple doesn't fall far from the [MASK]. apple -4.95 1 #> # ℹ 30,512 more rows
We can also mask several tokens (but bear in mind that this type of model is trained with only 10-15% of the tokens masked):
df_masks <- masked_tokens_pred_tbl("The apple doesn't fall far from the [MASK][MASK]")
#> Processing using masked model 'bert-base-uncased/' ...

df_masks |> filter(mask_n == 1)
#> # A tidytable: 30,522 × 4
#>    masked_sentence                                    token     pred mask_n
#>    <chr>                                              <chr>    <dbl>  <int>
#>  1 The apple doesn't fall far from the [MASK][MASK]  tree    -0.738      1
#>  2 The apple doesn't fall far from the [MASK][MASK]  ground  -1.72       1
#>  3 The apple doesn't fall far from the [MASK][MASK]  sky     -2.31       1
#>  4 The apple doesn't fall far from the [MASK][MASK]  table   -3.67       1
#>  5 The apple doesn't fall far from the [MASK][MASK]  floor   -4.47       1
#>  6 The apple doesn't fall far from the [MASK][MASK]  top     -4.67       1
#>  7 The apple doesn't fall far from the [MASK][MASK]  ceiling -4.89       1
#>  8 The apple doesn't fall far from the [MASK][MASK]  window  -5.02       1
#>  9 The apple doesn't fall far from the [MASK][MASK]  bush    -5.02       1
#> 10 The apple doesn't fall far from the [MASK][MASK]  vine    -5.03       1
#> # ℹ 30,512 more rows

df_masks |> filter(mask_n == 2)
#> # A tidytable: 30,522 × 4
#>    masked_sentence                                    token     pred mask_n
#>    <chr>                                              <chr>    <dbl>  <int>
#>  1 The apple doesn't fall far from the [MASK][MASK]  .       -0.0570     2
#>  2 The apple doesn't fall far from the [MASK][MASK]  ;       -2.91       2
#>  3 The apple doesn't fall far from the [MASK][MASK]  !       -7.33       2
#>  4 The apple doesn't fall far from the [MASK][MASK]  ?       -9.09       2
#>  5 The apple doesn't fall far from the [MASK][MASK]  ...    -11.9        2
#>  6 The apple doesn't fall far from the [MASK][MASK]  ,      -12.4        2
#>  7 The apple doesn't fall far from the [MASK][MASK]  -      -12.8        2
#>  8 The apple doesn't fall far from the [MASK][MASK]  |      -13.3        2
#>  9 The apple doesn't fall far from the [MASK][MASK]  so     -13.4        2
#> 10 The apple doesn't fall far from the [MASK][MASK]  :      -13.9        2
#> # ℹ 30,512 more rows
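When several positions are masked, it can be handy to look at just the most probable candidate for each mask position. The following is a minimal sketch; it assumes that tidytable's slice_max() behaves like its dplyr counterpart, keeping the rows with the largest pred within each mask_n:

# Most probable token for each [MASK] position
df_masks |>
  slice_max(pred, n = 1, .by = mask_n)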
We can also use BERT to examine the predictability of words assuming that both the left and right contexts are known:
(df_sent <- data.frame(
  left = c("The", "The"),
  critical = c("apple", "pear"),
  right = c(
    "doesn't fall far from the tree.",
    "doesn't fall far from the tree."
  )
))
#>   left critical                           right
#> 1  The    apple doesn't fall far from the tree.
#> 2  The     pear doesn't fall far from the tree.
The function masked_targets_pred() gives us the log-probability of the target word (and takes care of summing the log-probabilities when the target is composed of several tokens).
df_sent <- df_sent %>%
  mutate(lp = masked_targets_pred(
    prev_contexts = left,
    targets = critical,
    after_contexts = right
  ))
#> Processing using masked model 'bert-base-uncased/' ...
#> Processing 1 batch(es) of 13 tokens.
#> The [apple] doesn't fall far from the tree.
#> Processing 1 batch(es) of 13 tokens.
#> The [pear] doesn't fall far from the tree.
#> ***
df_sent
#> # A tidytable: 2 × 4
#>   left  critical right                              lp
#>   <chr> <chr>    <chr>                           <dbl>
#> 1 The   apple    doesn't fall far from the tree. -4.68
#> 2 The   pear     doesn't fall far from the tree. -8.60
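Assuming that lp is a natural log-probability (as the output above suggests), it is straightforward to transform it into a raw probability or into surprisal in bits:

# Assumes lp is a natural log-probability
df_sent |>
  mutate(
    prob      = exp(lp),        # probability of the target word
    surprisal = -lp / log(2)    # surprisal in bits
  )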
As expected (given the popularity of the proverb), "apple" is a more likely target word than "pear".