#' Three-state HMM with mixture model for differential peak calling using zero-inflated negative binomial distributions
#'
#' This function fits a three-state HMM with mixture model of Zero-Inflated Negative Binomials for differential peak detection across conditions with replicates.
#'
#' @param object a mixNBHMMDataSet
#' @param control list of control arguments from controlEM(). Every element is
#'   copied into a local variable; the ones read here are \code{epsilon.em},
#'   \code{criterion}, \code{maxit.em}, \code{minit.em}, \code{maxcount.em},
#'   \code{gap.em}, \code{maxit.innerem}, \code{epsilon.innerem}, \code{pcut},
#'   \code{pattern}, \code{trim.offset}, \code{min.zero}, \code{max.phi}, and
#'   \code{quiet}. \code{pattern} must be \code{"all"}, \code{"cluster"}, or a
#'   list whose elements give, for each differential pattern, the indices of
#'   the (ordered) conditions that are enriched. When \code{criterion} is
#'   \code{"MULTI"}, \code{epsilon.em} must have length 4.
#'
#' @return
#'
#' The mixNBHMMDataSet with the following list as \code{metadata}:
#' \describe{
#' \item{pi}{Vector of estimated initial probabilities of the HMM}
#' \item{gamma}{Vector of estimated transition probabilities of the HMM}
#' \item{psi}{Vector of estimated emission parameters: zero-inflation, mean, and dispersion intercepts of the ZINB component and mean and dispersion intercepts of the NB component}
#' \item{prob}{Data table of window-based (rows) posterior probabilities of the HMM emission distributions (columns)}
#' \item{mixProb}{Data table of window-based (rows) posterior probabilities of the mixture model from each mixture component (columns). Columns are named by differential pattern, one letter per ordered condition: B (background) or E (enrichment)}
#' \item{viterbi}{Character vector of the Viterbi sequence of hidden states ('B', 'D', or 'E')}
#' \item{logF}{Data table of window-based (rows) log forward probabilities from the emission distributions (columns)}
#' \item{logB}{Data table of window-based (rows) log backward probabilities from the emission distributions (columns)}
#' \item{logLik}{Data table of window-based (rows) log probabilities from the emission distributions (columns)}
#' \item{parHist}{Data table with one row per EM iteration: iteration number, BIC, Q-function, error, parameter estimates, and the optim convergence code}
#' }
#' IMPORTANT: the output mixNBHMMDataSet has conditions and replicates reordered. Make sure to check \code{colData()} of your output.
#'
#' @author Pedro L. Baldoni, \email{pedrobaldoni@gmail.com}
#' @references \url{https://github.com/plbaldoni/mixNBHMM}
#'
#' @examples
#' data(ENCODE)
#' ENCODE <- createOffset(object = ENCODE,type = 'loess',span = 1)
#' \dontrun{ENCODE <- mixZINBHMM(object = ENCODE,control = controlEM())}
#'
#' @importFrom ZIMHMM HMM
#' @import data.table
#' @rawNamespace import(SummarizedExperiment, except = shift)
#'
#' @export
mixZINBHMM <- function(object, control) {
  # Bindings for data.table columns and control elements (the latter are
  # overwritten below), declared so R CMD check does not flag them as globals
  epsilon.em <- criterion <- Condition <- Replicate <- trim.offset <- NULL
  pcut <- pattern <- z <- Group <- min.zero <- max.phi <- NULL
  maxcount.em <- maxit.em <- maxit.innerem <- epsilon.innerem <- NULL
  PostProb1 <- PostProb2 <- PostProb3 <- NULL
  Rejection1 <- Rejection3 <- weights.tmp <- NULL
  minit.em <- gap.em <- old <- new <- quiet <- NULL

  # Prints a separator line made of 45 '#' characters
  hline <- function() cat(paste0(c(rep("#", 45), "\n")))

  # Creating control elements: each entry of `control` becomes a local variable
  for (i in seq_along(control)) {
    assign(names(control)[i], control[[i]])
  }
  if (length(epsilon.em) != 4 && isTRUE(criterion == "MULTI")) {
    stop("For MULTI criterion, the length of epsilon.em must be 4.")
  }

  # Sorting the object by condition, then replicate
  hline()
  cat("Setting up the data...\n")
  object <- object[, with(colData(object), order(Condition, Replicate))]
  # Extracting elements from object
  groupfactor <- factor(colData(object)$Condition,
                        levels = unique(colData(object)$Condition))
  group <- as.numeric(groupfactor)
  cat("Ordered conditions:", levels(groupfactor), "\n")
  ChIP <- SummarizedExperiment::assay(object, "counts")
  if ("offset" %in% SummarizedExperiment::assayNames(object)) {
    offset <- SummarizedExperiment::assay(object, "offset")
  } else {
    offset <- matrix(0, nrow = nrow(ChIP), ncol = ncol(ChIP))
  }
  if (!is.null(trim.offset)) {
    offset <- round(offset, digits = trim.offset)
  }

  # General parameters
  rare <- 0.05 # minimum window proportion for a pattern under pattern = "cluster"
  ngroup <- length(unique(group))
  M <- nrow(ChIP)
  N <- ncol(ChIP)
  K <- 3
  error.em <- 1
  it.em <- 0
  count.em <- 0
  modellist <- list()
  parlist <- list()
  zlist <- list()

  # Parameter names shared by the initialization and the EM loop.
  # gamma names follow expand.grid order ("11", "21", "31", "12", ...);
  # joint-probability names follow ("11", "12", "13", "21", ...).
  gamma.names <- paste0("gamma", rep(1:K, times = K), rep(1:K, each = K))
  joint.names <- paste0("JoinProb", rep(1:K, each = K), rep(1:K, times = K))
  psi1.names <- paste0("ZINB.", c("ZIP.Int", "Mean.Int", "Disp.Int"))
  psi2.names <- paste0("NB.", c("Mean.Int", "Disp.Int"))
  # Box constraints passed to optim for the five emission parameters
  psi.lower <- rep(log(min.zero), 5)
  psi.upper <- c(Inf, Inf, log(max.phi), Inf, log(max.phi))

  # Parameter initializations: one single-condition HMM per condition
  hline()
  cat("Algorithm initialization...\n")
  for (i in unique(group)) {
    modellist[[paste0("model", i)]] <- ZIMHMM::HMM(
      ChIP.init = rowSums(ChIP[, which(group == i), drop = FALSE]),
      Control.init = NULL,
      offset.init = rowMeans(offset[, which(group == i), drop = FALSE]),
      pcut = pcut
    )$Viterbi
  }
  modellist <- data.table::setDT(modellist)
  # Combine the per-condition paths; the resulting z values are used below as
  # row indices of Ref (the enumerated condition patterns)
  z.old <- HMM.enumerate(chain = modellist)
  Ref <- z.old$Ref
  rm(modellist)

  # Transforming data into data.table
  DT <- data.table::data.table(ChIP = ChIP, Dsg.Int = 1, offset = offset,
                               PostProb1 = 1, PostProb2 = 1, PostProb3 = 1,
                               z = z.old$Group)
  data.table::setnames(DT, c(paste0("ChIP.", 1:N), "Dsg.Int",
                             paste0("offset.", 1:N),
                             "PostProb1", "PostProb2", "PostProb3", "z"))
  rm(z.old)

  # Determining number of mixture components (B) and the z values of each
  if (is.list(pattern)) {
    B <- length(pattern)
    z.seq <- lapply(pattern, FUN = function(x) {
      aux <- rep(FALSE, ngroup)
      aux[x] <- TRUE
      which(apply(as.matrix(Ref[, 1:ngroup]) == 1, 1,
                  FUN = function(r) all(r == aux)))
    })
  } else if (pattern == "all") {
    B <- 2^(length(unique(group))) - 2
    z.seq <- as.list(1 + 1:B)
  } else if (pattern == "cluster") {
    # Keep frequent patterns, excluding the smallest and largest z values
    # (coded as background and enrichment). This must read z from DT: the
    # local `data` binding is not the data set.
    propdiff <- which(DT[, table(z) / .N] > rare)
    z.seq <- as.list(as.numeric(
      names(propdiff[!names(propdiff) %in% range(DT$z)])
    ))
    B <- length(z.seq)
    if (B == 0) {
      stop("Error: differential patterns are rare.")
    }
  } else {
    stop('Error: pattern should be either "all", "cluster", or a list. Check controlPeaks().')
  }
  # Row of Ref describing each mixture component (used to label mixProb)
  mix.rows <- vapply(z.seq, FUN = function(x) as.numeric(x[1]),
                     FUN.VALUE = numeric(1))
  z.diff <- DT[(!z %in% c(1, max(z))), z]

  # Stacking DT
  DTvec <- vecData(data = DT, M = M, N = N, B = B, ref = Ref, group = group,
                   ngroup = ngroup, pattern = pattern, rare = rare)
  # Creating aggregating variable
  DTvec[, Group := .GRP,
        by = c("ChIP", "Dsg.Int", paste0("Dsg.Mix", 1:B), "offset")]
  # Creating unique data.table
  DTvec.unique <- unique(DTvec, by = "Group")[
    , c("ChIP", "Dsg.Int", paste0("Dsg.Mix", 1:B), "offset", "Group"),
    with = FALSE
  ]

  ## Initial probabilities
  pi1.old <- 0.999
  pi.old <- c(pi1.old, (1 - pi1.old) / 2, (1 - pi1.old) / 2)

  ## Model-specific parameters
  ### Aggregating data by the initial pattern assignment
  dtlist <- c(
    list(agg(data = DTvec, data.unique = DTvec.unique,
             rows = paste0("(z==", 1, ")"), agg = "PostProb1")),
    lapply(seq_len(B), FUN = function(i) {
      agg(data = DTvec, data.unique = DTvec.unique,
          rows = paste0("(z%in%c",
                        paste0("(", paste(z.seq[[i]], collapse = ","), "))")),
          agg = "PostProb2")
    }),
    list(agg(data = DTvec, data.unique = DTvec.unique,
             rows = paste0("(z==", max(DT$z), ")"), agg = "PostProb3"))
  )
  ### Calculating MLEs; on failure fall back to default starting values.
  ### The value is returned from tryCatch (not assigned with <<- in the
  ### handler) so nothing leaks into the caller's global environment.
  psi.old <- tryCatch(
    stats::optim(par = c(0.5, 0, 1, 1, 0), fn = glm.zinb2_Constr,
                 gr = deriv.zinb2_Constr, method = "L-BFGS-B",
                 lower = psi.lower, upper = psi.upper,
                 dt = rbindlist(dtlist, idcol = "id"), maxid = 2 + B)$par,
    error = function(e) c(0.5, 0, 1, 1, 0)
  )

  ## Transition probabilities
  # All patterns z are used for the initial gamma, regardless of the subset of
  # interest: smallest z -> 0, largest z -> 2, anything in between -> 1
  gamma.old <- HMM.chain(
    DT[, 0 * (z == min(z)) + 1 * (!z %in% c(min(z), max(z))) + 2 * (z == max(z))],
    K
  )

  ## Mixing probability: share of initially differential windows per component
  subz.diff <- z.diff[z.diff %in% unlist(z.seq)]
  delta.old <- vapply(seq_len(B), FUN = function(i) {
    sum((subz.diff %in% z.seq[[i]]) / length(subz.diff))
  }, FUN.VALUE = numeric(1))

  # Putting all together
  theta.k <- c(pi.old, gamma.old, psi.old, delta.old)
  names(theta.k) <- c(paste0("pi", 1:K), gamma.names, psi1.names, psi2.names,
                      paste0("delta", 1:B))
  cat(paste0("Initialization completed!\n"))
  hline()

  # EM algorithm begins
  cat("The EM algorithm begins...\n")
  while (count.em < maxcount.em && it.em < maxit.em) {
    it.em <- it.em + 1
    # Unpacking current parameters
    pi.k <- theta.k[paste0("pi", 1:K)]
    gamma.k <- matrix(theta.k[gamma.names], nrow = K, ncol = K, byrow = FALSE)
    psi1.k <- theta.k[psi1.names]
    psi2.k <- theta.k[psi2.names]
    delta.k <- theta.k[paste0("delta", 1:B)]

    # Outer EM: E-step
    loglik <- ZIHMM.LLmix2_Constr(dt = DTvec, psi1 = psi1.k, psi2 = psi2.k,
                                  delta = delta.k, N = N, M = M)
    ## Forward-Backward probabilities
    logF <- hmm_logF(logf1 = loglik[, 1], logf2 = loglik[, 2],
                     logf3 = loglik[, 3], pi = pi.k, gamma = gamma.k)
    logB <- hmm_logB(logf1 = loglik[, 1], logf2 = loglik[, 2],
                     logf3 = loglik[, 3], pi = pi.k, gamma = gamma.k)
    ## Posterior probabilities
    DT[, paste0("PostProb", 1:K) :=
         as.data.table(check.prob(hmm_P1(logF = logF, logB = logB)))]
    DT[, (joint.names) :=
         as.data.table(check.prob(hmm_P2(logF = logF, logB = logB,
                                         logf1 = loglik[, 1],
                                         logf2 = loglik[, 2],
                                         logf3 = loglik[, 3],
                                         gamma = gamma.k)))]

    # Outer EM: M-step
    ## Initial and transition probabilities
    PostProb <- HMM.prob(DT)
    pi.k1 <- PostProb$pi
    gamma.k1 <- PostProb$gamma
    zlist[[it.em]] <- Viterbi(LOGF = loglik, P = pi.k1, GAMMA = gamma.k1)

    ## Model parameters
    ### Inner EM
    count.innerem <- 0
    error.innerem <- 1
    delta.tmp <- delta.k
    psi1.tmp <- psi1.k
    psi2.tmp <- psi2.k
    # Rejection threshold: 0.9^it.em floored at pcut; 0 disables rejection
    rejection <- (pcut > 0) * max(0.9^it.em, pcut)
    while (count.innerem < maxit.innerem && error.innerem > epsilon.innerem) {
      #### Inner EM: E-step
      eta <- ZIHMM.mixprob2_Constr(dt = DTvec, delta = delta.tmp,
                                   psi1 = psi1.tmp, psi2 = psi2.tmp,
                                   N = N, M = M)
      #### Inner EM: M-step
      ##### Mixing probability
      delta.new <- vapply(seq_len(B), FUN = function(i) {
        sum(DT[, PostProb2] * eta[, i]) / sum(DT[, PostProb2])
      }, FUN.VALUE = numeric(1))
      names(delta.new) <- names(delta.tmp)

      ##### Model parameters
      ###### Rejection-controlled EM: a weight below `rejection` becomes
      ###### `rejection` with probability weight/rejection and 0 otherwise.
      ###### Random draws happen in the order background, mixture 1..B,
      ###### enrichment.
      DT[, c("Rejection1") := list(PostProb1)][
        PostProb1 < rejection,
        Rejection1 := stats::rbinom(.N, 1, prob = PostProb1 / rejection) * rejection
      ]
      DTvec[, c("Rejection1") := list(rep(DT[, Rejection1], N))]
      agg.bg <- agg(data = DTvec, data.unique = DTvec.unique,
                    rows = "(Rejection1>0)", agg = "Rejection1")

      agg.mix <- lapply(seq_len(B), FUN = function(b) {
        col.b <- paste0("Rejection2.", b)
        DT[, weights.tmp := PostProb2 * eta[, b]][
          , (col.b) := list(weights.tmp)
        ][
          weights.tmp < rejection,
          (col.b) := stats::rbinom(.N, 1, prob = weights.tmp / rejection) * rejection
        ]
        DTvec[, (col.b) := list(rep(DT[[col.b]], N))]
        agg(data = DTvec, data.unique = DTvec.unique,
            rows = paste0("(", col.b, ">0)"), agg = col.b)
      })

      DT[, c("Rejection3") := list(PostProb3)][
        PostProb3 < rejection,
        Rejection3 := stats::rbinom(.N, 1, prob = PostProb3 / rejection) * rejection
      ]
      DTvec[, c("Rejection3") := list(rep(DT[, Rejection3], N))]
      agg.en <- agg(data = DTvec, data.unique = DTvec.unique,
                    rows = "(Rejection3>0)", agg = "Rejection3")

      dtlist <- c(list(agg.bg), agg.mix, list(agg.en))

      ###### Calculating MLEs; on failure keep the current values and flag
      ###### convergence code 99 (returned from tryCatch, never via <<-)
      model <- tryCatch(
        stats::optim(par = c(psi1.tmp, psi2.tmp), fn = glm.zinb2_Constr,
                     gr = deriv.zinb2_Constr, method = "L-BFGS-B",
                     lower = psi.lower, upper = psi.upper,
                     dt = rbindlist(dtlist, idcol = "id"), maxid = 2 + B),
        error = function(e) {
          list(par = c(psi1.tmp, psi2.tmp), convergence = 99)
        }
      )
      ###### Saving parameters
      psi1.new <- model$par[1:3]
      psi2.new <- model$par[4:5]
      ##### Checking convergence (max. abs. relative change)
      par.tmp <- c(delta.tmp, psi1.tmp, psi2.tmp)
      error.innerem <- max(abs((c(delta.new, psi1.new, psi2.new) - par.tmp) / par.tmp))
      ##### Updating parameters
      delta.tmp <- delta.new
      psi1.tmp <- psi1.new
      psi2.tmp <- psi2.new
      count.innerem <- count.innerem + 1
    }
    ### Saving parameters
    delta.k1 <- delta.new
    psi.k1 <- c(psi1.new, psi2.new)

    # Updating parameter history
    theta.k1 <- c(pi.k1, gamma.k1, psi.k1, delta.k1)
    names(theta.k1) <- names(theta.k)
    theta.k <- theta.k1
    parlist[[it.em]] <- c(
      it = it.em,
      BIC = getBIC(logF = logF, nPar = K * (K - 1) + (K - 1) + (B - 1) + 5,
                   group = group),
      Q = Q(as.matrix(DT[, list(PostProb1, PostProb2, PostProb3)]),
            as.matrix(DT[, joint.names, with = FALSE]),
            loglik, pi.k1, gamma.k1),
      error = error.em[1], theta.k1, m = model$convergence
    )

    # Computing EM error against the estimates from `gap` iterations ago
    gap <- if (it.em > minit.em) gap.em else 1
    if (it.em > 1) {
      # Never look back past the first iteration (avoids a zero/negative index)
      gap <- min(gap, it.em - 1)
      parlist.old <- parlist[[it.em - gap]][names(psi.k1)]
      parlist.new <- parlist[[it.em]][names(psi.k1)]
      zlist.table <- data.table(old = zlist[[it.em - gap]], new = zlist[[it.em]])
      ACC <- 100 * zlist.table[, .N, by = list(old, new)][
        (old == 0 & new == 0) | (old == 1 & new == 1) | (old == 2 & new == 2),
        sum(N)
      ] / M
      # Abs. rel. change of expected log-likelihood of complete data (Q function)
      ARCEL <- abs((parlist[[it.em]][["Q"]] - parlist[[it.em - gap]][["Q"]]) /
                     parlist[[it.em - gap]][["Q"]])
    } else {
      parlist.old <- rep(1, length(names(psi.k1)))
      parlist.new <- rep(1, length(names(psi.k1)))
      ACC <- 0
      ARCEL <- 0
    }
    MRCPE <- max(abs((parlist.new - parlist.old) / parlist.old)) # Max. abs. rel. change of par. estimates
    MACPE <- max(abs(parlist.new - parlist.old)) # Max. abs. change of par. estimates
    MULTI <- c(MRCPE, MACPE, ARCEL, 100 - ACC)
    # `criterion` names one of the metrics above; force errors to 1 on iteration 1
    error.em <- (it.em >= 2) * get(criterion) +
      (it.em < 2) * rep(1, length(get(criterion)))
    count.em <- as.numeric(any(error.em <= epsilon.em)) * (it.em > minit.em) * (count.em + 1) + 0

    # Outputting history
    if (!quiet) {
      hline()
      cat('\rIteration: ', it.em, ', Error(s): ',
          paste(formatC(error.em, format = "e", digits = 2), collapse = ', '),
          ', Viterbi Agreement: ', round(ACC, 2), '%.\n', sep = '')
      cat("\r", paste0('Q-function: ', round(parlist[[it.em]][['Q']], 2),
                       ', BIC: ', round(parlist[[it.em]][['BIC']], 2)), "\n")
      cat("\r", paste('Max. abs. rel. change of parameter estimates: '), MRCPE, "\n")
      cat("\r", paste('Max. abs. change of parameter estimates: '), MACPE, "\n")
      cat("\r", paste('Abs. rel. change of Q-function: '), ARCEL, "\n")
      hline()
    }
  }

  # Organizing output
  logF <- setnames(as.data.table(logF), c("Background", "Differential", "Enrichment"))
  logB <- setnames(as.data.table(logB), c("Background", "Differential", "Enrichment"))
  loglik <- setnames(as.data.table(loglik), c("Background", "Differential", "Enrichment"))
  # Label each mixture component with its own Ref row (0/1 per condition), so
  # "cluster" and list patterns are named correctly, not only pattern = "all"
  eta <- setnames(as.data.table(eta), do.call(paste0, Ref[mix.rows, 1:ngroup]))
  prob <- setnames(DT[, list(PostProb1, PostProb2, PostProb3)],
                   c("Background", "Differential", "Enrichment"))
  viterbi <- ifelse(zlist[[it.em]] == 0, "B", ifelse(zlist[[it.em]] == 1, "D", "E"))
  colnames(eta) <- gsub("1", "E", gsub("0", "B", colnames(eta)))
  metadata(object) <- list("pi" = pi.k1,
                           "gamma" = gamma.k1,
                           "psi" = psi.k1,
                           "prob" = prob,
                           "mixProb" = eta,
                           "viterbi" = viterbi,
                           "logF" = logF,
                           "logB" = logB,
                           "logLik" = loglik,
                           "parHist" = as.data.table(do.call(rbind, parlist)))
  cat("Done!\n")
  hline()
  object
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.