# tests/testthat/test-build_hmm.R

# Shared test fixture. `s` serves both as the number of hidden states used by
# the tests and as the alphabet size of `obs`: 5 sequences of length 10, each
# element drawn with replacement from the first `s` lowercase letters.
set.seed(123)
s <- 4
obs <- seqdef(
  matrix(
    sample(letters[seq_len(s)], size = 50, replace = TRUE),
    ncol = 10
  )
)

test_that("build_hmm returns object of class 'hmm'", {
  # Only the number of hidden states is supplied.
  expect_error(
    model <- build_hmm(obs, n_states = s),
    NA
  )
  expect_s3_class(model, "hmm")

  # All probabilities supplied explicitly: 2 states that never switch
  # (identity transitions) and both emit the first symbol with probability 1.
  # The emission matrix is 2 x s: one row per state, one column per symbol.
  expect_error(
    model_fixed <- build_hmm(
      obs,
      initial_probs = c(1, 0),
      transition_probs = diag(2),
      emission_probs = cbind(1, matrix(0, 2, s - 1))
    ),
    NA
  )
  # Previously only "no error" was checked here; the class is the contract.
  expect_s3_class(model_fixed, "hmm")
})
test_that("build_hmm errors with incorrect dims", {
  # A well-formed 2-state emission matrix: 2 rows (states), s columns
  # (symbols), reused by the cases that break a different dimension.
  emission_2 <- cbind(1, matrix(0, 2, s - 1))

  # Messages are literal text, so match them with `fixed = TRUE` rather than
  # as regular expressions (where "." would match any character).

  # Emission matrix has only 2 columns, but `obs` has `s` symbols.
  expect_error(
    build_hmm(obs, initial_probs = c(1, 0),
              transition_probs = diag(2),
              emission_probs = diag(2)),
    "Number of columns in 'emission_probs' is not equal to the number of symbols.",
    fixed = TRUE
  )
  # Three initial probabilities alongside a 2 x 2 transition matrix.
  expect_error(
    build_hmm(obs, initial_probs = c(1, 0, 0),
              transition_probs = diag(2),
              emission_probs = emission_2),
    "Length of 'initial_probs' is not equal to the number of states.",
    fixed = TRUE
  )
  # Initial and transition probabilities describe 3 states, but the emission
  # matrix only has 2 rows.
  expect_error(
    build_hmm(obs, initial_probs = c(1, 0, 0),
              transition_probs = diag(3),
              emission_probs = emission_2),
    "Number of rows in 'emission_probs' is not equal to the number of states.",
    fixed = TRUE
  )
})

test_that("build_hmm errors with incorrect observations", {
  # A bare number is neither an 'stslist' nor a list of them; the model
  # parameters themselves are otherwise consistent (2 states).
  # The expected text is compared literally (`fixed = TRUE`), not as a regex.
  # NOTE(review): "should a" is kept verbatim; presumably it mirrors the
  # package's own message, so verify before correcting the wording here.
  expect_error(
    build_hmm(1, initial_probs = c(1, 0),
              transition_probs = diag(2),
              emission_probs = diag(2)),
    paste0("Argument 'observations' should a 'stslist' object created with ",
           "'seqdef' function, or a list of such objects in case of multichannel data."
    ),
    fixed = TRUE
  )
})

test_that("build_hmm returns the correct number of states", {
  expect_error(
    model <- build_hmm(obs, n_states = s),
    NA
  )
  # The fixture happens to use `s` for the alphabet size too, so derive the
  # symbol count from `obs` instead of reusing `s`. Otherwise this test would
  # silently encode "n_symbols == n_states".
  # emission_probs is n_states x n_symbols (rows = states, cols = symbols).
  n_symbols <- length(TraMineR::alphabet(obs))

  expect_length(model$initial_probs, s)
  expect_equal(dim(model$transition_probs), c(s, s))
  expect_equal(dim(model$emission_probs), c(s, n_symbols))
})

test_that("build_hmm returns the correct probabilities", {
  model <- build_hmm(obs, n_states = s)
  init <- model$initial_probs
  trans <- model$transition_probs
  emiss <- model$emission_probs

  # Every row of a stochastic matrix sums to one; rows are named "State i".
  ones_by_state <- setNames(rep(1, s), paste("State", seq_len(s)))

  # Initial distribution: each entry in [0, 1], total mass one.
  expect_true(all(init >= 0))
  expect_true(all(init <= 1))
  expect_equal(sum(init), 1)

  # Transition and emission matrices are row-stochastic; emission columns
  # are labelled by the observed symbols.
  expect_equal(rowSums(trans), ones_by_state)
  expect_equal(rowSums(emiss), ones_by_state)
  expect_equal(colnames(emiss), letters[seq_len(s)])

  # Entry-wise probability bounds for both matrices.
  expect_true(all(trans >= 0))
  expect_true(all(emiss >= 0))
  expect_true(all(trans <= 1))
  expect_true(all(emiss <= 1))
})

# Try the seqHMM package in your browser.
#
# Any scripts or data that you put into this service are public.
#
# seqHMM documentation built on July 9, 2023, 6:35 p.m.