R/MHMM.R

Defines functions MHMM

Documented in MHMM

#' Mixture HMM with Bernoulli distributions
#'
#' Fits, via an EM algorithm, a mixture of two-state hidden Markov models with
#' Bernoulli emissions to the \code{'counts'} assay of \code{object}. Each
#' mixture component (cluster) has its own pair of initial probabilities,
#' its own 2x2 transition matrix and its own pair of emission intercepts.
#' Starting values are obtained by averaging, within each initial cluster,
#' the per-sample parameters stored in \code{initializer$par}.
#'
#' @param object a mikadoDataSet. It must have a \code{'counts'} assay and a
#' column \code{assignedClusters} in its \code{colData} giving the initial
#' cluster of every sample (column).
#' @param initializer the output of the function initializer. Its element
#' \code{par} is read as a list with one entry per sample, each providing
#' the elements \code{gamma}, \code{psi} and \code{pi}.
#' @param control list of control arguments from controlEM(). Every element
#' is made available inside the function under its own name. The elements
#' \code{epsilon.em}, \code{maxcount.em}, \code{maxit.em}, \code{minit.em},
#' \code{quiet} and \code{dir} are required; \code{cores} is passed to
#' \code{doParallel::registerDoParallel()}.
#'
#' @return
#' The input \code{object} with its \code{metadata} replaced by a list with
#' the final estimates \code{pi}, \code{gamma}, \code{delta} and \code{psi},
#' plus \code{prob}, the exponentiated output of the last call to
#' \code{mhmmK_logW}. Clusters are ordered as in
#' \code{sort(unique(assignedClusters))}.
#'
#' @author Pedro L. Baldoni, \email{pedrobaldoni@gmail.com}
#' @references
#' \url{https://github.com/plbaldoni/epigraHMM}
#'
#' @import data.table
#' @importFrom  stats rbinom dbinom glm quasibinomial
#' @rawNamespace import(SummarizedExperiment, except = shift)
#' @rawNamespace import(S4Vectors, except = c(first,second))
#'
#' @export
MHMM <- function(object,initializer,control)
{
    # Placeholders for the variables that are created from `control` below
    epsilon.em <- maxcount.em <- maxit.em <- minit.em <- quiet <- cores <- NULL
    
    # Creating control elements
    
    # Fail early, with a clear message, rather than midway through the first
    # EM iteration (e.g. a missing `dir` would otherwise resolve to base::dir)
    requiredControl <- c('epsilon.em','maxcount.em','maxit.em','minit.em','quiet','dir')
    missingControl <- setdiff(requiredControl,names(control))
    if(length(missingControl) > 0){
        stop("'control' is missing required element(s): ",
             paste(missingControl,collapse = ', '),call. = FALSE)
    }
    
    for(i in seq_along(control)){assign(names(control)[i],control[[i]])}
    
    # Registering number of cores
    
    doParallel::registerDoParallel(cores = cores)
    
    # General parameters
    
    clustInit <- colData(object)[['assignedClusters']]
    if(is.null(clustInit)){
        stop("colData(object) must contain the column 'assignedClusters'",call. = FALSE)
    }
    
    # A single, sorted set of observed clusters drives every per-cluster
    # quantity below, so that pi, gamma, delta and psi share the same order.
    # It also works when 'assignedClusters' is not a factor or has unused levels.
    clustInitLevels <- sort(unique(clustInit))
    
    M <- nrow(object)
    N <- ncol(object)
    L <- length(clustInitLevels)
    
    it.em <- 0
    count.em <- 0 
    error.em <- c('error' = 1)
    
    parlist <- list()
    parNames <- c('pi','gamma','delta','psi')
    theta.old <- stats::setNames(vector('list',length(parNames)),parNames)
    theta.new <- theta.old
    
    # Parameter initializations
    
    # Per-sample initial parameters of the samples assigned to cluster y
    clusterPar <- function(y){initializer$par[which(clustInit==y)]}
    
    # Column means of a per-sample vector parameter within cluster y
    clusterMean <- function(y,what){
        colMeans(do.call(rbind,lapply(clusterPar(y),function(x){x[[what]]})))
    }
    
    # Putting all together
    
    ## Block-diagonal matrix with the average 2x2 transition matrix of each cluster
    theta.old[['gamma']] <- Matrix::bdiag(lapply(clustInitLevels,function(y){
        gammaList <- lapply(clusterPar(y),function(x){x[['gamma']]})
        rowMeans(array(unlist(gammaList),dim = c(2,2,length(gammaList))),dims = 2)
    }))
    
    ## Initial mixture proportions, in the same order as clustInitLevels
    theta.old[['delta']] <- as.numeric(unlist(lapply(clustInitLevels,function(y){
        sum(clustInit==y,na.rm = TRUE)
    })))/N
    
    theta.old[['psi']] <- unname(unlist(lapply(clustInitLevels,clusterMean,what = 'psi')))
    
    ## Initial probabilities are weighted by the mixture proportions and renormalized
    theta.old[['pi']] <- unlist(lapply(clustInitLevels,clusterMean,what = 'pi'))
    theta.old[['pi']] <- theta.old[['pi']]*rep(theta.old[['delta']],each = 2)
    theta.old[['pi']] <- theta.old[['pi']]/sum(theta.old[['pi']])
    
    # Paths handed to the mhmmK_* helpers (one file per cluster and quantity)
    
    pathF <- file.path(dir,paste0('logF',seq_len(L),'.bin'))
    pathB <- file.path(dir,paste0('logB',seq_len(L),'.bin'))
    pathP <- file.path(dir,paste0('logP',seq_len(L),'.bin'))
    pathT <- file.path(dir,paste0('logT',seq_len(L),'.bin'))
    
    # EM algorithm begins
    message(paste0(c(rep('#',80))));message(Sys.time());message("Starting the EM algorithm")
    
    while(count.em<maxcount.em && it.em<maxit.em){
        it.em <- it.em + 1
        
        # E-step
        ## Forward-Backward probabilities
        
        # NOTE(review): `%dopar%` is not imported by this file's roxygen tags;
        # presumably it is imported elsewhere in the package - confirm.
        # NOTE(review): psi is updated below as a logit-link GLM intercept, so
        # `prob` would be expected to be an inverse logit of psi. `logit` is
        # defined outside this file - confirm it maps to (0,1).
        foreach::foreach(i = iterators::iter(seq_len(L)),.final = NULL) %dopar% {
            mhmmK_logFB(counts = assay(object,'counts'),
                        pi = theta.old$pi[c(2*i-1,2*i)],
                        gamma = as.matrix(theta.old$gamma[c(2*i-1,2*i),c(2*i-1,2*i)]),
                        lprob = do.call(rbind,lapply(theta.old$psi[c(2*i-1,2*i)],function(x){dbinom(c(0,1),size = 1,prob = logit(x),log = TRUE)})),
                        name = file.path(dir,paste0(c('logF','logB'),i,'.bin')))
        }
        
        ## Cluster membership posterior probabilities
        
        logW <- mhmmK_logW(dirFP = pathF,delta = theta.old$delta,M = M,N = N)
        
        ## Window-based posterior probabilities
        
        foreach::foreach(i = iterators::iter(seq_len(L)),.final = NULL) %dopar% {
            mhmmK_logPP(counts = assay(object,'counts'),
                        name = file.path(dir,paste0(c('logP','logT'),i,'.bin')),
                        delta = theta.old$delta,
                        gamma = as.matrix(theta.old$gamma[c(2*i-1,2*i),c(2*i-1,2*i)]),
                        lprob = do.call(rbind,lapply(theta.old$psi[c(2*i-1,2*i)],function(x){dbinom(c(0,1),size = 1,prob = logit(x),log = TRUE)})),
                        dirFP = pathF,
                        dirBP = pathB,
                        M = M,
                        N = N,
                        cluster = i)
        }
        
        # M-step
        
        theta.new[['delta']] <- colSums(exp(logW))/sum(exp(logW))
        theta.new[['pi']] <- as.numeric(mhmmK_maxP(dirP = pathP,M = M,N = N))
        theta.new[['gamma']] <- mhmmK_maxT(dirT = pathT,M = M,N = N)
        
        ## Model parameters
        
        # One intercept-only weighted GLM per (cluster, state) pair: the response
        # is the two possible outcomes (0,1) and row x of `weights` weighs them
        weights <- mhmmK_agg(counts = assay(object,'counts'),dirP = pathP)
        fit <- lapply(seq_len(2*L),function(x){stats::glm(c(0,1)~1,weights = weights[x,],family = stats::quasibinomial(link = "logit"))})
        
        theta.new[['psi']] <- unname(unlist(lapply(fit,function(x){x[['coefficients']]})))
        
        # Updating parameter history
        
        ## Sum over parameters of the Euclidean distance between old and new values
        error.em['error'] <- sum(unlist(lapply(names(theta.old),function(x){sqrt(sum((theta.old[[x]]-theta.new[[x]])^2))})))
        
        ## m = 1 flags that at least one of the GLM fits did not converge
        parlist[[it.em]] <- c(it=it.em,error.em,m=1*any(!unname(unlist(lapply(fit,function(x){x[['converged']]})))))
        
        theta.old <- theta.new
        
        # Computing EM error
        
        ## Number of consecutive iterations (after minit.em) with error <= epsilon.em;
        ## it resets to zero whenever the error exceeds the tolerance
        count.em <- as.numeric(error.em<=epsilon.em)*(it.em>minit.em)*(count.em+1) + 0
        
        #Outputing history
        if(!quiet){
            message(paste0(c(rep('#',80))))
            message('\rIteration: ',it.em,'. Error: ',paste(formatC(error.em, format = "e", digits = 2)),sep='')
            message("\r",paste('Mixture estimates: '),paste(formatC(theta.new[['delta']], format = "e", digits = 2),collapse = ' '))
            message("\r",paste('Intercept estimates: '),paste(formatC(theta.new[['psi']], format = "e", digits = 2),collapse = ' '))
            message("\r",paste('Initial prob. estimates: '),paste(formatC(theta.new[['pi']], format = "e", digits = 2),collapse = ' '))
            message(paste0(c(rep('#',80))))
        }
    }
    
    # Note: this replaces any metadata already stored in `object`
    S4Vectors::metadata(object) <- list('pi'=theta.new[['pi']],
                                        'gamma'=theta.new[['gamma']],
                                        'delta'=theta.new[['delta']],
                                        'psi'=theta.new[['psi']],
                                        'prob'=exp(logW))
    
    message('Done!');message(Sys.time());message(paste0(c(rep('#',80))))
    return(object)
}
plbaldoni/mikado documentation built on June 9, 2020, 3:34 p.m.