R/vglmer_regression.R

#' Variational Inference for Hierarchical Generalized Linear Models
#'
#' This function estimates hierarchical models using mean-field variational
#' inference. \code{vglmer} accepts standard syntax used for \code{lme4}, e.g.,
#' \code{y ~ x + (x | g)}. Options are described below. Goplerud (2022; 2024)
#' provides details on the variational algorithms.
#'
#' @param formula \code{lme4}-style formula including random effects. Typically,
#'   \code{(1 + z | g)} indicates a random effect for each level of variable
#'   \code{"g"} with a differing slope for the effect of variable \code{"z"} and
#'   an intercept (\code{1}); see "Details" for further discussion and how to
#'   incorporate splines.
#' @param data \code{data.frame} containing the outcome and predictors.
#' @param family Options are "binomial", "linear", or "negbin" (experimental).
#'   If "binomial", outcome must be either binary (\eqn{\{0,1\}}) or
#'   \code{cbind(success, failure)} as per standard \code{glm(er)} syntax.
#'   Non-integer values are permitted for binomial if \code{force_whole} is set
#'   to \code{FALSE} in \code{vglmer_control}.
#' @param control Adjust internal options for estimation. Must use an object
#'   created by \link{vglmer_control}.
#'
#' @examples
#' 
#' set.seed(234)
#' sim_data <- data.frame(
#'   x = rnorm(100),
#'   y = rbinom(100, 1, 0.5),
#'   g = sample(letters, 100, replace = TRUE)
#' )
#'
#' # Run with defaults
#' est_vglmer <- vglmer(y ~ x + (x | g), data = sim_data, family = "binomial")
#'
#' # Simple prediction
#' predict(est_vglmer, newdata = sim_data)
#'
#' # Summarize results
#' summary(est_vglmer)
#'
#' # Extract parameters
#' coef(est_vglmer); vcov(est_vglmer)
#'
#' # Methods comparable to lme4,
#' # although ranef is formatted differently.
#' ranef(est_vglmer); fixef(est_vglmer)
#'
#' \donttest{
#' # Run with the weaker (i.e., more accurate) variational approximation
#' vglmer(y ~ x + (x | g),
#'   data = sim_data,
#'   control = vglmer_control(factorization_method = "weak"),
#'   family = "binomial")
#' }
#' 
#' \donttest{
#' # Use a spline on x with a linear outcome
#' vglmer(y ~ v_s(x),
#'   data = sim_data,
#'   family = "linear")
#' }
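#'
#' \donttest{
#' # Aggregated binomial outcomes via cbind(successes, failures);
#' # a sketch assuming ten trials per observation
#' sim_data$s <- rbinom(100, 10, plogis(sim_data$x))
#' vglmer(cbind(s, 10 - s) ~ x + (1 | g),
#'   data = sim_data,
#'   family = "binomial")
#' }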
#' 
#' @details
#' 
#' \bold{Estimation Syntax:} The \code{formula} argument takes syntax designed
#' to be as similar as possible to \code{lme4}. That is, one can specify models
#' using \code{y ~ x + (1 | g)} where \code{(1 | g)} indicates a random intercept. While
#' not tested extensively, terms such as \code{(1 | g / f)} should work as expected. Terms
#' of the form \code{(1 + x || g)} may work, although they will raise a warning about
#' duplicated names of random effects. \code{(1 + x || g)} terms may not work with spline
#' estimation. To get around this, one can copy the column \code{g} to
#' \code{g_copy} and then write \code{(1 | g) + (0 + x | g_copy)}.
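#'
#' For example, a sketch of that workaround using the simulated data from the
#' examples:
#' \preformatted{
#' sim_data$g_copy <- sim_data$g
#' vglmer(y ~ x + (1 | g) + (0 + x | g_copy),
#'   data = sim_data, family = "binomial")
#' }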
#' 
#' \bold{Splines:} Splines can be added using the term \code{v_s(x)} for a
#' spline on the variable \code{x}. These are transformed into hierarchical
#' terms in a standard fashion (e.g. Ruppert et al. 2003) and then estimated
#' using the variational algorithms. At present, only truncated linear
#' functions (\code{type = "tpf"}; the default) and O'Sullivan splines (Wand and
#' Ormerod 2008) are included. The options are described in more detail at
#' \link{v_s}.
#'
#' It is possible to have the spline vary across some categorical predictor by
#' specifying the \code{"by"} argument such as \code{v_s(x, by = g)}. In effect,
#' this adds additional hierarchical terms for the group-level deviations from
#' the "global" spline. \emph{Note:} In contrast to the typical presentation of
#' these splines interacted with categorical variables (e.g., Ruppert et al.
#' 2003), the default use of \code{"by"} includes regularized lower-order
#' interactions, i.e. \code{(1 + x | g)}, rather than their unregularized
#' versions (e.g., \code{x * g}); this can be changed using the \code{by_re}
#' argument described in \link{v_s}. Further, all group-level deviations from
#' the global spline share the same smoothing parameter (same prior
#' distribution).
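#'
#' For example, a sketch of a spline on \code{x} that varies by \code{g}, using
#' the simulated data from the examples:
#' \preformatted{
#' vglmer(y ~ v_s(x, by = g),
#'   data = sim_data, family = "binomial")
#' }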
#' 
#' \bold{Default Settings:} By default, the model is estimated using the
#' "strong" (i.e. fully factorized) variational assumption. Setting
#' \code{vglmer_control(factorization_method = "weak")} will improve the quality
#' of the variance approximation but may take considerably more time to
#' estimate. See Goplerud (2022) for discussion. 
#' 
#' By default, the prior on each random effect variance (\eqn{\Sigma_j}) uses a Huang-Wand prior (Huang
#' and Wand 2013) with hyper-parameters \eqn{\nu_j = 2} and \eqn{A_{j,k} = 5}.
#' This is designed to be proper but weakly informative. Other options are
#' discussed in \link{vglmer_control} under the \code{prior_variance} argument.
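#'
#' For example, a sketch switching to a Jeffreys prior on each \eqn{\Sigma_j}:
#' \preformatted{
#' vglmer(y ~ x + (x | g), data = sim_data, family = "binomial",
#'   control = vglmer_control(prior_variance = "jeffreys"))
#' }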
#' 
#' By default, estimation is accelerated using SQUAREM (Varadhan and Roland
#' 2008) and (one-step-late) parameter expansion for variational Bayes. Under
#' the default \code{"strong"} factorization, a "translation" expansion is used;
#' under other factorizations a "mean" expansion is used. These can be adjusted
#' using \link{vglmer_control}. See Goplerud (2024) for more discussion of
#' these methods.
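#'
#' For example, a sketch disabling both accelerations (slower, but sometimes
#' useful for debugging):
#' \preformatted{
#' vglmer(y ~ x + (x | g), data = sim_data, family = "binomial",
#'   control = vglmer_control(do_SQUAREM = FALSE, parameter_expansion = "none"))
#' }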
#' 
#' @return This returns an object of class \code{vglmer}. The available methods
#'   (e.g. \code{coef}) can be found using \code{methods(class="vglmer")}.
#' \describe{
#' \item{beta}{Contains the estimated distribution of the fixed effects
#' (\eqn{\beta}). It is multivariate normal. \code{mean} contains the means;
#' \code{var} contains the variance matrix; \code{decomp_var} contains a matrix
#' \eqn{L} such that \eqn{L^T L} equals the full variance matrix.}
#' \item{alpha}{Contains the estimated distribution of the random effects
#' (\eqn{\alpha}). They are all multivariate normal. \code{mean} contains the
#' means; \code{dia.var} contains the variance of each random effect. \code{var}
#' contains the variance matrix of each random effect (j,g). \code{decomp_var}
#' contains a matrix \eqn{L} such that \eqn{L^T L} equals the full variance of
#' the entire set of random effects.}
#' \item{joint}{If \code{factorization_method="weak"}, this is a list with one
#' element (\code{decomp_var}) that contains a matrix \eqn{L} such that \eqn{L^T
#' L} equals the full variance matrix between the fixed and random effects
#' \eqn{q(\beta,\alpha)}. The marginal variances are included in \code{beta} and
#' \code{alpha}. If the factorization method is not \code{"weak"}, this is
#' \code{NULL}.}
#' \item{sigma}{Contains the estimated distribution of each random
#' effect covariance \eqn{\Sigma_j}; all distributions are Inverse-Wishart.
#' \code{cov} contains a list of the estimated scale matrices. \code{df}
#' contains a list of the degrees of freedom.}
#' \item{hw}{If a Huang-Wand prior is used (see Huang and Wand 2013 or Goplerud
#' 2024 for more details), this contains the estimated distribution of the
#' auxiliary parameters. Otherwise, it is
#' \code{NULL}. All distributions are Inverse-Gamma. \code{a} contains a list of
#' the shape parameters. \code{b} contains a list of the scale parameters.}
#' \item{sigmasq}{If \code{family="linear"}, this contains a list of the
#' estimated parameters for \eqn{\sigma^2}; its distribution is Inverse-Gamma.
#' \code{a} contains the shape parameter; \code{b} contains the scale
#' parameter.}
#' \item{ln_r}{If \code{family="negbin"}, this contains the variational
#' parameters for the log dispersion parameter \eqn{\ln(r)}. \code{mu} contains
#' the mean; \code{sigma} contains the variance.}
#' \item{family}{Family of outcome.}
#' \item{ELBO}{Contains the ELBO at the termination of the algorithm.}
#' \item{ELBO_trajectory}{\code{data.frame} tracking the ELBO per iteration.}
#' \item{control}{Contains the control parameters from \code{vglmer_control}
#' used in estimation.}
#' \item{internal_parameters}{Variety of internal parameters used in
#' post-estimation functions.}
#' \item{formula}{Contains the formula used for estimation; contains the
#' original formula, fixed effects, and random effects parts separately for
#' post-estimation functions. See \code{formula.vglmer} for more details.}
#' }
#' @importFrom lme4 mkReTrms findbars nobars subbars
#' @importFrom stats model.response model.matrix model.frame terms update.formula
#'   rnorm rWishart qlogis optim residuals lm plogis setNames .getXlevels
#' @importFrom graphics plot
#' @importFrom Rcpp sourceCpp
#' @importFrom mgcv interpret.gam
#' @references
#' Goplerud, Max. 2022. "Fast and Accurate Estimation of Non-Nested Binomial
#' Hierarchical Models Using Variational Inference." \emph{Bayesian Analysis}. 17(2):
#' 623-650.
#' 
#' Goplerud, Max. 2024. "Re-Evaluating Machine Learning for MRP Given the
#' Comparable Performance of (Deep) Hierarchical Models." \emph{American
#' Political Science Review}. 118(1): 529-536.
#'
#' Huang, Alan, and Matthew P. Wand. 2013. "Simple Marginally Noninformative
#' Prior Distributions for Covariance Matrices." \emph{Bayesian Analysis}.
#' 8(2):439-452.
#' 
#' Ruppert, David, Matt P. Wand, and Raymond J. Carroll. 2003.
#' \emph{Semiparametric Regression}. Cambridge University Press.
#' 
#' Varadhan, Ravi, and Christophe Roland. 2008. "Simple and Globally Convergent
#' Methods for Accelerating the Convergence of any EM Algorithm." \emph{Scandinavian
#' Journal of Statistics}. 35(2): 335-353.
#' 
#' Wand, Matt P., and John T. Ormerod. 2008. "On Semiparametric Regression with
#' O'Sullivan Penalized Splines." \emph{Australian & New Zealand Journal of Statistics}.
#' 50(2): 179-198.
#' 
#' @useDynLib vglmer
#' @export
vglmer <- function(formula, data, family, control = vglmer_control()) {

  # Verify integrity of parameter arguments
  family <- match.arg(family, choices = c("negbin", "binomial", "linear"))
  if (family == "negbin" & !(control$parameter_expansion %in% c('none', 'mean'))){
    message('Setting parameter_expansion to mean for negative binomial estimation')
    control$parameter_expansion <- 'mean'
  }
  checkdf <- inherits(data, 'data.frame')
  if (is.null(data)){
    checkdf <- TRUE
  }
  if (!checkdf) {
    warning(paste0("data is not a data.frame; behavior may be unexpected. Class provided: ",
                   paste(class(data), collapse = ", ")))
  }
  if (!inherits(formula, 'formula')){
    stop('"formula" must be a formula.')
  }
  # Drop observations with missing data
  # (i.e., substitute out the random-effect bars, then run model.frame)
  nobs_init <- nrow(data)

  # Interpret gam using mgcv::interpret.gam
  parse_formula <- tryCatch(
    interpret.gam(subbars(formula), extra.special = 'v_s'), error = function(e){NULL})
  if (is.null(parse_formula)){
    # This can fail when (a) a custom argument lives in the calling environment or (b)
    # an argument with "xt" and (1|g) is given to subbars
    parse_formula <- fallback_interpret.gam0(fallback_subbars(formula), extra.special = 'v_s')
  }
  
  if (any(!sapply(parse_formula$smooth.spec, inherits, what = 'vglmer_spline'))){
    stop('gam specials are not permitted; use v_s(...) and see documentation.')
  }
  
  if (control$verify_columns){
    if (!all(parse_formula$pred.names %in% colnames(data))){
      missing_columns <- setdiff(parse_formula$pred.names, colnames(data))
      stop(
        paste0('The following columns are missing from "data". Can override with vglmer_control (not usually desirable): ', 
               paste(missing_columns, collapse =', '))
      )
    }
  }
  
  data <- model.frame(parse_formula$fake.formula, data,
                      drop.unused.levels = TRUE)
  
  tt <- terms(data)

  nobs_complete <- nrow(data)
  missing_obs <- nobs_init - nobs_complete
  if (length(missing_obs) == 0) {
    missing_obs <- "??"
  }

  #Extract the Outcome
  y <- model.response(data)
  if (is.matrix(y)){
    N <- nrow(y)
    rownames(y) <- NULL
  }else{
    N <- length(y)
    y <- as.vector(y)
    names(y) <- NULL
  }
  
  
  if (!inherits(control, "vglmer_control")) {
    stop("control must be object from vglmer_control().")
  }

  do_timing <- control$do_timing
  factorization_method <- control$factorization_method
  print_prog <- control$print_prog
  iterations <- control$iterations
  quiet <- control$quiet
  parameter_expansion <- control$parameter_expansion
  tolerance_elbo <- control$tolerance_elbo
  tolerance_parameters <- control$tolerance_parameters
  debug_param <- control$debug_param
  linpred_method <- control$linpred_method
  vi_r_method <- control$vi_r_method
  if (is.numeric(vi_r_method)){
    if (length(vi_r_method) > 1){stop('If "vi_r_method" is numeric, it must be a single number.')}
    vi_r_val <- as.numeric(vi_r_method)
    vi_r_method <- "fixed"
  }else{
    vi_r_val <- NA
  }
  debug_ELBO <- control$debug_ELBO
  # Flip given that "tictoc" accepts "quiet=quiet_time"
  quiet_time <- !control$verbose_time

  if (do_timing) {
    if (!requireNamespace("tictoc", quietly = TRUE)) {
      stop("tictoc must be installed to do timing")
    }
    tic <- tictoc::tic
    toc <- tictoc::toc
    tic.clear <- tictoc::tic.clear
    tic.clearlog <- tictoc::tic.clearlog

    tic.clear()
    tic.clearlog()
    tic("Prepare Model")
  }
  if (!(factorization_method %in% c("weak", "strong", "partial", "collapsed"))) {
    stop("factorization_method must be 'weak', 'strong', 'partial', or 'collapsed'.")
  }
  if (is.null(print_prog)) {
    print_prog <- max(c(1, floor(iterations / 20)))
  }
  if (!(family %in% c("binomial", "negbin", "linear"))) {
    stop('family must be one of "linear", "binomial", "negbin".')
  }
  
  if (family == "binomial") {
    if (is.matrix(y)) {
      # if (!(class(y) %in% c('numeric', 'integer'))){
      if (min(y) < 0) {
        stop("Negative numbers not permitted in outcome")
      }
      is.wholenumber <- function(x, tol = .Machine$double.eps^0.5) abs(x - round(x)) < tol
      if (any(is.wholenumber(y) == FALSE)) {
        if (control$force_whole) {
          stop("If force_whole = TRUE, must provide whole numbers as outcome")
        } else {
          warning("Non-integer numbers in y")
        }
      }
      # Total trials (Success + Failure)
      trials <- rowSums(y)
      names(trials) <- NULL
      # Successes
      y <- y[, 1]
      names(y) <- NULL
    } else {
      if (!all(y %in% c(0, 1))) {
        stop("Only {0,1} outcomes permitted for numeric y.")
      }
      trials <- rep(1, length(y))
    }
  } else if (family == 'negbin') {
    
    if (is.matrix(y)) {
      stop('"negbin" family requires a vector outcome.')
    }
    
    if (!is.numeric(y)) {
      stop("Must provide vector of numbers with negbin.")
    }

    if (min(y) < 0) {
      stop("Negative numbers not permitted in outcome")
    }

    is.wholenumber <- function(x, tol = .Machine$double.eps^0.5) abs(x - round(x)) < tol
    if (any(is.wholenumber(y) == FALSE)) {
      if (control$force_whole) {
        stop("If force_whole = TRUE, must provide whole numbers")
      } else {
        warning("Non-integer numbers in y")
      }
    }
  } else if (family == 'linear') {
    
    if (is.matrix(y)) {
      stop('"linear" family requires a vector outcome.')
    }
    if (!is.numeric(y)) {
      stop("Must provide vector of numbers with linear.")
    }
    
    y <- as.numeric(y)
    
    #Do nothing if linear
  } else {
    stop('family is invalid.')
  }

  if (family %in% c("binomial", "linear")) {
    ELBO_type <- "augmented"
  } else if (family == "negbin") {
    ELBO_type <- "profiled"
  } else {
    stop("Check ELBO_type")
  }

  # Extract X (FE design matrix)
  fe_fmla <- tryCatch(
    interpret.gam(nobars(formula), extra.special = 'v_s'), error = function(e){NULL})
  if (is.null(fe_fmla)){
    # If this fails (usually when a custom argument lives in the calling environment), use this instead
    fe_fmla <- fallback_interpret.gam0(nobars(formula), extra.special = 'v_s')
  }
  
  if (length(fe_fmla$smooth.spec) > 0){
    
    # Add the linear spline terms to the main effect.
    fe_update <- sapply(fe_fmla$smooth.spec, FUN=function(i){
      if (i$type %in% c('gKRLS', 'randwalk')){
        #  Do *not* add linear terms if gKRLS or randwalk
        fe_i <- NULL
        return(fe_i)
      }else{
        if (i$by != "NA" & i$by_re == FALSE){
          fe_i <- paste0(i$term, ' * ', i$by)
        }else{
          fe_i <- i$term
        }
        return(fe_i)
      }
    })
    fe_update <- unique(unlist(fe_update))
    if (!is.null(fe_update)){
      fe_update <- paste0(fe_update, collapse = ' + ')
      fe_fmla <- update.formula(fe_fmla$pf,
                                paste0('. ~ . + 1 + ', fe_update)
      )
    }else{
      fe_fmla <- fe_fmla$pf
    }

  }else{
    fe_fmla <- fe_fmla$pf
  }
  
  # Create the FE design
  X <- model.matrix(fe_fmla, data = data)
  fe_terms <- terms(fe_fmla)
  fe_Xlevels <- .getXlevels(fe_terms, data)
  fe_contrasts <- attr(X, 'contrasts')
  
  # Extract the Z (Random Effect) design matrix.
  re_fmla <- findbars(formula)
  
  # If using splines by group, add random effects to
  # the main level.
  if (!all(sapply(parse_formula$smooth.spec, 
      FUN=function(i){i$by}) %in% c('NA'))){
    
    by_splines <- parse_formula$smooth.spec[
      which(sapply(parse_formula$smooth.spec, FUN=function(i){(i$by != "NA" & i$by_re == TRUE)}))
    ]
    
    character_re <- lapply(re_fmla, FUN=function(i){strsplit(deparse(i), split = ' \\| ')[[1]]})
    character_re_group <- sapply(character_re, FUN=function(i){i[2]})
    
    if (any(duplicated(character_re_group))){
      stop('Some grouping factors for random effects are duplicated. Reformulate initial formula.')
    }
    
    for (v in sapply(character_re, FUN=function(i){i[2]})){
      if (!(is.factor(data[[v]]) | is.character(data[[v]]))){
        data[[v]] <- as.character(data[[v]])
      }
    } 
    
    for (b in by_splines){
      
      b_term <- b$term
      b_by <- b$by
      
      if (!(is.factor(data[[b_by]]) | is.character(data[[b_by]]))){
       stop('For now, all v_s spline "by" factors must be characters or factors.') 
      }

      if (!(b$type %in% c('gKRLS', 'randwalk'))){
        # If "by" grouping already used, then add to the RE
        if (b_by %in% character_re_group){
          position_b_by <- which(b_by == character_re_group)
          existing_re_b_by <- character_re[[position_b_by]][1]
          new_re_b_by <- paste0(unique(c('1', strsplit(existing_re_b_by, split=' \\+ ')[[1]], b_term)), collapse = ' + ')
          character_re[[position_b_by]][1] <- new_re_b_by
        }else{
          # If not, then add a new RE group with a 
          # random intercept and random slope.
          character_re <- c(character_re, list(c(paste0('1 + ', b_term), b_by)))
          character_re_group <- sapply(character_re, FUN=function(i){i[2]})
        }
      }else{
        if (b_by %in% character_re_group){
          # Do Nothing
        }else{
          # If not, then add a new RE group with
          # only a random intercept.
          character_re <- c(character_re, list(c('1', b_by)))
          character_re_group <- sapply(character_re, FUN=function(i){i[2]})
        }
      }
    }
    
    character_re_fmla <- paste(sapply(character_re, FUN=function(i){paste0('(', i[1], ' | ', i[2], ' )')}), collapse = " + ")
    
    old_re <- re_fmla
    re_fmla <- lapply(character_re, FUN=function(i){str2lang(paste0(i[1], ' | ', i[2]))})
    
  }
  
  if (!is.null(re_fmla) & (length(re_fmla) > 0)){
    mk_Z <- mkReTrms(re_fmla, data, reorder.terms = FALSE, reorder.vars = FALSE)
    Z <- t(mk_Z$Zt)
    
    p.X <- ncol(X)
    p.Z <- ncol(Z)
    
    ####
    # Process the REs to get various useful terms.
    ####
    # RE names and names of variables included for each.
    names_of_RE <- mk_Z$cnms
    
    if (anyDuplicated(names(names_of_RE)) > 0){
      warning('Some random effects names are duplicated. Re-naming for stability by adding "-[0-9]" at end.')
      nre <- names(names_of_RE)
      unre <- unique(nre)
      for (u in unre){
        nre_u <- which(nre == u)
        if (length(nre_u) > 1){
          nre[nre_u] <- paste0(nre[nre_u], '-', seq_len(length(nre_u)))
        }
      }
      names(names_of_RE) <- nre
      if (anyDuplicated(names(names_of_RE)) > 0){
        stop('Renaming duplicates failed. Please rename random effects to proceed.')
      }
    }

    number_of_RE <- length(mk_Z$Gp) - 1
    
    if ( (number_of_RE < 1) & (length(parse_formula$smooth.spec) == 0) ) {
      stop("Need to provide at least one random effect or spline...")
    }
    # The positions that demarcate each random effect.
    # That is, RE 2 starts at position breaks_for_RE[2] + 1.
    breaks_for_RE <- c(0, cumsum(diff(mk_Z$Gp)))
    # Dimensionality of \alpha_{j,g}, i.e. 1 if random intercept
    # 2 if random intercept + random slope
    d_j <- lengths(names_of_RE)
    # Number of GROUPs for each random effect.
    g_j <- diff(mk_Z$Gp) / d_j
    
    # Empty vector to build the formatted names for each random effect.
    fmt_names_Z <- c()
    init_Z_names <- colnames(Z)
    for (v in 1:number_of_RE) {
      name_of_effects_v <- names_of_RE[[v]]
      
      mod_name <- rep(name_of_effects_v, g_j[v])
      
      levels_of_re <- init_Z_names[(1 + breaks_for_RE[v]):breaks_for_RE[v + 1]]
      
      fmt_names_Z <- c(fmt_names_Z, paste0(names(names_of_RE)[v], " @ ", mod_name, " @ ", levels_of_re))
    }
    colnames(Z) <- fmt_names_Z
    cyclical_pos <- lapply(1:number_of_RE, FUN = function(i) {
      seq(breaks_for_RE[i] + 1, breaks_for_RE[i + 1])
    })
  }else{
    
    Z <- drop0(Matrix(nrow = nrow(X), ncol = 0))
    p.X <- ncol(X)
    p.Z <- 0
    names_of_RE <- c()
    number_of_RE <- 0
    breaks_for_RE <- c(0)
    d_j <- c()
    g_j <- c()
    fmt_names_Z <- c()
    cyclical_pos <- list()
    
    if ( (length(parse_formula$smooth.spec) == 0) ) {
      stop("Need to provide at least one random effect or spline...")
    }

  }
  
  M.names <- cbind(unlist(mapply(names_of_RE, g_j, SIMPLIFY = FALSE, FUN = function(i, j) {
    rep(i, j)
  })))
  
  if (!is.null(M.names)){
    U_names <- unique(cbind(rep(names(names_of_RE), g_j * d_j), M.names))
    B_j <- lapply(split(U_names[,2], U_names[,1]), FUN=function(j){
      B_jj <- which(j %in% colnames(X))
      if (length(B_jj) == 0){
        return(matrix(nrow = length(j), ncol = 0))
      }
      sparseMatrix(i = B_jj, j = 1:length(B_jj), 
        x = 1, dims = c(length(j), length(B_jj)))   
    })
    U_names_Bj <- unique(U_names[,2])
    M <- cbind(match(M.names[, 1], U_names_Bj), rep(1 / g_j, d_j * g_j))
    M <- sparseMatrix(i = 1:nrow(M), j = M[, 1], x = M[, 2], dims = c(ncol(Z), length(U_names_Bj)))
  }else{
    B_j <- list()
    M <- drop0(matrix(0, nrow = 0, ncol = ncol(X)))
  }

  if (!is.null(names_of_RE)){
    any_Mprime <- TRUE
    M_prime.names <- paste0(rep(names(names_of_RE), g_j * d_j), " @ ", M.names)
    M_prime <- cbind(match(M_prime.names, unique(M_prime.names)), rep(1 / g_j, d_j * g_j))
    M_prime <- sparseMatrix(i = seq_len(ncol(Z)),
                            j = M_prime[, 1], 
                            x = M_prime[, 2],
                            dims = c(ncol(Z), max(M_prime[,1])))
    colnames(M_prime) <- unique(M_prime.names)
    
    M_prime_one <- M_prime
    M_prime_one@x <- rep(1, length(M_prime_one@x))
    
    stopifnot(identical(paste0(rep(names(names_of_RE), d_j), " @ ", unlist(names_of_RE)), colnames(M_prime)))
    
    mu_to_beta_names <- match(unlist(names_of_RE), colnames(X))
    
    id_mu_to_beta <- seq_len(sum(d_j))
    which_is_na_mu_to_beta <- which(is.na(mu_to_beta_names))
    if (length(which_is_na_mu_to_beta) > 0){
      mu_to_beta_names <- mu_to_beta_names[-which_is_na_mu_to_beta]
      id_mu_to_beta <- id_mu_to_beta[-which_is_na_mu_to_beta]
    }
    
    M_mu_to_beta <- sparseMatrix(
      i = id_mu_to_beta, j = mu_to_beta_names, 
      x = 1, dims = c(sum(d_j), p.X))
    
  }else{
    any_Mprime <- FALSE
    M_prime_one <- M_prime <- drop0(matrix(0, nrow = 0, ncol = 0))
    M_mu_to_beta <- drop0(matrix(0, nrow = 0, ncol = p.X))
  }
  
  colnames(M_mu_to_beta) <- colnames(X)
  rownames(M_mu_to_beta) <- colnames(M_prime)
  
  
  # Extract the Specials
  if (length(parse_formula$smooth.spec) > 0){
    
    any_Mprime <- TRUE
    
    base_specials <- length(parse_formula$smooth.spec)
    
    # Number of splines + one for each "by"...
    
    n.specials <- base_specials +
      sum(sapply(parse_formula$smooth.spec, FUN=function(i){i$by}) != "NA")
    
    Z.spline.attr <- as.list(rep(NA, base_specials))

    Z.spline <- as.list(rep(NA, n.specials))
    Z.spline.size <- rep(NA, n.specials)
    special_counter <- 0
    
    for (i in 1:base_specials){
      
      special_i <- parse_formula$smooth.spec[[i]]

      if (special_i$type %in% c('gKRLS', 'randwalk')){
        all_splines_i <- data[special_i$term]
        special_i$fmt_term <- paste0('(', paste0(special_i$term, collapse=','), ')')
      }else{
        all_splines_i <- data[[special_i$term]]
        special_i$fmt_term <- special_i$term
      }
      all_splines_i <- vglmer_build_spline(x = all_splines_i, 
          by = data[[special_i$by]],
          knots = special_i$knots, type = special_i$type, 
          force_vector = special_i$force_vector,
          xt = special_i$xt,
          outer_okay = special_i$outer_okay, by_re = special_i$by_re)
      
      Z.spline.attr[[i]] <- c(all_splines_i[[1]]$attr, 
        list(type = special_i$type, by = special_i$by))
      spline_counter <- 1
      for (spline_i in all_splines_i){

        special_counter <- special_counter + 1
        
        stopifnot(spline_counter %in% 1:2)
        
        colnames(spline_i$x) <- paste0('spline @ ', special_i$fmt_term, ' @ ', colnames(spline_i$x))
        
        if (spline_counter > 1){
          spline_name <- paste0('spline-',special_i$fmt_term,'-', i, '-int')
        }else{
          spline_name <- paste0('spline-', special_i$fmt_term, '-', i, '-base')
        }
        
        Z.spline[[special_counter]] <- spline_i$x
        Z.spline.size[special_counter] <- ncol(spline_i$x)
        
        names_of_RE[[spline_name]] <- spline_name
        number_of_RE <- number_of_RE + 1
        d_j <- setNames(c(d_j, 1), c(names(d_j), spline_name))
        g_j <- setNames(c(g_j, ncol(spline_i$x)), c(names(g_j), spline_name))
        breaks_for_RE <- c(breaks_for_RE, max(breaks_for_RE) + ncol(spline_i$x))
        fmt_names_Z <- c(fmt_names_Z, colnames(spline_i$x))
        p.Z <- p.Z + ncol(spline_i$x)
        spline_counter <- spline_counter + 1
      }
    }
    
    cyclical_pos <- lapply(1:number_of_RE, FUN = function(i) {
      seq(breaks_for_RE[i] + 1, breaks_for_RE[i + 1])
    })

    Z.spline <- drop0(do.call('cbind', Z.spline))
    Z <- drop0(cbind(Z, Z.spline))
    
    if (ncol(M_prime) == 0){
      M_prime <- rbind(M_prime, 
          drop0(matrix(0, nrow = ncol(Z.spline), ncol = 0)))
      M_prime_one <- rbind(M_prime_one, 
          drop0(matrix(0, nrow = ncol(Z.spline), ncol = 0)))
    }else{
      M_prime <- rbind(M_prime, 
         drop0(sparseMatrix(i = 1, j = 1, x = 0, 
         dims = c(ncol(Z.spline), ncol(M_prime)))))
      M_prime_one <- rbind(M_prime_one, 
         drop0(sparseMatrix(i = 1, j = 1, x = 0, 
         dims = c(ncol(Z.spline), ncol(M_prime_one)))))
    }
    
    M_prime <- cbind(M_prime, drop0(sparseMatrix(i = 1, j = 1, x = 0,
      dims = c(nrow(M_prime), special_counter)
    )))
    M_prime_one <- cbind(M_prime_one, drop0(sparseMatrix(i = 1, j = 1, x = 0,
      dims = c(nrow(M_prime_one), special_counter)
    )))
    
    M_mu_to_beta <- rbind(M_mu_to_beta, 
      drop0(sparseMatrix(i = 1, j = 1, x = 0, 
        dims = c(special_counter, p.X)))
    )
    
    extra_Bj <- lapply(setdiff(names(names_of_RE), names(B_j)), FUN=function(i){Diagonal(n = d_j[i])})
    names(extra_Bj) <- setdiff(names(names_of_RE), names(B_j))
    B_j <- c(B_j, extra_Bj)[names(names_of_RE)]
    
  }else{
    n.specials <- 0
    Z.spline.attr <- NULL
    Z.spline <- NULL
    Z.spline.size <- NULL
  }
  
  if (length(B_j) > 0){
    B_j <- bdiag(B_j)
  }else{B_j <- NULL}
  
  debug_px <- control$debug_px
  if (control$parameter_expansion %in% c('translation', 'diagonal')){
    px_method <- control$px_method
    px_it <- control$px_numerical_it
    opt_prior_rho <- NULL
    parsed_RE_groups <- get_RE_groups(formula = re_fmla, data = data)
  }
  
  # List of Lists
  # Outer list: one for RE
  # Inner List: One for each GROUP with its row positions.
  outer_alpha_RE_positions <- mapply(d_j, g_j, breaks_for_RE[-length(breaks_for_RE)], 
    SIMPLIFY = FALSE, FUN = function(a, b, m) {
      split(m + seq(1, a * b), rep(1:b, each = a))
  })

  if (anyDuplicated(unlist(outer_alpha_RE_positions)) != 0 | max(unlist(outer_alpha_RE_positions)) != ncol(Z)) {
    stop("Issue with creating OA positions")
  }
  ####
  # Prepare Initial Values
  ###

  vi_sigmasq_prior_a <- 0
  vi_sigmasq_prior_b <- 0
  
  vi_sigmasq_a <- vi_sigmasq_b <- 1
  
  if (family == "linear") {

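    # q(sigma^2) is Inverse-Gamma(a, b): the shape a collects the prior shape plus
    # (N + number of random-effect coefficients) / 2; the scale b is initialized from
    # the residual sum of squares of an intercept-only fit.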
    vi_sigmasq_a <- (nrow(X) + sum(d_j * g_j))/2 + vi_sigmasq_prior_a
    vi_sigmasq_b <- sum(residuals(lm(y ~ 1))^2)/2 + vi_sigmasq_prior_b
    
    s <- y
    vi_pg_b <- 1
    vi_pg_c <- NULL
    vi_r_mu <- 0
    vi_r_sigma <- 0
    vi_r_mean <- 0
    

    choose_term <- -length(y)/2 * log(2 * pi)
    
  } else if (family == "binomial") {
    
    s <- y - trials / 2
    vi_pg_b <- trials
    vi_r_mu <- 0
    vi_r_mean <- 0
    vi_r_sigma <- 0
    choose_term <- sum(lchoose(n = round(trials), k = round(y)))
    
  } else if (family == 'negbin') {
    # Initialize
    if (vi_r_method == "fixed") {
      vi_r_mu <- vi_r_val
      vi_r_mean <- exp(vi_r_mu)
      vi_r_sigma <- 0
    } else if (vi_r_method == "VEM") {
      if (!requireNamespace("MASS", quietly = TRUE)) {
        stop("Install MASS to use negbin")
      }
      vi_r_mean <- MASS::glm.nb(y ~ 1)$theta
      vi_r_mu <- log(vi_r_mean)
      vi_r_sigma <- 0
    } else if (vi_r_method %in% c("Laplace", "delta")) {
      init_r <- optim(
        par = 0, fn = VEM.PELBO.r, method = "L-BFGS-B", hessian = TRUE,
        control = list(fnscale = -1), y = y, psi = rep(log(mean(y)), length(y)), zVz = 0
      )
      vi_r_mu <- init_r$par
      vi_r_sigma <- as.numeric(-1 / init_r$hessian)

      vi_r_mean <- exp(vi_r_mu + vi_r_sigma / 2)
    } else {
      stop("vi_r_method must be 'VEM', 'fixed', 'Laplace', or 'delta'.")
    }
    s <- (y - vi_r_mean) / 2
    vi_pg_b <- y + vi_r_mean

    choose_term <- -sum(lgamma(y + 1)) - sum(y) * log(2)
  }else{
    stop('family must be linear, binomial, or negative binomial.')
  }

  # Initialize variational parameters.
  # Note that we keep a sparse lower-triangular factor such that
  # t(vi_beta_decomp) %*% vi_beta_decomp = VARIANCE

  vi_beta_decomp <- Diagonal(x = rep(0, ncol(X)))
  vi_alpha_decomp <- Diagonal(x = rep(0, ncol(Z)))

  vi_sigma_alpha_nu <- g_j

  prior_variance <- control$prior_variance
  do_huangwand <- FALSE
  vi_a_APRIOR_jp <- vi_a_nu_jp <- vi_a_a_jp <- vi_a_b_jp <- NULL
  prior_sigma_alpha_nu <- prior_sigma_alpha_phi <- NULL
  
  if (prior_variance == 'hw') {
    
    do_huangwand <- TRUE
    INNER_IT <- control$hw_inner
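    # Huang-Wand (2013) prior: Sigma_j | a_{j,1},...,a_{j,d_j} is
    # Inverse-Wishart(nu_j + d_j - 1, 2 * nu_j * diag(1/a_{j,p})) with each
    # a_{j,p} ~ Inverse-Gamma(1/2, 1/A_{j,p}^2). The defaults below (nu_j = 2,
    # A_{j,p} = 5) match the documentation; q(a_{j,p}) is Inverse-Gamma with
    # shape (nu_j + d_j)/2 and scale initialized at 1/A_{j,p}^2.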
    vi_a_nu_jp <- rep(2, length(d_j))
    names(vi_a_nu_jp) <- names(names_of_RE)
    vi_a_APRIOR_jp <- lapply(d_j, FUN=function(i){rep(5, i)})
    vi_a_a_jp <- mapply(d_j, vi_a_nu_jp, SIMPLIFY = FALSE, 
                        FUN=function(i,nu){1/2 * (nu + rep(i, i))})
    vi_a_b_jp <- lapply(vi_a_APRIOR_jp, FUN=function(i){1/i^2})
  } else if (prior_variance == "jeffreys") {
    prior_sigma_alpha_nu <- rep(0, number_of_RE)
    prior_sigma_alpha_phi <- lapply(d_j, FUN = function(i) {
      diag(x = 0, nrow = i, ncol = i)
    })
  } else if (prior_variance == "mean_exists") {
    prior_sigma_alpha_nu <- d_j + 1 # Ensures the mean exists...
    prior_sigma_alpha_phi <- lapply(d_j, FUN = function(i) {
      diag(x = 1, nrow = i, ncol = i)
    })
  } else if (prior_variance == "limit") {
    prior_sigma_alpha_nu <- d_j - 1
    prior_sigma_alpha_phi <- lapply(d_j, FUN = function(i) {
      diag(x = 0, nrow = i, ncol = i)
    })
  } else if (prior_variance == "uniform") {
    prior_sigma_alpha_nu <- -(d_j + 1)
    prior_sigma_alpha_phi <- lapply(d_j, FUN = function(i) {
      diag(x = 0, nrow = i, ncol = i)
    })
  } else {
    stop("Invalid option for prior variance provided.")
  }

  if (do_huangwand){
    iw_prior_constant <- mapply(vi_a_nu_jp, d_j,
      FUN = function(nu, d) {
        nu <- nu + d - 1
        return(- (nu * d) / 2 * log(2) - multi_lgamma(a = nu / 2, p = d))
      }
    )
    vi_sigma_alpha_nu <- vi_sigma_alpha_nu + vi_a_nu_jp + d_j - 1
  }else{
    # Normalizing constant for the Inverse-Wishart prior so the ELBO has the right value to compare models.
    iw_prior_constant <- mapply(prior_sigma_alpha_nu, prior_sigma_alpha_phi,
                                FUN = function(nu, Phi) {
                                  if (nu <= (ncol(Phi) - 1)) {
                                    return(0)
                                  } else {
                                    return(make_log_invwishart_constant(nu, Phi))
                                  }
                                }
    )
    
    vi_sigma_alpha_nu <- vi_sigma_alpha_nu + prior_sigma_alpha_nu
  }

  if ( control$init %in% c("EM", "EM_FE") ){
    if (family == "linear"){
      jointXZ <- cbind(X,Z)
      if (control$init == 'EM_FE'){
        EM_init <- LinRegChol(X = drop0(X),
           omega = sparseMatrix(i = 1:nrow(X), j = 1:nrow(X), x = 1),
           y = y, prior_precision = sparseMatrix(i = 1:ncol(X), j = 1:ncol(X), x = 1e-5))$mean
        # stop('Setup EM init for linear')
        # solve(Matrix::Cholesky(  t(joint.XZ) %*% sparseMatrix(i = 1:N, j = 1:N, x = pg_mean) %*% joint.XZ + EM_variance),
        #       t(joint.XZ) %*% (adj_out) )
        EM_init <- list('beta' = EM_init, 'alpha' = rep(0, ncol(Z)))
      }else{
        stop('Setup EM init')
        
        EM_init <- LinRegChol(X = jointXZ, 
                              omega = sparseMatrix(i = 1:nrow(jointXZ), j = 1:nrow(jointXZ), x = 1), 
                              y = y, prior_precision = sparseMatrix(i = 1:ncol(jointXZ), j = 1:ncol(jointXZ), x = 1/4))$mean
        EM_init <- list('beta' = EM_init[1:ncol(X)], 'alpha' = EM_init[-1:-ncol(X)])
      }
      rm(jointXZ)
    } else if (family == "negbin") {
      if (control$init == 'EM_FE'){
        EM_init <- EM_prelim_nb(X = X, Z = drop0(matrix(0, nrow = nrow(X), ncol = 0)), y = y, est_r = exp(vi_r_mu), iter = 15, ridge = 10^5)
        EM_init <- list('beta' = EM_init$beta, 'alpha' = rep(0, ncol(Z)))
      }else{
        EM_init <- EM_prelim_nb(X = X, Z = Z, y = y, est_r = exp(vi_r_mu), iter = 15, ridge = 4)
      }
    } else {
      if (control$init == 'EM_FE'){
        EM_init <- EM_prelim_logit(X = X, Z = drop0(matrix(0, nrow = nrow(X), ncol = 0)), s = s, pg_b = vi_pg_b, iter = 15, ridge = 10^5)
        EM_init <- list('beta' = EM_init$beta, 'alpha' = rep(0, ncol(Z)))
      }else{
        EM_init <- EM_prelim_logit(X = X, Z = Z, s = s, pg_b = vi_pg_b, iter = 15, ridge = 4)
      }
    }

    vi_beta_mean <- matrix(EM_init$beta)
    vi_alpha_mean <- matrix(EM_init$alpha)

    vi_sigma_alpha <- calculate_expected_outer_alpha(
      alpha_mu = vi_alpha_mean,
      L = sparseMatrix(i = 1, j = 1, x = 1e-4, dims = rep(ncol(Z), 2)),
      re_position_list = outer_alpha_RE_positions
    )
    
    if (!do_huangwand){
      vi_sigma_alpha <- mapply(vi_sigma_alpha$outer_alpha, prior_sigma_alpha_phi, SIMPLIFY = FALSE, FUN = function(i, j) {
        i + j
      })
    }else{
      
      #Update Inverse-Wishart
      vi_sigma_alpha <- mapply(vi_sigma_alpha$outer_alpha, vi_a_a_jp, 
       vi_a_b_jp, vi_a_nu_jp, SIMPLIFY = FALSE, FUN = function(i, tilde.a, tilde.b, nu) {
         i + sparseMatrix(i = seq_len(nrow(i)), j = seq_len(nrow(i)), x = 1)
       })
      
      #Update a_{j,p}
      diag_Einv_sigma <- mapply(vi_sigma_alpha, 
          vi_sigma_alpha_nu, d_j, SIMPLIFY = FALSE, FUN = function(phi, nu, d) {
            inv_phi <- solve(phi)
            sigma.inv <- nu * inv_phi
            return(diag(sigma.inv))
          })
      vi_a_b_jp <- mapply(vi_a_nu_jp, vi_a_APRIOR_jp, diag_Einv_sigma,
        SIMPLIFY = FALSE,
        FUN=function(nu, APRIOR, diag_j){
          1/APRIOR^2 + nu * diag_j
        })
    }
  } else if (control$init == "random") {
    vi_beta_mean <- rnorm(ncol(X))
    vi_alpha_mean <- rep(0, ncol(Z))

    vi_sigma_alpha <- mapply(d_j, g_j, SIMPLIFY = FALSE, FUN = function(d, g) {
      
      out <- rWishart(n = 1, df = ifelse(g >= d, g, d), Sigma = diag(d))[ , , 1]
      
      if (d == 1){
        out <- matrix(out)
      }
      
      return(out)
      
    })

  } else if (control$init == "zero") {
    vi_beta_mean <- rep(0, ncol(X))

    if (family == "binomial") {
      vi_beta_mean[1] <- qlogis(sum(y) / sum(trials))
    } else if (family == "negbin") {
      vi_beta_mean[1] <- log(mean(y))
    } else if (family == 'linear'){
      vi_beta_mean[1] <- mean(y)
    } else {
      stop('Set up init')
    }

    vi_alpha_mean <- rep(0, ncol(Z))

    vi_sigma_alpha <- mapply(d_j, g_j, SIMPLIFY = FALSE, FUN = function(d, g) {
      diag(x = 1, ncol = d, nrow = d)
    })
    # if (do_huangwand){stop('Setup init for zero')}
  } else {
    stop("Invalid initialization method")
  }

  zero_mat <- sparseMatrix(i = 1, j = 1, x = 0, dims = c(ncol(X), ncol(X)))
  zero_mat <- drop0(zero_mat)

  if (factorization_method %in% c("weak", "collapsed")) {
    vi_joint_decomp <- bdiag(vi_beta_decomp, vi_alpha_decomp)
    joint.XZ <- cbind(X, Z)
    log_det_beta_var <- log_det_alpha_var <- NULL
  } else {
    vi_joint_decomp <- NULL
    log_det_joint_var <- NULL
  }

  # Create mapping for this to allow sparse implementations.

  mapping_sigma_alpha <- make_mapping_alpha(vi_sigma_alpha)

  running_log_det_alpha_var <- rep(NA, number_of_RE)

  lagged_alpha_mean <- rep(-Inf, ncol(Z))
  lagged_beta_mean <- rep(-Inf, ncol(X))
  lagged_sigma_alpha <- vi_sigma_alpha
  if (factorization_method %in% c("weak", "collapsed")) {
    lagged_joint_decomp <- vi_joint_decomp
  } else {
    lagged_alpha_decomp <- vi_alpha_decomp
    lagged_beta_decomp <- vi_beta_decomp
  }
  lagged_vi_r_mu <- -Inf
  lagged_vi_sigmasq_a <- lagged_vi_sigmasq_b <- -Inf
  lagged_ELBO <- -Inf

  skip_translate <- FALSE
  
  accepted_times <- 0
  attempted_expansion <- 0
  
  spline_REs <- grepl(names(d_j), pattern='^spline-')
  zeromat_beta <- drop0(Diagonal(x = rep(0, ncol(X))))
  stationary_rho <- do.call('c', lapply(d_j[!spline_REs], FUN=function(i){as.vector(diag(x = i))}))
  
  if (parameter_expansion %in%  c("translation", "diagonal") & any_Mprime & any(!spline_REs)) {
    
    if (do_timing){
      tic('Build PX R Terms')
    }
    
    
    
    nonspline_positions <- sort(unlist(outer_alpha_RE_positions[!spline_REs]))
    
    size_splines <- sum((d_j * g_j)[spline_REs])
    
    est_rho <- stationary_rho
    diag_rho <- which(stationary_rho == 1)
    

    # parsed_RE_groups <- get_RE_groups(formula = formula, data = data)
    # parsed_RE_groups <- parsed_RE_groups
    
    mapping_new_Z <- do.call('cbind', parsed_RE_groups$design)
    
    mapping_J <- split(1:sum(d_j[!spline_REs]^2), rep(1:length(d_j[!spline_REs]), d_j[!spline_REs]^2))
    mapping_J <- lapply(mapping_J, FUN=function(i){i-1})
    mapping_J <- sapply(mapping_J, min)

    mapping_to_re <- parsed_RE_groups$factor
    mapping_to_re <- unlist(apply(do.call('cbind', mapping_to_re), MARGIN = 1, list), recursive = F)
    # mapping_to_re <- purrr::array_branch(do.call('cbind', mapping_to_re), margin = 1)
    
    mapping_to_re <- lapply(mapping_to_re, FUN=function(i){
      mapply(outer_alpha_RE_positions[!spline_REs], i, SIMPLIFY = FALSE, 
          FUN=function(a,b){a[[b]]})
    })
    Mmap <- do.call('rbind', lapply(mapping_to_re, FUN=function(i){as.integer(sapply(i, min))}))

    start_base_Z <- cumsum(c(0,d_j[!spline_REs]))[-(number_of_RE - sum(spline_REs) +1)]
    names(start_base_Z) <- NULL

    store_re_id <- store_id <- list()
    id_range <- 1:nrow(Mmap)
    for (j in 1:(number_of_RE - sum(spline_REs))){
      store_re_id_j <- store_id_j <- list()
      if (factorization_method == 'strong'){
        loop_j <- j
      }else{
        loop_j <- 1:j
      }
      for (jprime in loop_j){
        # print(c(j, jprime))
        umap <- unique(Mmap[, c(j, jprime)])
        store_re_id_j[[jprime]] <- unlist(apply(umap, MARGIN = 1, list), recursive = F)
        # store_re_id_j[[jprime]] <- purrr::array_branch(umap, margin = 1)
        
        id_lookup <- split(id_range, paste(Mmap[,j], Mmap[,jprime]))
        id_lookup <- id_lookup[paste(umap[,1], umap[,2])]
        names(id_lookup) <- NULL
        
        # id_lookup <- lapply(1:nrow(umap), FUN=function(i){
        #   umap_r <- umap[i,]
        #   id_r <- which( (Mmap[,j] %in% umap_r[1]) & (Mmap[,jprime] %in% umap_r[2]))
        #   return(id_r)
        # })
        store_id_j[[jprime]] <- id_lookup
      }
      store_id[[j]] <- store_id_j
      store_re_id[[j]] <- store_re_id_j
    }
    store_design <- parsed_RE_groups$design
    
    rm(parsed_RE_groups, mapping_to_re)
    
    gc()
    if (do_timing){
      toc(quiet = quiet_time, log = T)
    }
  }
  store_parameter_traj <- store_vi <- store_ELBO <- data.frame()

  if (debug_param) {
    store_beta <- array(NA, dim = c(iterations, ncol(X)))
    store_alpha <- array(NA, dim = c(iterations, ncol(Z)))
    store_sigma <- array(NA, dim = c(iterations, sum(d_j^2)))
    if (do_huangwand){
      store_hw <- array(NA, dim = c(iterations, sum(d_j)))
    }
  }
  if (do_timing) {
    toc(quiet = quiet_time, log = TRUE)
    tic.clear()
  }
  ## Begin VI algorithm:
  if (!quiet) {
    message("Begin Regression")
  }
  do_SQUAREM <- control$do_SQUAREM
  if (factorization_method == 'collapsed'){
    warning('Turning off SQUAREM for "collapsed" factorization')
    do_SQUAREM <- FALSE
  }
  if (family %in% c('negbin')){
    if (do_SQUAREM){warning('Turning off SQUAREM for negbin temporarily.')}
    do_SQUAREM <- FALSE
  }
  if (family == 'negbin' & !(control$vi_r_method %in% c('VEM', 'fixed'))){
    if (do_SQUAREM){warning('Turning off SQUAREM if "negbin" and not VEM/fixed.')}
    do_SQUAREM <- FALSE
  }

  if (do_SQUAREM){
    namedList <- utils::getFromNamespace('namedList', 'lme4')
    squarem_success <- c(0, 0)
    squarem_list <- list()
    squarem_counter <- 1
  }else{
    squarem_success <- NA
  }
  if (debug_px){
    debug_PX_ELBO <- rep(NA, iterations)
  }else{
    debug_PX_ELBO <- NULL
  }
  for (it in 1:iterations) {
    
    if (it %% print_prog == 0) {
      cat(".")
    }
    ###
    ## Polya-Gamma Updates
    ###
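    # For the binomial/negbin families, q(omega_i) is Polya-Gamma PG(b_i, c_i),
    # where c_i^2 = E[(x_i' beta + z_i' alpha - ln(r))^2] under q (the ln(r) term
    # is zero for binomial). Its mean is E[omega_i] = b_i / (2 c_i) * tanh(c_i / 2),
    # which tends to b_i / 4 as c_i -> 0 (the fill_zero correction below).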
    # Get the x_i^T Var(beta) x_i terms.
    if (do_timing) {
      tic("Update PG")
    }

    if (family %in% 'linear'){# Ignore Polya-Gamma or Similar Updates
      
      vi_pg_mean <- rep(1, nrow(X))
      diag_vi_pg_mean <- sparseMatrix(i = 1:N, j = 1:N, x = vi_pg_mean)
      
    }else{# Estimate Polya-Gamma or Similar Updates

      if (factorization_method %in% c("weak", "collapsed")) {
        # joint_quad <- rowSums( (joint.XZ %*% t(vi_joint_decomp))^2 )
        # vi_joint_decomp <<- vi_joint_decomp
        # joint.XZ <<- joint.XZ
        joint_quad <- cpp_zVz(Z = joint.XZ, V = as(vi_joint_decomp, "generalMatrix")) 
        if (family == 'negbin'){
          joint_quad <- joint_quad + vi_r_sigma
        }
        vi_pg_c <- sqrt(as.vector(X %*% vi_beta_mean + Z %*% vi_alpha_mean - vi_r_mu)^2 + joint_quad)
      } else {
        beta_quad <- rowSums((X %*% t(vi_beta_decomp))^2)
        alpha_quad <- rowSums((Z %*% t(vi_alpha_decomp))^2)
        joint_var <- beta_quad + alpha_quad
        if (family == 'negbin'){
          joint_var <- joint_var + vi_r_sigma
        }
        vi_pg_c <- sqrt(as.vector(X %*% vi_beta_mean + Z %*% vi_alpha_mean - vi_r_mu)^2 + joint_var)
      }
      vi_pg_mean <- vi_pg_b / (2 * vi_pg_c) * tanh(vi_pg_c / 2)
      fill_zero <- which(abs(vi_pg_c) < 1e-6)
      if (length(fill_zero) > 0){
        vi_pg_mean[fill_zero] <- vi_pg_b[fill_zero] / 4
      }
      diag_vi_pg_mean <- sparseMatrix(i = 1:N, j = 1:N, x = vi_pg_mean)
    }
    
    sqrt_pg_weights <- Diagonal(x = sqrt(vi_pg_mean))
    
    if (debug_ELBO & it != 1) {
      debug_ELBO.1 <- calculate_ELBO(family = family,
        ELBO_type = ELBO_type,
        factorization_method = factorization_method,
        d_j = d_j, g_j = g_j, prior_sigma_alpha_phi = prior_sigma_alpha_phi,
        prior_sigma_alpha_nu = prior_sigma_alpha_nu,
        iw_prior_constant = iw_prior_constant,
        X = X, Z = Z, s = s, y = y,
        vi_pg_b = vi_pg_b, vi_pg_mean = vi_pg_mean, vi_pg_c = vi_pg_c,
        vi_sigma_alpha = vi_sigma_alpha, vi_sigma_alpha_nu = vi_sigma_alpha_nu,
        vi_sigma_outer_alpha = vi_sigma_outer_alpha,
        vi_beta_mean = vi_beta_mean, vi_alpha_mean = vi_alpha_mean,
        log_det_beta_var = log_det_beta_var, log_det_alpha_var = log_det_alpha_var,
        vi_beta_decomp = vi_beta_decomp, vi_alpha_decomp = vi_alpha_decomp,
        vi_joint_decomp = vi_joint_decomp, choose_term = choose_term,
        vi_sigmasq_a = vi_sigmasq_a, vi_sigmasq_b = vi_sigmasq_b, 
        vi_sigmasq_prior_a = vi_sigmasq_prior_a, vi_sigmasq_prior_b = vi_sigmasq_prior_b,
        log_det_joint_var = log_det_joint_var, vi_r_mu = vi_r_mu, vi_r_mean = vi_r_mean, vi_r_sigma = vi_r_sigma,
        do_huangwand = do_huangwand, vi_a_a_jp = vi_a_a_jp, vi_a_b_jp = vi_a_b_jp,
        vi_a_nu_jp = vi_a_nu_jp, vi_a_APRIOR_jp = vi_a_APRIOR_jp
      )
    }

    if (do_timing) {
      toc(quiet = quiet_time, log = TRUE)
      tic("Prepare Sigma")
    }

    # Process Sigma_j for manipulation
    # if Sigma_{j} is InverseWishart(a,Phi)
    # Then E[Sigma^{-1}_j] = a * Phi^{-1}
    if (factorization_method == "strong") {
      cyclical_T <- TRUE
    } else {
      cyclical_T <- FALSE
    }
    inv_mapping_alpha <- mapply(vi_sigma_alpha_nu, lapply(vi_sigma_alpha, solve),
      SIMPLIFY = FALSE, FUN = function(a, b) {
        a * b
      }
    )
    inv_mapping_alpha <- make_mapping_alpha(inv_mapping_alpha)

    if (factorization_method == "collapsed"){
      cyclical_T <- TRUE
    }
    
    Tinv <- prepare_T(
      mapping = inv_mapping_alpha, levels_per_RE = g_j, num_REs = number_of_RE,
      variables_per_RE = d_j, running_per_RE = breaks_for_RE, cyclical = cyclical_T
    )

    if (!cyclical_T & factorization_method != "collapsed") {
      Tinv <- as(Tinv, "generalMatrix")
    } else {
      Tinv <- lapply(Tinv, FUN = function(i) {
        as(i, "generalMatrix")
      })
    }
    
    if (do_timing) {
      toc(quiet = quiet_time, log = T)
      tic("Update Beta")
    }
    if (factorization_method == "weak") {
      ## Update <beta, alpha> jointly
      chol.update.joint <- LinRegChol(
        X = joint.XZ, omega = diag_vi_pg_mean,
        prior_precision = bdiag(zero_mat, Tinv),
        y = s + vi_pg_mean * vi_r_mu
      )
      Pmatrix <- sparseMatrix(i = 1:ncol(joint.XZ), j = 1 + chol.update.joint$Pindex, x = 1)

      vi_joint_L_nonpermute <- drop0(solve(chol.update.joint$origL))
      vi_joint_LP <- Pmatrix
      vi_joint_decomp <- vi_joint_L_nonpermute %*% t(vi_joint_LP)

      vi_beta_mean <- Matrix(chol.update.joint$mean[1:p.X], dimnames = list(colnames(X), NULL))
      vi_alpha_mean <- Matrix(chol.update.joint$mean[-1:-p.X], dimnames = list(fmt_names_Z, NULL))

      vi_alpha_decomp <- vi_joint_decomp[, -1:-p.X, drop = F]
      vi_beta_decomp <- vi_joint_decomp[, 1:p.X, drop = F]

      log_det_joint_var <- -2 * sum(log(diag(chol.update.joint$origL)))
      if (do_SQUAREM){
        vi_joint_L_nonpermute <- vi_joint_decomp
        vi_joint_LP <- Diagonal(n = ncol(vi_joint_decomp))
      }
    } else if (factorization_method == "collapsed") {
      
      if (family != 'binomial'){stop('"collapsed" factorization is only implemented for family = "binomial".')}
      
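      # Collapsed update: profile beta out of the PG-weighted least-squares problem
      # (beta_hat is the weighted LS solution), residualize Z on X via M = Z - X %*% P,
      # solve for alpha with prior precision Tinv, then recover beta = beta_hat - P %*% alpha.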
      beta_var <- solve(t(X) %*% diag_vi_pg_mean %*% X)
      beta_hat <- beta_var %*% t(X) %*% s
      
      P <- beta_var %*% t(X) %*% diag_vi_pg_mean %*% Z
      M <- Z - X %*% P
      
      vi_alpha_mean <- solve(t(M) %*% diag_vi_pg_mean %*% M + bdiag(Tinv),
         t(M) %*% (s - diag_vi_pg_mean %*% X %*% beta_hat)
      )
      vi_beta_mean <- beta_hat - P %*% vi_alpha_mean
      
      sqrt_pg_weights <- Diagonal(x = sqrt(vi_pg_mean))
      
      for (j in 1:number_of_RE) {
        index_j <- cyclical_pos[[j]]
        M_j <- as(M[, index_j, drop = F], 'generalMatrix')
        prec_j <- crossprod(sqrt_pg_weights %*% M_j) + Tinv[[j]]
        
        chol_var_j <- solve(t(chol(prec_j)))
        running_log_det_alpha_var[j] <- 2 * sum(log(diag(chol_var_j)))
        
        vi_alpha_decomp[index_j, index_j] <- drop0(chol_var_j)
        # as(
        #   as(chol_var_j, "generalMatrix"), "TsparseMatrix"
        # )
      }
      vi_alpha_L_nonpermute <- vi_alpha_decomp
      vi_alpha_LP <- Diagonal(n = nrow(vi_alpha_L_nonpermute))
      vi_alpha_decomp <- vi_alpha_L_nonpermute  %*% t(vi_alpha_LP)
      vi_alpha_decomp <- drop0(vi_alpha_decomp)
      vi_alpha_decomp <- as(vi_alpha_decomp, 'generalMatrix')
      
      log_det_alpha_var <- sum(running_log_det_alpha_var)
      
      var_ALPHA <- t(vi_alpha_decomp) %*% vi_alpha_decomp
      vi_joint_all <- bdiag(beta_var, var_ALPHA)
      
      vi_joint_all[seq_len(nrow(beta_var)), seq_len(nrow(beta_var))] <- 
        P %*% var_ALPHA %*% t(P) + vi_joint_all[seq_len(nrow(beta_var)), seq_len(nrow(beta_var))]
      vi_joint_all[seq_len(nrow(beta_var)),-seq_len(nrow(beta_var)), drop = F] <- - P %*% var_ALPHA
      vi_joint_all[-seq_len(nrow(beta_var)),seq_len(nrow(beta_var)),drop=F] <- - t(P %*% var_ALPHA)
      
      vi_joint_decomp <- chol(vi_joint_all)

      vi_beta_decomp <- vi_joint_decomp[,1:p.X,drop=F]
      # vi_beta_decomp <- chol(beta_var)
      # vi_beta_L_nonpermute <- vi_beta_decomp
      # vi_beta_LP <- Diagonal(n = nrow(vi_beta_mean))
      # vi_joint_LP <- Diagonal(n = nrow(vi_joint_decomp))
      # vi_joint_L_nonpermute <- vi_joint_decomp
      
      log_det_joint_var <- NA
      log_det_beta_var <- as.numeric(determinant(beta_var)$modulus)
      
    } else if (factorization_method == "partial") {
      if (linpred_method == "cyclical") {
        # Do not run except as backup
        # ###Non optimized
        # precision_beta <- t(X) %*% diag_vi_pg_mean %*% X
        # nonopt_beta <- solve(precision_beta, t(X) %*% (s - diag_vi_pg_mean %*% Z %*% vi_alpha_mean))
        # precision_alpha <- t(Z) %*% diag_vi_pg_mean %*% Z + Tinv
        # nonopt_alpha <- solve(precision_alpha, t(Z) %*% (s - diag_vi_pg_mean %*% X %*% nonopt_beta))

        chol.update.beta <- LinRegChol(
          X = as(X, "sparseMatrix"), omega = diag_vi_pg_mean, prior_precision = zero_mat,
          y = as.vector(s - diag_vi_pg_mean %*% Z %*% vi_alpha_mean)
        )
        Pmatrix <- sparseMatrix(i = 1:p.X, j = 1 + chol.update.beta$Pindex, x = 1)

        # P origL origL^T P^T = PRECISION
        # t(decompVar) %*% decompVar = VARIANCE = (origL^{-1} t(P))^T (origL^{-1} t(P))

        vi_beta_L_nonpermute <- drop0(solve(chol.update.beta$origL))
        vi_beta_LP <- Pmatrix
        vi_beta_decomp <- vi_beta_L_nonpermute %*% t(vi_beta_LP)
        vi_beta_mean <- chol.update.beta$mean
        log_det_beta_var <- -2 * sum(log(diag(chol.update.beta$origL)))

        chol.update.alpha <- LinRegChol(
          X = Z, omega = diag_vi_pg_mean, prior_precision = Tinv,
          y = as.vector(s - diag_vi_pg_mean %*% X %*% vi_beta_mean)
        )
        Pmatrix <- sparseMatrix(i = 1:p.Z, j = 1 + chol.update.alpha$Pindex, x = 1)

        vi_alpha_L_nonpermute <- drop0(solve(chol.update.alpha$origL))
        vi_alpha_LP <- Pmatrix
        vi_alpha_decomp <- vi_alpha_L_nonpermute  %*% t(vi_alpha_LP)
        vi_alpha_decomp <- drop0(vi_alpha_decomp)
        vi_alpha_decomp <- as(vi_alpha_decomp, 'generalMatrix')
        vi_alpha_mean <- chol.update.alpha$mean
        log_det_alpha_var <- -2 * sum(log(diag(chol.update.alpha$origL)))

        vi_beta_mean <- Matrix(vi_beta_mean, dimnames = list(colnames(X), NULL))
        vi_alpha_mean <- Matrix(vi_alpha_mean, dimnames = list(fmt_names_Z, NULL))
      } else if (linpred_method == "joint") {
        joint.XZ <- cbind(X, Z)

        
        chol.update.joint <- solve(Matrix::Cholesky(  
          crossprod(Diagonal(x = sqrt(vi_pg_mean)) %*% joint.XZ) + 
            bdiag(zero_mat, bdiag(Tinv)) ),
               t(joint.XZ) %*% (s + vi_pg_mean * vi_r_mu) )
        vi_beta_mean <- Matrix(chol.update.joint[1:p.X,], dimnames = list(colnames(X), NULL))
        vi_alpha_mean <- Matrix(chol.update.joint[-1:-p.X,], dimnames = list(fmt_names_Z, NULL))
        
        # chol.update.joint <- LinRegChol(
        #   X = joint.XZ, omega = diag_vi_pg_mean, prior_precision = bdiag(zero_mat, Tinv),
        #   y = s + vi_pg_mean * vi_r_mu,
        #   save_chol = FALSE
        # )
        # vi_beta_mean <- Matrix(chol.update.joint$mean[1:p.X], dimnames = list(colnames(X), NULL))
        # vi_alpha_mean <- Matrix(chol.update.joint$mean[-1:-p.X], dimnames = list(fmt_names_Z, NULL))
        
        vi_beta_decomp <- solve(t(chol(as.matrix(t(X) %*% diag_vi_pg_mean %*% X))))
        
        vi_beta_L_nonpermute <- vi_beta_decomp
        vi_beta_LP <- Diagonal(n = ncol(vi_beta_decomp))
        
        log_det_beta_var <- 2 * sum(log(diag(vi_beta_decomp)))

        chol.update.alpha <- LinRegChol(
          X = Z, omega = diag_vi_pg_mean, prior_precision = Tinv,
          y = s + vi_pg_mean * vi_r_mu
        )
        Pmatrix <- sparseMatrix(i = 1:p.Z, j = 1 + chol.update.alpha$Pindex, x = 1)

        vi_alpha_L_nonpermute <- drop0(solve(chol.update.alpha$origL))
        vi_alpha_LP <- Pmatrix
        
        vi_alpha_decomp <- vi_alpha_L_nonpermute %*% t(vi_alpha_LP)
        vi_alpha_decomp <- drop0(vi_alpha_decomp)
        log_det_alpha_var <- -2 * sum(log(diag(chol.update.alpha$origL)))
        
        if (do_SQUAREM){
          vi_alpha_L_nonpermute <- vi_alpha_decomp
          vi_alpha_LP <- Diagonal(n = ncol(vi_alpha_decomp))
        }
        
      } else {
        stop("Invalid linpred method for partial scheme")
      }
    } else if (factorization_method == "strong") {
      running_log_det_alpha_var <- rep(NA, number_of_RE)
      vi_alpha_decomp <- sparseMatrix(i = 1, j = 1, x = 0, dims = rep(p.Z, 2))
      
      if (linpred_method == "joint") {
        if (it == 1){
          joint.XZ <- cbind(X, Z)
        }
        if (do_timing) {
          tic("ux_mean")
        }
        
        chol.update.joint <- solve(Matrix::Cholesky(  
          crossprod(sqrt_pg_weights %*% joint.XZ) + 
            bdiag(zero_mat, bdiag(Tinv)) ),
          t(joint.XZ) %*% (s + vi_pg_mean * vi_r_mu) )
        
        vi_beta_mean <- Matrix(chol.update.joint[1:p.X,], dimnames = list(colnames(X), NULL))
        vi_alpha_mean <- Matrix(chol.update.joint[-1:-p.X,], dimnames = list(fmt_names_Z, NULL))

        # chol.update.joint <- LinRegChol(X = joint.XZ,
        #   omega = diag_vi_pg_mean,
        #   prior_precision = bdiag(zero_mat, bdiag(Tinv)),
        #   y = s + vi_pg_mean * vi_r_mu,
        #   save_chol = FALSE)
        # vi_beta_mean <- Matrix(chol.update.joint$mean[1:p.X], dimnames = list(colnames(X), NULL))
        # vi_alpha_mean <- Matrix(chol.update.joint$mean[-1:-p.X], dimnames = list(fmt_names_Z, NULL))

        if (do_timing) {
          toc(quiet = quiet_time, log = T)
          tic("ux_var")
        }
        
        vi_beta_decomp <- solve(t(chol(as.matrix(t(X) %*% diag_vi_pg_mean %*% X))))
        
        vi_beta_L_nonpermute <- vi_beta_decomp
        vi_beta_LP <- Diagonal(n = nrow(vi_beta_decomp))
        log_det_beta_var <- 2 * sum(log(diag(vi_beta_decomp)))
        #-log(det(t(X) %*% diag_vi_pg_mean %*% X))

        running_log_det_alpha_var <- rep(NA, number_of_RE)


        for (j in 1:number_of_RE) {
          index_j <- cyclical_pos[[j]]
          Z_j <- Z[, index_j, drop = F]
          prec_j <- crossprod(sqrt_pg_weights %*% Z_j) + Tinv[[j]]

          chol_var_j <- solve(t(chol(prec_j)))
          running_log_det_alpha_var[j] <- 2 * sum(log(diag(chol_var_j)))

          vi_alpha_decomp[index_j, index_j] <- drop0(chol_var_j)
          # as(
          #   as(chol_var_j, "generalMatrix"), "TsparseMatrix"
          # )
          
        }
        
        vi_alpha_L_nonpermute <- vi_alpha_decomp
        vi_alpha_LP <- Diagonal(n = nrow(vi_alpha_L_nonpermute))
        
        log_det_alpha_var <- sum(running_log_det_alpha_var)
        
        if (do_timing){
          toc(quiet = quiet_time, log = T)
        }
      } else if (linpred_method == "solve_normal") {
        bind_rhs_j <- list()
        bind_lhs_j <- list()

        for (j in 1:number_of_RE) {
          index_j <- cyclical_pos[[j]]
          Z_j <- Z[, index_j, drop = F]
          Z_negj <- Z[, -index_j, drop = F]
          prec_j <- crossprod(sqrt_pg_weights %*% Z_j) + Tinv[[j]]

          chol_prec_j <- t(chol(prec_j))
          chol_var_j <- solve(chol_prec_j)

          mod_j <- solve(prec_j)

          term_j <- mod_j %*% t(Z_j) %*% diag_vi_pg_mean %*% Z
          term_j[, index_j, drop = F] <- Diagonal(n = ncol(Z_j))
          term_j <- cbind(term_j, mod_j %*% t(Z_j) %*% diag_vi_pg_mean %*% X)

          bind_lhs_j[[j]] <- term_j
          bind_rhs_j[[j]] <- mod_j %*% t(Z_j) %*% s

          running_log_det_alpha_var[j] <- 2 * sum(log(diag(chol_var_j)))
          vi_alpha_decomp[index_j, index_j] <- drop0(chol_var_j)
          # as(
          #   as(chol_var_j, "generalMatrix"), "TsparseMatrix"
          # )
        }

        log_det_alpha_var <- sum(running_log_det_alpha_var)

        bind_lhs_j <- drop0(do.call("rbind", bind_lhs_j))
        bind_rhs_j <- do.call("rbind", bind_rhs_j)

        vi_beta_decomp <- solve(t(chol(as.matrix(t(X) %*% diag_vi_pg_mean %*% X))))
        vi_beta_var <- solve(t(X) %*% diag_vi_pg_mean %*% X)
        log_det_beta_var <- 2 * sum(log(diag(vi_beta_decomp)))

        # vi_beta_mean <- vi_beta_var %*% t(X) %*% (s - diag_vi_pg_mean %*% Z %*% vi_alpha_mean)
        # vi_alpha_mean <- solve(bind_lhs_j[,1:ncol(Z)], bind_rhs_j)
        #
        # vi_alpha_mean <- Matrix(vi_alpha_mean)
        # vi_beta_mean <- Matrix(vi_beta_mean)

        bind_lhs_j <- drop0(rbind(bind_lhs_j, cbind(vi_beta_var %*% t(X) %*% diag_vi_pg_mean %*% Z, Diagonal(n = ncol(X)))))
        bind_rhs_j <- rbind(bind_rhs_j, vi_beta_var %*% t(X) %*% s)
        #
        bind_solution <- solve(bind_lhs_j) %*% bind_rhs_j
        # print(cbind(bind_solution, rbind(vi_alpha_mean, vi_beta_mean)))
        #
        vi_beta_mean <- Matrix(bind_solution[-1:-ncol(Z)], dimnames = list(colnames(X), NULL))
        vi_alpha_mean <- Matrix(bind_solution[1:ncol(Z)], dimnames = list(fmt_names_Z, NULL))
      } else if (linpred_method == "cyclical") {
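        # "cyclical": update each q(alpha_j) in turn, conditioning on the
        # current means of beta and of the other random effects, and then
        # update q(beta) last with the refreshed alpha means.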
        for (j in 1:number_of_RE) {
          index_j <- cyclical_pos[[j]]
          Z_j <- Z[, index_j, drop = F]
          Z_negj <- Z[, -index_j, drop = F]

          chol.j <- LinRegChol(
            X = Z_j, omega = diag_vi_pg_mean, prior_precision = Tinv[[j]],
            y = as.vector(s + vi_pg_mean * vi_r_mu - diag_vi_pg_mean %*% (X %*% vi_beta_mean + Z_negj %*% vi_alpha_mean[-index_j]))
          )
          vi_alpha_mean[index_j] <- chol.j$mean

          Pmatrix <- sparseMatrix(i = 1:ncol(Z_j), j = 1 + chol.j$Pindex, x = 1)

          running_log_det_alpha_var[j] <- -2 * sum(log(diag(chol.j$origL)))
          vi_alpha_decomp[index_j, index_j] <- solve(chol.j$origL) %*% t(Pmatrix)
        }

        vi_alpha_L_nonpermute <- vi_alpha_decomp
        vi_alpha_LP <- Diagonal(n = ncol(vi_alpha_L_nonpermute))
        # vi_alpha_decomp <- bdiag(vi_alpha_decomp)
        log_det_alpha_var <- sum(running_log_det_alpha_var)

        chol.update.beta <- LinRegChol(
          X = as(X, "sparseMatrix"), omega = diag_vi_pg_mean, prior_precision = zero_mat,
          y = as.vector(s + vi_pg_mean * vi_r_mu - diag_vi_pg_mean %*% Z %*% vi_alpha_mean)
        )
        Pmatrix <- sparseMatrix(i = 1:p.X, j = 1 + chol.update.beta$Pindex, x = 1)

        vi_beta_L_nonpermute <- drop0(solve(chol.update.beta$origL))
        vi_beta_LP <- Pmatrix
        
        vi_beta_decomp <- vi_beta_L_nonpermute %*% t(Pmatrix)
        vi_beta_mean <- chol.update.beta$mean
        log_det_beta_var <- -2 * sum(log(diag(chol.update.beta$origL)))

        vi_beta_mean <- Matrix(vi_beta_mean, dimnames = list(colnames(X), NULL))
        vi_alpha_mean <- Matrix(vi_alpha_mean, dimnames = list(fmt_names_Z, NULL))
      } else {
        stop("Invalid linpred method")
      }
    } else {
      stop("Invalid factorization method.")
    }

    if (family == 'linear'){
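      # Linear model: rescale the stored covariance factors by
      # 1/sqrt(E_q[1/sigma^2]) = sqrt(vi_sigmasq_b/vi_sigmasq_a) and shift
      # each log-determinant by its dimension times (log b - log a).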
      
      adjust_var <- 1/sqrt(vi_sigmasq_a/vi_sigmasq_b)
      
      vi_beta_decomp <- vi_beta_decomp * adjust_var
      vi_alpha_decomp <- vi_alpha_decomp * adjust_var
      vi_joint_decomp <- vi_joint_decomp * adjust_var
      
      if (factorization_method == 'weak'){
        vi_joint_L_nonpermute <- vi_joint_L_nonpermute * adjust_var
      }else{
        vi_beta_L_nonpermute <- vi_beta_L_nonpermute * adjust_var
        vi_alpha_L_nonpermute <- vi_alpha_L_nonpermute * adjust_var
      }

      ln_sigmasq <- log(vi_sigmasq_b) - log(vi_sigmasq_a)
      log_det_joint_var <- log_det_joint_var + ncol(vi_joint_decomp) * ln_sigmasq
      log_det_beta_var <- log_det_beta_var + ncol(vi_beta_decomp) * ln_sigmasq
      log_det_alpha_var <- log_det_alpha_var + ncol(vi_alpha_decomp) * ln_sigmasq
    }

    if (debug_ELBO & it != 1) {
      
      variance_by_alpha_jg <- calculate_expected_outer_alpha(L = vi_alpha_decomp, alpha_mu = as.vector(vi_alpha_mean), re_position_list = outer_alpha_RE_positions)
      vi_sigma_outer_alpha <- variance_by_alpha_jg$outer_alpha

      debug_ELBO.2 <- calculate_ELBO(family = family,
        ELBO_type = ELBO_type,
        factorization_method = factorization_method,
        d_j = d_j, g_j = g_j, prior_sigma_alpha_phi = prior_sigma_alpha_phi,
        prior_sigma_alpha_nu = prior_sigma_alpha_nu,
        iw_prior_constant = iw_prior_constant,
        X = X, Z = Z, s = s, y = y,
        vi_pg_b = vi_pg_b, vi_pg_mean = vi_pg_mean, vi_pg_c = vi_pg_c,
        vi_sigma_alpha = vi_sigma_alpha, vi_sigma_alpha_nu = vi_sigma_alpha_nu,
        vi_sigma_outer_alpha = vi_sigma_outer_alpha,
        vi_beta_mean = vi_beta_mean, vi_alpha_mean = vi_alpha_mean,
        log_det_beta_var = log_det_beta_var, log_det_alpha_var = log_det_alpha_var,
        vi_beta_decomp = vi_beta_decomp, vi_alpha_decomp = vi_alpha_decomp,
        vi_joint_decomp = vi_joint_decomp, choose_term = choose_term,
        vi_sigmasq_a = vi_sigmasq_a, vi_sigmasq_b = vi_sigmasq_b, 
        vi_sigmasq_prior_a = vi_sigmasq_prior_a, vi_sigmasq_prior_b = vi_sigmasq_prior_b,
        log_det_joint_var = log_det_joint_var, vi_r_mu = vi_r_mu, vi_r_mean = vi_r_mean,
        vi_r_sigma = vi_r_sigma,
        do_huangwand = do_huangwand, vi_a_a_jp = vi_a_a_jp, vi_a_b_jp = vi_a_b_jp,
        vi_a_nu_jp = vi_a_nu_jp, vi_a_APRIOR_jp = vi_a_APRIOR_jp
      )
    }
    if (do_timing) {
      toc(quiet = quiet_time, log = T)
      tic("Update Sigma")
    }
    ###
    # Update \Sigma_j
    ##
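    # The scale of each q(Sigma_j) combines E_q[sum_g alpha_{j,g} alpha_{j,g}']
    # (scaled by vi_sigmasq_a/vi_sigmasq_b) with the prior scale; under the
    # Huang-Wand prior, the a_{j,p} rates are refreshed INNER_IT times
    # alongside Sigma_j.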

    if (!do_huangwand){#Update standard Inverse-Wishart
      
      variance_by_alpha_jg <- calculate_expected_outer_alpha(L = vi_alpha_decomp, alpha_mu = as.vector(vi_alpha_mean), re_position_list = outer_alpha_RE_positions)
      vi_sigma_outer_alpha <- variance_by_alpha_jg$outer_alpha
      
      vi_sigma_alpha <- mapply(vi_sigma_outer_alpha, prior_sigma_alpha_phi, SIMPLIFY = FALSE, FUN = function(i, j) {
        i * vi_sigmasq_a/vi_sigmasq_b + j
      })
      
    }else{
        #Update Inverse-Wishart
        variance_by_alpha_jg <- calculate_expected_outer_alpha(L = vi_alpha_decomp, alpha_mu = as.vector(vi_alpha_mean), re_position_list = outer_alpha_RE_positions)
        vi_sigma_outer_alpha <- variance_by_alpha_jg$outer_alpha
        
        for (inner_it in 1:INNER_IT){
          
          vi_sigma_alpha <- mapply(vi_sigma_outer_alpha, vi_a_a_jp, 
           vi_a_b_jp, vi_a_nu_jp, SIMPLIFY = FALSE, 
           FUN = function(i, tilde.a, tilde.b, nu) {
             i * vi_sigmasq_a/vi_sigmasq_b + Diagonal(x = tilde.a/tilde.b) * 2 * nu
           })
          
          #Update a_{j,p}
          diag_Einv_sigma <- mapply(vi_sigma_alpha, 
              vi_sigma_alpha_nu, d_j, SIMPLIFY = FALSE, FUN = function(phi, nu, d) {
                inv_phi <- solve(phi)
                sigma.inv <- nu * inv_phi
                return(diag(sigma.inv))
              })
          vi_a_b_jp <- mapply(vi_a_nu_jp, vi_a_APRIOR_jp, diag_Einv_sigma,
                              SIMPLIFY = FALSE,
                              FUN=function(nu, APRIOR, diag_j){
                                1/APRIOR^2 + nu * diag_j
                              })
          
        }
    }

    if (do_timing) {
      toc(quiet = quiet_time, log = T)
      tic("Update Aux")
    }

    # Update the auxiliary parameters
    if (family == "negbin") {
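      # Negative binomial: update q(r) for the dispersion parameter (on the
      # log scale) via update_r, then refresh its mean and the Polya-Gamma
      # quantities s and vi_pg_b that depend on it.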
      vi_r_param <- update_r(
        vi_r_mu = vi_r_mu, vi_r_sigma = vi_r_sigma,
        y = y, X = X, Z = Z, factorization_method = factorization_method,
        vi_beta_mean = vi_beta_mean, vi_beta_decomp = vi_beta_decomp,
        vi_alpha_mean = vi_alpha_mean, vi_alpha_decomp = vi_alpha_decomp,
        vi_joint_decomp = vi_joint_decomp, vi_r_method = vi_r_method
      )

      vi_r_mu <- vi_r_param[1]
      vi_r_sigma <- vi_r_param[2]
      vi_r_mean <- exp(vi_r_mu + vi_r_sigma / 2)
      
      s <- (y - vi_r_mean) / 2
      vi_pg_b <- y + vi_r_mean
    } else if (family == 'linear') {
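      # Linear model: update the rate vi_sigmasq_b of q(sigma^2). The expected
      # squared residuals include the quadratic terms from the variational
      # covariances, and vi_kernel adds the expected prior kernel for alpha.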
      
      if (factorization_method == 'weak'){
        joint_quad <- cpp_zVz(Z = joint.XZ, V = as(vi_joint_decomp, "generalMatrix"))
        vi_lp <- (s - as.vector(X %*% vi_beta_mean + Z %*% vi_alpha_mean))^2 + joint_quad
      } else{
        beta_quad <- rowSums((X %*% t(vi_beta_decomp))^2)
        alpha_quad <- rowSums((Z %*% t(vi_alpha_decomp))^2)
        vi_lp <- (s - as.vector(X %*% vi_beta_mean + Z %*% vi_alpha_mean))^2 + beta_quad + alpha_quad
      }
      
      vi_kernel <- expect_alpha_prior_kernel(vi_sigma_alpha = vi_sigma_alpha, 
          vi_sigma_alpha_nu = vi_sigma_alpha_nu, d_j = d_j,
          vi_sigma_outer_alpha = vi_sigma_outer_alpha)
      vi_sigmasq_b <- (sum(vi_lp) + vi_kernel)/2 + vi_sigmasq_prior_b
      
    }

    if (do_timing) {
      toc(quiet = quiet_time, log = T)
    }
    ### PARAMETER EXPANSIONS!
    if (debug_ELBO) {
      debug_ELBO.3 <- calculate_ELBO(family = family,
        ELBO_type = ELBO_type,
        factorization_method = factorization_method,
        d_j = d_j, g_j = g_j, prior_sigma_alpha_phi = prior_sigma_alpha_phi,
        prior_sigma_alpha_nu = prior_sigma_alpha_nu,
        iw_prior_constant = iw_prior_constant,
        X = X, Z = Z, s = s, y = y,
        vi_pg_b = vi_pg_b, vi_pg_mean = vi_pg_mean, vi_pg_c = vi_pg_c,
        vi_sigma_alpha = vi_sigma_alpha, vi_sigma_alpha_nu = vi_sigma_alpha_nu,
        vi_sigma_outer_alpha = vi_sigma_outer_alpha,
        vi_beta_mean = vi_beta_mean, vi_alpha_mean = vi_alpha_mean,
        log_det_beta_var = log_det_beta_var, log_det_alpha_var = log_det_alpha_var,
        vi_beta_decomp = vi_beta_decomp, vi_alpha_decomp = vi_alpha_decomp,
        vi_joint_decomp = vi_joint_decomp,
        log_det_joint_var = log_det_joint_var, 
        vi_r_mu = vi_r_mu, vi_r_mean = vi_r_mean, vi_r_sigma = vi_r_sigma, 
        choose_term = choose_term,
        vi_sigmasq_a = vi_sigmasq_a, vi_sigmasq_b = vi_sigmasq_b, 
        vi_sigmasq_prior_a = vi_sigmasq_prior_a, vi_sigmasq_prior_b = vi_sigmasq_prior_b,
        do_huangwand = do_huangwand, vi_a_a_jp = vi_a_a_jp, vi_a_b_jp = vi_a_b_jp,
        vi_a_nu_jp = vi_a_nu_jp, vi_a_APRIOR_jp = vi_a_APRIOR_jp
      )
    }

    if (parameter_expansion == "none" | !any_Mprime) {
      
      accept.PX <- TRUE
      
    } else {
      
      if (do_timing) {
        tic("Update PX")
      }
      
      # Do a simple mean adjusted expansion.
      # Get the mean of each random effect.
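      # The random-effect means are projected onto the columns of B_j,
      # weighted by E_q[Sigma_j^{-1}]; this projected mass is removed from
      # alpha and added back to the fixed effects below.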

      vi_mu_j <- t(M_prime) %*% vi_alpha_mean
      
      meat_Bj <- bdiag(mapply(vi_sigma_alpha, vi_sigma_alpha_nu, d_j, 
        SIMPLIFY = FALSE, FUN = function(phi, nu, d) {
          inv_phi <- solve(phi)
          sigma.inv <- nu * inv_phi
          return(sigma.inv)
        }))
      
      proj_vi_mu_j <- B_j %*% solve(t(B_j) %*% meat_Bj %*% B_j) %*% t(B_j) %*% meat_Bj %*% vi_mu_j
      
      # Remove the "excess mean" mu_j from each random effect \alpha_{j,g}
      # and add the summed mass back to the betas.
      vi_alpha_mean <- vi_alpha_mean - M_prime_one %*% proj_vi_mu_j
      vi_beta_mean <- vi_beta_mean + t(M_mu_to_beta) %*% proj_vi_mu_j
      
      variance_by_alpha_jg <- calculate_expected_outer_alpha(
        L = vi_alpha_decomp,
        alpha_mu = as.vector(vi_alpha_mean), 
        re_position_list = outer_alpha_RE_positions
      )
      
      vi_sigma_outer_alpha <- variance_by_alpha_jg$outer_alpha
      if (parameter_expansion == "mean"){accept.PX <- TRUE}
    }  

    quiet_rho <- control$quiet_rho
    
    if (parameter_expansion %in% c("translation", "diagonal") & skip_translate == FALSE & any_Mprime) {
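      # Translation parameter expansion: propose per-term transformations R_j
      # of the random effects (plus adjusted fixed effects and, under
      # Huang-Wand, adjusted a_{j,p} rates) via update_rho, and accept the
      # proposal only if the reported ELBO improvement is positive.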
      
      attempted_expansion <- attempted_expansion + 1
      
      if (debug_px){

        prior.ELBO <- calculate_ELBO(family = family, ELBO_type = ELBO_type,
           factorization_method = factorization_method,
           d_j = d_j, g_j = g_j, prior_sigma_alpha_phi = prior_sigma_alpha_phi,
           prior_sigma_alpha_nu = prior_sigma_alpha_nu,
           iw_prior_constant = iw_prior_constant,
           X = X, Z = Z, s = s, y = y,
           vi_pg_b = vi_pg_b, vi_pg_mean = vi_pg_mean, vi_pg_c = vi_pg_c,
           vi_sigma_alpha = vi_sigma_alpha, vi_sigma_alpha_nu = vi_sigma_alpha_nu,
           vi_sigma_outer_alpha = vi_sigma_outer_alpha,
           vi_beta_mean = vi_beta_mean, vi_alpha_mean = vi_alpha_mean,
           log_det_beta_var = log_det_beta_var, log_det_alpha_var = log_det_alpha_var,
           vi_beta_decomp = vi_beta_decomp, vi_alpha_decomp = vi_alpha_decomp,
           vi_joint_decomp = vi_joint_decomp, choose_term = choose_term,
           vi_sigmasq_a = vi_sigmasq_a, vi_sigmasq_b = vi_sigmasq_b, 
           vi_sigmasq_prior_a = vi_sigmasq_prior_a, vi_sigmasq_prior_b = vi_sigmasq_prior_b,
           log_det_joint_var = log_det_joint_var, 
           vi_r_mu = vi_r_mu, vi_r_mean = vi_r_mean, vi_r_sigma = vi_r_sigma,
           do_huangwand = do_huangwand, vi_a_a_jp = vi_a_a_jp, vi_a_b_jp = vi_a_b_jp,
           vi_a_nu_jp = vi_a_nu_jp, vi_a_APRIOR_jp = vi_a_APRIOR_jp
        )
      }

      if (!quiet_rho){cat('r')}
      
      if (do_timing){
        tic('px_r')
      }
      
      if (any(!spline_REs)){
        raw_R <- R_ridge <- vecR_ridge_new(L = vi_alpha_decomp[,nonspline_positions], pg_mean = diag(diag_vi_pg_mean),
                                           mapping_J = mapping_J, d = d_j[!spline_REs],
                                           store_id = store_id, store_re_id = store_re_id,
                                           store_design = store_design, 
                                           diag_only = (factorization_method == 'strong'))
      }else{
        raw_R <- R_ridge <- matrix(0, ncol = 0, nrow = 0)
      }

      if (factorization_method == 'weak'){
        stop('no Translation PX for weak yet...')
      }
      
      if (!quiet_rho){cat('r')}

      if (any(!spline_REs)){
        R_design <- vecR_design(alpha_mu = as.vector(vi_alpha_mean), Z = mapping_new_Z, 
                                M = Mmap, mapping_J = mapping_J, d = d_j[!spline_REs],
                                start_z = start_base_Z)
      }else{
        R_design <- matrix(0, nrow = N, ncol = 0)
      }

      if (sum(spline_REs)){
        R_spline_design <- sapply(cyclical_pos[spline_REs], FUN=function(i){
          as.vector(Z[,i,drop=F] %*% vi_alpha_mean[i,])
        })
        
        R_spline_ridge <- sapply(cyclical_pos[spline_REs], FUN=function(s){vi_alpha_decomp[,s, drop = F]})
        R_spline_ridge <- Diagonal(x = mapply(R_spline_ridge, cyclical_pos[spline_REs], FUN=function(V, pos){
          sum(vi_pg_mean * cpp_zVz(Z = drop0(Z[,pos,drop=F]), V = as(V, 'generalMatrix')))
        }))
        # Manually convert "ddiMatrix" to "generalMatrix" so this doesn't fail
        # on older versions of the "Matrix" package.
        if (inherits(R_spline_ridge, 'ddiMatrix')){
          R_spline_ridge <- diag(R_spline_ridge)
          R_spline_ridge <- sparseMatrix(
            i = seq_len(length(R_spline_ridge)),
            j = seq_len(length(R_spline_ridge)),
            x = R_spline_ridge)
        }else{
          R_spline_ridge <- as(R_spline_ridge, 'generalMatrix')
        }
      }else{
        R_spline_ridge <- drop0(matrix(0, nrow = 0, ncol = 0))
        R_spline_design <- matrix(nrow = nrow(X), ncol = 0)
      }
      
      
      if (do_timing){
        toc(quiet = quiet_time, log = TRUE)
        tic('px_fit')
      }
      # If using a DIAGONAL expansion, only update the diagonal elements
      if (parameter_expansion == "diagonal"){
        stop('parameter_expansion "diagonal" turned off.')
        # XR <- cbind(X, R_spline_design, R_design[, diag_rho])
        # R_ridge <- bdiag(zeromat_beta, R_spline_ridge, R_ridge[diag_rho, diag_rho])
        # 
        # if (do_huangwand){
        #   vec_OSL_prior <- do.call('c', mapply(vi_a_APRIOR_jp[!spline_REs], 
        #                                        vi_a_a_jp[!spline_REs], 
        #                                        vi_a_b_jp[!spline_REs],
        #                                        SIMPLIFY = FALSE,
        #     FUN=function(i,a,b){1-2/i^2 * a/b}))
        #   vec_OSL_prior <- c(rep(0, p.X), OSL_spline_prior, vec_OSL_prior)
        # }else{
        #   vec_OSL_prior <- vec_OSL_prior[c(seq_len(p.X + sum(spline_REs)), p.X + sum(spline_REs) + diag_rho),,drop=F]
        # }
        # if (length(vec_OSL_prior) != ncol(XR)){stop('MISALIGNED DIMENSIONS')}
        # 
        # update_expansion_XR <- vecR_fast_ridge(X = drop0(XR), 
        #  omega = diag_vi_pg_mean, prior_precision = R_ridge, y = as.vector(s), 
        #  adjust_y = as.vector(vec_OSL_prior))
        # 
        # update_expansion_bX <- Matrix(update_expansion_XR[1:p.X])
        # update_expansion_splines <- Matrix(update_expansion_XR[-(1:p.X)][seq_len(size_splines)])
        # 
        # update_expansion_R <- mapply(split(update_expansion_XR[-seq_len(p.X + size_splines)], 
        #   rep(1:(number_of_RE - sum(spline_REs)), d_j[!spline_REs])), d_j[!spline_REs], SIMPLIFY = FALSE, 
        #   FUN=function(i,d){
        #     dg <- diag(x = d)
        #     diag(dg) <- i
        #     return(dg)
        #   })
        #  update_diag_R <- split(update_expansion_XR[-seq_len(p.X + size_splines)], 
        #                         rep(1:(number_of_RE - sum(spline_REs)), d_j[!spline_REs]))
        #  rownames(update_expansion_bX) <- colnames(X)
      }else{
        
        XR <- drop0(cbind(drop0(X), drop0(R_spline_design), drop0(R_design)))
        R_ridge <- bdiag(zeromat_beta, R_spline_ridge, R_ridge)
        
        moments_sigma_alpha <- mapply(vi_sigma_alpha, vi_sigma_alpha_nu, d_j, 
            SIMPLIFY = FALSE, FUN = function(phi, nu, d) {
              inv_phi <- solve(phi)
              sigma.inv <- nu * inv_phi
              ln.det <- log(det(phi)) - sum(digamma((nu - 1:d + 1) / 2)) - d * log(2)
              return(list(sigma.inv = sigma.inv, ln.det = ln.det))
            })
        
        if (family == 'linear'){# Rescale for linear
          XR <- XR * sqrt(vi_sigmasq_a/vi_sigmasq_b)
          adj_s <- s * sqrt(vi_sigmasq_a/vi_sigmasq_b)
          R_ridge <- R_ridge * vi_sigmasq_a/vi_sigmasq_b
          offset <- 0
        }else if (family == 'negbin'){
          adj_s <- s
          offset <- vi_r_mu
          stop('translation not set up for negative binomial.')
        }else if (family == 'binomial'){
          adj_s <- s
          offset <- 0
        }else{stop("family not set up for translation expansion.")}
        
        update_expansion_XR <- update_rho(
          XR = XR, y = adj_s, omega = diag_vi_pg_mean, 
          prior_precision = R_ridge, vi_beta_mean = vi_beta_mean,
          moments_sigma_alpha = moments_sigma_alpha,
          prior_sigma_alpha_nu = prior_sigma_alpha_nu, prior_sigma_alpha_phi = prior_sigma_alpha_phi,
          vi_a_a_jp = vi_a_a_jp, vi_a_b_jp = vi_a_b_jp, vi_a_nu_jp = vi_a_nu_jp,
          vi_a_APRIOR_jp = vi_a_APRIOR_jp, 
          stationary_rho = stationary_rho,
          spline_REs = spline_REs, d_j = d_j,
          do_huangwand = do_huangwand, offset = offset,
          p.X = p.X, method = px_method, px_it = px_it,
          init_rho = opt_prior_rho
        )
        
        if (px_method %in% c('numerical_hw', 'profiled', 'dynamic')){
          px_improve <- update_expansion_XR$improvement
          opt_prior_rho <- update_expansion_XR$opt_par
          update_expansion_hw <- update_expansion_XR$hw
          update_expansion_XR <- update_expansion_XR$rho
        }else if (px_method %in% c('numerical', 'OSL')){
          px_improve <- update_expansion_XR$improvement
          opt_prior_rho <- update_expansion_XR <- update_expansion_XR$rho
        }
        opt_prior_rho <- NULL
        
        update_expansion_bX <- Matrix(update_expansion_XR[1:p.X])
        update_expansion_splines <- as.list(update_expansion_XR[-(1:p.X)][seq_len(sum(spline_REs))])
        
        if (any(!spline_REs)){
          update_expansion_R <- mapply(split(update_expansion_XR[-1:-(p.X + sum(spline_REs))], 
                                             rep(1:(number_of_RE - sum(spline_REs)), d_j[!spline_REs]^2)), d_j[!spline_REs], 
                                       SIMPLIFY = FALSE, FUN=function(i,d){matrix(i, nrow = d)})
        }
        
      }
      
      if (do_timing){
        toc(quiet = quiet_time, log = TRUE)
        tic('px_propose')
      }
      
      est_rho_all <- update_expansion_XR[-(1:p.X)]
      if (sum(spline_REs)){
        est_rho_spline <- est_rho_all[seq_len(sum(spline_REs))]
        est_rho <- est_rho_all[-seq_len(sum(spline_REs))]
      }else{
        est_rho <- est_rho_all
        est_rho_spline <- 1
      }
      
      if (px_method %in% c('numerical_hw', 'profiled', 'dynamic')){
        
        check_rho_hw <- unlist(vi_a_b_jp[c(which(spline_REs), which(!spline_REs))])
        check_rho_hw <- check_rho_hw - unlist(update_expansion_hw)
        names(check_rho_hw) <- NULL
        
      }else{
        check_rho_hw <- 0
      }
      if (!quiet_rho){
        print(round(c(est_rho_spline, est_rho, check_rho_hw), 5))
      }
      if (parameter_expansion == 'diagonal'){
        if (!is.na(px_improve) & (max(abs(est_rho - 1)) < 1e-6) & (max(abs(est_rho_spline - 1)) < 1e-6) ){
          if (!quiet_rho){print('No further improvements')}
          skip_translate <- TRUE
        }
      }else{
        if (length(est_rho) > 0){
          diff_rho <- max(abs(est_rho - stationary_rho))
        }else{diff_rho <- 0}
        if (!is.na(px_improve) & (diff_rho < 1e-6) & (max(abs(est_rho_spline - 1)) < 1e-6) ){
          if (!quiet_rho){print('No further improvements')}
          skip_translate <- TRUE
        }
        if (!is.na(px_improve)){
          if (abs(px_improve) < 1e-7){
            if (!quiet_rho){print('No further improvements (ELBO)')}
            skip_translate <- TRUE
          }
        }
      }
      
      if (sum(spline_REs) > 0){
        
        if (parameter_expansion == 'diagonal'){
          old_update_diag_R <- update_diag_R
          update_diag_R <- lapply(d_j, FUN=function(i){rep(1, i)})
          update_diag_R[!spline_REs] <- old_update_diag_R
        }
        
        if (any(!spline_REs)){
          old_update_expansion_R <- update_expansion_R
          update_expansion_R <- lapply(d_j, FUN=function(i){diag(i)})
          update_expansion_R[!spline_REs] <- old_update_expansion_R
          rm(old_update_expansion_R)
        }else{
          update_expansion_R <- as.list(rep(NA, length(spline_REs)))
        }
        update_expansion_R[spline_REs] <- lapply(update_expansion_splines, FUN=function(i){matrix(i)})
        
      }
      
      prop_vi_sigma_alpha <- mapply(vi_sigma_alpha, update_expansion_R, SIMPLIFY = FALSE,
        FUN=function(Phi, R){R %*% Phi %*% t(R)})
      
      # cat('r')
      # Do any of the estimated "R_j" have a negative determinant?
      sign_detRj <- sign(sapply(update_expansion_R, det))
      any_neg_det <- any(sign_detRj < 0)
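      # Expand each R_j into a block-diagonal transformation over all levels
      # of its random effect and verify it against a direct bdiag construction.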
      
      mapping_for_R_block <- make_mapping_alpha(update_expansion_R, px.R = TRUE)
      update_expansion_Rblock <- prepare_T(mapping = mapping_for_R_block, levels_per_RE = g_j, num_REs = number_of_RE,
                        variables_per_RE = d_j, running_per_RE = breaks_for_RE, cyclical = FALSE, px.R = TRUE)

      check_Rblock <- bdiag(mapply(update_expansion_R, g_j, FUN=function(i,g){bdiag(lapply(1:g, FUN=function(k){i}))}))
      if (max(abs(check_Rblock - update_expansion_Rblock)) != 0){
        warning('Error in creating parameter expansion; check that ELBO increases monotonically.')
      }
      
      update_expansion_R_logdet <- sapply(update_expansion_R, FUN=function(i){determinant(i)$modulus})
      
      prop_vi_beta_mean <- update_expansion_bX
      prop_vi_alpha_mean <- update_expansion_Rblock %*% vi_alpha_mean
      
      if (!quiet_rho){cat('r')}

      if (factorization_method != 'weak'){
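        # Propagate the proposed transformation to the stored decompositions:
        # the alpha factor becomes L %*% t(R_block) and its log-determinant
        # shifts by 2 * sum(g_j * log|det R_j|); the beta covariance factor is
        # carried over unchanged. If some R_j has a negative determinant, the
        # factor is rebuilt from R Var R' via a fresh Cholesky instead.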
        
        prop_log_det_joint_var <- prop_vi_joint_decomp <- NULL
        
        if (!any_neg_det){
          prop_vi_alpha_decomp <- vi_alpha_decomp %*% t(update_expansion_Rblock)
        }else{
          warning(paste0('Manually corrected R_j with negative determinant at iteration ', it))
          if (all(d_j == 1)){
            if (!isDiagonal(update_expansion_Rblock)){
              stop('Correction failed as R_j is not diagonal. Try requiring optimization of PX objective.')
            }
            diag(update_expansion_Rblock) <- abs(diag(update_expansion_Rblock))
            prop_vi_alpha_decomp <- vi_alpha_decomp %*% t(update_expansion_Rblock)
          }else{
            prop_vi_alpha_decomp <- update_expansion_Rblock %*% t(vi_alpha_decomp) %*% 
              vi_alpha_decomp %*% t(update_expansion_Rblock)
            prop_vi_alpha_decomp <- Matrix::Cholesky(prop_vi_alpha_decomp)
            prop_vi_alpha_decomp <- with(expand(prop_vi_alpha_decomp), t(L) %*% P)
          }
        }
        
        prop_log_det_alpha_var <- log_det_alpha_var + 2 * sum(update_expansion_R_logdet * g_j)
        prop_log_det_beta_var <- log_det_beta_var
        prop_vi_beta_decomp <- vi_beta_decomp
        
        prop_variance_by_alpha_jg <- calculate_expected_outer_alpha(
            L = prop_vi_alpha_decomp, 
            alpha_mu = as.vector(prop_vi_alpha_mean), 
            re_position_list = outer_alpha_RE_positions)
        prop_vi_sigma_outer_alpha <- prop_variance_by_alpha_jg$outer_alpha
      }else{
        stop('Translation expansion not yet implemented for the weak factorization.')
        # Be sure to set up linear case here too..
      }
      
      if (do_huangwand){
        if (parameter_expansion == "diagonal"){
          prop_vi_a_b_jp <- mapply(vi_a_b_jp, update_diag_R, SIMPLIFY = FALSE,
                              FUN=function(i,j){i / j^2})
          if (px_method != 'OSL'){stop('Double check diagonal expansion')}
        }else{

          if (px_method %in% c('OSL')){
            
            prop_moments <- mapply(moments_sigma_alpha, update_expansion_R, SIMPLIFY = FALSE,
               FUN=function(Phi, R){
                 inv_R <- solve(R)
                 return(diag(t(inv_R) %*% Phi$sigma.inv %*% inv_R))
               })
            prop_vi_a_b_jp <- mapply(vi_a_nu_jp, vi_a_APRIOR_jp, prop_moments,
               SIMPLIFY = FALSE,
               FUN=function(nu, APRIOR, diag_j){
                 1/APRIOR^2 + nu * diag_j
               })
            
          }else if (px_method == 'numerical'){
            prop_vi_a_b_jp <- vi_a_b_jp
          }else{
            prop_vi_a_b_jp <- update_expansion_hw[names(vi_a_b_jp)]
          }
          

        }
      }else{
        prop_vi_a_b_jp <- NULL
      }
      
      # Note: if L^T L = Variance, then under the transformation R Var R^T
      # the corresponding factor is L %*% t(R).

      if (debug_px){
        prop.ELBO <- calculate_ELBO(family = family,
            ELBO_type = ELBO_type,
            factorization_method = factorization_method,
            d_j = d_j, g_j = g_j, prior_sigma_alpha_phi = prior_sigma_alpha_phi,
            prior_sigma_alpha_nu = prior_sigma_alpha_nu,
            iw_prior_constant = iw_prior_constant,
            X = X, Z = Z, s = s, y = y,
            vi_pg_b = vi_pg_b, vi_pg_mean = vi_pg_mean, vi_pg_c = vi_pg_c,
            vi_sigma_alpha_nu = vi_sigma_alpha_nu,
            
            vi_sigmasq_a = vi_sigmasq_a, vi_sigmasq_b = vi_sigmasq_b, 
            vi_sigmasq_prior_a = vi_sigmasq_prior_a, vi_sigmasq_prior_b = vi_sigmasq_prior_b,
            
            vi_r_mean = vi_r_mean, vi_r_sigma = vi_r_sigma, vi_r_mu = vi_r_mu,
            
            vi_sigma_alpha = prop_vi_sigma_alpha, 
            vi_a_b_jp = prop_vi_a_b_jp,
            vi_sigma_outer_alpha = prop_vi_sigma_outer_alpha,
            vi_beta_mean = prop_vi_beta_mean, vi_alpha_mean = prop_vi_alpha_mean,
            
            log_det_beta_var = prop_log_det_beta_var, 
            log_det_alpha_var = prop_log_det_alpha_var,
            log_det_joint_var = prop_log_det_joint_var,
            
            vi_beta_decomp = prop_vi_beta_decomp, 
            vi_alpha_decomp = prop_vi_alpha_decomp,
            vi_joint_decomp = prop_vi_joint_decomp,
            
            do_huangwand = do_huangwand, vi_a_a_jp = vi_a_a_jp, 
            vi_a_nu_jp = vi_a_nu_jp, vi_a_APRIOR_jp = vi_a_APRIOR_jp,
            choose_term
        )
      }
      if (!quiet_rho){cat('d')}
      
      # If debugging, check whether the change in ELBO
      # from the profiled objective agrees with the 
      # change from the actual ELBO.
      if (debug_px){
        ELBO_diff <- prop.ELBO$ELBO - prior.ELBO$ELBO
        if (!is.na(px_improve)){
          if (abs(ELBO_diff - px_improve) > sqrt(.Machine$double.eps)){
            stop('PX does not agree with debug.')
            # browser()
            # stop()
          }
        }else{
          if (!isTRUE(all.equal(ELBO_diff, 0))){stop('PX altered parameters when NA.')}
        }
        debug_PX_ELBO[it] <- ELBO_diff
      }
      
      if (is.na(px_improve)){
        accept.PX <- FALSE
      }else if (px_improve > 0){
        accept.PX <- TRUE
      }else{
        accept.PX <- FALSE
      }

      if (accept.PX){
        
        # Accept the PX-VB adjustment
        
        vi_beta_mean <- prop_vi_beta_mean
        vi_alpha_mean <- prop_vi_alpha_mean
        vi_sigma_alpha <- prop_vi_sigma_alpha
        if (factorization_method == 'weak'){
          stop('Setup reassignment weak')
          if (do_SQUAREM){stop('...')}
        }else{
          vi_alpha_decomp <- prop_vi_alpha_decomp
          log_det_alpha_var <- prop_log_det_alpha_var
          if (do_SQUAREM){
            vi_alpha_L_nonpermute <- vi_alpha_decomp
            vi_alpha_LP <- Diagonal(n = ncol(vi_alpha_decomp))
          }
        }
        variance_by_alpha_jg <- prop_variance_by_alpha_jg
        vi_sigma_outer_alpha <- prop_vi_sigma_outer_alpha
        
        if (do_huangwand){
          vi_a_b_jp <- prop_vi_a_b_jp
        }
      }
      
      if (!quiet_rho){
        print(accept.PX)
        if (debug_px){
          out_px <- c(prop.ELBO$ELBO, prior.ELBO$ELBO)
          names(out_px) <- c('PX', 'prior')
          print(out_px)
        }
      }
      
      if (isFALSE(accept.PX) & (px_method %in% c('numerical', 'profiled', 'numerical_hw'))){stop("PX SHOULD NOT FAIL")}
      
      accepted_times <- accept.PX + accepted_times

      if (do_timing){
        toc(quiet = quiet_time, log = TRUE)
      }
      rm(prop_vi_beta_mean, prop_vi_alpha_mean, prop_vi_sigma_alpha, prop_vi_alpha_decomp,
         prop_log_det_alpha_var, prop_variance_by_alpha_jg, prop_vi_sigma_outer_alpha)


      rownames(vi_alpha_mean) <- fmt_names_Z
    }
    
    # Adjust the terms in the ELBO calculation that are different.
    
    final.ELBO <- calculate_ELBO(family = family,
      ELBO_type = ELBO_type,
      factorization_method = factorization_method,
      d_j = d_j, g_j = g_j, prior_sigma_alpha_phi = prior_sigma_alpha_phi,
      prior_sigma_alpha_nu = prior_sigma_alpha_nu,
      iw_prior_constant = iw_prior_constant,
      X = X, Z = Z, s = s, y = y,
      vi_pg_b = vi_pg_b, vi_pg_mean = vi_pg_mean, vi_pg_c = vi_pg_c,
      vi_sigma_alpha = vi_sigma_alpha, vi_sigma_alpha_nu = vi_sigma_alpha_nu,
      vi_sigma_outer_alpha = vi_sigma_outer_alpha,
      vi_beta_mean = vi_beta_mean, vi_alpha_mean = vi_alpha_mean,
      log_det_beta_var = log_det_beta_var, log_det_alpha_var = log_det_alpha_var,
      vi_beta_decomp = vi_beta_decomp, vi_alpha_decomp = vi_alpha_decomp,
      vi_joint_decomp = vi_joint_decomp, choose_term = choose_term,
      vi_sigmasq_a = vi_sigmasq_a, vi_sigmasq_b = vi_sigmasq_b, 
      vi_sigmasq_prior_a = vi_sigmasq_prior_a, vi_sigmasq_prior_b = vi_sigmasq_prior_b,
      log_det_joint_var = log_det_joint_var, vi_r_mu = vi_r_mu, vi_r_mean = vi_r_mean, vi_r_sigma = vi_r_sigma,
      do_huangwand = do_huangwand, vi_a_a_jp = vi_a_a_jp, vi_a_b_jp = vi_a_b_jp,
      vi_a_nu_jp = vi_a_nu_jp, vi_a_APRIOR_jp = vi_a_APRIOR_jp
    )

    if (do_timing) {
      toc(quiet = quiet_time, log = TRUE)
      tic("Update Squarem")
    }
    
    if (do_SQUAREM){
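      # SQUAREM acceleration: store the variational parameters from three
      # consecutive iterations and, every third iteration, take an
      # extrapolated step; the step is only kept if it does not decrease the
      # ELBO (with backtracking below).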
      
      if (factorization_method %in% c('weak', 'collapsed')){
        vi_alpha_L_nonpermute <- vi_beta_L_nonpermute <- NULL
        vi_alpha_LP <- vi_beta_LP <- NULL
      }else{
        vi_joint_L_nonpermute <- vi_joint_LP <- NULL
      }
      
      squarem_list[[squarem_counter]] <- namedList(vi_sigma_alpha_nu, 
           vi_sigma_alpha, vi_alpha_mean, vi_beta_mean,
           vi_pg_c, vi_alpha_L_nonpermute, vi_alpha_LP,
           vi_beta_L_nonpermute, vi_beta_LP,
           vi_alpha_L_nonpermute,
           vi_joint_L_nonpermute, vi_joint_LP,
           vi_a_a_jp, vi_a_b_jp,
           vi_r_mu, vi_r_sigma, vi_r_mean)
      
      if (family == 'linear'){
        squarem_list[[squarem_counter]]$vi_sigmasq_a <- vi_sigmasq_a
        squarem_list[[squarem_counter]]$vi_sigmasq_b <- vi_sigmasq_b
      }
      if (squarem_counter %% 3 == 0){
        

        ELBOargs <- list(family = family,
           ELBO_type = ELBO_type,
           factorization_method = factorization_method,
           d_j = d_j, g_j = g_j, prior_sigma_alpha_phi = prior_sigma_alpha_phi,
           prior_sigma_alpha_nu = prior_sigma_alpha_nu,
           iw_prior_constant = iw_prior_constant,
           X = X, Z = Z, s = s, y = y,
           vi_pg_b = vi_pg_b, vi_pg_mean = vi_pg_mean, vi_pg_c = vi_pg_c,
           vi_sigma_alpha = vi_sigma_alpha, vi_sigma_alpha_nu = vi_sigma_alpha_nu,
           vi_sigma_outer_alpha = vi_sigma_outer_alpha,
           vi_beta_mean = vi_beta_mean, vi_alpha_mean = vi_alpha_mean,
           log_det_beta_var = log_det_beta_var, log_det_alpha_var = log_det_alpha_var,
           vi_beta_decomp = vi_beta_decomp, vi_alpha_decomp = vi_alpha_decomp,
           vi_joint_decomp = vi_joint_decomp, choose_term = choose_term,
           vi_sigmasq_a = vi_sigmasq_a, vi_sigmasq_b = vi_sigmasq_b, 
           vi_sigmasq_prior_a = vi_sigmasq_prior_a, vi_sigmasq_prior_b = vi_sigmasq_prior_b,
           log_det_joint_var = log_det_joint_var, 
           vi_r_mu = vi_r_mu, vi_r_mean = vi_r_mean, vi_r_sigma = vi_r_sigma,
           do_huangwand = do_huangwand, vi_a_a_jp = vi_a_a_jp, vi_a_b_jp = vi_a_b_jp,
           vi_a_nu_jp = vi_a_nu_jp, vi_a_APRIOR_jp = vi_a_APRIOR_jp
        )
        
        if (factorization_method %in% c('weak', 'collapsed')){
          squarem_par <- c('vi_a_b_jp', 'vi_sigma_alpha', 'vi_pg_c',
                           'vi_alpha_mean', 'vi_beta_mean', 'vi_joint_L_nonpermute')
          squarem_type <- c('positive', 'matrix', 'positive',
                            'real', 'real', 'cholesky')
          squarem_structure <- c('list', 'list', 'vector', 'vector', 'vector',
                                 'vector')
        }else{
          squarem_par <- c('vi_a_b_jp', 'vi_sigma_alpha', 'vi_pg_c',
                           'vi_alpha_mean', 'vi_beta_mean', 'vi_beta_L_nonpermute',
                           'vi_alpha_L_nonpermute')
          squarem_type <- c('positive', 'matrix', 'positive',
                            'real', 'real', 'cholesky', 'cholesky')
          squarem_structure <- c('list', 'list', 'vector', 'vector', 'vector',
                                 'vector', 'vector')
          
        }
        
        if (family == 'negbin'){
          
          stop('Setup SQUAREM For negbin')
          if (vi_r_method == 'VEM'){
            
            squarem_par <- c(squarem_par, 'vi_r_mu')
            squarem_type <- c(squarem_type, 'real')
            squarem_structure <- c(squarem_structure, 'vector')
            
          } else if (vi_r_method %in% c('Laplace', 'delta')) {
            
            stop('Set up Laplace/delta for SQUAREM')
            
            squarem_par <- c(squarem_par, 'vi_r_mu', 'vi_r_sigma')
            squarem_type <- c(squarem_type, 'real', 'positive')
            squarem_structure <- c(squarem_structure, 'vector', 'vector')
            
          } else if (vi_r_method == 'fixed') {
            
          }
        }
        if (family %in% 'linear'){
          squarem_par <- c(squarem_par, 'vi_sigmasq_b')
          squarem_type <- c(squarem_type, 'positive')
          squarem_structure <- c(squarem_structure, 'vector')
        }
        
        remove_hw_b <- FALSE
        if (!do_huangwand){
          squarem_type <- squarem_type[!grepl(squarem_par, pattern='vi_a_b_jp')]
          squarem_structure <- squarem_structure[!grepl(squarem_par, pattern='vi_a_b_jp')]
          squarem_par <- squarem_par[!grepl(squarem_par, pattern='vi_a_b_jp')]
        }else{
          if (remove_hw_b){
            squarem_type <- squarem_type[!grepl(squarem_par, pattern='vi_a_b_jp')]
            squarem_structure <- squarem_structure[!grepl(squarem_par, pattern='vi_a_b_jp')]
            squarem_par <- squarem_par[!grepl(squarem_par, pattern='vi_a_b_jp')]
          }
        }
        
        remove_c <- FALSE
        if (remove_c | family %in% 'linear'){
          squarem_type <- squarem_type[!grepl(squarem_par, pattern='vi_pg_c')]
          squarem_structure <- squarem_structure[!grepl(squarem_par, pattern='vi_pg_c')]
          squarem_par <- squarem_par[!grepl(squarem_par, pattern='vi_pg_c')]
        }
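        # Parameters tracked as 'cholesky' must be triangular in every stored
        # iterate; otherwise fall back to an LU-based representation for the
        # extrapolation.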
        
        check_tri <- sapply(squarem_par[squarem_type == 'cholesky'], FUN=function(nm_i){
          si <- sapply(squarem_list, FUN=function(i){isTriangular(i[[nm_i]])})
          return(all(si))
        })
        squarem_type[squarem_type == 'cholesky'][check_tri == FALSE] <- 'lu'
        
        if ('vi_pg_c' %in% squarem_par){
          # Address possibility of "zero" for vi_pg_c
          squarem_list <- lapply(squarem_list, FUN=function(i){
            i$vi_pg_c <- ifelse(abs(i$vi_pg_c) < 1e-6, 1e-6, i$vi_pg_c)
            return(i)
          })
        }
        
        squarem_list <- lapply(squarem_list, FUN=function(i){
          
          i[squarem_par] <- mapply(squarem_par, squarem_type, 
            squarem_structure, SIMPLIFY = FALSE, FUN=function(s_par, s_type, s_str){
              if (s_str == 'vector'){
                out <- squarem_prep_function(i[[s_par]], s_type) 
              }else{
                out <- lapply(i[[s_par]], FUN=function(j){squarem_prep_function(j, s_type)})
              }
              return(out)
          })
          return(i)
        })
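        # For each tracked parameter, form the SQUAREM differences
        # r = x2 - x1 and v = (x3 - x2) - r on the transformed scale, along
        # with their squared norms and the largest element of x3 - x2.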
        
        prep_SQUAREM <- mapply(squarem_par, squarem_structure, squarem_type,
          SIMPLIFY = FALSE, FUN=function(s_par, s_str, s_type){
            
            if (s_type == 'lu'){
              
              r <- list(
                'L' = squarem_list[[2]][[s_par]]$L - squarem_list[[1]][[s_par]]$L,
                'U' = squarem_list[[2]][[s_par]]$U - squarem_list[[1]][[s_par]]$U
              )
              d2 <- list(
                'L' = squarem_list[[3]][[s_par]]$L - squarem_list[[2]][[s_par]]$L,
                'U' = squarem_list[[3]][[s_par]]$U - squarem_list[[2]][[s_par]]$U
              )
              v <- list("L" = d2$L - r$L, 'U' = d2$U - r$U)
              norm_sq_r <- sum(sapply(r, FUN=function(i){sum(i@x^2)}))
              norm_sq_v <- sum(sapply(v, FUN=function(i){sum(i@x^2)}))
              max_d <- max(sapply(d2, FUN=function(i){max(abs(i@x))}))
              P <- squarem_list[[3]][[s_par]]$P
              Q <- squarem_list[[3]][[s_par]]$Q
            }else if (s_str == 'list'){
              r <- mapply(squarem_list[[2]][[s_par]], squarem_list[[1]][[s_par]], SIMPLIFY = FALSE, FUN=function(i,j){i - j})
              d2 <- mapply(squarem_list[[3]][[s_par]], squarem_list[[2]][[s_par]], SIMPLIFY = FALSE, FUN=function(i,j){i - j})
              v <- mapply(d2, r, SIMPLIFY = FALSE, FUN=function(i,j){i - j})
              norm_sq_r <- sum(unlist(lapply(r, as.vector))^2)
              norm_sq_v <- sum(unlist(lapply(v, as.vector))^2)
              max_d <- max(abs(sapply(d2, FUN=function(j){max(abs(j))})))
              P <- NULL
              Q <- NULL
            }else{
              r <- squarem_list[[2]][[s_par]] - squarem_list[[1]][[s_par]]
              d2 <- squarem_list[[3]][[s_par]] - squarem_list[[2]][[s_par]]
              v <- d2 - r
              norm_sq_r <- sum(r^2)
              norm_sq_v <- sum(v^2)
              max_d <- max(abs(d2))
              P <- NULL
              Q <- NULL
            }
          return(list(first = squarem_list[[1]][[s_par]], 
                      second = squarem_list[[2]][[s_par]], 
                      max_d = max_d, P = P, Q = Q,
                      r = r, v = v, norm_sq_r = norm_sq_r, norm_sq_v = norm_sq_v))
        })
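        # SQUAREM step length: alpha = -||r|| / ||v||, clamped to [-10, -1.01],
        # pooled across all tracked parameters by default (ind_alpha = FALSE)
        # or computed per parameter block otherwise.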
        
        ind_alpha <- FALSE
        
        if (ind_alpha){
          alpha <- -sqrt((sapply(prep_SQUAREM, FUN=function(i){i$norm_sq_r}))) /
            sqrt((sapply(prep_SQUAREM, FUN=function(i){i$norm_sq_v})))
          if (any(alpha > -1)){
            alpha[which(alpha > -1)] <- -1.01
          }
          if (any(alpha < -10)){
            alpha[which(alpha < -10)] <- -10
          }
          max_d <- sapply(prep_SQUAREM, FUN=function(i){i$max_d})
          if (any(max_d < tolerance_parameters)){
            alpha[which(max_d < tolerance_parameters)] <- -1.01
          }
        }else{
          
          alpha <- -sqrt(sum(sapply(prep_SQUAREM, FUN=function(i){i$norm_sq_r}))) /
            sqrt(sum(sapply(prep_SQUAREM, FUN=function(i){i$norm_sq_v})))
          
          if (alpha > -1){
            alpha <- -1.01
          }
          if (alpha < -10){
            alpha <- -10
          }
          
          alpha <- rep(alpha, length(prep_SQUAREM))
          names(alpha) <- names(prep_SQUAREM)
        }
        if (!quiet_rho){print(alpha)}
        
        orig_squarempar <- squarem_par
        orig_alpha <- alpha
        for (attempt_SQUAREM in 1:10){
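          # Backtracking: if the extrapolated point lowers the ELBO, pull
          # alpha halfway toward -1 and retry (up to 10 attempts); at
          # alpha = -1 the proposal reduces to the latest un-accelerated
          # update.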
          
          if (!quiet_rho){print(mean(alpha))}
          
          squarem_par <- orig_squarempar
          if (attempt_SQUAREM > 1){
            alpha <- (alpha - 1)/2
          }
          
          prop_squarem <- mapply(prep_SQUAREM, squarem_structure, squarem_type, alpha, SIMPLIFY = FALSE,
             FUN=function(i, s_str, s_type, s_alpha){
               if (s_type == 'lu'){
                 prop_squarem <- lapply(c('L', 'U'), FUN=function(k){
                   i$first[[k]] - 2 * s_alpha * i$r[[k]] + s_alpha^2 * i$v[[k]] 
                 })
                 names(prop_squarem) <- c('L', 'U')
                 prop_squarem$P <- i$P
                 prop_squarem$Q <- i$Q
                 if (!quiet_rho){if (!isTRUE(all.equal(i$second$P, i$P))){print('MISALIGNED at P')}}
                 if (!quiet_rho){if (!isTRUE(all.equal(i$second$Q, i$Q))){print('MISALIGNED at Q')}}
               }else if (s_str == 'list'){
                 prop_squarem <- mapply(i$first, i$second, 
                    i$r, i$v, SIMPLIFY = FALSE, FUN=function(i_1, s_1, r_1, v_1){
                      out <- i_1 - 2 * s_alpha * r_1 + s_alpha^2 * v_1
                      return(out)
                    })
                 names(prop_squarem) <- names(i$first)
               }else{
                 prop_squarem <- i$first - 2 * s_alpha * i$r + s_alpha^2 * i$v
               }
               return(prop_squarem)
             })
          
          names(prop_squarem) <- squarem_par
          
          prop_ELBOargs <- ELBOargs
          
          prop_squarem <- mapply(prop_squarem, squarem_type, 
           squarem_structure, SIMPLIFY = FALSE, FUN=function(i, s_type, s_str){
             if (s_str == 'vector'){
               out <- squarem_unprep_function(i, s_type) 
             }else{
               out <- lapply(i, FUN=function(j){squarem_unprep_function(j, s_type)})
             }
             return(out)
           })
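          # Rebuild the covariance decompositions and log-determinants from
          # the extrapolated factors (Cholesky or LU), then refresh the
          # quantities that depend on them (outer-alpha moments, Polya-Gamma
          # parameters) before evaluating the ELBO at the proposed point.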
          
          if (factorization_method == 'weak'){
            
            if (squarem_type[squarem_par == 'vi_joint_L_nonpermute'] == 'lu'){
              prop_squarem$vi_joint_decomp <- prop_squarem$vi_joint_L_nonpermute$M              
              prop_ELBOargs$log_det_joint_var <- prop_squarem$vi_joint_L_nonpermute$logdet_M
            }else{
              prop_squarem$vi_joint_decomp <- prop_squarem$vi_joint_L_nonpermute %*% t(squarem_list[[1]]$vi_joint_LP)
              prop_ELBOargs$log_det_joint_var <- 2 * sum(log(diag(prop_squarem$vi_joint_L_nonpermute)))
            }
            prop_squarem$vi_alpha_decomp <- prop_squarem$vi_joint_decomp[, -1:-p.X, drop = F]
            prop_squarem$vi_beta_decomp <- prop_squarem$vi_joint_decomp[, 1:p.X, drop = F]
            
            squarem_par <- c(squarem_par, 'log_det_joint_var')
            squarem_par <- c(squarem_par, 'vi_joint_decomp')
            
          }else if (factorization_method == 'collapsed'){
            stop('Setup squarem for collapsed')
            if (squarem_type[squarem_par == 'vi_joint_L_nonpermute'] == 'lu'){
              prop_squarem$vi_joint_decomp <- prop_squarem$vi_joint_L_nonpermute$M              
              prop_ELBOargs$log_det_joint_var <- prop_squarem$vi_joint_L_nonpermute$logdet_M
            }else{
              prop_squarem$vi_joint_decomp <- prop_squarem$vi_joint_L_nonpermute %*% t(squarem_list[[1]]$vi_joint_LP)
              prop_ELBOargs$log_det_joint_var <- 2 * sum(log(diag(prop_squarem$vi_joint_L_nonpermute)))
            }
            prop_squarem$vi_alpha_decomp <- prop_squarem$vi_joint_decomp[, -1:-p.X, drop = F]
            prop_squarem$vi_beta_decomp <- prop_squarem$vi_joint_decomp[, 1:p.X, drop = F]
            
            squarem_par <- c(squarem_par, 'log_det_joint_var')
            squarem_par <- c(squarem_par, 'vi_joint_decomp')
            
            
          }else{

            if (squarem_type[squarem_par == 'vi_beta_L_nonpermute'] == 'lu'){
              prop_ELBOargs$log_det_beta_var <- prop_squarem$vi_beta_L_nonpermute$logdet_M
              prop_squarem$vi_beta_decomp <- prop_squarem$vi_beta_L_nonpermute$M
            }else{
              prop_ELBOargs$log_det_beta_var <- 2 * sum(log(diag(prop_squarem$vi_beta_L_nonpermute)))
              prop_squarem$vi_beta_decomp <- prop_squarem$vi_beta_L_nonpermute %*% t(squarem_list[[1]]$vi_beta_LP)
            }

            if (squarem_type[squarem_par == 'vi_alpha_L_nonpermute'] == 'lu'){
              prop_ELBOargs$log_det_alpha_var <- prop_squarem$vi_alpha_L_nonpermute$logdet_M
              prop_squarem$vi_alpha_decomp <- prop_squarem$vi_alpha_L_nonpermute$M              
            }else{
              prop_ELBOargs$log_det_alpha_var <- 2 * sum(log(diag(prop_squarem$vi_alpha_L_nonpermute)))
              prop_squarem$vi_alpha_decomp <- prop_squarem$vi_alpha_L_nonpermute %*% t(squarem_list[[1]]$vi_alpha_LP)
            }
            
            squarem_par <- c(squarem_par, 'log_det_alpha_var', 'log_det_beta_var')
            squarem_par <- c(squarem_par, 'vi_alpha_decomp', 'vi_beta_decomp')
            
          }
          
          if (family == 'negbin'){
            
            if (vi_r_method == 'VEM'){
              
              prop_ELBOargs$vi_r_mean <- exp(prop_squarem$vi_r_mu)

            } else if (vi_r_method %in% c('Laplace', 'delta')){
              
              prop_ELBOargs$vi_r_mean <- exp(prop_squarem$vi_r_mu + prop_squarem$vi_r_sigma/2)
            
            } 
            
            if (factorization_method != 'weak'){
              prop_joint_var <- rowSums((X %*% t(prop_squarem$vi_beta_decomp))^2) + 
                rowSums((Z %*% t(prop_squarem$vi_alpha_decomp))^2)
            }else{
              prop_joint_var <-  cpp_zVz(Z = joint.XZ, 
                V = as(prop_squarem$vi_joint_decomp, "generalMatrix")) 
            }
            
            if (vi_r_method %in% c('Laplace', 'delta')){
              prop_joint_var <- prop_joint_var + vi_r_sigma
            }

            prop_ELBOargs$vi_pg_c <- sqrt(as.vector(X %*% prop_squarem$vi_beta_mean + Z %*% prop_squarem$vi_alpha_mean - prop_squarem$vi_r_mu)^2 + prop_joint_var)
            prop_ELBOargs$vi_pg_b <- y + prop_ELBOargs$vi_r_mean
            
          }
          
          
          for (v in names(prop_squarem)){
            prop_ELBOargs[[v]] <- prop_squarem[[v]]
          }
          
          prop_ELBOargs$vi_alpha_L_nonpermute <- NULL
          prop_ELBOargs$vi_beta_L_nonpermute <- NULL
          prop_ELBOargs$vi_joint_L_nonpermute <- NULL

          if ('vi_alpha_mean' %in% names(prop_squarem)){
            prop_variance_by_alpha_jg <- calculate_expected_outer_alpha(
              L = (prop_squarem$vi_alpha_decomp), 
              alpha_mu = as.vector(prop_squarem$vi_alpha_mean), 
              re_position_list = outer_alpha_RE_positions)
            prop_ELBOargs[['vi_sigma_outer_alpha']] <- prop_variance_by_alpha_jg$outer_alpha
            squarem_par <- c(squarem_par, 'vi_sigma_outer_alpha')
          }
          
          if (remove_hw_b){
            prop_diag_Einv_sigma <- mapply(prop_ELBOargs$vi_sigma_alpha, 
                                           vi_sigma_alpha_nu, d_j, SIMPLIFY = FALSE, FUN = function(phi, nu, d) {
                                             inv_phi <- solve(phi)
                                             sigma.inv <- nu * inv_phi
                                             return(diag(sigma.inv))
                                           })
            prop_ELBOargs$vi_a_b_jp <- mapply(vi_a_nu_jp, vi_a_APRIOR_jp, prop_diag_Einv_sigma,
                                              SIMPLIFY = FALSE,
                                              FUN=function(nu, APRIOR, diag_j){
                                                1/APRIOR^2 + nu * diag_j
                                              })
            squarem_par <- c(squarem_par, 'vi_a_b_jp')
          }

          if ('vi_pg_c' %in% squarem_par){
            
            if (family %in% 'binomial'){
              prop_vi_pg_mean <- prop_ELBOargs$vi_pg_b / (2 * prop_ELBOargs$vi_pg_c) * tanh(prop_ELBOargs$vi_pg_c / 2)
              fill_zero <- which(abs(prop_ELBOargs$vi_pg_c) < 1e-6)
              if (length(fill_zero) > 0){
                prop_vi_pg_mean[fill_zero] <- prop_ELBOargs$vi_pg_b[fill_zero]/4
              }
              prop_ELBOargs[['vi_pg_mean']] <- prop_vi_pg_mean
              squarem_par <- c(squarem_par, 'vi_pg_mean')
            }else{stop('Set up SQUAREM for other family')}

          }else if (!(family %in% 'linear')){
            
            if (family != 'binomial'){stop('check squarem for non-binomial case')}
            
            if (factorization_method %in% c("weak", "collapsed")) {
              joint_quad <- cpp_zVz(Z = joint.XZ, 
                  V = as(prop_ELBOargs$vi_joint_decomp, "generalMatrix")) 
              if (family == 'negbin'){
                joint_quad <- joint_quad + prop_ELBOargs$vi_r_sigma
              }
              prop_ELBOargs$vi_pg_c <- sqrt(as.vector(X %*% prop_ELBOargs$vi_beta_mean + Z %*% prop_ELBOargs$vi_alpha_mean - prop_ELBOargs$vi_r_mu)^2 + joint_quad)
            } else {
              beta_quad <- rowSums((X %*% t(prop_ELBOargs$vi_beta_decomp))^2)
              alpha_quad <- rowSums((Z %*% t(prop_ELBOargs$vi_alpha_decomp))^2)
              joint_var <- beta_quad + alpha_quad
              if (family == 'negbin'){
                joint_var <- joint_var + prop_ELBOargs$vi_r_sigma
              }
              prop_ELBOargs$vi_pg_c <- sqrt(as.vector(X %*% prop_ELBOargs$vi_beta_mean + Z %*% prop_ELBOargs$vi_alpha_mean - prop_ELBOargs$vi_r_mu)^2 + joint_var)
            }
            
            prop_vi_pg_mean <- prop_ELBOargs$vi_pg_b / (2 * prop_ELBOargs$vi_pg_c) * tanh(prop_ELBOargs$vi_pg_c / 2)
            fill_zero <- which(abs(prop_ELBOargs$vi_pg_c) < 1e-6)
            if (length(fill_zero) > 0){
              prop_vi_pg_mean[fill_zero] <- prop_ELBOargs$vi_pg_b[fill_zero]/4
            }
            
            prop_ELBOargs[['vi_pg_mean']] <- prop_vi_pg_mean
            squarem_par <- c(squarem_par, 'vi_pg_c', 'vi_pg_mean')
            
          }
          
          elbo_init <- do.call("calculate_ELBO", ELBOargs)
          elbo_squarem <- do.call("calculate_ELBO", prop_ELBOargs)
          if (!quiet_rho){print(c(elbo_squarem$ELBO, elbo_init$ELBO))}
          if (elbo_squarem$ELBO >= elbo_init$ELBO){break}
            
        }
        
        if (elbo_squarem$ELBO >= elbo_init$ELBO){
          if (!quiet_rho){cat('SUCCESS')}
          squarem_success <- squarem_success + 1
          squarem.ELBO <- elbo_squarem
          final.ELBO <- elbo_squarem
          
          for (v in squarem_par){
            assign(v, prop_ELBOargs[[v]])
          }
          
          test_ELBO <- calculate_ELBO(family = family,
               ELBO_type = ELBO_type,
               factorization_method = factorization_method,
               d_j = d_j, g_j = g_j, prior_sigma_alpha_phi = prior_sigma_alpha_phi,
               prior_sigma_alpha_nu = prior_sigma_alpha_nu,
               iw_prior_constant = iw_prior_constant,
               X = X, Z = Z, s = s, y = y,
               vi_pg_b = vi_pg_b, vi_pg_mean = vi_pg_mean, vi_pg_c = vi_pg_c,
               vi_sigma_alpha = vi_sigma_alpha, vi_sigma_alpha_nu = vi_sigma_alpha_nu,
               vi_sigma_outer_alpha = vi_sigma_outer_alpha,
               vi_beta_mean = vi_beta_mean, vi_alpha_mean = vi_alpha_mean,
               log_det_beta_var = log_det_beta_var, log_det_alpha_var = log_det_alpha_var,
               vi_beta_decomp = vi_beta_decomp, vi_alpha_decomp = vi_alpha_decomp,
               vi_joint_decomp = vi_joint_decomp, choose_term = choose_term,
               vi_sigmasq_a = vi_sigmasq_a, vi_sigmasq_b = vi_sigmasq_b, 
               vi_sigmasq_prior_a = vi_sigmasq_prior_a, vi_sigmasq_prior_b = vi_sigmasq_prior_b,
               log_det_joint_var = log_det_joint_var, vi_r_mu = vi_r_mu, vi_r_mean = vi_r_mean, vi_r_sigma = vi_r_sigma,
               do_huangwand = do_huangwand, vi_a_a_jp = vi_a_a_jp, vi_a_b_jp = vi_a_b_jp,
               vi_a_nu_jp = vi_a_nu_jp, vi_a_APRIOR_jp = vi_a_APRIOR_jp
          )
          if (test_ELBO$ELBO != elbo_squarem$ELBO){stop('SQUAREM check failed: recomputed ELBO does not match the accepted ELBO.')}
        }else{
          if (!quiet_rho){cat('FAIL')}
          squarem_success[1] <- squarem_success[1] + 1
          final.ELBO <- squarem.ELBO <- final.ELBO
        }
        squarem_list <- list()
        squarem_counter <- 1
      }else{
        
        squarem_counter <- squarem_counter + 1
        
      }
      
    }
    
    if (do_timing) {
      toc(quiet = quiet_time, log = T)
      tic("Final Cleanup")
    }

    if (debug_ELBO & it != 1) {
      debug_ELBO.1$step <- 1
      debug_ELBO.2$step <- 2
      debug_ELBO.3$step <- 3
      if (do_SQUAREM & (it %% 3 == 0)){
        squarem.ELBO$step <- 4
        final.ELBO$step <- 5
        update_ELBO <- rbind(debug_ELBO.1, debug_ELBO.2, debug_ELBO.3, squarem.ELBO, final.ELBO)
      }else{
        final.ELBO$step <- 4
        update_ELBO <- rbind(debug_ELBO.1, debug_ELBO.2, debug_ELBO.3, final.ELBO)
      }
      update_ELBO$it <- it
      store_ELBO <- rbind(store_ELBO, update_ELBO)
    } else {
      final.ELBO$step <- NA
      final.ELBO$it <- it
      store_ELBO <- rbind(store_ELBO, final.ELBO)
    }

    # if (!quiet_rho){
    #   if (factorization_method == 'weak'){
    #     print('NonsparseA')
    #     print(length(vi_joint_decomp@x))
    #   }else{
    #     print('NonsparseA')
    #     print(length(vi_alpha_decomp@x))
    #   }
    # }

    ## Change diagnostics
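    # Track iteration-to-iteration changes in the variational parameters and
    # the ELBO; convergence is declared when every parameter change falls
    # below tolerance_parameters or the ELBO increase is positive but smaller
    # than tolerance_elbo.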
    
    change_elbo <- final.ELBO$ELBO - lagged_ELBO

    change_alpha_mean <- max(abs(vi_alpha_mean - lagged_alpha_mean))
    change_beta_mean <- max(abs(vi_beta_mean - lagged_beta_mean))
    
    unlist_vi <- c(unlist(lapply(vi_sigma_alpha, as.vector)), unlist(vi_a_b_jp))
    if (debug_ELBO){
      svi <- data.frame(t(as.vector(unlist_vi)))
      svi$it <- it
      store_vi <- rbind(store_vi,  svi)
    }
    
    change_sigma_mean <- mapply(vi_sigma_alpha, lagged_sigma_alpha, FUN = function(i, j) {
      max(abs(i - j))
    })

    if (factorization_method %in% c("weak", "collapsed")) {
      change_joint_var <- 0 # change_joint_var <- max(abs(vi_joint_decomp - lagged_joint_decomp))
      change_alpha_var <- change_beta_var <- 0
    } else {
      change_joint_var <- 0
      change_alpha_var <- max(abs(vi_alpha_decomp - lagged_alpha_decomp))
      change_beta_var <- max(abs(vi_beta_decomp - lagged_beta_decomp))
    }

    change_vi_r_mu <- vi_r_mu - lagged_vi_r_mu

    if (do_timing) {
      toc(quiet = quiet_time, log = TRUE)
    }
    if (debug_param) {
      store_beta[it, ] <- as.vector(vi_beta_mean)
      store_alpha[it, ] <- as.vector(vi_alpha_mean)
      if (do_huangwand){
        store_hw[it,] <- unlist(vi_a_b_jp)
        colnames(store_hw) <- names(unlist(vi_a_b_jp))
      }
      store_sigma[it,] <- unlist(lapply(vi_sigma_alpha, as.vector))
      colnames(store_sigma) <- names(unlist(lapply(vi_sigma_alpha, as.vector)))
      
    }
    
    change_all <- data.frame(change_alpha_mean, change_beta_mean, 
        t(change_sigma_mean), change_alpha_var, change_beta_var, change_joint_var, change_vi_r_mu)

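    # Declare convergence when every tracked parameter change is below
    # tolerance_parameters, or when the ELBO increased by less than tolerance_elbo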
    if ((max(change_all) < tolerance_parameters) | (change_elbo > 0 & change_elbo < tolerance_elbo)) {
      if (!quiet) {
        message(paste0("Converged after ", it, " iterations with ELBO change of ", round(change_elbo, 1 + abs(floor(log(tolerance_elbo) / log(10))))))
        message(paste0("The largest change in any variational parameter was ", round(max(change_all), 1 + abs(floor(log(tolerance_parameters) / log(10))))))
      }
      break
    }
    if (debug_ELBO){
      change_all$it <- it
      store_parameter_traj <- rbind(store_parameter_traj, change_all)
    }
    if (!quiet & (it %% print_prog == 0)) {
      message(paste0("ELBO Change: ", round(change_elbo, 10)))
      message(paste0("Other Parameter Changes: ", max(change_all)))
    }

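    # Cache the current estimates as the lagged values for the next iteration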
    lagged_alpha_mean <- vi_alpha_mean
    lagged_beta_mean <- vi_beta_mean
    lagged_alpha_decomp <- vi_alpha_decomp
    lagged_beta_decomp <- vi_beta_decomp
    lagged_sigma_alpha <- vi_sigma_alpha
    lagged_vi_r_mu <- vi_r_mu
    lagged_ELBO <- final.ELBO$ELBO
  }
  if (it == iterations) {
    message(paste0("Ended without Convergence after ", it, " iterations : ELBO change of ", round(change_elbo[1], abs(floor(log(tolerance_elbo) / log(10))))))
  }

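  # Acceptance rate of attempted parameter-expansion (PX) steps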
  if (parameter_expansion %in% c("translation", "diagonal")) {
    final.ELBO$accepted_PX <- accepted_times / attempted_expansion
  }

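  # Label the fixed effects and assemble the returned object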
  rownames(vi_beta_mean) <- colnames(X)
  
  output <- list(
    beta = list(mean = vi_beta_mean),
    ELBO = final.ELBO, 
    ELBO_trajectory = store_ELBO,
    sigma = list(cov = vi_sigma_alpha, df = vi_sigma_alpha_nu),
    alpha = list(mean = vi_alpha_mean)
  )
  if (family == 'linear') {
    output$sigmasq <- list(a = vi_sigmasq_a, b = vi_sigmasq_b)
  }
  # negbin-specific output (the log-dispersion ln_r) is attached further below
  output$family <- family
  output$control <- control

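  # Summarize the tictoc log by stage: count, mean, min, max, and total time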
  if (do_timing) {
    tic_log <- tictoc::tic.log(format = FALSE)
    tic_log <- data.frame(
      stage = sapply(tic_log, FUN = function(i) i$msg),
      time = sapply(tic_log, FUN = function(i) i$toc - i$tic),
      stringsAsFactors = FALSE
    )

    tic.clear()
    tic.clearlog()

    tic_summary <- lapply(split(tic_log$time, tic_log$stage),
      FUN=function(i){
        data.frame(n = length(i), mean = mean(i), min = min(i), max = max(i),
                   total = sum(i))
      }
    )
    tic_summary <- do.call('rbind', tic_summary)
    tic_summary$variable <- rownames(tic_summary)
    rownames(tic_summary) <- NULL
  } else {
    tic_summary <- NULL
  }
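  # Trim the stored parameter trajectories to the iterations actually run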
  if (debug_param) {
    
    store_beta <- store_beta[1:it, , drop = FALSE]
    store_alpha <- store_alpha[1:it, , drop = FALSE]
    if (do_huangwand) {
      store_hw <- store_hw[1:it, , drop = FALSE]
    } else {
      store_hw <- NULL
    }
    store_sigma <- store_sigma[1:it, , drop = FALSE]
    output$parameter_trajectory <- list(beta = store_beta,
                                        alpha = store_alpha,
                                        sigma = store_sigma,
                                        hw = store_hw)
  }
  if (factorization_method %in% c("weak", "collapsed")) {
    output$joint <- list(decomp_var = vi_joint_decomp)
  }
  if (control$return_data) {
    output$data <- list(X = X, Z = Z, y = y)
    if (family == 'binomial'){
      output$data$trials <- trials
    }
  }
  
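  # Formula components (terms, factor levels, contrasts) retained so model
  # frames can be reconstructed later, e.g. for prediction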
  output$formula <- list(formula = formula, 
     re = re_fmla, fe = fe_fmla,
     interpret_gam = parse_formula,
     tt = tt, fe_Xlevels = fe_Xlevels,
     fe_contrasts = fe_contrasts, fe_terms = fe_terms)
  
  output$alpha$dia.var <- unlist(lapply(variance_by_alpha_jg$variance_jg, FUN = function(i) {
    as.vector(sapply(i, diag))
  }))
  output$beta$var <- t(vi_beta_decomp) %*% vi_beta_decomp
  output$beta$decomp_var <- vi_beta_decomp

  if (family == "negbin") {
    output$ln_r <- list(mu = vi_r_mu, sigma = vi_r_sigma, method = vi_r_method)
  }
  if (do_huangwand){
    output$hw <- list(a = vi_a_a_jp, b = vi_a_b_jp)
  }

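  # Internal quantities kept for diagnostics and post-estimation routines:
  # iteration counts, linear predictor, spline attributes, acceleration
  # results, and random-effect dimensions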
  output$internal_parameters <- list(
    it_used = it, it_max = iterations,
    cyclical_pos = cyclical_pos,
    lp = as.vector(X %*% vi_beta_mean + Z %*% vi_alpha_mean - vi_r_mu),
    parameter.change = change_all,
    parameter.vi = store_vi,
    parameter.path = store_parameter_traj,
    spline = list(attr = Z.spline.attr, size = Z.spline.size),
    missing_obs = missing_obs, N = nrow(X),
    acceleration = list(accept.PX = accept.PX, 
      squarem_success = squarem_success, debug_PX_ELBO = debug_PX_ELBO),
    names_of_RE = names_of_RE, d_j = d_j, g_j = g_j
  )

  MAVB_parameters <- list(
    M_mu_to_beta = M_mu_to_beta,
    M_prime = M_prime,
    M_prime_one = M_prime_one,
    B_j = B_j,
    outer_alpha_RE_positions = outer_alpha_RE_positions,
    d_j = d_j, g_j = g_j
  )
  output$internal_parameters$MAVB_parameters <- MAVB_parameters
  output$alpha$var <- variance_by_alpha_jg$variance_jg
  output$alpha$decomp_var <- vi_alpha_decomp
  output$timing <- tic_summary
  class(output) <- "vglmer"
  
  return(output)
}