predict.textspace: Predict using a Starspace model


View source: R/embed-all-the-things.R

Description

The prediction functionality retrieves one of four types of output from a Starspace model, selected with the type argument: 'generic', 'labels', 'knn' or 'embedding'.

Usage

## S3 method for class 'textspace'
predict(
  object,
  newdata,
  type = c("generic", "labels", "knn", "embedding"),
  k = 5L,
  sep = " ",
  basedoc,
  ...
)

Arguments

object

an object of class textspace as returned by starspace or starspace_load_model

newdata

a data frame with columns doc_id and text, or a character vector of texts whose names serve as the identifiers of those texts

type

character string, one of 'generic', 'labels', 'embedding' or 'knn'. Defaults to 'generic'

k

integer with the number of predictions to make. Defaults to 5. Only used in case type is set to 'generic' or 'knn'

sep

character string used to split newdata using boost::split. Only used in case type is set to 'generic'

basedoc

optional; either a character vector of possible elements to predict, or the path to a file in labelDoc format containing the basedocs, i.e. the set of possible things to predict, if these differ from the ones in the training data. Only used in case type is set to 'generic'

...

not used
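
As a small illustration of the two input forms described for newdata, the following base R sketch builds the same two documents both ways. It does not call ruimtehol; the texts and the helper as_newdata_frame are hypothetical and only show how the names of a character vector correspond to the doc_id column.

```r
# Form 1: a data frame with the columns doc_id and text
newdata_df <- data.frame(
  doc_id = c("q1", "q2"),
  text = c("wat zijn de cijfers", "hoeveel personeelsleden zijn er"),
  stringsAsFactors = FALSE
)

# Form 2: a character vector whose names act as document identifiers
newdata_chr <- c(q1 = "wat zijn de cijfers",
                 q2 = "hoeveel personeelsleden zijn er")

# Hypothetical helper (not part of ruimtehol): convert form 2 into form 1
as_newdata_frame <- function(x) {
  stopifnot(is.character(x), !is.null(names(x)))
  data.frame(doc_id = names(x), text = unname(x), stringsAsFactors = FALSE)
}

converted <- as_newdata_frame(newdata_chr)
print(converted)
```

Either object could then be passed as the newdata argument, assuming a fitted model is available.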

Value

The structure of the returned object depends on the argument type.

Examples

data(dekamer, package = "ruimtehol")
## Tokenise on non-word characters, drop empty tokens, re-join with spaces
dekamer$text <- strsplit(dekamer$question, "\\W")
dekamer$text <- lapply(dekamer$text, FUN = function(x) x[x != ""])
dekamer$text <- sapply(dekamer$text, 
                       FUN = function(x) paste(x, collapse = " "))

## Hold out 10 percent of the questions as test data
idx <- sample(nrow(dekamer), size = round(nrow(dekamer) * 0.9))
traindata <- dekamer[idx, ]
testdata <- dekamer[-idx, ]
set.seed(123456789)
model <- embed_tagspace(x = traindata$text, 
                        y = traindata$question_theme_main, 
                        early_stopping = 0.8,
                        dim = 10, minCount = 5)
## Default type = "generic"
scores <- predict(model, testdata)
scores <- predict(model, testdata, type = "labels")
str(scores)
emb <- predict(model, testdata[, c("doc_id", "text")], type = "embedding")
knn <- predict(model, testdata[1:5, c("doc_id", "text")], type = "knn", k = 3)


## Not run: 
library(udpipe)
data(dekamer, package = "ruimtehol")
dekamer <- subset(dekamer, question_theme_main == "DEFENSIEBELEID")
x <- udpipe(dekamer$question, "dutch", tagger = "none", parser = "none", trace = 100)
x <- x[, c("doc_id", "sentence_id", "sentence", "token")]
set.seed(123456789)
model <- embed_sentencespace(x, dim = 15, epoch = 5, minCount = 5)
scores <- predict(model, "Wat zijn de cijfers qua doorstroming van 2016?", 
                  basedoc = unique(x$sentence), k = 3) 
str(scores)

## clean up for cran
file.remove(list.files(pattern = ".udpipe$"))

## End(Not run)
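
To give some intuition for what a k-nearest-neighbour lookup in an embedding space involves, here is a minimal base R sketch on a toy embedding matrix. It is not how ruimtehol or Starspace implements type = "knn": the embedding values are made up, cosine similarity is an assumed similarity measure, and the helpers cosine_to_rows and nearest_rows as well as the output columns are hypothetical.

```r
# Toy embedding matrix: one row per dictionary element (values are made up)
dictionary <- rbind(
  leger   = c(0.9, 0.1, 0.0),
  budget  = c(0.1, 0.9, 0.2),
  soldaat = c(0.8, 0.2, 0.1)
)
# Toy embedding of one input text
query <- c(1.0, 0.0, 0.1)

# Hypothetical helper: cosine similarity of a vector to every matrix row
cosine_to_rows <- function(m, v) {
  unname(as.vector(m %*% v) / (sqrt(rowSums(m^2)) * sqrt(sum(v^2))))
}

# Hypothetical helper: the k most similar rows, most similar first
nearest_rows <- function(m, v, k = 5L) {
  sims <- cosine_to_rows(m, v)
  top <- head(order(sims, decreasing = TRUE), k)
  data.frame(label = rownames(m)[top], similarity = sims[top],
             rank = seq_along(top), stringsAsFactors = FALSE)
}

nearest <- nearest_rows(dictionary, query, k = 2L)
print(nearest)
```

In this sketch k plays the same role as the k argument above: it caps the number of neighbours returned per input text.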

ruimtehol documentation built on Jan. 13, 2021, 8 p.m.