R/norm_viterbi.R

Defines functions norm_viterbi

Documented in norm_viterbi

#' Global decoding by the Viterbi algorithm
#'
#' This function takes in data `x` assumed to be generated by the HMM `hmm`
#' and outputs, for each subject, the most likely sequence of hidden states
#' that could have generated the data (global decoding by the Viterbi
#' algorithm).
#'
#' @param x The data to be fit with an HMM in the form of a 3D array. The
#'   first index (row) corresponds to time, the second (column) to the
#'   variable number, and the third (matrix number) to the subject number.
#' @param hmm A list of parameters that specify the normal HMM, including
#'   `num_states`, `num_variables`, `num_subjects`, `num_covariates`, `mu`,
#'   `sigma`, `gamma`, `delta`. `delta[[i]]` is the initial state
#'   distribution of subject `i`. When `num_covariates` is 0, `gamma[[i]]` is
#'   a single transition matrix for subject `i`; otherwise it is an array
#'   whose slice `gamma[[i]][, , t]` is the transition matrix from time
#'   `t - 1` to time `t`.
#' @param state_dep_dist_pooled A logical variable indicating whether the
#'   state dependent distribution parameters `mu` and `sigma` should be
#'   treated as equal for all subjects. Passed on to `norm_allprobs()`.
#'
#' @return A numeric matrix with one row per time point and one column per
#'   subject; column `i` contains the decoded sequence of state indices
#'   (values in `1:num_states`) for subject `i`.
#' @export
norm_viterbi <- function(x, hmm, state_dep_dist_pooled = FALSE) {
  num_time       <- nrow(x)
  num_states     <- hmm$num_states
  num_variables  <- hmm$num_variables
  num_subjects   <- hmm$num_subjects
  num_covariates <- hmm$num_covariates
  sequence       <- matrix(0, nrow = num_time, ncol = num_subjects)
  # Forward the caller's pooling choice instead of always using FALSE.
  allprobs       <- norm_allprobs(num_states, num_variables,
                                  num_subjects, num_time, x, hmm,
                                  state_dep_dist_pooled = state_dep_dist_pooled)

  for (i in seq_len(num_subjects)) {
    prob          <- allprobs[[i]]
    subject_gamma <- hmm$gamma[[i]]
    # Transition matrix for the step from time t - 1 into time t; with
    # covariates the matrix varies over time, otherwise it is constant.
    tpm <- function(t) {
      if (num_covariates != 0) subject_gamma[, , t] else subject_gamma
    }

    # Forward pass: row t holds, for each state k, a value proportional to the
    # highest probability of any state path ending in k at time t. Rows are
    # normalised to sum to one to avoid numerical underflow.
    state_probs   <- matrix(0, nrow = num_time, ncol = num_states)
    forward_probs <- hmm$delta[[i]] * prob[1, ]
    state_probs[1, ] <- forward_probs / sum(forward_probs)
    # seq_len(num_time)[-1] is empty when there is a single time point.
    for (t in seq_len(num_time)[-1]) {
      # (j, k) entry: state_probs[t - 1, j] * gamma_t[j, k]; max over j.
      forward_probs <- apply(state_probs[t - 1, ] * tpm(t), 2, max) *
        prob[t, ]
      state_probs[t, ] <- forward_probs / sum(forward_probs)
    }

    # Backtracking: start from the most likely final state, then at each
    # earlier time choose the predecessor of this subject's next state, using
    # the same transition matrix (t -> t + 1) as the forward pass did.
    sequence[num_time, i] <- which.max(state_probs[num_time, ])
    for (t in rev(seq_len(num_time - 1))) {
      next_state     <- sequence[t + 1, i]
      sequence[t, i] <- which.max(tpm(t + 1)[, next_state] *
                                    state_probs[t, ])
    }
  }
  sequence
}
simonecollier/lizardHMM documentation built on Dec. 23, 2021, 2:24 a.m.