predict.textspace: Predict using a Starspace model

View source: R/embed-all-the-things.R



The prediction functionality allows you to retrieve the following types of elements from a Starspace model:

  • generic: detailed Starspace predictions for your text

  • labels: the similarity of your text to each label of the Starspace model

  • embedding: document embeddings of your text (shorthand for starspace_embedding)

  • knn: the k nearest neighbouring (most similar) elements of the model dictionary to your input text (shorthand for starspace_knn)


## S3 method for class 'textspace'
predict(
  object,
  newdata,
  type = c("generic", "labels", "knn", "embedding"),
  k = 5L,
  sep = " ",
  basedoc,
  ...
)
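The type argument has a vector default, yet it is documented as defaulting to 'generic'. The usual R idiom behind this is match.arg(). The sketch below assumes, without confirmation from this page, that the method resolves type that way; resolve_type is a hypothetical stand-in, not a ruimtehol function.

```r
# Hypothetical stand-in with the same default as the `type` argument;
# assumes the real method resolves it with base R match.arg()
resolve_type <- function(type = c("generic", "labels", "knn", "embedding")) {
  match.arg(type)
}

resolve_type()        # nothing supplied: the first choice, "generic"
resolve_type("knn")   # exact match: "knn"
resolve_type("emb")   # unique partial match: "embedding"
```

Under that assumption, an unambiguous abbreviation would also select a type, and a value matching none of the four choices would raise an error.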



object: an object of class textspace as returned by starspace or starspace_load_model

newdata: a data frame with columns doc_id and text, or a character vector of texts whose names are used as the identifiers of those texts

type: character string, one of 'generic', 'labels', 'embedding' or 'knn'. Defaults to 'generic'

k: integer giving the number of predictions to make. Defaults to 5. Only used when type is 'generic' or 'knn'

sep: character string used to split newdata into terms with boost::split. Only used when type is 'generic'

basedoc: optional, either a character vector of possible elements to predict or the path to a file in labelDoc format containing the basedocs, i.e. the set of possible things to predict, if these differ from the ones in the training data. Only used when type is 'generic'

...: not used
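Both accepted shapes of newdata carry the same information. Assuming a data frame with doc_id and text columns (the toy values below are made up), base R setNames() converts it to the named character vector form and back:

```r
# First accepted form: a data frame with columns doc_id and text (toy values)
docs <- data.frame(
  doc_id = c("q1", "q2"),
  text = c("first example text", "second example text"),
  stringsAsFactors = FALSE
)

# Second accepted form: a character vector of texts named by their identifiers
docs_vec <- setNames(docs$text, docs$doc_id)

# Converting back: the names become doc_id, the values become text
docs_back <- data.frame(doc_id = names(docs_vec), text = unname(docs_vec),
                        stringsAsFactors = FALSE)

# Either object could then be passed as newdata, given a trained model, e.g.
# predict(model, docs_vec, type = "embedding")
```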


The following is returned, depending on the argument type:

  • In case type is set to 'generic': a list with one element for each row or element in newdata. Each element is itself a list with the elements

    • doc_id: the identifier of the text

    • text: the character string with the text

    • prediction: data.frame with columns label, label_starspace and similarity indicating the predicted label and the similarity of the text to the label

    • terms: a list with elements basedoc_index and basedoc_terms, giving the position in basedoc and the terms of the model dictionary that were used to compute the similarity

  • In case type is set to 'labels': a data.frame, namely newdata with one column added for each label in the Starspace model.
    Each of these columns contains the similarity of the text to that label, computed with embedding_similarity using either cosine or dot product similarity, whichever was used during model training.

  • In case type is set to 'embedding':
    a matrix of document embeddings with one embedding for each text in newdata, as returned by starspace_embedding. The rownames of this matrix are set to the document identifiers of newdata.

  • In case type is set to 'knn': a list of data.frames, one for each row or element in newdata.
    Each of these data frames contains the columns doc_id, label, similarity and rank, giving the k nearest neighbouring (most similar) elements of the model dictionary to your input text, as returned by starspace_knn.
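To illustrate what the 'labels' and 'knn' outputs measure, here is a base R sketch with made-up embeddings that compares dot product and cosine similarity and ranks the top k matches. It is only an illustration: the method itself relies on embedding_similarity and starspace_knn with the similarity measure chosen at training time, and top_k is a hypothetical helper, not part of ruimtehol.

```r
# Toy embeddings (hypothetical values): 2 documents and 3 labels in 3 dimensions
doc_emb <- rbind(doc1 = c(1, 0, 0),
                 doc2 = c(0, 2, 0))
label_emb <- rbind(labelA = c(1, 0, 0),
                   labelB = c(0, 1, 0),
                   labelC = c(1, 1, 0))

# Dot product similarity: every document against every label
sim_dot <- doc_emb %*% t(label_emb)

# Cosine similarity: dot product of L2-normalised rows
l2_normalise <- function(m) m / sqrt(rowSums(m^2))
sim_cos <- l2_normalise(doc_emb) %*% t(l2_normalise(label_emb))

# knn-style view for one document: the k most similar labels with a rank column
top_k <- function(sim_row, k) {
  ord <- order(sim_row, decreasing = TRUE)[seq_len(min(k, length(sim_row)))]
  data.frame(label = names(sim_row)[ord],
             similarity = unname(sim_row[ord]),
             rank = seq_along(ord),
             stringsAsFactors = FALSE)
}

res <- top_k(sim_cos["doc2", ], k = 2)
```

With these toy vectors doc2 scores labelB and labelC equally under the dot product (2 and 2), but prefers labelB under cosine similarity (1 versus about 0.71), which is why the measure used during training matters for the ranking.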


library(ruimtehol)
data(dekamer, package = "ruimtehol")
## Tokenise the questions on non-word characters and drop empty tokens
dekamer$text <- strsplit(dekamer$question, "\\W")
dekamer$text <- lapply(dekamer$text, FUN = function(x) x[x != ""])
dekamer$text <- sapply(dekamer$text, 
                       FUN = function(x) paste(x, collapse = " "))

## Split into 90% training data and 10% test data
set.seed(123456789)
idx <- sample(nrow(dekamer), size = round(nrow(dekamer) * 0.9))
traindata <- dekamer[idx, ]
testdata <- dekamer[-idx, ]
## Train a TagSpace model predicting the main question theme
model <- embed_tagspace(x = traindata$text, 
                        y = traindata$question_theme_main, 
                        early_stopping = 0.8,
                        dim = 10, minCount = 5)
## One call per prediction type
scores <- predict(model, testdata)
scores <- predict(model, testdata, type = "labels")
emb <- predict(model, testdata[, c("doc_id", "text")], type = "embedding")
knn <- predict(model, testdata[1:5, c("doc_id", "text")], type = "knn", k = 3)

## Not run: 
library(ruimtehol)
library(udpipe)
data(dekamer, package = "ruimtehol")
dekamer <- subset(dekamer, question_theme_main == "DEFENSIEBELEID")
## Tokenise with udpipe (this downloads a Dutch udpipe model into the working directory)
x <- udpipe(dekamer$question, "dutch", tagger = "none", parser = "none", trace = 100)
x <- x[, c("doc_id", "sentence_id", "sentence", "token")]
model <- embed_sentencespace(x, dim = 15, epoch = 5, minCount = 5)
## Predict against all sentences seen in the data, keeping the 3 most similar
scores <- predict(model, "Wat zijn de cijfers qua doorstroming van 2016?", 
                  basedoc = unique(x$sentence), k = 3)

## Clean up the downloaded udpipe model file for CRAN
file.remove(list.files(pattern = "\\.udpipe$"))

## End(Not run)
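Because type = 'generic' returns a nested list, it is often convenient to stack the per-document prediction data frames into one table. A minimal sketch, using a mock object shaped like the documented 'generic' return value (all values, including the label_starspace strings, are hypothetical, and the terms element is left out because the sketch does not use it):

```r
# Mock of the documented 'generic' output (hypothetical values); with a trained
# model this would come from predict(model, newdata, type = "generic")
generic_scores <- list(
  list(doc_id = "d1", text = "first text",
       prediction = data.frame(label = c("A", "B"),
                               label_starspace = c("__label__A", "__label__B"),
                               similarity = c(0.9, 0.4),
                               stringsAsFactors = FALSE)),
  list(doc_id = "d2", text = "second text",
       prediction = data.frame(label = "B",
                               label_starspace = "__label__B",
                               similarity = 0.7,
                               stringsAsFactors = FALSE))
)

# Stack the prediction data frames, repeating doc_id on every row
flat <- do.call(rbind, lapply(generic_scores, function(x) {
  data.frame(doc_id = x$doc_id, x$prediction, stringsAsFactors = FALSE)
}))

# Keep the most similar label per document without assuming any row order
flat <- flat[order(flat$doc_id, -flat$similarity), ]
top1 <- flat[!duplicated(flat$doc_id), ]
```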

ruimtehol documentation built on Jan. 7, 2023, 1:25 a.m.