# tests/testthat/test-predict.tidylda.R

context("tests of predict method for tidylda")

# Keep only the columns (terms) whose total count in `m` is positive, so that
# each split ends up with its own vocabulary.
drop_unused_terms <- function(m) {
  m[, Matrix::colSums(m) > 0]
}

# Full document-term matrix that both splits below are carved from.
dtm <- nih_sample_dtm

# Rows 1-50 are used to fit the model; rows 51-100 are the new data passed to
# predict() in the tests further down.
d1 <- drop_unused_terms(dtm[seq_len(50), ])

d2 <- drop_unused_terms(dtm[seq.int(51, 100), ])

# Small model fit on the first split; every test below predicts from it.
lda <- tidylda(
  data = d1,
  k = 4,
  iterations = 20,
  burnin = 10,
  alpha = 0.1,
  eta = 0.05,
  optimize_alpha = TRUE,
  calc_likelihood = TRUE,
  calc_r2 = TRUE,
  return_data = FALSE,
  verbose = FALSE
)

### Tests for predictions ----

test_that("can make predictions without error", {
  # Column names of the fitted model's theta; predictions are compared to them.
  topic_names <- colnames(lda$theta)

  # Assert that `pred` has `n_docs` rows, as many columns as lda$theta, and
  # the same set of column names (order is not checked).
  expect_topic_matrix <- function(pred, n_docs) {
    expect_equal(nrow(pred), n_docs)
    expect_equal(ncol(pred), ncol(lda$theta))
    expect_setequal(colnames(pred), topic_names)
  }

  # --- method = "gibbs" with burn-in ---------------------------------------

  # a single document
  gibbs_single <- predict(
    lda,
    new_data = d2[1, ],
    method = "gibbs",
    iterations = 20,
    burnin = 10,
    verbose = FALSE
  )
  expect_topic_matrix(gibbs_single, 1)

  # all documents of the second split
  gibbs_multi <- predict(
    lda,
    new_data = d2,
    method = "gibbs",
    iterations = 20,
    burnin = 10,
    verbose = FALSE
  )
  expect_topic_matrix(gibbs_multi, nrow(d2))

  # --- method = "dot" -------------------------------------------------------

  # a single document
  dot_single <- predict(lda, new_data = d2[1, ], method = "dot")
  expect_topic_matrix(dot_single, 1)

  # all documents of the second split
  dot_multi <- predict(lda, new_data = d2, method = "dot")
  expect_topic_matrix(dot_multi, nrow(d2))

  # --- threads argument -----------------------------------------------------

  # Prediction is no longer parallel; this only checks that supplying
  # `threads` does not cause an error and a matrix still comes back.
  gibbs_threads <- predict(
    lda,
    new_data = d2,
    method = "gibbs",
    iterations = 20,
    burnin = 10,
    threads = 2,
    verbose = FALSE
  )
  expect_true(inherits(gibbs_threads, "matrix"))

  # --- type = "class" with the dot method ----------------------------------

  # a single document yields a single integer
  class_single <- predict(
    lda,
    new_data = d2[1, ],
    type = "class",
    method = "dot"
  )
  expect_length(class_single, 1)
  expect_true(inherits(class_single, "integer"))

  # many documents yield one integer per document
  class_multi <- predict(
    lda,
    new_data = d2,
    type = "class",
    method = "dot"
  )
  expect_length(class_multi, nrow(d2))
  expect_true(inherits(class_multi, "integer"))

  # --- type = "distribution" ------------------------------------------------

  # a single document, gibbs method
  dist_gibbs <- predict(
    lda,
    new_data = d2[1, ],
    type = "distribution",
    method = "gibbs",
    iterations = 20,
    burnin = 10,
    times = 10,
    threads = 1,
    verbose = FALSE
  )
  expect_true(inherits(dist_gibbs, "tbl"))

  # many documents, dot method
  dist_dot <- predict(
    lda,
    new_data = d2,
    type = "distribution",
    method = "dot",
    times = 10
  )
  expect_true(inherits(dist_dot, "tbl"))
})



test_that("malformed args in predict throw errors", {
  # Exercises predict() on the fitted `lda` with invalid or edge-case
  # arguments and checks the condition (error / warning / message) signalled.

  # threads > nrow(dtm): expected to signal a message rather than fail
  expect_message(
    predict(
      object = lda,
      new_data = d2,
      method = "gibbs",
      iterations = 20,
      burnin = 10,
      threads = nrow(d2) + 2,
      verbose = FALSE
    ),
    label = "threads > nrow(dtm)"
  )

  # no iterations specified for the gibbs method
  expect_error(
    predict(object = lda, new_data = d2, method = "gibbs")
  )

  # burnin >= iterations
  expect_error(
    predict(
      object = lda,
      new_data = d2,
      method = "gibbs",
      iterations = 5,
      burnin = 6
    )
  )

  # incorrect method
  expect_error(
    predict(object = lda, new_data = d2, method = "oops")
  )

  # A single all-zero "document" whose term names are the numbers 1..10;
  # the intent is that none of them overlaps with the model's vocabulary.
  nd <- numeric(10)

  names(nd) <- seq_along(nd)

  # no overlap in vocabulary throws a warning on "dot" by default
  expect_warning(
    predict(object = lda, new_data = nd, method = "dot")
  )

  # with no_common_tokens = "zero" a message is signalled instead
  expect_message(
    predict(
      object = lda,
      new_data = nd,
      method = "dot",
      no_common_tokens = "zero"
    )
  )

  # with no_common_tokens = "uniform" every topic is set to 1 / k
  p <- predict(
    object = lda,
    new_data = nd,
    method = "dot",
    no_common_tokens = "uniform"
  )

  # Compare element-wise: checking only mean(p) == 1 / k would pass for any
  # k values that sum to 1, not just for the uniform distribution.
  k <- nrow(lda$beta)

  expect_equal(as.numeric(p), rep(1 / k, k))

  # no_common_tokens has illegal value
  expect_error(
    predict(
      object = lda,
      new_data = nd,
      method = "dot",
      no_common_tokens = "WRONG!"
    )
  )

  # type misspecified
  expect_error(
    predict(object = lda, new_data = d2, type = "blah", method = "dot")
  )

  # times misspecified while type = "distribution": NA, non-numeric, and zero
  expect_error(
    predict(
      object = lda,
      new_data = d2,
      type = "distribution",
      method = "dot",
      times = NA
    )
  )

  expect_error(
    predict(
      object = lda,
      new_data = d2,
      type = "distribution",
      method = "dot",
      times = "yeet"
    )
  )

  expect_error(
    predict(
      object = lda,
      new_data = d2,
      type = "distribution",
      method = "dot",
      times = 0
    )
  )
})

# Try the tidylda package in your browser
#
# Any scripts or data that you put into this service are public.
#
# tidylda documentation built on July 26, 2023, 5:34 p.m.