R/EMGeneric.R

Defines functions EM_M_zero EM_E_z_zero_lat EM_E_z_zero_obs EM_E_k EM_E_z_lat EM_E_z_obs EMMalpha

Documented in EMMalpha

## M-Step for alpha
#' ECM: M-Step for logit regression coefficients \code{alpha}.
#'
#' Each non-reference row of \code{alpha} is updated in turn by Newton-Raphson
#' steps while the other rows are held at their current values. The last row
#' is the reference category and is returned unchanged.
#'
#' @param X A N*P matrix of numerical covariates.
#' @param alpha A g*P matrix of old logit regression coefficients.
#' @param comp.zkz.e.list An object returned by \code{EMEzkz}; its elements
#'   \code{z.e.obs}, \code{z.e.lat} and \code{k.e} are used.
#' @param alpha.iter.max Numeric: iteration cap per component; Newton steps
#'   continue while the step counter is \code{<= alpha.iter.max}.
#' @param penalty TRUE/FALSE, which indicates whether penalty is applied.
#' @param hyper.alpha A numeric of penalty applied to \code{alpha}.
#'
#' @return \code{alpha.new} Updated logit regression coefficients.
#'
#' @importFrom matrixStats rowLogSumExps
#'
#' @keywords internal
#'
#' # @export EMMalpha
EMMalpha <- function(X, alpha, comp.zkz.e.list, alpha.iter.max,
                     penalty, hyper.alpha) {
  # Observed weights plus latent weights scaled row-wise by k.e; equivalent to
  # z.e.obs + sweep(z.e.lat, 1, k.e, FUN = "*", check.margin = FALSE).
  comp.zpzk <- XPlusYColTimesZ(comp.zkz.e.list$z.e.obs,
                               comp.zkz.e.list$z.e.lat,
                               comp.zkz.e.list$k.e)

  sample.size.n <- nrow(X)
  n.covar.p <- ncol(X)
  n.comp <- nrow(alpha)
  iter <- array(0, dim = c(n.comp, 1))
  alpha.new <- alpha
  # Seed the previous iterate at -Inf so the convergence test cannot stop a
  # row before its first Newton step.
  alpha.old <- alpha - Inf
  comp.zpzk.marg <- rowSums(comp.zpzk)

  # The Hessian adjustment is the same for every Newton step, so build it once:
  # the penalty's curvature 1 / hyper.alpha^2 when penalised, otherwise a tiny
  # ridge (presumably to keep -dQ2 numerically positive definite for chol()).
  hess.adj <- if (penalty) {
    diag(1 / hyper.alpha^2, nrow = n.covar.p, ncol = n.covar.p)
  } else {
    diag(10^(-7), nrow = n.covar.p, ncol = n.covar.p)
  }

  # The last component is the reference category: its coefficients are never
  # touched. seq_len() makes the loop a no-op when n.comp == 1, whereas
  # 1:(n.comp - 1) would run j = 1 and then j = 0.
  for (j in seq_len(n.comp - 1)) {
    # Stop once the step counter exceeds alpha.iter.max, or once the squared
    # change between successive iterates drops to 1e-8 or below.
    while (iter[j] <= alpha.iter.max &&
           sum((alpha.old[j, ] - alpha.new[j, ])^2) > 10^(-8)) {
      alpha.old[j, ] <- alpha.new[j, ]

      gate.body <- tcrossprod(X, alpha.new)
      gate.norm <- rowLogSumExps(gate.body)
      pp <- exp(gate.body - gate.norm)
      # Total gating probability of every component except j, computed on the
      # log scale (mathematically 1 - pp[, j], but stable when pp[, j] ~ 1).
      qqj <- exp(rowLogSumExps(array(gate.body[, -j],
                                     dim = c(sample.size.n, n.comp - 1))) -
                   gate.norm)

      grad.adj <- if (penalty) alpha.new[j, ] / hyper.alpha^2 else 0
      dQ <- EMalphadQ(X, comp.zpzk[, j], comp.zpzk.marg, pp[, j]) - grad.adj
      # Pure-R reference for dQ:
      # apply(sweep(X, 1, comp.zpzk[, j] - comp.zpzk.marg * pp[, j], FUN = "*", check.margin = FALSE), 2, sum) - grad.adj
      dQ2 <- EMalphadQ2(X, comp.zpzk.marg, pp[, j], qqj) - hess.adj
      # Pure-R reference for dQ2:
      # -crossprod(sweep(X, 1, comp.zpzk.marg * pp[, j] * qqj, FUN = "*", check.margin = FALSE), X) - hess.adj

      # Newton-Raphson step, alpha + dQ %*% solve(-dQ2), via a Cholesky inverse.
      alpha.new[j, ] <- alpha.new[j, ] + crossprod(dQ, chol2inv(chol(-dQ2)))
      iter[j] <- iter[j] + 1
    }
  }

  alpha.new
}


# Returns exp(XColMinusY(gate_expert_ll_comp, gate_expert_ll)).
# NOTE(review): presumably these are the E-step component memberships for
# observed data (per-component log-likelihood minus the combined one), with
# XColMinusY subtracting gate_expert_ll from every column — confirm.
#' @keywords internal
EM_E_z_obs <- function(gate_expert_ll_comp, gate_expert_ll) {
  log_z <- XColMinusY(gate_expert_ll_comp, gate_expert_ll)
  exp(log_z)
}

# Latent-data counterpart of EM_E_z_obs: exp(XColMinusY(comp, total)).
# Any entry that comes out NA/NaN (e.g. -Inf minus -Inf, if XColMinusY
# subtracts) is replaced by the uniform weight 1 / ncol(gate_expert_tn_bar_comp).
#' @keywords internal
EM_E_z_lat <- function(gate_expert_tn_bar_comp, gate_expert_tn_bar) {
  uniform_weight <- 1 / ncol(gate_expert_tn_bar_comp)
  z_lat <- exp(XColMinusY(gate_expert_tn_bar_comp, gate_expert_tn_bar))
  undefined <- is.na(z_lat)
  z_lat[undefined] <- uniform_weight
  z_lat
}

# Returns exp(-gate_expert_tn) - 1, evaluated with expm1() so the result stays
# accurate when gate_expert_tn is close to 0.
# NOTE(review): presumably gate_expert_tn is the log-probability of the
# truncation region, which makes this the expected number of truncated-away
# observations per observed one — confirm.
#' @keywords internal
EM_E_k <- function(gate_expert_tn) {
  neg_log_prob <- -gate_expert_tn
  expm1(neg_log_prob)
}

# Where yl is exactly 0, returns
#   p_old / (p_old + (1 - p_old) * exp(gate_expert_ll_pos_comp));
# every other entry is 0 (and an NA in yl gives NA).
# NOTE(review): presumably the posterior weight of the zero-inflation mass,
# with p_old the current zero probability and the log-likelihood taken from
# the positive part — confirm.
#' @keywords internal
EM_E_z_zero_obs <- function(yl, p_old, gate_expert_ll_pos_comp) {
  ifelse(
    test = yl == 0,
    yes = p_old / (p_old + (1 - p_old) * exp(gate_expert_ll_pos_comp)),
    no = 0.0
  )
}

# Where tl is strictly positive, returns
#   p_old / (p_old + (1 - p_old) * exp(gate_expert_tn_bar_pos_comp));
# every other entry is 0 (and an NA in tl gives NA).
# NOTE(review): presumably the zero-inflation weight for latent (truncated)
# data, relevant only when the truncation lower bound excludes zero — confirm.
#' @keywords internal
EM_E_z_zero_lat <- function(tl, p_old, gate_expert_tn_bar_pos_comp) {
  ifelse(
    test = tl > 0,
    yes = p_old / (p_old + (1 - p_old) * exp(gate_expert_tn_bar_pos_comp)),
    no = 0.0
  )
}

# Returns the zero-part share of the total weight:
#   Z / (Z + P), where Z = sum(z_zero_e_obs + z_zero_e_lat * k_e)
#                  and P = sum(z_pos_e_obs + z_pos_e_lat * k_e).
# Latent weights are scaled elementwise by k_e. The result is NaN when both
# totals are 0.
# NOTE(review): presumably the M-step update of the zero-inflation
# probability — confirm.
#' @keywords internal
EM_M_zero <- function(z_zero_e_obs, z_pos_e_obs, z_zero_e_lat, z_pos_e_lat, k_e) {
  zero_mass <- sum(z_zero_e_obs + (z_zero_e_lat * k_e))
  pos_mass <- sum(z_pos_e_obs + (z_pos_e_lat * k_e))
  zero_mass / (zero_mass + pos_mass)
}
sparktseung/LRMoE documentation built on March 21, 2022, 3:22 a.m.