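# Nothing  (NOTE: appears to be page-header residue from the site this file was copied from; not part of the tests)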
# Label under which testthat reports the expectations in this file.
context("tests of predict method for tidylda")

# Split the sample document-term matrix in two: d1 is used to fit the
# model, d2 is passed as `new_data` to predict() in the tests below.
dtm <- nih_sample_dtm
d1 <- dtm[seq_len(50), ]
d2 <- dtm[51:100, ]

# Keep only columns with at least one count in each half, so the two
# data sets end up with different vocabularies.
d1_has_counts <- Matrix::colSums(d1) > 0
d1 <- d1[, d1_has_counts]
d2_has_counts <- Matrix::colSums(d2) > 0
d2 <- d2[, d2_has_counts]

# Fit one small model up front; the test_that() blocks reuse `lda`.
lda <- tidylda(
  data = d1,
  k = 4,
  iterations = 20,
  burnin = 10,
  alpha = 0.1,
  eta = 0.05,
  optimize_alpha = TRUE,
  calc_likelihood = TRUE,
  calc_r2 = TRUE,
  return_data = FALSE,
  verbose = FALSE
)
### Tests for predictions ----
test_that("can make predictions without error", {
  # Shared check for probability predictions: one row per input document
  # and exactly the same topic columns as the fitted model's theta.
  expect_topic_matrix <- function(pred, n_docs) {
    expect_equal(nrow(pred), n_docs)
    expect_equal(ncol(pred), ncol(lda$theta))
    expect_setequal(colnames(pred), colnames(lda$theta))
  }

  # one row gibbs with burnin
  p <- predict(
    object = lda,
    new_data = d2[1, ],
    method = "gibbs",
    iterations = 20,
    burnin = 10,
    verbose = FALSE
  )
  expect_topic_matrix(p, 1)

  # multi-row gibbs with burnin
  p <- predict(
    object = lda,
    new_data = d2,
    method = "gibbs",
    iterations = 20,
    burnin = 10,
    verbose = FALSE
  )
  expect_topic_matrix(p, nrow(d2))

  # single row dot method
  p <- predict(object = lda, new_data = d2[1, ], method = "dot")
  expect_topic_matrix(p, 1)

  # multi-row dot method
  p <- predict(object = lda, new_data = d2, method = "dot")
  expect_topic_matrix(p, nrow(d2))

  # multi row parallel
  # (no longer parallel, but checks that threads arg doesn't cause an error)
  p <- predict(
    object = lda,
    new_data = d2,
    method = "gibbs",
    iterations = 20,
    burnin = 10,
    threads = 2,
    verbose = FALSE
  )
  expect_true(is.matrix(p))

  # single row class with dot
  p <- predict(
    object = lda,
    new_data = d2[1, ],
    type = "class",
    method = "dot"
  )
  # expect_length()/expect_type() report the actual length/type on failure,
  # unlike expect_true(inherits(...)) which only reports "not TRUE".
  expect_length(p, 1)
  expect_type(p, "integer")

  # multi row class with dot
  p <- predict(
    object = lda,
    new_data = d2,
    type = "class",
    method = "dot"
  )
  expect_length(p, nrow(d2))
  expect_type(p, "integer")

  # single row distribution with gibbs
  p <- predict(
    object = lda,
    new_data = d2[1, ],
    type = "distribution",
    method = "gibbs",
    iterations = 20,
    burnin = 10,
    times = 10,
    threads = 1,
    verbose = FALSE
  )
  expect_s3_class(p, "tbl")

  # multi row distribution with dot
  p <- predict(
    object = lda,
    new_data = d2,
    type = "distribution",
    method = "dot",
    times = 10
  )
  expect_s3_class(p, "tbl")
})
test_that("malformed args in predict throw errors", {
  # threads > nrow(dtm): expected to produce a message, not an error
  expect_message(
    predict(
      object = lda,
      new_data = d2,
      method = "gibbs",
      iterations = 20,
      burnin = 10,
      threads = nrow(d2) + 2,
      verbose = FALSE
    ),
    label = "threads > nrow(dtm)"
  )

  # no iterations specified
  expect_error(
    predict(object = lda, new_data = d2, method = "gibbs")
  )

  # burnin >= iterations
  expect_error(
    predict(
      object = lda,
      new_data = d2,
      method = "gibbs",
      iterations = 5,
      burnin = 6
    )
  )

  # incorrect method
  expect_error(
    predict(object = lda, new_data = d2, method = "oops")
  )

  # A document whose token names are plain numbers, so that it shares no
  # vocabulary with the fitted model.
  nd <- numeric(10)
  names(nd) <- seq_along(nd)

  # no overlap in vocabulary throws warning on "dot" by default
  expect_warning(
    predict(object = lda, new_data = nd, method = "dot")
  )

  # no overlap in vocabulary doesn't throw warning on "dot" if specified
  # (a message is expected instead)
  expect_message(
    predict(
      object = lda,
      new_data = nd,
      method = "dot",
      no_common_tokens = "zero"
    )
  )

  # no overlap in vocabulary sets every topic to 1/k.
  # Compare element-wise: checking only mean(p) would pass for ANY vector of
  # k values summing to 1, so it could not detect a non-uniform result.
  p <- predict(
    object = lda,
    new_data = nd,
    method = "dot",
    no_common_tokens = "uniform"
  )
  k <- nrow(lda$beta)
  expect_equal(as.numeric(p), rep(1 / k, length(p)))

  # no_common_tokens has illegal value
  expect_error(
    predict(
      object = lda,
      new_data = nd,
      method = "dot",
      no_common_tokens = "WRONG!"
    )
  )

  # type misspecified
  expect_error(
    predict(object = lda, new_data = d2, type = "blah", method = "dot")
  )

  # times misspecified while type = "distribution":
  # missing, non-numeric, and non-positive values must all be rejected
  for (bad_times in list(NA, "yeet", 0)) {
    expect_error(
      predict(
        object = lda,
        new_data = d2,
        type = "distribution",
        method = "dot",
        times = bad_times
      ),
      info = paste("times =", format(bad_times))
    )
  }
})
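# NOTE: the three lines below are footer text from the website this file was
# copied from; they are not part of the test code.
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.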