R/baum_welch.R

Defines functions baum_welch

#' Fit a Gaussian hidden Markov model with Baum-Welch (EM) updates
#'
#' Starting from the supplied initial parameters, repeatedly re-estimates the
#' transition matrix and the per-state multivariate normal response parameters
#' from the output of \code{rkHMM::forward_backward()}, then re-runs
#' \code{rkHMM::forward_backward()} with the updated parameters, until the
#' log-likelihood change falls below \code{tol} or the iteration budget is
#' exhausted.
#'
#' @param data Numeric matrix of observations with one row per time step and
#'   one column per factor.
#' @param init_p Initial state probabilities. Coerced to a one-row matrix and
#'   passed unchanged as \code{p_state_0} on every forward-backward pass (it
#'   is not re-estimated).
#' @param init_trn_mtx Initial transition matrix (required). Its number of
#'   columns defines the number of states. It is transposed before being
#'   handed to \code{rkHMM::forward_backward()}; column \code{i} is treated as
#'   the distribution of transitions out of state \code{i}.
#' @param init_response List with elements \code{mean} and \code{cov_mtx}
#'   giving the multivariate normal parameters used for every state's initial
#'   observation density.
#' @param init_trn_model_data Data passed to \code{stats::model.matrix()} to
#'   build the transition model design matrix.
#' @param init_trn_model_formula Formula passed to
#'   \code{stats::model.matrix()} to build the transition model design matrix.
#' @param maxit Largest iteration index. Iterations are numbered from 0, so
#'   at most \code{maxit + 1} parameter updates are performed.
#' @param tol Convergence tolerance on the log-likelihood improvement.
#' @param crit Convergence criterion: \code{"relative"} compares the
#'   improvement divided by the absolute previous log-likelihood against
#'   \code{tol}; \code{"absolute"} compares the raw improvement.
#' @param random.start Currently unused.
#' @param verbose If \code{TRUE}, print a line to the console when the
#'   algorithm converges.
#' @param na.allow Currently unused.
#' @param ... Currently unused.
#'
#' @return A list with elements
#'   \item{y}{First row of the \code{gamma} component of the forward-backward
#'     pass that preceded the final parameter update (one-row matrix).}
#'   \item{trn_mtx}{Final transition matrix, in the same orientation as
#'     \code{init_trn_mtx}.}
#'   \item{response}{List with one \code{list(mean, cov_mtx)} per state.}
#'   \item{log_like}{Log-likelihood reported by the final forward-backward
#'     pass.}
#'   \item{converged}{\code{TRUE} if the convergence criterion was met.}
#'   \item{message}{Text describing why iteration stopped.}
#'   \item{history}{List of per-iteration parameter snapshots; the first entry
#'     holds the initial parameters.}
#' @export
baum_welch <- function(data,
                       init_p = c(0.8, 0.2),
                       init_trn_mtx = init_trn_mtx,
                       init_response = list(
                         mean = matrix(rep(0, ncol(data)), ncol = ncol(data)),
                         cov_mtx = diag(ncol(data))
                       ),
                       init_trn_model_data,
                       init_trn_model_formula,
                       maxit = 100,
                       tol = 1e-8,
                       crit = c("relative","absolute"),
                       random.start = TRUE,
                       verbose = FALSE,
                       na.allow = TRUE,
                       ...) {
  crit <- match.arg(crit)

  # The default for `init_trn_mtx` refers to itself, so omitting the argument
  # would otherwise surface as an obscure "promise already under evaluation"
  # error. Fail early with an actionable message instead.
  if (missing(init_trn_mtx)) {
    stop("'init_trn_mtx' must be supplied", call. = FALSE)
  }

  n_states <- ncol(init_trn_mtx)
  n_steps <- nrow(data)
  state_ids <- seq_len(n_states)

  # Link functions used to map transition probabilities to coefficients and
  # back again.
  logit_link_fun <- stats::make.link("logit")
  mlogit_linkinv <- depmixS4::mlogit()$linkinv
  base <- 1

  # Initialize the transition model.
  trn_mtx <- init_trn_mtx
  trn_model_x <- stats::model.matrix(
    init_trn_model_formula,
    init_trn_model_data
  )

  # Initialize the response model. Every state starts from the same density,
  # so it is evaluated once and replicated across states.
  init_p <- matrix(init_p, nrow = 1)
  init_state_dens <- mvtnorm::dmvnorm(
    data,
    init_response$mean,
    init_response$cov_mtx
  )
  init_dens <- vapply(
    state_ids,
    function(s) init_state_dens,
    numeric(n_steps)
  )

  # Initial forward-backward pass.
  fbo <- rkHMM::forward_backward(
    p_state_0 = init_p,
    trn_mtx = t(init_trn_mtx),
    p_obs = init_dens
  )
  LL.old <- fbo$logLike

  # Slot 1 holds the initial parameters; iteration j is stored in slot j + 2
  # because iterations are numbered from 0.
  em_history <- vector("list", length = maxit + 2)
  em_history[[1]] <- list(
    initial_p_state = init_p,
    trn_mtx = init_trn_mtx,
    response = rep(list(init_response), n_states),
    log_like = fbo$logLike
  )

  # Seed the per-state results with the initial response so that a state that
  # never receives any weight keeps valid parameters and a density column
  # (a NULL entry would silently drop the column in cbind() below).
  response_list <- rep(list(init_response), n_states)
  factor_density <- lapply(state_ids, function(s) init_state_dens)

  converge <- FALSE
  for (j in 0:maxit) {
    # Re-estimate the transition probabilities out of each state.
    trm <- matrix(0, n_states, n_states)
    for (i in state_ids) {
      gamma_total <- sum(fbo$gamma[-nrow(fbo$gamma), i])
      if (gamma_total == 0) {
        # No weight on this state: keep its current transition probabilities.
        trm[i, ] <- trn_mtx[, i]
      } else {
        for (k in state_ids) {
          trm[i, k] <- sum(fbo$xi[-n_steps, k, i]) / gamma_total
        }
      }

      # NOTE(review): only the second logit is used, so this coefficient
      # vector appears to assume a two-state model -- confirm before using
      # more states.
      trn_mtx_coef <- c(0, logit_link_fun$linkfun(trm[i, ])[2])
      trn_mtx[, i] <- mlogit_linkinv(
        trn_model_x %*% trn_mtx_coef,
        base = base
      )
    }

    # Re-estimate each state's mean and covariance by a weighted
    # intercept-only fit, then refresh that state's observation density.
    for (i in state_ids) {
      state_weights <- fbo$gamma[, i]
      if (sum(state_weights) > 0) {
        model_fit <- stats::lm.wfit(
          matrix(1, n_steps),
          data,
          w = state_weights
        )
        state_mean <- model_fit$coefficients
        state_cov <- stats::cov.wt(model_fit$residuals, state_weights)$cov
        response_list[[i]] <- list(mean = state_mean, cov_mtx = state_cov)
        factor_density[[i]] <- mvtnorm::dmvnorm(data, state_mean, state_cov)
      }
    }

    # Captured before the forward-backward pass below replaces `fbo`.
    y <- fbo$gamma[1, , drop = FALSE]

    fbo <- rkHMM::forward_backward(
      p_state_0 = init_p,
      trn_mtx = t(trn_mtx),
      p_obs = do.call(cbind, factor_density)
    )

    # Record the updated parameters together with the log-likelihood they
    # produced.
    em_history[[j + 2]] <- list(
      trn_mtx = trn_mtx,
      response = response_list,
      log_like = fbo$logLike
    )

    if (fbo$logLike >= LL.old) {
      ll_gain <- fbo$logLike - LL.old
      converge <- switch(
        crit,
        absolute = ll_gain < tol,
        relative = ll_gain / abs(LL.old) < tol
      )
      if (converge) {
        if (verbose) {
          cat("converged at iteration", j, "with logLik:", fbo$logLike, "\n")
        }
        break
      }
    } else {
      # EM should never decrease the likelihood; tolerate only numerical
      # noise (and the very first update).
      if (j > 0 && (LL.old - fbo$logLike) > tol) {
        stop("likelihood decreased on iteration ", j, " with rk model")
      }
    }
    LL.old <- fbo$logLike
  }

  # Drop the slots for iterations that never ran.
  em_history <- em_history[seq_len(j + 2)]

  conv_message <- if (converge) {
    switch(
      crit,
      relative = "Log likelihood converged to within tol. (relative change)",
      absolute = "Log likelihood converged to within tol. (absolute change)"
    )
  } else {
    "'maxit' iterations reached in EM without convergence."
  }

  list(
    y = y,
    trn_mtx = trn_mtx,
    response = response_list,
    log_like = fbo$logLike,
    converged = converge,
    message = conv_message,
    history = em_history
  )
}
ricky-kotecha/rkHMM documentation built on May 4, 2020, 12:08 a.m.