R/mixZIHMMConstr.R

Defines functions mixZIHMMConstr FUN

Documented in mixZIHMMConstr

#' Three-state HMM with mixture model for differential peak calling
#'
#' This function fits a three-state HMM with mixture model of Zero-Inflated Negative Binomials (ZINB) and Negative Binomials (NB) for differential peaks across conditions with replicates.
#' The emission distribution of the HMM component associated with background counts follows a ZINB distribution.
#' The emission distribution of the HMM component associated with enrichment counts in consensus follows a NB distribution.
#' Only five parameters are associated with the emission distributions of this three-state HMM: zero-inflation probability, mean and dispersion of background counts, and
#' mean and dispersion of enrichment counts.
#'
#' @param ChIP M*N matrix of ChIP read counts, where M is the number of windows in the analyzed genome and N is the number of conditions*replicates.
#' @param Control M*N matrix of log-transformed Control read counts (not yet implemented; the argument is currently ignored)
#' @param offset M*N matrix of offsets. If no offset is used, use offset = matrix(0,nrow=M,ncol=N)
#' @param group vector of length N with condition (numeric) labels (e.g. c(1,1,2,2,3,3) for three conditions and two replicates each)
#' @param control list of control arguments from controlPeaks(). Every element of the list is unpacked into a local variable of the same name.
#' The function reads the elements \code{epsilon.em}, \code{criterion}, \code{trim.offset}, \code{quiet}, \code{pcut}, \code{pattern}, \code{min.zero}, \code{max.phi},
#' \code{maxcount.em}, \code{maxit.em}, \code{minit.em}, \code{gap.em}, \code{maxit.innerem}, and \code{epsilon.innerem}.
#'
#' @details For ChIP, Control, and offset matrices, columns should be ordered by condition and replicates, respectively. For instance, the first column should have data from
#' the first replicate of the first condition, the second column should have data from the second replicate of the first condition, and so on. The same applies
#' for the elements of group.
#'
#' The number of mixture components (B) is determined by \code{control$pattern}: the length of the list if a list is given, 2^(number of conditions)-2 if
#' \code{'all'}, or the number of non-extreme initial patterns whose relative frequency exceeds 0.05 if \code{'cluster'}.
#'
#' @return A list with components
#' \item{Pi}{initial state probabilities from the last EM iteration}
#' \item{Gamma}{transition probabilities from the last EM iteration}
#' \item{Psi}{the five emission parameters (ZINB zero-inflation, mean, and dispersion; NB mean and dispersion) from the last EM iteration}
#' \item{Prob}{data.table with the posterior probabilities of the three states (PostProb1, PostProb2, PostProb3)}
#' \item{LogF}{data.table of forward log-probabilities (columns Background, Differential, Enrichment)}
#' \item{LogB}{data.table of backward log-probabilities (columns Background, Differential, Enrichment)}
#' \item{Loglik}{data.table of state-specific log-likelihoods (columns Background, Differential, Enrichment)}
#' \item{Parhist}{data.frame with one row per EM iteration (iteration, BIC, Q-function, error, parameter estimates, and optimizer convergence code)}
#' \item{Viterbi}{Viterbi path computed in the last EM iteration}
#' \item{Mix.Prob}{data.table of mixture component probabilities from the last inner EM E-step}
#'
#' @author Pedro L. Baldoni, \email{pedrobaldoni@gmail.com}
#' @references \url{https://github.com/plbaldoni/mixHMM}
#'
#' @examples
#' data(H3K36me3)
#' ChIP = SummarizedExperiment::assay(H3K36me3)
#' offset = ZIMHMM::createOffset(ChIP,method = 'ratio')
#' group = c(1,1,1,2,2,2,3,3,3)
#' # The number of mixture components is derived inside the function from control$pattern
#' \dontrun{output = mixZIHMMConstr(ChIP,Control=NULL,offset,group,control = controlPeaks())}
#'
#' @importFrom ZIMHMM HMM
#'
#' @export
mixZIHMMConstr <- function(ChIP,Control=NULL,offset,group,control){
    # Creating control elements (each element of `control` becomes a local variable)
    for(i in seq_along(control)){assign(names(control)[i],control[[i]])}
    if(length(epsilon.em)!=4 && criterion=='MULTI'){stop('For MULTI criterion, the length of epsilon.em must be 4.')}
    if(!is.null(trim.offset)){offset <- round(offset,digits = trim.offset)}

    # General parameters
    rare <- 0.05
    ngroup <- length(unique(group))
    M <- nrow(ChIP)
    N <- ncol(ChIP)
    K <- 3
    error.em <- 1
    it.em <- 0
    count.em <- 0
    modellist <- list()
    parlist <- list()
    zlist <- list()

    # Loop-invariant labels
    ## Transition probability suffixes in column-major order (11,21,31,12,...)
    idx.gamma <- as.character(transform(expand.grid(1:K,1:K),idx=paste0(Var1,Var2))$idx)
    ## Joint probability column names in row-major order (11,12,13,21,...)
    names.joinprob <- paste0('JoinProb',c(sapply(1:K,FUN = function(x){sapply(1:K,FUN = function(y){paste0(x,y)})})))
    ## Separator line printed in the progress log
    hline <- function(){cat(paste0(c(rep('#',45),'\n')))}

    # Parameter initializations
    if(!quiet){hline();cat("Algorithm initialization...\n")}
    ## One HMM per condition on the pooled replicates; only the Viterbi path is kept
    for(i in unique(group)){
        modellist[[paste0('model',i)]] <- ZIMHMM::HMM(ChIP.init = rowSums(ChIP[,which(group==i),drop=FALSE]),
                                                       Control.init = NULL,
                                                       offset.init = rowMeans(offset[,which(group==i),drop=FALSE]),
                                                       pcut = pcut)$Viterbi
    }
    modellist <- data.table::setDT(modellist)
    z.old <- HMM.enumerate(chain = modellist)
    Ref <- z.old$Ref
    rm(modellist)

    # Transforming data into data.table
    DT <- data.table::data.table(ChIP = ChIP,Dsg.Int = 1,offset = offset,PostProb1=1,PostProb2=1,PostProb3=1,z=z.old$Group)
    data.table::setnames(DT,c(paste0('ChIP.',1:N),'Dsg.Int',paste0('offset.',1:N),'PostProb1','PostProb2','PostProb3','z'))
    rm(z.old)

    # Determining number of mixture components and initialization
    if(is.list(pattern)){
        B <- length(pattern)
        z.seq <- lapply(pattern,FUN = function(x){
            aux <- rep(FALSE,ngroup);aux[x] <- TRUE
            which(apply(as.matrix(Ref[,1:ngroup])==1,1,FUN = function(r){all(r==aux)}))
        })
    } else{
        if(pattern=='all'){
            B <- 2^(length(unique(group)))-2
            z.seq <- as.list(1+1:B)
        } else if(pattern=='cluster'){
            propdiff <- which(DT[,table(z)/.N]>rare)
            # Dropping the two extreme patterns (smallest and largest z) from the candidate set
            z.seq <- as.list(as.numeric(names(propdiff[!names(propdiff)%in%range(DT$z)])))
            B <- length(z.seq)
            if(B==0){stop('Error: differential patterns are rare.')}
        } else{
            stop('Error: pattern should be either "all", "cluster", or a list. Check controlPeaks().')
        }
    }
    z.diff <- DT[(!z%in%c(1,max(z))),z]

    # Stacking DT
    DTvec <- vecData(data = DT,M = M,N = N,B = B,ref = Ref,group = group,ngroup = ngroup,pattern = pattern,rare = rare)

    # Creating Aggregating Variable
    DTvec[,Group := .GRP,by=c('ChIP','Dsg.Int',paste0('Dsg.Mix',1:B),'offset')]

    # Creating Unique data.table
    DTvec.unique <- unique(DTvec,by='Group')[,c('ChIP','Dsg.Int',paste0('Dsg.Mix',1:B),'offset','Group'),with=FALSE]

    ## Initial probabilities
    pi1.old <- 0.999
    pi2.old <- (1-pi1.old)/2
    pi3.old <- (1-pi1.old)/2
    pi.old <- c(pi1.old,pi2.old,pi3.old)

    ## Model-specific parameters
    ### Aggregating data
    dtlist <- c(list(agg(data = DTvec,data.unique = DTvec.unique,rows = paste0('(z==',1,')'),agg = 'PostProb1')),
                lapply(1:B,FUN = function(i){agg(data = DTvec,data.unique = DTvec.unique,rows = paste0('(z%in%c',paste0('(',paste(z.seq[[i]],collapse = ','),'))')),agg = 'PostProb2')}),
                list(agg(data = DTvec,data.unique = DTvec.unique,rows = paste0('(z==',max(DT$z),')'),agg = 'PostProb3')))

    ### Calculating MLEs (falling back to the starting values if the optimizer fails)
    psi.old <- tryCatch(
        optim(par = c(0.5,0,1,1,0),fn = glm.zinb2_Constr,gr = deriv.zinb2_Constr,method = 'L-BFGS-B',lower = rep(log(min.zero),5),upper = c(Inf,Inf,log(max.phi),Inf,log(max.phi)),
              dt = rbindlist(dtlist,idcol = TRUE),maxid = 2+B)$par,
        error = function(e){c(0.5,0,1,1,0)})

    ## Transition probabilities
    # To compute the initial value of gamma (transition probabilities), I use all possible patterns z, regardless if we are interested in a subset of them or not
    gamma.old <- HMM.chain(DT[,0*(z==min(z))+1*(!z%in%c(min(z),max(z)))+2*(z==max(z))],K)

    ## Mixing probability (share of each component among the initial differential windows of interest)
    subz.diff <- z.diff[z.diff%in%unlist(z.seq)]
    delta.old <- unlist(lapply(1:B,FUN = function(i){
        sum((subz.diff%in%z.seq[[i]])/length(subz.diff))
    }))

    # Putting all together
    theta.k <- c(pi.old,gamma.old,psi.old,delta.old)
    names(theta.k) <- c(paste0('pi',1:K),paste0('gamma',idx.gamma),
                        paste0('ZINB.',c('ZIP.Int','Mean.Int','Disp.Int')),
                        paste0('NB.',c('Mean.Int','Disp.Int')),paste0('delta',1:B))

    if(!quiet){cat(paste0("Initialization completed!\n"));hline()}

    # EM algorithm begins
    if(!quiet){cat("EM algorithm begins...\n")}

    while(count.em<maxcount.em && it.em<maxit.em){
        it.em <- it.em+1

        # Updating parameters
        pi.k <- theta.k[paste0('pi',1:K)]
        gamma.k <- matrix(theta.k[paste0('gamma',idx.gamma)],nrow=K,ncol=K,byrow=FALSE)
        psi1.k <- theta.k[paste0('ZINB.',c('ZIP.Int','Mean.Int','Disp.Int'))]
        psi2.k <- theta.k[paste0('NB.',c('Mean.Int','Disp.Int'))]
        delta.k <- theta.k[paste0('delta',1:B)]

        # Outer EM: E-step
        loglik <- ZIHMM.LLmix2_Constr(dt = DTvec,psi1 = psi1.k,psi2 = psi2.k,delta = delta.k,N = N,M = M)

        ## Forward-Backward probabilities
        logF <- hmm_logF(logf1 = loglik[,1], logf2 = loglik[,2],logf3 = loglik[,3], pi = pi.k,gamma=gamma.k)
        logB <- hmm_logB(logf1 = loglik[,1], logf2 = loglik[,2],logf3 = loglik[,3], pi = pi.k,gamma=gamma.k)

        ## Posterior probabilities
        DT[,paste0('PostProb',1:K):=as.data.table(check.prob(hmm_P1(logF=logF,logB=logB)))]
        DT[,(names.joinprob):=as.data.table(check.prob(hmm_P2(logF=logF,logB=logB,logf1=loglik[,1],logf2=loglik[,2],logf3=loglik[,3],gamma=gamma.k)))]

        # Outer EM: M-step
        ## Initial and transition probabilities
        PostProb <- HMM.prob(DT)
        pi.k1 <- PostProb$pi
        gamma.k1 <- PostProb$gamma
        zlist[[it.em]] <- Viterbi(LOGF=loglik,P=pi.k1,GAMMA=gamma.k1)

        ## Model parameters
        ### Inner EM
        count.innerem <- 0
        error.innerem <- 1
        delta.tmp <- delta.k
        psi1.tmp <- psi1.k
        psi2.tmp <- psi2.k
        # Rejection threshold: decays as 0.9^iteration, is floored at pcut, and is disabled (0) when pcut is 0
        rejection <- (pcut>0)*ifelse((0.9^it.em)>=pcut,(0.9^it.em),pcut)

        while(count.innerem<maxit.innerem && error.innerem>epsilon.innerem){
            #### Inner EM: E-step
            eta <- ZIHMM.mixprob2_Constr(dt = DTvec,delta = delta.tmp,psi1 = psi1.tmp,psi2 = psi2.tmp,N = N,M = M)

            #### Inner EM: M-step
            ##### Mixing probability
            delta.new <- unlist(lapply(1:B,FUN = function(i){sum(DT[,PostProb2]*eta[,i])/sum(DT[,PostProb2])}))
            names(delta.new) <- names(delta.tmp)

            ##### Model parameters
            # Weights below the threshold are replaced by Bernoulli(weight/threshold)*threshold draws.
            # DT and DTvec are updated by reference inside the closures below.
            dtlist <- c(lapply(1, FUN = function(i){
                ###### Updating posterior probabilities with rejection-controlled EM
                DT[,c('Rejection1') := list(PostProb1)][PostProb1<rejection,Rejection1 := rbinom(.N,1,prob=PostProb1/rejection)*rejection]
                ###### Updating the vectorized dataset
                DTvec[,c('Rejection1') := .(rep(DT[,Rejection1],N))]
                ###### Aggregating data
                agg(data = DTvec,data.unique = DTvec.unique,rows = '(Rejection1>0)',agg = 'Rejection1')
            }),
            lapply(2:(B+1),FUN = function(i){
                ###### Updating posterior probabilities with rejection-controlled EM
                DT[,weights.tmp := PostProb2*eta[,(i-1)]][,paste0('Rejection2.',(i-1)) := list(weights.tmp)][weights.tmp<rejection,paste0('Rejection2.',(i-1)) := rbinom(.N,1,prob=weights.tmp/rejection)*rejection]
                ###### Updating the vectorized dataset
                DTvec[,paste0('Rejection2.',(i-1)) := .(rep(DT[[paste0('Rejection2.',(i-1))]],N))]
                ###### Aggregating data
                agg(data = DTvec,data.unique = DTvec.unique,rows = paste0('(Rejection2.',(i-1),'>0)'),agg = paste0('Rejection2.',(i-1)))
            }),
            lapply(B+2, FUN = function(i){
                ###### Updating posterior probabilities with rejection-controlled EM
                DT[,c('Rejection3') := list(PostProb3)][PostProb3<rejection,Rejection3 := rbinom(.N,1,prob=PostProb3/rejection)*rejection]
                ###### Updating the vectorized dataset
                DTvec[,c('Rejection3') := .(rep(DT[,Rejection3],N))]
                ###### Aggregating data
                agg(data = DTvec,data.unique = DTvec.unique,rows = '(Rejection3>0)',agg = 'Rejection3')
            }))

            ###### Calculating MLEs (on failure, keep the current estimates and flag convergence code 99)
            model <- tryCatch(
                optim(par = c(psi1.tmp,psi2.tmp),fn = glm.zinb2_Constr,gr = deriv.zinb2_Constr,method = 'L-BFGS-B',lower = rep(log(min.zero),5),upper = c(Inf,Inf,log(max.phi),Inf,log(max.phi)),
                      dt = rbindlist(dtlist,idcol = TRUE),maxid = 2+B),
                error = function(e){list(par = c(psi1.tmp,psi2.tmp),convergence = 99)})

            ###### Saving parameters
            psi1.new <- model$par[1:3]
            psi2.new <- model$par[4:5]

            ##### Checking convergence (maximum absolute relative change)
            error.innerem <- max(abs((c(delta.new,c(psi1.new,psi2.new))-c(delta.tmp,c(psi1.tmp,psi2.tmp)))/c(delta.tmp,c(psi1.tmp,psi2.tmp))))

            ##### Updating parameters
            delta.tmp <- delta.new
            psi1.tmp <- psi1.new
            psi2.tmp <- psi2.new
            count.innerem <- count.innerem + 1
        }

        ### Saving parameters
        delta.k1 <- delta.new
        psi.k1 <- c(psi1.new,psi2.new)

        # Updating parameter history
        theta.k1 <- c(pi.k1,gamma.k1,psi.k1,delta.k1)
        names(theta.k1) <- names(theta.k)
        theta.k <- theta.k1
        parlist[[it.em]] <- c(it=it.em,BIC = getBIC(logF = logF,nPar = K*(K-1)+(K-1)+(B-1)+5,group = group),
                              Q=Q(as.matrix(DT[,.(PostProb1,PostProb2,PostProb3)]),as.matrix(DT[,names.joinprob,with=FALSE]),loglik,pi.k1,gamma.k1),
                              error=error.em[1],theta.k1,m=model$convergence)

        # Computing EM error
        ## Iterations are compared `gap.em` steps apart once past `minit.em`, and one step apart before that
        gap <- ifelse(it.em>minit.em,gap.em,1)
        if(it.em>1){
            parlist.old <- parlist[[(it.em-gap)]][names(psi.k1)]
            parlist.new <- parlist[[it.em]][names(psi.k1)]
            zlist.table <- data.table(old = zlist[[(it.em-gap)]], new = zlist[[it.em]])
            ACC <- 100*zlist.table[,.N,by=.(old,new)][(old==0 & new==0) | (old==1 & new==1) | (old==2 & new==2),sum(N)]/M
        } else{
            parlist.old <- rep(1,length(names(psi.k1)))
            parlist.new <- rep(1,length(names(psi.k1)))
            ACC <- 0
        }

        MRCPE <- max(abs((parlist.new-parlist.old)/parlist.old)) #Max. Abs. Rel. Change. of par. estimates
        MACPE <- max(abs(parlist.new-parlist.old)) #Max. Abs. Change. of par. estimates
        ARCEL <- ifelse(it.em>=2,abs((parlist[[it.em]][['Q']] - parlist[[(it.em-gap)]][['Q']])/parlist[[(it.em-gap)]][['Q']]),0) #Abs. Rel. Change of expected log-likelihood of complete data (Q function)
        MULTI <- c(MRCPE,MACPE,ARCEL,100-ACC)
        error.em <- (it.em>=2)*get(criterion) + (it.em<2)*rep(1,length(get(criterion)))
        # The counter grows while any error is within tolerance (after minit.em iterations) and resets to zero otherwise
        count.em <- as.numeric(any(error.em<=epsilon.em))*(it.em>minit.em)*(count.em+1) + 0

        #Outputing history
        if(!quiet){
            hline()
            cat('\rIteration: ',it.em,', Error(s): ',paste(formatC(error.em, format = "e", digits = 2),collapse = ', '),', Viterbi Agreement: ',round(ACC,2),'%.\n',sep='')
            cat("\r",paste0('Q-function: ',round(parlist[[it.em]][['Q']],2),', BIC: ',round(parlist[[it.em]][['BIC']],2)),"\n")
            cat("\r",paste('Max. abs. rel. change of parameter estimates: '),MRCPE,"\n")
            cat("\r",paste('Max. abs. change of parameter estimates: '),MACPE,"\n")
            cat("\r",paste('Abs. rel. change of Q-function: '),ARCEL,"\n")
            hline()
        }
    }

    # Organizing output
    logF <- setnames(as.data.table(logF),c('Background','Differential','Enrichment'))
    logB <- setnames(as.data.table(logB),c('Background','Differential','Enrichment'))
    loglik <- setnames(as.data.table(loglik),c('Background','Differential','Enrichment'))
    # NOTE(review): columns are labelled with rows 2..(B+1) of Ref; confirm this matches z.seq when pattern is a list or 'cluster'
    eta <- setnames(as.data.table(eta),do.call(paste0,Ref[2:(2+B-1),1:ngroup]))

    if(!quiet){cat('\nDone!\n')}
    return(list('Pi'=pi.k1,'Gamma'=gamma.k1,'Psi'=psi.k1,
                'Prob'=DT[,.(PostProb1,PostProb2,PostProb3)],'LogF'=logF,'LogB'=logB,'Loglik'=loglik,
                'Parhist'=as.data.frame(do.call(rbind,parlist)),'Viterbi'=zlist[[it.em]],'Mix.Prob'=eta))
}
plbaldoni/mixHMM documentation built on Nov. 8, 2019, 8:05 p.m.