# Variational inference for a latent class (mixture) model on categorical data.
# Runs coordinate-ascent updates of omega (class-weight parameters), zeta
# (per-individual class responsibilities) and phi (per-feature, per-class
# category parameters) until the ELBO improves by less than `epsilon` or
# `max_iter` iterations have run.
#
# Arguments, as used in the visible code:
#   X           Matrix, one row per individual and one column per feature.
#               Entries are integer category codes (compared with 1, 2, ... in
#               the phi update and used as indices into U[[j]][[k]]); NA marks
#               a missing answer.
#   n_latent    Number of latent classes.
#   alpha       Prior concentration for the class weights (enters as
#               alpha + colSums(zeta)); 1 if NULL or empty.
#   beta        Pseudo-count added to every category count in the phi update;
#               0.1 if NULL or empty.
#   categories  List with one element per column of X; length(categories[[j]])
#               is the number of categories of feature j and its values name
#               the entries of the returned category probabilities.
#   max_iter    Maximum number of iterations.
#   epsilon     Stop once the ELBO increases by less than this; also the
#               tolerance for the "ELBO decreased" warning.
#   omega_init, zeta_init, phi_init
#               Optional starting values (see the NOTE at their assignment).
#   verbose     If TRUE, message() the ELBO every 10 iterations.
#
# Calls extraDistr::rdirichlet plus update_zeta_cpp(), expec_log_x_cpp(),
# conv_phi_to_array() and the expec_*/entrop_* helpers, none of which are
# defined in this block (the *_cpp ones are presumably compiled code -- TODO
# confirm).
#
# NOTE(review): as shown, this block is missing the closing braces of the
# `else if` branches, the phi_init closures, the for/if/while bodies and the
# function itself, as well as most of the final return value, so it does not
# parse. The notes below mark the places where code appears to be missing.
mixdir_vi <- function(X, n_latent, alpha, beta, categories, max_iter, epsilon,
                      omega_init=NULL, zeta_init=NULL, phi_init=NULL, verbose=FALSE){

  # Initialize the parameters
  n_ind <- nrow(X)
  n_quest <- ncol(X)
  # Largest category code in X (passed to the compiled helpers).
  n_cat <- max(X, na.rm=TRUE)

  if(is.null(alpha) || length(alpha) == 0){
    alpha <- 1
  }else if(length(alpha) == 1){
    # NOTE(review): this branch runs for a valid single value and still warns
    # that alpha "should only be a single value"; the condition presumably
    # should be length(alpha) != 1. The repeated `alpha <- alpha[1]` is
    # redundant.
    alpha <- alpha[1]
    warning(paste0("alpha should only be a single value. Using alpha=alpha[1]=", alpha[1]))
    alpha <- alpha[1]

  if(is.null(beta) || length(beta) == 0){
    beta <- 0.1
  }else if(length(beta) == 1){
    # NOTE(review): same inverted condition as for alpha; presumably
    # length(beta) != 1 was intended. The repeated assignment is redundant.
    beta <- beta[1]
    warning(paste0("beta should only be a single value. Using beta=beta[1]=", beta[1]))
    beta <- beta[1]

    # NOTE(review): omega_init, zeta_init and phi_init are overwritten here
    # unconditionally, discarding caller-supplied values. An
    # `if(is.null(...))` guard around each assignment, and an `else` before
    # the phi_init length check below, appear to be missing.
    omega_init <- rep(1, n_latent)
    # One random draw per individual (rows) over the n_latent classes.
    zeta_init <- extraDistr::rdirichlet(n_ind, rep(1, n_latent))
    # Random starting values in 1..3, one per category of feature j.
    # NOTE(review): the closure below has no visible return value (`x`) or
    # closing `}))`.
    phi_init <- lapply(1:n_quest, function(j) lapply(1:n_latent, function(k) {
      x <- sample(1:3, size=length(categories[[j]]), replace=TRUE)
      # names(x) <- categories[[j]]
    # Check that phi_init[[j]][[k]] has one entry per category of feature j
    # for every latent class k.
    for(k in 1:n_latent){
      phi_init_elem_lengths <- sapply(phi_init, function(x) length(x[[k]]))
      if(! all(phi_init_elem_lengths == sapply(categories, length))){
        stop(paste0("phi_init has the wrong number of elements for feature: ",
                    paste0(which(phi_init_elem_lengths != sapply(categories, length)), collapse = ", ")))

  omega <-omega_init
  zeta <- zeta_init
  phi <- phi_init
  # Dense n_quest x n_latent x max-categories copy of phi for the compiled helpers.
  phia <- conv_phi_to_array(phi, n_quest, n_latent)
  # Run the algorithm
  iter <- 1
  converged <- FALSE
  elbo_hist <- rep(NA, max_iter)
  # NOTE(review): `&&` is the idiomatic scalar operator for a while condition.
  while(iter <= max_iter & ! converged){

    # Update omega
    omega <- alpha + colSums(zeta)

    # Update zeta
    # The compiled update_zeta_cpp() below presumably implements this R loop
    # (kept for reference) -- TODO confirm.
    # for(k in 1:n_latent){
    #   zeta[, k] <- sapply(1:n_ind, function(i){
    #     exp(digamma(omega[k]) - digamma(sum(omega)) + sum(sapply(1:n_quest, function(j){
    #       if(is.na(X[i,j])){
    #         # If X_ij is missing ignore it
    #         0
    #       }else{
    #         digamma(phi[[j]][[k]][X[i,j]]) - digamma(sum(phi[[j]][[k]]))
    #       }
    #     })) - 1)
    #   })
    # }
    zeta <- update_zeta_cpp(zeta, X, phia, omega, n_ind, n_quest, n_latent, n_cat)
    # A row whose unnormalized weights are all 0 cannot be normalized.
    if(any(rowSums(zeta) == 0)){
      stop(paste0("There was an underflow in the calculation of zeta. Cannot continue.\n",
                  "The problem probably came from the large number of columns in the input ",
                  "data (", (ncol(X)), "). Is it possible that you want to work on t(X)?"))
    # NOTE(review): with no closing brace on the `if` above, this row
    # normalization sits after stop() and never runs; it is presumably meant
    # to follow the `if` block.
    zeta <- zeta / rowSums(zeta)

    # Update phi
    # phi[[j]][[k]][r] = beta + zeta-weighted count of individuals whose
    # answer to feature j is category r (missing answers are dropped by
    # na.rm=TRUE).
    for(j in 1:n_quest){
      for(k in 1:n_latent){
        for(r in seq_along(categories[[j]])){
          phi[[j]][[k]][r] <- sum(zeta[ ,k] * (X[, j] == r), na.rm=TRUE) + beta
    phia <- conv_phi_to_array(phi, n_quest, n_latent)

    # ELBO: expected log prior/likelihood terms plus the entropies of the
    # variational distributions over lambda (omega), z (zeta) and u (phi).
    elbo <- expec_log_lambda(omega, rep(alpha, length(omega))) +
      sum(sapply(1:n_ind, function(i)expec_log_zi(zeta[i, ], omega))) +
      sum(sapply(1:n_quest, function(j) sum(sapply(1:n_latent, function(k) expec_log_ujk(phi[[j]][[k]], rep(beta, length(categories[[j]]))) )))) +
      # sum(sapply(1:n_ind, function(i) sum(sapply(1:n_quest, function(j) sum(sapply(1:n_latent, function(k)
      #   expec_log_xij(X[i,j], phi[[j]][[k]], zeta[i,k]) )))))) +
      expec_log_x_cpp(X, phia, zeta, n_quest, n_latent, n_cat) +
      entrop_omega(omega) +
      sum(sapply(1:n_ind, function(i)entrop_zeta(zeta[i, ]))) +
      sum(sapply(1:n_quest, function(j) sum(sapply(1:n_latent, function(k) entrop_phi(phi[[j]][[k]]) ))))

    # Warn if the ELBO dropped by more than epsilon since the previous iteration.
    if(iter != 1 && ! is.infinite(elbo) && elbo - elbo_hist[iter - 1] < - epsilon){
      warning(paste0("The ELBO decreased. This should not happen, it might be due to numerical instabilities or a bug in the code. ",
                     "Please contact the maintainer to report this.\n"))
    # Converged once the ELBO improves by less than epsilon.
    # NOTE(review): the warning `if` above has no visible closing brace, which
    # would nest this check inside it.
    if(iter != 1 && ! is.infinite(elbo) && elbo - elbo_hist[iter - 1] < epsilon) converged <- TRUE

    if(verbose && iter %% 10 == 0) message(paste0("Iter: ", iter, " ELBO: ", formatC(elbo, digits=4)))

    elbo_hist[iter] <- elbo
    iter <- iter + 1

  # Keep only the iterations that actually ran.
  elbo_hist <- elbo_hist[! is.na(elbo_hist)]

  # Category probabilities: U[[j]][[k]] is phi[[j]][[k]] normalized to sum 1,
  # named by categories[[j]]; U is named by colnames(X).
  # NOTE(review): the pre-allocation closure below has no visible return value
  # (`x`) or closing `}))`.
  U <- lapply(1:n_quest, function(j) lapply(1:n_latent, function(k) {
    x <- rep(NA, times=length(categories[[j]]))
    names(x) <- categories[[j]]
  names(U) <- colnames(X)
  for(j in 1:n_quest){
    for(k in 1:n_latent){
      U[[j]][[k]] <- (phi[[j]][[k]]) / sum(phi[[j]][[k]])
      names(U[[j]][[k]]) <- categories[[j]]

  # Class weights: the mean of Dirichlet(omega).
  omega <- alpha + colSums(zeta)
  lambda <- omega/sum(omega)

  # Bring the zeta and U to a consistent state (lambda is slightly off)
  # prob_z[i, k] is proportional to lambda[k] * prod_j U[[j]][[k]][X[i, j]];
  # missing answers contribute a factor of 1, and the product is formed as
  # exp(sum(log(...))).
  # NOTE(review): the inner vapply iterates over colnames(X) and looks U up by
  # name, so this appears to require non-NULL column names (with NULL names
  # the product is empty and prob_z reduces to lambda) -- TODO confirm the
  # caller always sets them. With many columns, exp() can underflow to 0 for
  # every k, and the row normalization below would then give NaN.
  prob_z <- matrix(vapply(seq_along(lambda), function(k){
    lambda[k] * exp(rowSums(log(matrix(vapply(colnames(X), function(j){
      ifelse(is.na(X[ ,j]), 1, U[[j]][[k]][X[, j]])
    }, FUN.VALUE=rep(0.0, times=n_ind)), nrow=n_ind))))
  }, FUN.VALUE=rep(0.0, times=n_ind)), nrow=n_ind)
  prob_z <- prob_z / rowSums(prob_z)

  # NOTE(review): only this entry of what looks like the returned list is
  # visible: the index of the most probable latent class for each row. The
  # rest of the return value and the end of the function are not shown.
    pred_class=apply(prob_z, 1, which.max),

# Helper Functions for VI

# Expected log Dirichlet(alpha) prior density of the class weights lambda
# under the variational distribution Dirichlet(omega):
#   lgamma(sum(alpha)) - sum(lgamma(alpha))
#     + sum_k (alpha_k - 1) * (digamma(omega_k) - digamma(sum(omega)))
#
# omega: variational concentration parameters.
# alpha: prior concentration parameters; only the first length(alpha)
#        entries of omega enter the sum (digamma(sum(omega)) uses all).
# Returns a single numeric value.
expec_log_lambda <- function(omega, alpha) {
  k <- seq_along(alpha)
  # E_q[log lambda_k] for a Dirichlet(omega) variational distribution.
  expec_log_weight <- digamma(omega[k]) - digamma(sum(omega))
  lgamma(sum(alpha)) - sum(lgamma(alpha)) + sum((alpha - 1) * expec_log_weight)
}

# Expected log probability of the class assignment of one individual:
#   sum_k zeta_ik * (digamma(omega_k) - digamma(sum(omega)))
#
# zeta_i: class responsibilities of individual i (one entry per class).
# omega:  variational concentration parameters of the class weights.
# Returns a single numeric value.
expec_log_zi <- function(zeta_i, omega){
  sum(zeta_i * (digamma(omega) - digamma(sum(omega))))
}

# Expected log Dirichlet(beta) prior density of the category probabilities
# u_jk (feature j, latent class k) under the variational Dirichlet(phi_jk):
#   lgamma(sum(beta)) - sum(lgamma(beta))
#     + sum_r (beta_r - 1) * (digamma(phi_jk_r) - digamma(sum(phi_jk)))
#
# phi_jk: variational parameters, one per category.
# beta:   prior parameters; only the first length(beta) entries of phi_jk
#         enter the sum (digamma(sum(phi_jk)) uses all).
# Returns a single numeric value.
expec_log_ujk <- function(phi_jk, beta){
  r <- seq_along(beta)
  # E_q[log u_jkr] for a Dirichlet(phi_jk) variational distribution.
  expec_log_u <- digamma(phi_jk[r]) - digamma(sum(phi_jk))
  lgamma(sum(beta)) - sum(lgamma(beta)) + sum((beta - 1) * expec_log_u)
}

# Entropy of a Dirichlet(omega) distribution:
#   -lgamma(sum(omega)) + sum(lgamma(omega))
#     - sum_k (omega_k - 1) * (digamma(omega_k) - digamma(sum(omega)))
#
# omega: concentration parameters.
# Returns a single numeric value.
entrop_omega <- function(omega){
  omega_sum <- sum(omega)
  - lgamma(omega_sum) + sum(lgamma(omega)) -
    sum((omega - 1) * (digamma(omega) - digamma(omega_sum)))
}

# Entropy of the categorical distribution zeta_i: -sum(zeta_i * log(zeta_i)).
# Entries at or below the smallest positive normalized double are clamped to
# it, so zero probabilities contribute ~0 instead of 0 * -Inf = NaN.
#
# zeta_i: class probabilities of one individual.
# Returns a single numeric value.
entrop_zeta <- function(zeta_i){
  zeta_i <- pmax(zeta_i, .Machine$double.xmin)
  - sum(zeta_i * log(zeta_i))
}

# Entropy of a Dirichlet(phi_jk) distribution:
#   -lgamma(sum(phi_jk)) + sum(lgamma(phi_jk))
#     - sum_r (phi_jk_r - 1) * (digamma(phi_jk_r) - digamma(sum(phi_jk)))
#
# The previous version repeated the function header, so the outer function
# returned an inner function instead of a number; there is now a single
# definition.
#
# phi_jk: variational parameters for one feature / latent class pair.
# Returns a single numeric value.
entrop_phi <- function(phi_jk){
  phi_sum <- sum(phi_jk)
  - lgamma(phi_sum) + sum(lgamma(phi_jk)) -
    sum((phi_jk - 1) * (digamma(phi_jk) - digamma(phi_sum)))
}

# Convert the nested list `phi` (phi[[j]][[k]] is the parameter vector of
# feature j and latent class k) into a dense numeric array of dimension
# n_quest x n_latent x (largest number of categories). Features with fewer
# categories are padded with NA.
#
# phi:      list of n_quest lists, each holding n_latent numeric vectors.
# n_quest:  number of features (first array dimension).
# n_latent: number of latent classes (second array dimension).
# Returns the filled array. An empty phi gives a zero-extent third dimension.
conv_phi_to_array <- function(phi, n_quest, n_latent){
  # Number of categories per feature, read from the first latent class.
  cat_length <- vapply(phi, function(phi_j) length(phi_j[[1]]), FUN.VALUE = integer(1))
  # c(0L, ...) keeps max() well-defined when phi is empty.
  phia <- array(NA_real_, dim = c(n_quest, n_latent, max(c(0L, cat_length))))
  for(j in seq_len(n_quest)){
    r <- seq_len(cat_length[j])
    for(k in seq_len(n_latent)){
      phia[j, k, r] <- phi[[j]][[k]][r]
    }
  }
  phia
}

