# R/EM.R
#
# Defines functions scale_m_step m_step.lik_mixture_normal_per_scale m_step L_mixsq.mixture_normal_per_scale L_mixsq.mixture_normal L_mixsq cal_L_mixsq_s_per_scale EM_pi
#
# Documented in L_mixsq L_mixsq.mixture_normal L_mixsq.mixture_normal_per_scale m_step m_step.lik_mixture_normal_per_scale

#
#
#
# @title EM algorithm to select the mixture weights in an empirical Bayes way
#
# @description Select the mixture weights by maximizing the marginal
#   likelihood. Each iteration refreshes the assignment probabilities with
#   cal_zeta (E step), refits the mixture proportions with m_step (M step),
#   updates the prior with update_prior and recomputes the log Bayes factors
#   with log_BF. Iterations stop once the change in cal_lik drops below
#   espsilon or once max_step iterations have been run.
#
# @param G_prior mixture normal prior or mixture normal per scale prior
#
# @param Bhat matrix pxJ regression coefficient, Bhat[j,t] corresponds to regression coefficient of Y[,t] on X[,j]
#
# @param Shat matrix pxJ standard error, Shat[j,t] corresponds to standard error of the regression coefficient of Y[,t] on X[,j]
#
# @param indx_lst list generated by \code{\link{gen_wavelet_indx}} for the given level of resolution, used only with class mixture_normal_per_scale
#
# @param max_step numeric, maximum number of EM iterations
#
# @param espsilon numeric, tolerance of the EM algorithm
#
# @param init_pi0_w starting value of the weight on the null component in mixsqp
#
# @param control_mixsqp list of parameters for the mixsqp function, see the mixsqp package
#
# @param lowc_wc wavelet coefficients with low count to be discarded (passed to log_BF)
#
# @param nullweight numeric value for penalizing the likelihood at point mass 0
#   (should be between 0 and 1; useful in small sample size)
#
# @param max_SNP_EM numeric, maximum number of rows of Bhat/Shat used to build
#   the likelihood matrix of the M step; when there are more rows, only those
#   with the largest initial log Bayes factors are kept
#
# @param df passed unchanged to log_BF
#
# @param tol_null_prior tolerance for the mixture on the null component, passed to m_step
#
# @return a list of class c("EM_pi", "list") with
#\item{tpi_k}{ fitted mixture proportion}
#\item{lBF}{ log Bayes Factor}
#
# @export
#
EM_pi <- function(G_prior, Bhat, Shat, indx_lst,
                  max_step = 100,
                  espsilon = 0.0001,
                  init_pi0_w = 1,
                  control_mixsqp,
                  lowc_wc,
                  nullweight,
                  max_SNP_EM = 1000,
                  df = NULL,
                  tol_null_prior = 0.001) {
  # static parameters

  ## Deal with overfitted cases
  # some rare cases in overfitting and numerical limitation of Rfast
  Shat[is.na(Shat)] <- 1e-32
  Shat[Shat <= 0] <- 1e-32

  lBF <- log_BF(G_prior,
                Bhat, Shat,
                indx_lst = indx_lst,
                lowc_wc  = lowc_wc,
                df       = df)

  if (any(is.na(lBF))) {
    # normally due to problems related to too little residual variance
    lBF[is.na(lBF)] <- 3 * max(lBF, na.rm = TRUE)
  }

  # Basically allow running the EM only on the data points with most signal.
  if (length(lBF) > max_SNP_EM) {
    idx <- order(lBF, decreasing = TRUE)[1:ceiling(max_SNP_EM)]
  } else {
    idx <- seq_along(lBF)
  }

  # drop = FALSE keeps Bhat/Shat two-dimensional even when a single row is
  # selected; otherwise the subset silently degrades to a plain vector.
  Lmat <- L_mixsq(G_prior,
                  Bhat[idx, , drop = FALSE],
                  Shat[idx, , drop = FALSE],
                  indx_lst)
  J <- nrow(Bhat)

  # dynamic parameters
  tpi_k <- get_pi_G_prior(G_prior)
  oldloglik <- 0
  newloglik <- 1

  zeta <- rep(1 / J, J) # initial assignment probabilities
  k <- 1                # iteration counter

  while (k <= max_step && abs(newloglik - oldloglik) >= espsilon) {
    # E step ----
    oldloglik <- cal_lik(lBF, zeta)
    zeta      <- cal_zeta(lBF)

    # M step ----
    tpi_k <- m_step(Lmat, zeta = zeta[idx],
                    indx_lst,
                    init_pi0_w     = init_pi0_w,
                    control_mixsqp = control_mixsqp,
                    nullweight     = nullweight,
                    tol_null_prior = tol_null_prior)
    G_prior <- update_prior(G_prior, tpi_k)

    # NOTE(review): unlike the initial call, NA values in the refreshed lBF
    # are not patched here -- confirm that cal_lik/cal_zeta tolerate them.
    lBF <- log_BF(G_prior,
                  Bhat,
                  Shat,
                  indx_lst = indx_lst,
                  lowc_wc  = lowc_wc,
                  df       = df)

    newloglik <- cal_lik(lBF, zeta)
    k <- k + 1
  }

  out <- list(tpi_k = tpi_k, lBF = lBF)
  class(out) <- c("EM_pi", "list")
  return(out)
}



# @title Subroutine to compute likelihood matrix at scale s for mixsqp under mixture normal per scale prior
#
# @description Builds the likelihood matrix for scale s: one row per entry of
#   Bhat[, indx_lst[[s]]] and one column per mixture component of
#   G_prior[[s]]. Each entry is the normal density of the coefficient, centred
#   at the component mean, with variance Shat^2 + (component sd)^2.
#   Entries whose log-likelihood is NA are replaced by the median
#   log-likelihood of their column.
#
# @param G_prior mixture normal per scale prior
#
# @param s scale where the likelihood matrix should be computed
#
# @param Bhat  matrix pxJ regression coefficient, Bhat[j,t] corresponds to regression coefficient of Y[,t] on X[,j]
#
# @param Shat matrix pxJ standard error, Shat[j,t] corresponds to standard error of the regression coefficient of Y[,t] on X[,j]
#
# @param indx_lst list generated by \code{\link{gen_wavelet_indx}} for the given level of resolution, used only with class mixture_normal_per_scale
#
# @param is.EBmvFR logical; if FALSE (default) a penalty row c(1, 0, ..., 0)
#   is prepended to the likelihood matrix
#
# @return L see L argument mixsqp package mixsqp function
#
#
# @export
#
#' @importFrom stats dnorm
cal_L_mixsq_s_per_scale <- function(G_prior, s, Bhat, Shat, indx_lst,
                                    is.EBmvFR = FALSE) {
  prior_s <- G_prior[[s]]
  cols_s  <- indx_lst[[s]]

  # Marginal sd of each coefficient under each mixture component.
  sdmat <- sqrt(outer(c(Shat[, cols_s]^2),
                      get_sd_G_prior(G_prior)[[s]]^2, "+"))
  centred <- outer(c(Bhat[, cols_s]), prior_s$fitted_g$mean, FUN = "-")
  log_lik <- dnorm(centred / sdmat, log = TRUE) - log(sdmat)

  # Dealing with NA caused by small sd due to small sample size: impute the
  # column median. Done by direct indexing (rather than apply() over columns)
  # so that a single-row matrix keeps its dimensions.
  na_cells <- is.na(log_lik)
  if (any(na_cells)) {
    col_median <- apply(log_lik, 2, stats::median, na.rm = TRUE)
    log_lik[na_cells] <- col_median[col(log_lik)[na_cells]]
  }

  lik <- exp(log_lik)
  if (!is.EBmvFR) {
    # adding penalty line (kept as an inline call so no row name is created)
    lik <- rbind(c(1, rep(0, ncol(lik) - 1)), lik)
  }

  return(lik)
}

#' @title Compute likelihood matrix for mixsqp
#'
#' @description S3 generic computing the likelihood matrix passed to mixsqp;
#'   the method is selected from the class of \code{G_prior}.
#'
#' @param G_prior mixture normal prior
#'
#' @param Bhat  matrix pxJ regression coefficient, Bhat[j,t] corresponds to regression coefficient of Y[,t] on X[,j]
#'
#' @param Shat matrix pxJ standard error, Shat[j,t] corresponds to standard error of the regression coefficient of Y[,t] on X[,j]
#'
#' @param indx_lst list generated by \code{\link{gen_wavelet_indx}} for the given level of resolution, used only with class  mixture_normal_per_scale
#' @param \dots Other arguments.
#' @return See L argument mixsqp package mixsqp function
#'
#' @export
#' @keywords internal
#'
L_mixsq <- function(G_prior, Bhat, Shat, indx_lst, ...) {
  UseMethod("L_mixsq")
}

#' @rdname L_mixsq
#'
#' @param is.EBmvFR logical; if FALSE (default) a penalty row
#'   c(1, 0, ..., 0) is prepended to the likelihood matrix
#'
#' @method L_mixsq mixture_normal
#'
#' @export L_mixsq.mixture_normal
#'
#' @export
#' @keywords internal
#'
L_mixsq.mixture_normal <- function(G_prior,
                                   Bhat,
                                   Shat,
                                   indx_lst,
                                   is.EBmvFR = FALSE, ...) {
  prior_sd <- get_sd_G_prior(G_prior)

  # Marginal sd of every coefficient (rows) under every component (columns).
  sdmat <- sqrt(outer(c(Shat^2), prior_sd^2, "+"))

  # All components are centred at zero, so each column holds c(Bhat) itself.
  centred <- outer(c(Bhat), rep(0, length(prior_sd)), FUN = "-")
  log_lik <- dnorm(centred / sdmat, log = TRUE) - log(sdmat)

  # Dealing with NA caused by small sd due to small sample size: impute the
  # column median. Done by direct indexing (rather than apply() over columns)
  # so that a single-row matrix keeps its dimensions.
  na_cells <- is.na(log_lik)
  if (any(na_cells)) {
    col_median <- apply(log_lik, 2, stats::median, na.rm = TRUE)
    log_lik[na_cells] <- col_median[col(log_lik)[na_cells]]
  }

  L <- exp(log_lik)
  if (!is.EBmvFR) {
    # adding penalty line (kept as an inline call so no row name is created)
    L <- rbind(c(1, rep(0, ncol(L) - 1)), L)
  }

  class(L) <- "lik_mixture_normal"
  return(L)
}


#' @rdname L_mixsq
#'
#' @method L_mixsq mixture_normal_per_scale
#'
#' @export L_mixsq.mixture_normal_per_scale
#'
#' @export
#' @keywords internal
#'
L_mixsq.mixture_normal_per_scale <- function(G_prior,
                                             Bhat,
                                             Shat,
                                             indx_lst,
                                             is.EBmvFR = FALSE, ...) {
  # One likelihood matrix per scale. seq_along() (rather than 1:length())
  # yields no iteration at all when indx_lst is empty.
  L <- lapply(seq_along(indx_lst), function(s) {
    cal_L_mixsq_s_per_scale(G_prior, s, Bhat, Shat, indx_lst,
                            is.EBmvFR = is.EBmvFR)
  })

  class(L) <- c("lik_mixture_normal_per_scale", "list")
  return(L)
}

#' @title Compute M step in the weighted ash problem for different prior
#'
#' @description S3 generic computing the M step of the weighted ash problem;
#'   the method is selected from the class of \code{L}.
#'
#' @param L output of the L_mixsq function
#'
#' @param zeta assignment probabilities for each covariate
#'
#' @param indx_lst list generated by \code{\link{gen_wavelet_indx}} for the given level of resolution, used only with class  mixture_normal_per_scale
#' @param init_pi0_w starting value of the weight on the null component in mixsqp
#' @param control_mixsqp list of parameters for the mixsqp function, see the mixsqp package
#' @param nullweight numeric value for penalizing likelihood at point mass 0 (should be between 0 and 1)
#' (useful in small sample size)
#' @param is.EBmvFR logical; if FALSE (default) the weight vector handed to
#'   mixsqp starts with a penalty weight built from \code{nullweight}; if TRUE
#'   no penalty weight is added
#' @param tol_null_prior tolerance for the mixture on the null component: if
#'   the fitted weight on the first (null) component is larger than
#'   1 - tol_null_prior, all the mass is put on that component
#' @param \dots Other arguments.
#' @return a vector of proportion (class pi_mixture_normal)
#'
#' @export
#' @keywords internal
m_step <- function(L, zeta, indx_lst, init_pi0_w, control_mixsqp, nullweight,
                   is.EBmvFR = FALSE, tol_null_prior = 0.001, ...) {
  UseMethod("m_step")
}


#' @rdname m_step
#'
#' @importFrom mixsqp mixsqp
#'
#' @method m_step lik_mixture_normal
#'
#' @export m_step.lik_mixture_normal
#'
#' @export
#' @keywords internal
#'
m_step.lik_mixture_normal <- function(L,
                                      zeta,
                                      indx_lst,
                                      init_pi0_w,
                                      control_mixsqp,
                                      nullweight,
                                      is.EBmvFR = FALSE,
                                      tol_null_prior = 0.001,
                                      ...) {
  # Total number of indices listed in indx_lst.
  n_coef <- sum(lengths(indx_lst))

  # Setting the weights to fit the weighted ash problem: zeta is repeated
  # once per index, and (unless is.EBmvFR) a leading penalty weight is added.
  w <- rep(zeta, n_coef)
  if (!is.EBmvFR) {
    w <- c(nullweight * n_coef, w)
  }

  # Put the starting point close to the sparse solution.
  start <- c(init_pi0_w, rep(1e-6, ncol(L) - 1))

  fit <- mixsqp::mixsqp(L,
                        w,
                        log     = FALSE,
                        x0      = start,
                        control = control_mixsqp)

  out <- fit$x
  # If nearly all the mass sits on the first component, put all of it there.
  if (out[1] > 1 - tol_null_prior) {
    out <- 0 * out
    out[1] <- 1
  }
  class(out) <- "pi_mixture_normal"
  out
}

#' @rdname m_step
#'
#' @method m_step lik_mixture_normal_per_scale
#'
#' @export m_step.lik_mixture_normal_per_scale
#'
#' @export
#' @keywords internal
#'
m_step.lik_mixture_normal_per_scale <- function(L,
                                                zeta,
                                                indx_lst,
                                                init_pi0_w = 1,
                                                control_mixsqp,
                                                nullweight,
                                                is.EBmvFR = FALSE,
                                                tol_null_prior = 0.001,
                                                ...) {
  # Solve the weighted ash problem independently at every scale.
  # seq_along() (rather than 1:length()) yields no iteration at all when
  # indx_lst is empty.
  out <- lapply(seq_along(indx_lst), function(s) {
    scale_m_step(L, s, zeta, indx_lst,
                 init_pi0_w     = init_pi0_w,
                 control_mixsqp = control_mixsqp,
                 nullweight     = nullweight,
                 is.EBmvFR      = is.EBmvFR,
                 tol_null_prior = tol_null_prior)
  })
  class(out) <- c("pi_mixture_normal_per_scale")
  return(out)
}

# @title Subroutine to compute M step in the weighted ash problem for normal mixture prior per scale at a given scale s
#
# @description Subroutine to compute the M step in the weighted ash problem
#   at scale s: the mixture proportions are fitted with mixsqp on the
#   likelihood matrix L[[s]].
#
# @param L output of the L_mixsq.mixture_normal_per_scale function
#
# @param s scale
#
# @param zeta assignment probabilities for each covariate
#
# @param indx_lst list generated by \code{\link{gen_wavelet_indx}} for the given level of resolution, used only with class  mixture_normal_per_scale
# @param init_pi0_w starting value of the weight on the null component in mixsqp
# @param control_mixsqp list of parameters for the mixsqp function, see the mixsqp package
# @param nullweight penalization parameter
# @param is.EBmvFR logical; if FALSE (default) the weight vector starts with
#   a penalty weight nullweight * length(indx_lst[[s]])
# @param tol_null_prior if the fitted weight on the first component is larger
#   than 1 - tol_null_prior, all the mass is put on that component
# @param \dots Other arguments.
# @return a vector of proportion for the scale s
#
# @importFrom mixsqp mixsqp
#
# @export
scale_m_step <- function(L,
                         s,
                         zeta,
                         indx_lst,
                         init_pi0_w = 0.5,
                         control_mixsqp,
                         nullweight,
                         is.EBmvFR = FALSE,
                         tol_null_prior = 0.001,
                         ...) {
  # Number of indices listed for scale s.
  n_s <- length(indx_lst[[s]])

  # zeta is repeated once per index; unless is.EBmvFR, a penalty weight
  # comes first.
  w <- rep(zeta, n_s)
  if (!is.EBmvFR) {
    w <- c(nullweight * n_s, w)
  }

  L_s <- L[[s]]
  start <- c(init_pi0_w, rep(1e-6, ncol(L_s) - 1))

  fit <- mixsqp::mixsqp(L_s,
                        w,
                        x0      = start,
                        log     = FALSE,
                        control = control_mixsqp)

  out <- fit$x
  # If nearly all the mass sits on the first component, put all of it there.
  if (out[1] > 1 - tol_null_prior) {
    out <- 0 * out
    out[1] <- 1
  }
  out
}
# stephenslab/susiF.alpha documentation built on March 1, 2025, 4:28 p.m.