View source: R/eval_sbo_predictor.R
Evaluate next-word predictions from a Stupid Back-off N-gram model on a test corpus.
eval_sbo_predictor(model, test, L = attr(model, "L"))
model | a
test | a character vector. Perform a single prediction on each entry of this vector (see details).
L | Maximum number of predictions for each input sentence (maximum allowed is
This function provides information on the quality of Stupid Back-off model predictions, such as next-word prediction accuracy or the word-rank distribution of correct predictions, by testing the model directly against a test corpus. For a reasonable estimate of prediction accuracy, the entries of the test vector should be uncorrelated documents (e.g. separate tweets, as in the twitter_test example dataset).
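One reason independence matters: when the entries are uncorrelated, the fraction of correct predictions behaves like a binomial proportion, so its uncertainty can be estimated with the usual standard-error formula. The snippet below is a minimal sketch on a made-up logical vector (the values are illustrative, not results from any real model).

```r
# Made-up outcomes standing in for a `correct` column (illustrative only)
correct <- c(TRUE, FALSE, TRUE, TRUE, FALSE, FALSE, TRUE, FALSE)

n <- length(correct)
accuracy <- mean(correct)  # fraction of correct predictions

# Binomial standard error; only meaningful if the entries are independent
std_error <- sqrt(accuracy * (1 - accuracy) / n)
```

If the test documents were correlated, this formula would tend to understate the true uncertainty.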
In more detail, eval_sbo_predictor performs the following operations:
1. Sample a single sentence from each entry of the character vector test.
2. Sample a single $N$-gram from each sentence obtained in the previous step.
3. Predict next words from the $(N-1)$-gram prefix.
4. Return all predictions, together with the true word completions.
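The four steps above can be sketched in base R as follows. This is an illustration under stated assumptions, not the package's implementation: toy_predict, split_sentences and eval_sketch are hypothetical helpers, the sentence splitting and tokenization are deliberately naive, and the padding with "<BOS>" tokens is an assumption about how prefixes at the start of a sentence are formed.

```r
# Illustrative sketch only: none of these helpers are part of the sbo package.

# Hypothetical predictor: returns up to L candidate next words for a prefix.
toy_predict <- function(prefix, L = 3) {
  lookup <- list("i love"   = c("you", "it", "this"),
                 "love you" = c("<EOS>", "so", "too"))
  preds <- lookup[[prefix]]
  if (is.null(preds)) preds <- c("<EOS>", "the", "and")  # fallback guess
  head(preds, L)
}

# Naive sentence splitter on ., ? and ! (assumption, not sbo's tokenizer)
split_sentences <- function(doc) {
  s <- trimws(unlist(strsplit(doc, "[.?!]+")))
  s[nzchar(s)]
}

# Assumes N >= 2, so that every N-gram has a non-empty prefix
eval_sketch <- function(test, predict_fun, N = 3, L = 3) {
  rows <- lapply(test, function(doc) {
    sentences <- split_sentences(tolower(doc))
    if (length(sentences) == 0) return(NULL)
    # Step 1: sample a single sentence from the document
    sentence <- sentences[sample.int(length(sentences), 1)]
    # Pad with N - 1 begin-of-sentence tokens and one end-of-sentence token
    words <- strsplit(sentence, "\\s+")[[1]]
    tokens <- c(rep("<BOS>", N - 1), words, "<EOS>")
    # Step 2: sample a single N-gram; `end` is the position of its last word
    end <- N - 1 + sample.int(length(tokens) - (N - 1), 1)
    prefix <- paste(tokens[(end - N + 1):(end - 1)], collapse = " ")
    true <- tokens[end]
    # Step 3: predict next words from the (N-1)-gram prefix
    preds <- predict_fun(prefix, L)
    # Step 4: keep the predictions together with the true completion
    data.frame(input = prefix, true = true,
               preds = paste(preds, collapse = " "),
               correct = true %in% preds,
               stringsAsFactors = FALSE)
  })
  do.call(rbind, rows)
}

set.seed(840)
test <- c("I love you. See you soon!", "Thank you so much",
          "Good morning everyone.")
res <- eval_sketch(test, toy_predict)
mean(res$correct)  # fraction of entries with a correct prediction
```

Because only one N-gram is drawn per entry, each row of the result comes from a different document, which is what makes the rows approximately independent when the documents themselves are.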
A tibble containing the input $(N-1)$-grams, the true completions, the predicted completions, and a column indicating whether any of the predictions was correct.
Valerio Gherardi
# Evaluating next-word predictions from a Stupid Back-off N-gram model
if (suppressMessages(require(dplyr) && require(ggplot2))) {
        p <- sbo_predictor(twitter_predtable)
        set.seed(840) # Set seed for reproducibility
        test <- sample(twitter_test, 500)
        eval <- eval_sbo_predictor(p, test)

        ## Compute three-word accuracies
        eval %>% summarise(accuracy = sum(correct) / n()) # Overall accuracy
        eval %>% # Accuracy for in-sentence predictions
                filter(true != "<EOS>") %>%
                summarise(accuracy = sum(correct) / n())

        ## Make histogram of word-rank distribution for correct predictions
        dict <- attr(twitter_predtable, "dict")
        eval %>%
                filter(correct, true != "<EOS>") %>%
                transmute(rank = match(true, table = dict)) %>%
                ggplot(aes(x = rank)) + geom_histogram(binwidth = 30)
}