#' Baum-Welch (EM) estimation for a hidden Markov model with multivariate
#' normal responses
#'
#' Alternates an M-step, which re-estimates the transition matrix and the
#' per-state multivariate normal responses from the posterior state
#' probabilities, with an E-step, which is a forward-backward pass via
#' `rkHMM::forward_backward()`. It stops when the log-likelihood change falls
#' below `tol` or the iteration limit is reached.
#'
#' The transition update keeps only the second element of the
#' logit-transformed transition probabilities. The implementation therefore
#' effectively assumes two hidden states.
#'
#' @param data Numeric matrix (or object coercible with [as.matrix()]) of
#'   observations, with one row per time step and one column per observed
#'   variable.
#' @param init_p Initial state probabilities. They are held fixed during the EM
#'   iterations.
#' @param init_trn_mtx Initial transition matrix (required). The number of
#'   states is taken from its number of columns. Column `i` is re-estimated
#'   from the transitions out of state `i`. The matrix is transposed before
#'   being passed to `rkHMM::forward_backward()`.
#' @param init_response List with elements `mean` (a 1 x `ncol(data)` matrix)
#'   and `cov_mtx` (a `ncol(data)` x `ncol(data)` covariance matrix). The same
#'   initial response is used for every state.
#' @param init_trn_model_data Data frame passed to [model.matrix()] together
#'   with `init_trn_model_formula` to build the transition-model design matrix.
#' @param init_trn_model_formula Formula for the transition-model design
#'   matrix.
#' @param maxit Maximum iteration index. EM runs for at most `maxit + 1`
#'   iterations (indices `0:maxit`).
#' @param tol Convergence tolerance on the log-likelihood change.
#' @param crit Convergence criterion. `"relative"` divides the change by the
#'   absolute previous log-likelihood; `"absolute"` uses the raw change.
#' @param random.start Currently ignored.
#' @param verbose If `TRUE`, print a line when convergence is reached.
#' @param na.allow Currently ignored.
#' @param ... Currently ignored.
#'
#' @return A list with elements
#'   \describe{
#'     \item{y}{1 x n_states matrix of posterior state probabilities at the
#'       first time step. They come from the forward-backward pass that drove
#'       the final M-step.}
#'     \item{trn_mtx}{Estimated transition matrix, in the same layout as
#'       `init_trn_mtx`.}
#'     \item{response}{List of per-state responses. Each is a list with
#'       `mean` and `cov_mtx`.}
#'     \item{log_like}{Log-likelihood from the last forward-backward pass.}
#'     \item{converged}{`TRUE` if the convergence criterion was met.}
#'     \item{message}{Human-readable convergence status.}
#'     \item{em_history}{List of per-iteration snapshots. The first element
#'       holds the initial parameters. Each later element holds the
#'       parameters estimated in one iteration together with their
#'       log-likelihood.}
#'   }
#' @export
#'
#' @examples
#' \dontrun{
#' obs <- matrix(rnorm(200), ncol = 2)
#' fit <- baum_welch(
#'   obs,
#'   init_trn_mtx = matrix(c(0.9, 0.1, 0.1, 0.9), nrow = 2),
#'   init_trn_model_data = data.frame(x = 1),
#'   init_trn_model_formula = ~ 1
#' )
#' fit$trn_mtx
#' }
baum_welch <- function(data,
                       init_p = c(0.8, 0.2),
                       init_trn_mtx,
                       init_response = list(
                         mean = matrix(rep(0, ncol(data)), ncol = ncol(data)),
                         cov_mtx = diag(ncol(data))
                       ),
                       init_trn_model_data,
                       init_trn_model_formula,
                       maxit = 100,
                       tol = 1e-8,
                       crit = c("relative", "absolute"),
                       random.start = TRUE,
                       verbose = FALSE,
                       na.allow = TRUE,
                       ...) {
  crit <- match.arg(crit)
  stopifnot(is.numeric(maxit), length(maxit) == 1L, maxit >= 0)
  # lm.wfit() and dmvnorm() need a numeric matrix; this also accepts data
  # frames. (The init_response default is lazy, so it sees the matrix.)
  data <- as.matrix(data)
  store_iterations <- TRUE
  n_states <- ncol(init_trn_mtx)
  n_steps <- nrow(data)

  # Transition-model links. The first coefficient is fixed at 0, which
  # matches base = 1 in the multinomial-logit inverse link.
  logit_link <- stats::make.link("logit")
  mlogit_linkinv <- depmixS4::mlogit()$linkinv
  base <- 1

  trn_mtx <- init_trn_mtx
  trn_model_x <- stats::model.matrix(
    init_trn_model_formula,
    init_trn_model_data
  )

  init_p <- matrix(init_p, nrow = 1)
  # Every state starts from the same response. Evaluate the density once and
  # replicate it into an n_steps x n_states matrix.
  init_density <- mvtnorm::dmvnorm(
    data,
    init_response$mean,
    init_response$cov_mtx
  )
  init_dens <- matrix(init_density, nrow = n_steps, ncol = n_states)

  # Initial E-step.
  fbo <- rkHMM::forward_backward(
    p_state_0 = init_p,
    trn_mtx = t(trn_mtx),
    p_obs = init_dens
  )

  em_history <- NULL
  if (store_iterations) {
    # Slot 1 holds the initial parameters. Iteration j (0-based) goes to slot
    # j + 2, so maxit + 2 slots cover every possible iteration.
    em_history <- vector("list", length = maxit + 2)
    em_history[[1]] <- list(
      initial_p_state = init_p,
      trn_mtx = init_trn_mtx,
      response = rep(list(init_response), n_states),
      log_like = fbo$logLike
    )
  }

  ll_old <- fbo$logLike
  # Seed the per-state responses and densities with the initial values. A
  # state that gets zero posterior weight then keeps a valid density column
  # instead of a NULL, which cbind() would silently drop.
  response_list <- rep(list(init_response), n_states)
  factor_density <- lapply(seq_len(n_states), function(s) init_dens[, s])
  intercept_x <- matrix(1, nrow = n_steps, ncol = 1)
  converged <- FALSE
  p_state_first <- NULL

  for (j in 0:maxit) {
    # M-step (1): transition probabilities ----
    trm <- matrix(0, n_states, n_states)
    for (i in seq_len(n_states)) {
      gamma_total <- sum(fbo$gamma[-nrow(fbo$gamma), i])
      if (gamma_total == 0) {
        # No posterior mass in state i before the last step, so keep the
        # current probabilities.
        trm[i, ] <- trn_mtx[, i]
      } else {
        # NOTE(review): assumes fbo$xi[t, k, i] is the joint posterior of
        # state i at t and state k at t + 1. Confirm against rkHMM.
        for (k in seq_len(n_states)) {
          trm[i, k] <- sum(fbo$xi[-n_steps, k, i]) / gamma_total
        }
      }
      # Only element 2 of the logit is used, so this assumes two states.
      trn_mtx_coef <- c(0, logit_link$linkfun(trm[i, ])[2])
      trn_mtx[, i] <- mlogit_linkinv(
        trn_model_x %*% trn_mtx_coef,
        base = base
      )
    }

    # M-step (2): per-state multivariate normal responses ----
    for (i in seq_len(n_states)) {
      weights <- fbo$gamma[, i]
      if (sum(weights) > 0) {
        model_fit <- stats::lm.wfit(intercept_x, data, w = weights)
        # as.matrix() keeps cov.wt() working if the residuals come back as a
        # plain vector (single-column data). cov.wt() uses its default
        # "unbiased" method here.
        sigma <- stats::cov.wt(
          as.matrix(model_fit$residuals),
          wt = weights
        )$cov
        response_list[[i]] <- list(
          mean = model_fit$coefficients,
          cov_mtx = sigma
        )
        factor_density[[i]] <- mvtnorm::dmvnorm(
          data,
          model_fit$coefficients,
          sigma
        )
      }
    }

    # First-step posteriors from the pass that drove this M-step; this is
    # the value returned as `y`.
    p_state_first <- fbo$gamma[1, , drop = FALSE]

    # E-step with the updated parameters ----
    fbo <- rkHMM::forward_backward(
      p_state_0 = init_p,
      trn_mtx = t(trn_mtx),
      p_obs = do.call(cbind, factor_density)
    )
    ll_new <- fbo$logLike

    if (store_iterations) {
      # Pair the parameters just estimated with their own log-likelihood.
      em_history[[j + 2]] <- list(
        trn_mtx = trn_mtx,
        response = response_list,
        log_like = ll_new
      )
    }

    if (ll_new >= ll_old) {
      ll_change <- ll_new - ll_old
      converged <- switch(
        crit,
        absolute = ll_change < tol,
        relative = ll_change / abs(ll_old) < tol
      )
      if (converged) {
        if (verbose) {
          cat("converged at iteration", j, "with logLik:", ll_new, "\n")
        }
        break
      }
    } else if (j > 0 && (ll_old - ll_new) > tol) {
      # EM should never decrease the likelihood beyond numerical noise.
      stop(
        "likelihood decreased on iteration ", j, " with rk model",
        call. = FALSE
      )
    }
    ll_old <- ll_new
  }

  if (store_iterations) {
    em_history <- em_history[seq_len(j + 2)]
  }

  convergence_message <- if (converged) {
    switch(
      crit,
      relative = "Log likelihood converged to within tol. (relative change)",
      absolute = "Log likelihood converged to within tol. (absolute change)"
    )
  } else {
    "'maxit' iterations reached in EM without convergence."
  }

  list(
    y = p_state_first,
    trn_mtx = trn_mtx,
    response = response_list,
    log_like = fbo$logLike,
    converged = converged,
    message = convergence_message,
    em_history = em_history
  )
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.