predict.tidylda: Get predictions from a Latent Dirichlet Allocation model

predict.tidylda    R Documentation

Description

Obtains topic predictions for new documents from a fitted LDA model.

Usage

## S3 method for class 'tidylda'
predict(
  object,
  new_data,
  type = c("prob", "class", "distribution"),
  method = c("gibbs", "dot"),
  iterations = NULL,
  burnin = -1,
  no_common_tokens = c("default", "zero", "uniform"),
  times = 100,
  threads = 1,
  verbose = TRUE,
  ...
)

Arguments

object

a fitted object of class tidylda

new_data

a DTM or TCM of class dgCMatrix or a numeric vector

type

one of "prob", "class", or "distribution". Defaults to "prob".

method

One of "gibbs" or "dot". If "gibbs", Gibbs sampling is used and iterations must be specified.

iterations

If method = "gibbs", an integer number of iterations for the Gibbs sampler to run. A future version may include automatic stopping criteria.

burnin

If method = "gibbs", an integer number of burnin iterations. If burnin is greater than -1, the entries of the resulting "theta" matrix are an average over all iterations greater than burnin. Behavior is the same as documented in tidylda.

no_common_tokens

Behavior when encountering documents that have no tokens in common with the model. Options are "default", "zero", or "uniform". See 'Details', below, for an explanation of each option.

times

Integer, number of samples to draw if type = "distribution". Ignored if type is "class" or "prob". Defaults to 100.

threads

Number of parallel threads, defaults to 1. Note: currently ignored; only single-threaded prediction is implemented.

verbose

Logical. Whether to print a progress bar to the console. Only active if method = "gibbs". Defaults to TRUE.

...

Additional arguments, currently unused
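
The burnin argument above describes averaging over all iterations greater than burnin. The following is a minimal base R sketch of that idea only, not the package's sampler: the per-iteration draws are simulated with rgamma as a stand-in, and the names draws, kept, and theta_avg are hypothetical.

```r
set.seed(1)

iterations <- 10  # total sampler iterations (toy value)
burnin <- 7       # iterations 1..7 are discarded
k <- 3            # number of topics (toy value)

# Simulated topic proportions for one document at each iteration.
# This is a stand-in for real Gibbs sampler output.
draws <- lapply(seq_len(iterations), function(i) {
  x <- rgamma(k, shape = 1)
  x / sum(x)
})

# Keep only iterations strictly greater than burnin (here: 8, 9, 10)
kept <- draws[seq_len(iterations) > burnin]

# Element-wise average of the kept draws
theta_avg <- Reduce(`+`, kept) / length(kept)
```

Because each kept draw sums to 1, the averaged vector also sums to 1, so it can still be read as a set of topic proportions.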

Details

If predict.tidylda encounters documents that have no tokens in common with the model in object, it does one of three things, depending on the setting of no_common_tokens.

default (the default) sets all topics to 0 for offending documents. This lets downstream computations continue in a way that NA would not. It also emits a warning for every such document it encounters.

zero has the same behavior as default but it emits a message instead of a warning.

uniform sets every topic to 1/k for offending documents, where k is the number of topics. It does not emit a warning or message.
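
The three behaviors described above can be sketched in base R. This is an illustration only, not the package's code: fill_no_common_tokens and its has_common argument are hypothetical names, and the toy theta matrix uses NA to mark the row that needs filling.

```r
# Hypothetical helper (not part of tidylda): fills rows of a theta matrix
# for documents that share no tokens with the model's vocabulary.
fill_no_common_tokens <- function(theta, has_common,
                                  no_common_tokens = c("default", "zero", "uniform")) {
  no_common_tokens <- match.arg(no_common_tokens)
  bad <- which(!has_common)
  if (length(bad) == 0) return(theta)

  k <- ncol(theta)  # number of topics

  if (no_common_tokens == "uniform") {
    # every topic gets probability 1/k; nothing is signalled
    theta[bad, ] <- 1 / k
  } else {
    # "default" and "zero" both set all topics to 0
    theta[bad, ] <- 0
    txt <- paste0("Document ", bad, " has no tokens in common with the model.")
    if (no_common_tokens == "default") {
      for (m in txt) warning(m, call. = FALSE)  # one warning per document
    } else {
      for (m in txt) message(m)                 # one message per document
    }
  }
  theta
}

# Toy input: d1 overlaps with the model vocabulary, d2 does not
theta <- matrix(c(0.7, 0.3, NA, NA), nrow = 2, byrow = TRUE,
                dimnames = list(c("d1", "d2"), c("t1", "t2")))
has_common <- c(TRUE, FALSE)

theta_uniform <- fill_no_common_tokens(theta, has_common, "uniform")
theta_zero <- suppressMessages(fill_no_common_tokens(theta, has_common, "zero"))
theta_default <- suppressWarnings(fill_no_common_tokens(theta, has_common, "default"))
```

In this sketch "zero" and "default" produce identical matrices and differ only in whether a message or a warning is raised, which matters mainly for pipelines that treat warnings as errors.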

Value

The output depends on type. If "prob" (the default), returns a "theta" matrix with one row per document and one column per topic. If "class", returns a vector with the index of the most likely topic in each document. If "distribution", returns a tibble with one row per parameter per sample; the number of samples is set by the times argument.
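
To make the relationship between the "prob" and "class" outputs concrete, here is a small base R sketch on a hand-written toy matrix. It assumes that "most likely topic" means the column with the largest value in each row; how the package breaks ties is not stated here, and which.max simply takes the first maximum.

```r
# Toy "theta" matrix: 3 documents (rows) by 3 topics (columns); rows sum to 1
theta <- matrix(
  c(0.6, 0.3, 0.1,
    0.2, 0.2, 0.6,
    0.1, 0.8, 0.1),
  nrow = 3, byrow = TRUE,
  dimnames = list(paste0("doc", 1:3), paste0("topic", 1:3))
)

# Index of the most probable topic in each document
top_topic <- apply(theta, 1, which.max)
```

Here top_topic is a named integer vector with one entry per document, which mirrors the shape described for type = "class".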

Examples


# load the package and some data
library(tidylda)
data(nih_sample_dtm)

# fit a model
set.seed(12345)

m <- tidylda(
  data = nih_sample_dtm[1:20, ], k = 5,
  iterations = 200, burnin = 175
)

str(m)

# predict on held-out documents using gibbs sampling "fold in"
p1 <- predict(m, nih_sample_dtm[21:100, ],
  method = "gibbs",
  iterations = 200, burnin = 175
)

# predict on held-out documents using the dot product
p2 <- predict(m, nih_sample_dtm[21:100, ], method = "dot")

# compare the methods
barplot(rbind(p1[1, ], p2[1, ]), beside = TRUE, col = c("red", "blue"))

# predict classes on held out documents
p3 <- predict(m, nih_sample_dtm[21:100, ],
  method = "gibbs",
  type = "class",
  iterations = 100, burnin = 75
)

# predict distribution on held out documents
p4 <- predict(m, nih_sample_dtm[21:100, ],
  method = "gibbs",
  type = "distribution",
  iterations = 100, burnin = 75,
  times = 10
)


tidylda documentation built on July 26, 2023, 5:34 p.m.