R/mixZINBHMM.R

Defines functions mixZINBHMM

Documented in mixZINBHMM

#' Three-state HMM with mixture model for differential peak calling using zero-inflated negative binomial distributions
#'
#' This function fits a three-state HMM with mixture model of Zero-Inflated Negative Binomials for differential peak detection across conditions with replicates.
#' Parameters are estimated by an EM algorithm whose M-step for the emission and mixing parameters is itself an inner EM loop;
#' when \code{pcut > 0}, small posterior weights are handled with rejection-controlled EM.
#'
#' @param object a mixNBHMMDataSet. Its \code{colData()} must have the columns \code{Condition} and \code{Replicate}.
#' The assay \code{counts} is required; the assay \code{offset} is optional (zero offsets are used when it is absent).
#' At least two conditions are required.
#' @param control list of control arguments from controlEM(). Every element is copied into the function environment.
#' The elements read here are \code{epsilon.em}, \code{criterion}, \code{minit.em}, \code{gap.em}, \code{maxcount.em},
#' \code{maxit.em}, \code{maxit.innerem}, \code{epsilon.innerem}, \code{trim.offset}, \code{pcut}, \code{pattern},
#' \code{min.zero}, \code{max.phi}, and \code{quiet}.
#' \code{criterion} names the EM error measure: \code{"MRCPE"}, \code{"MACPE"}, \code{"ARCEL"}, or \code{"MULTI"}
#' (all three plus 100 minus the Viterbi agreement; \code{epsilon.em} must then have length 4).
#' \code{pattern} is \code{"all"} (every combination of conditions other than all-background and all-enrichment),
#' \code{"cluster"} (the combinations of that kind found in more than 5\% of the windows after initialization), or a list
#' whose elements are vectors of indices of the conditions coded 1 (enriched) in each differential combination,
#' using the order printed as "Ordered conditions".
#'
#' @return
#'
#' The mixNBHMMDataSet with the following list as \code{metadata}:
#' \describe{
#'   \item{pi}{Vector of estimated initial probabilities of the HMM}
#'   \item{gamma}{Vector of estimated transition probabilities of the HMM}
#'   \item{psi}{Vector of estimated emission parameters: zero-inflation, mean, and dispersion of the ZINB component followed by mean and dispersion of the NB component}
#'   \item{prob}{Data table of window-based (rows) posterior probabilities of the HMM emission distributions (columns)}
#'   \item{mixProb}{Data table of window-based (rows) posterior probabilities of the mixture model from each mixture component (columns)}
#'   \item{viterbi}{Vector of Viterbi sequence of hidden states ('B' background, 'D' differential, 'E' enrichment)}
#'   \item{logF}{Data table of window-based (rows) log forward probabilities from the emission distributions (columns)}
#'   \item{logB}{Data table of window-based (rows) log backward probabilities from the emission distributions (columns)}
#'   \item{logLik}{Data table of window-based (rows) log probabilities from the emission distributions (columns)}
#'   \item{parHist}{Data table of parameter estimates (column) from every EM iteration (rows). Besides the parameters it holds
#'   the iteration number, BIC, Q-function, EM error, and the optim convergence code \code{m} (99 if optim failed)}
#' }
#' IMPORTANT: the output mixNBHMMDataSet has conditions and replicates reordered. Make sure to check \code{colData()} of your output.
#'
#' @author Pedro L. Baldoni, \email{pedrobaldoni@gmail.com}
#' @references \url{https://github.com/plbaldoni/mixNBHMM}
#'
#' @examples
#' data(ENCODE)
#' ENCODE <- createOffset(object = ENCODE,type = 'loess',span = 1)
#' \dontrun{ENCODE <- mixZINBHMM(object = ENCODE,control = controlEM())}
#'
#' @importFrom ZIMHMM HMM
#' @import data.table
#' @rawNamespace import(SummarizedExperiment, except = shift)
#'
#' @export
mixZINBHMM = function(object,control){
    # NULL bindings for control elements (assigned below via assign()) and for data.table column names (NSE),
    # so that R CMD check does not flag them as undefined globals.
    epsilon.em = criterion = Condition = Replicate = trim.offset = pcut = pattern = z = Group = min.zero = max.phi = NULL
    maxcount.em = maxit.em = maxit.innerem = epsilon.innerem = PostProb1 = PostProb2 = PostProb3 = Rejection1  = Rejection3 = weights.tmp = NULL
    minit.em = gap.em = old = new = quiet = NULL

    # Maximum absolute relative change between two parameter vectors. An entry that did not change at all
    # (including 0 -> 0, which would otherwise give NaN and turn the convergence checks into NA) counts as 0.
    maxRelChange <- function(new, old){
        rel <- abs((new - old)/old)
        rel[which(new == old)] <- 0
        max(rel)
    }

    # Creating control elements (each element of `control` becomes a local variable)
    for(i in seq_along(control)){assign(names(control)[i],control[[i]])}
    if(isTRUE(criterion=='MULTI') && length(epsilon.em)!=4){stop('For MULTI criterion, the length of epsilon.em must be 4.')}

    # Sorting the object
    cat(paste0(c(rep('#',45),'\n')));cat("Setting up the data...\n")
    object <- object[,with(colData(object),order(Condition,Replicate))]

    # Extracting elements from object
    groupfactor <- factor(colData(object)$Condition,levels = unique(colData(object)$Condition))
    group <- as.numeric(groupfactor)
    cat("Ordered conditions:",levels(groupfactor),"\n")
    # With a single condition there is no differential pattern (B would be 0 and 1:B would become c(1,0)).
    if(nlevels(groupfactor)<2){stop('Error: at least two conditions are required for differential peak calling.')}

    ChIP <- SummarizedExperiment::assay(object,'counts')
    if('offset'%in%SummarizedExperiment::assayNames(object)){
        offset <- SummarizedExperiment::assay(object,'offset')
    } else{
        offset <- matrix(0,nrow = nrow(ChIP),ncol = ncol(ChIP))
    }
    if(!is.null(trim.offset)){offset <- round(offset,digits = trim.offset)}

    # General parameters
    rare <- 0.05
    ngroup <- length(unique(group))
    M <- nrow(ChIP);N <- ncol(ChIP);K <- 3
    error.em <- 1
    it.em <- 0
    count.em <- 0
    modellist <- list()
    parlist <- list()
    zlist <- list()
    dtlist <- list()

    # Names shared by the initialization and the EM loop:
    # gamma names follow expand.grid(1:K,1:K) order ("11","21","31","12",...); joint-probability columns are "11","12",...
    gammaNames <- paste0('gamma',rep(1:K,times = K),rep(1:K,each = K))
    joinNames <- paste0('JoinProb',rep(1:K,each = K),rep(1:K,times = K))

    # Parameter initializations: one single-condition HMM per condition, combined into joint configurations
    cat(paste0(c(rep('#',45),'\n')));cat("Algorithm initialization...\n")
    for(i in unique(group)){modellist[[paste0('model',i)]] <- ZIMHMM::HMM(ChIP.init = rowSums(ChIP[,which(group==i),drop=FALSE]),
                                                                          Control.init = NULL,
                                                                          offset.init = rowMeans(offset[,which(group==i),drop=FALSE]),
                                                                          pcut = pcut)$Viterbi}
    modellist <- data.table::setDT(modellist)
    z.old <- HMM.enumerate(chain = modellist)
    Ref <- z.old$Ref
    rm(modellist)

    # Transforming data into data.table
    DT <- data.table::data.table(ChIP = ChIP,Dsg.Int = 1,offset = offset,PostProb1=1,PostProb2=1,PostProb3=1,z=z.old$Group)
    data.table::setnames(DT,c(paste0('ChIP.',1:N),'Dsg.Int',paste0('offset.',1:N),'PostProb1','PostProb2','PostProb3','z'))
    rm(z.old)

    # Determining number of mixture components and initialization.
    # z indexes rows of Ref; z == 1 is the all-background and z == max(z) the all-enrichment configuration.
    if(is.list(pattern)){
        B <- length(pattern)
        z.seq <- lapply(pattern,FUN = function(x){
            aux <- rep(FALSE,ngroup);aux[x] <- TRUE
            which(apply(as.matrix(Ref[,1:ngroup])==1,1,FUN = function(x){all(x==aux)}))
        })
        if(B==0 || any(lengths(z.seq)==0)){
            stop('Error: pattern must be a non-empty list of vectors of condition indices between 1 and ',ngroup,'.')
        }
    } else{
        if(pattern=='all'){
            B <- 2^(length(unique(group)))-2
            z.seq <- lapply(1+1:B,FUN = function(x){x})
        } else if(pattern=='cluster'){
            propdiff <- which(DT[,table(z)/.N]>rare)
            # Keep frequent configurations, excluding all-background and all-enrichment (same definition as z.diff below)
            z.cand <- as.numeric(names(propdiff))
            z.seq <- as.list(z.cand[!z.cand%in%c(1,max(DT$z))])
            B <- length(z.seq)
            if(B==0){stop('Error: differential patterns are rare.')}
        } else{
            stop('Error: pattern should be either "all", "cluster", or a list. Check controlEM().')
        }
    }
    z.diff <- DT[(!z%in%c(1,max(z))),z]

    # Stacking DT
    DTvec <- vecData(data = DT,M = M,N = N,B = B,ref = Ref,group = group,ngroup = ngroup,pattern = pattern,rare = rare)

    # Creating Aggregating Variable
    DTvec[,Group := .GRP,by=c('ChIP','Dsg.Int',paste0('Dsg.Mix',1:B),'offset')]

    # Creating Unique data.table
    DTvec.unique <- unique(DTvec,by='Group')[,c('ChIP','Dsg.Int',paste0('Dsg.Mix',1:B),'offset','Group'),with=FALSE]

    ## Initial probabilities
    pi1.old <- 0.999;pi2.old <- (1-pi1.old)/2;pi3.old <- (1-pi1.old)/2;pi.old <- c(pi1.old,pi2.old,pi3.old)

    ## Model-specific parameters
    ### Aggregating data
    dtlist <- c(list(agg(data = DTvec,data.unique = DTvec.unique,rows = paste0('(z==',1,')'),agg = 'PostProb1')),
                lapply(1:B,FUN = function(i){agg(data = DTvec,data.unique = DTvec.unique,rows = paste0('(z%in%c',paste0('(',paste(z.seq[[i]],collapse = ','),'))')),agg = 'PostProb2')}),
                list(agg(data = DTvec,data.unique = DTvec.unique,rows = paste0('(z==',max(DT$z),')'),agg = 'PostProb3')))

    ### Calculating MLEs (falls back to fixed starting values if optim fails)
    psi.old <- tryCatch(stats::optim(par = c(0.5,0,1,1,0),fn = glm.zinb2_Constr,gr = deriv.zinb2_Constr,method = 'L-BFGS-B',
                                     lower = rep(log(min.zero),5),upper = c(Inf,Inf,log(max.phi),Inf,log(max.phi)),
                                     dt = rbindlist(dtlist,idcol = 'id'),maxid = 2+B)$par,
                        error = function(e){c(0.5,0,1,1,0)})

    ## Transition probabilities
    # To compute the initial value of gamma (transition probabilities), I use all possible patterns z, regardless if we are interested in a subset of them or not
    gamma.old <- HMM.chain(DT[,0*(z==min(z))+1*(!z%in%c(min(z),max(z)))+2*(z==max(z))],K)

    ## Mixing probability: share of differential windows (restricted to the selected patterns) in each pattern
    delta.old <- unlist(lapply(1:B,FUN = function(i){
        subz.diff <- z.diff[z.diff%in%unlist(z.seq)]
        sum(subz.diff%in%z.seq[[i]]/length(subz.diff))
    }))

    # Putting all together
    theta.old <- c(pi.old,gamma.old,psi.old,delta.old)
    theta.k <- theta.old
    names.theta <- c(paste0('pi',1:K),gammaNames,
                     paste0('ZINB.',c('ZIP.Int','Mean.Int','Disp.Int')),
                     paste0('NB.',c('Mean.Int','Disp.Int')),paste0('delta',1:B))
    names(theta.k) <- names.theta

    cat(paste0("Initialization completed!\n"));cat(paste0(c(rep('#',45),'\n')))

    # EM algorithm begins
    cat("The EM algorithm begins...\n")

    while(count.em<maxcount.em && it.em<maxit.em){
        it.em <- it.em+1

        # Updating parameters
        pi.k <- theta.k[paste0('pi',1:K)]
        gamma.k <- matrix(theta.k[gammaNames],nrow=K,ncol=K,byrow=FALSE)
        psi1.k <- theta.k[paste0('ZINB.',c('ZIP.Int','Mean.Int','Disp.Int'))]
        psi2.k <- theta.k[paste0('NB.',c('Mean.Int','Disp.Int'))]
        delta.k <- theta.k[paste0('delta',1:B)]

        # Outer EM: E-step
        loglik <- ZIHMM.LLmix2_Constr(dt = DTvec,psi1 = psi1.k,psi2 = psi2.k,delta = delta.k,N = N,M = M)

        ## Forward-Backward probabilities
        logF <- hmm_logF(logf1 = loglik[,1], logf2 = loglik[,2],logf3 = loglik[,3], pi = pi.k,gamma=gamma.k)
        logB <- hmm_logB(logf1 = loglik[,1], logf2 = loglik[,2],logf3 = loglik[,3], pi = pi.k,gamma=gamma.k)

        ## Posterior probabilities
        DT[,paste0('PostProb',1:K):=as.data.table(check.prob(hmm_P1(logF=logF,logB=logB)))]
        DT[,(joinNames):=as.data.table(check.prob(hmm_P2(logF=logF,logB=logB,logf1=loglik[,1],logf2=loglik[,2],logf3=loglik[,3],gamma=gamma.k)))]

        # Outer EM: M-step
        ## Initial and transition probabilities
        PostProb <- HMM.prob(DT)
        pi.k1 <- PostProb$pi
        gamma.k1 <- PostProb$gamma
        zlist[[it.em]] <- Viterbi(LOGF=loglik,P=pi.k1,GAMMA=gamma.k1)

        ## Model parameters
        ### Inner EM
        count.innerem <- 0
        error.innerem <- 1
        delta.tmp <- delta.k
        psi1.tmp <- psi1.k
        psi2.tmp <- psi2.k
        # Rejection threshold decays geometrically (0.9^it) down to pcut; 0 disables rejection control
        rejection <- (pcut>0)*max(0.9^it.em,pcut)

        while(count.innerem<maxit.innerem && error.innerem>epsilon.innerem){
            #### Inner EM: E-step
            eta <- ZIHMM.mixprob2_Constr(dt = DTvec,delta = delta.tmp,psi1 = psi1.tmp,psi2 = psi2.tmp,N = N,M = M)

            #### Inner EM: M-step
            ##### Mixing probability
            delta.new <- unlist(lapply(1:B,FUN = function(i){sum(DT[,PostProb2]*eta[,i])/sum(DT[,PostProb2])}));names(delta.new) <- names(delta.tmp)

            ##### Model parameters
            dtlist <- c(lapply(1, FUN = function(i){
                ###### Updating posterior probabilities with rejection-controlled EM
                DT[,c('Rejection1') := list(PostProb1)][PostProb1<rejection,Rejection1 := stats::rbinom(.N,1,prob=PostProb1/rejection)*rejection]
                ###### Updating the vectorized dataset
                DTvec[,c('Rejection1') := list(rep(DT[,Rejection1],N))]
                ###### Aggregating data
                agg(data = DTvec,data.unique = DTvec.unique,rows = '(Rejection1>0)',agg = 'Rejection1')
            }),
            lapply(2:(B+1),FUN = function(i){
                ###### Updating posterior probabilities with rejection-controlled EM
                DT[,weights.tmp := PostProb2*eta[,(i-1)]][,paste0('Rejection2.',(i-1)) := list(weights.tmp)][weights.tmp<rejection,paste0('Rejection2.',(i-1)) := stats::rbinom(.N,1,prob=weights.tmp/rejection)*rejection]
                ###### Updating the vectorized dataset
                DTvec[,paste0('Rejection2.',(i-1)) := list(rep(DT[[paste0('Rejection2.',(i-1))]],N))]
                ###### Aggregating data
                agg(data = DTvec,data.unique = DTvec.unique,rows = paste0('(Rejection2.',(i-1),'>0)'),agg = paste0('Rejection2.',(i-1)))
            }),
            lapply(B+2, FUN = function(i){
                ###### Updating posterior probabilities with rejection-controlled EM
                DT[,c('Rejection3') := list(PostProb3)][PostProb3<rejection,Rejection3 := stats::rbinom(.N,1,prob=PostProb3/rejection)*rejection]
                ###### Updating the vectorized dataset
                DTvec[,c('Rejection3') := list(rep(DT[,Rejection3],N))]
                ###### Aggregating data
                agg(data = DTvec,data.unique = DTvec.unique,rows = '(Rejection3>0)',agg = 'Rejection3')
            }))

            ###### Calculating MLEs (on failure keep the current estimates and flag convergence code 99)
            model <- tryCatch(stats::optim(par = c(psi1.tmp,psi2.tmp),fn = glm.zinb2_Constr,gr = deriv.zinb2_Constr,method = 'L-BFGS-B',
                                           lower = rep(log(min.zero),5),upper = c(Inf,Inf,log(max.phi),Inf,log(max.phi)),
                                           dt = rbindlist(dtlist,idcol = 'id'),maxid = 2+B),
                              error = function(e){list(par = c(psi1.tmp,psi2.tmp),convergence = 99)})

            ###### Saving parameters
            psi1.new <- model$par[1:3]
            psi2.new <- model$par[4:5]

            ##### Checking convergence
            error.innerem <- maxRelChange(c(delta.new,psi1.new,psi2.new),c(delta.tmp,psi1.tmp,psi2.tmp))

            ##### Updating parameters
            delta.tmp <- delta.new
            psi1.tmp <- psi1.new
            psi2.tmp <- psi2.new
            count.innerem <- count.innerem + 1
        }

        ### Saving parameters
        delta.k1 <- delta.new
        psi.k1 <- c(psi1.new,psi2.new)

        # Updating parameter history
        theta.k1 <- c(pi.k1,gamma.k1,psi.k1,delta.k1);names(theta.k1) <- names(theta.k)
        theta.k <- theta.k1
        parlist[[it.em]] <- c(it=it.em,BIC = getBIC(logF = logF,nPar = K*(K-1)+(K-1)+(B-1)+5,group = group),
                              Q=Q(as.matrix(DT[,list(PostProb1,PostProb2,PostProb3)]),as.matrix(DT[,joinNames,with=FALSE]),loglik,pi.k1,gamma.k1),
                              error=error.em[1],theta.k1,m=model$convergence)

        # Computing EM error against the iteration `gap` steps back (never earlier than the first one)
        gap <- if(it.em>minit.em) gap.em else 1
        gap <- max(1,min(gap,it.em-1))
        if(it.em>1){
            parlist.old <- parlist[[(it.em-gap)]][names(psi.k1)]
            parlist.new <- parlist[[it.em]][names(psi.k1)]
            zlist.table <- data.table(old = zlist[[(it.em-gap)]], new = zlist[[it.em]])
            ACC <- 100*zlist.table[,.N,by=list(old,new)][(old==0 & new==0) | (old==1 & new==1) | (old==2 & new==2),sum(N)]/M
        } else{
            parlist.old <- rep(1,length(names(psi.k1)))
            parlist.new <- rep(1,length(names(psi.k1)))
            ACC <- 0
        }

        MRCPE <- maxRelChange(parlist.new,parlist.old) #Max. Abs. Rel. Change. of par. estimates
        MACPE <- max(abs(parlist.new-parlist.old)) #Max. Abs. Change. of par. estimates
        ARCEL <- if(it.em>=2) abs((parlist[[it.em]][['Q']] - parlist[[(it.em-gap)]][['Q']])/parlist[[(it.em-gap)]][['Q']]) else 0 #Abs. Rel. Change of expected log-likelihood of complete data (Q function)
        MULTI <- c(MRCPE,MACPE,ARCEL,100-ACC)
        error.em <- (it.em>=2)*get(criterion) + (it.em<2)*rep(1,length(get(criterion)))
        # Count consecutive converged iterations (after minit.em); reset to 0 otherwise
        count.em <- as.numeric(any(error.em<=epsilon.em))*(it.em>minit.em)*(count.em+1) + 0

        #Outputing history
        if(!quiet){
            cat(paste0(c(rep('#',45),'\n')))
            cat('\rIteration: ',it.em,', Error(s): ',paste(formatC(error.em, format = "e", digits = 2),collapse = ', '),', Viterbi Agreement: ',round(ACC,2),'%.\n',sep='')
            cat("\r",paste0('Q-function: ',round(parlist[[it.em]][['Q']],2),', BIC: ',round(parlist[[it.em]][['BIC']],2)),"\n")
            cat("\r",paste('Max. abs. rel. change of parameter estimates: '),MRCPE,"\n")
            cat("\r",paste('Max. abs. change of parameter estimates: '),MACPE,"\n")
            cat("\r",paste('Abs. rel. change of Q-function: '),ARCEL,"\n")
            cat(paste0(c(rep('#',45),'\n')))
        }
    }

    # Organizing output
    logF <- setnames(as.data.table(logF),c('Background','Differential','Enrichment'))
    logB <- setnames(as.data.table(logB),c('Background','Differential','Enrichment'))
    loglik <- setnames(as.data.table(loglik),c('Background','Differential','Enrichment'))
    # Mixture component i corresponds to configuration z.seq[[i]] (a row of Ref), for every pattern type
    mixRows <- vapply(z.seq,FUN = function(s){as.numeric(s[1])},FUN.VALUE = numeric(1))
    eta <- setnames(as.data.table(eta),do.call(paste0,Ref[mixRows,1:ngroup]))
    prob <- setnames(DT[,list(PostProb1,PostProb2,PostProb3)],c('Background','Differential','Enrichment'))
    viterbi <- ifelse(zlist[[it.em]]==0,'B',ifelse(zlist[[it.em]]==1,'D','E'))

    colnames(eta) <- gsub('1','E',gsub('0','B',colnames(eta)))

    metadata(object) <- list('pi'=pi.k1,
                             'gamma'=gamma.k1,
                             'psi'=psi.k1,
                             'prob'=prob,
                             'mixProb'=eta,
                             'viterbi'= viterbi,
                             'logF'=logF,
                             'logB'=logB,
                             'logLik'=loglik,
                             'parHist'=as.data.table(do.call(rbind,parlist)))

    cat('Done!\n');cat(paste0(c(rep('#',45),'\n')))
    return(object)
}
plbaldoni/mixNBHMM documentation built on Dec. 24, 2019, 1:31 p.m.