R/callPeaksMultivariate.R

Defines functions transProbs prepareMultivariate runMultivariate callPeaksMultivariate

Documented in callPeaksMultivariate

#' Fit a Hidden Markov Model to multiple ChIP-seq samples
#'
#' Fit a HMM to multiple ChIP-seq samples to determine the combinatorial state of genomic regions. Input is a list of \code{\link{uniHMM}}s generated by \code{\link{callPeaksUnivariate}}.
#'
#' Emission distributions from the univariate HMMs are used with a Gaussian copula to generate a multivariate emission distribution for each combinatorial state. This multivariate distribution is then kept fixed and the transition probabilities are fitted with a Baum-Welch. Please refer to our manuscript at \url{http://dx.doi.org/10.1101/038612} for a detailed description of the method.
#'
#' @author Aaron Taudt, Maria Colome Tatche
#' @param hmms A list of \code{\link{uniHMM}}s generated by \code{\link{callPeaksUnivariate}}, e.g. \code{list(hmm1,hmm2,...)} or a vector of files that contain such objects, e.g. \code{c("file1","file2",...)}.
#' @param use.states A data.frame with combinatorial states which are used in the multivariate HMM, generated by function \code{\link{stateBrewer}}. If both \code{use.states} and \code{max.states} are \code{NULL}, the maximum possible number of combinatorial states will be used.
#' @param max.states Maximum number of combinatorial states to use in the multivariate HMM. The states are ordered by occurrence as determined from the combination of univariate state calls.
#' @param per.chrom If \code{per.chrom=TRUE} chromosomes will be treated separately. This tremendously speeds up the calculation but results might be noisier as compared to \code{per.chrom=FALSE}, where all chromosomes are concatenated for the HMM.
#' @param chromosomes A vector specifying the chromosomes to use from the models in \code{hmms}. The default (\code{NULL}) uses all available chromosomes.
#' @param eps Convergence threshold for the Baum-Welch algorithm.
#' @param keep.posteriors If set to \code{TRUE}, posteriors will be available in the output. This can be useful to change the posterior cutoff later, but increases the necessary disk space to store the result immensely.
#' @param num.threads Number of threads to use. Setting this to >1 may give increased performance.
#' @param max.time The maximum running time in seconds for the Baum-Welch algorithm. If this time is reached, the Baum-Welch will terminate after the current iteration finishes. The default \code{NULL} is no limit.
#' @param max.iter The maximum number of iterations for the Baum-Welch algorithm. The default \code{NULL} is no limit.
#' @param keep.densities If set to \code{TRUE} (default=\code{FALSE}), densities will be available in the output. This should only be needed for debugging.
#' @param verbosity Verbosity level for the fitting procedure. 0 - No output, 1 - Iterations are printed.
#' @param temp.savedir A directory where to store intermediate results if \code{per.chrom=TRUE}.
#' @return A \code{\link{multiHMM}} object.
#' @seealso \code{\link{multiHMM}}, \code{\link{callPeaksUnivariate}}, \code{\link{callPeaksReplicates}}
#' @import doParallel
#' @import foreach
#' @export
#' @examples
#'# Get example BAM files for 2 different marks in hypertensive rat
#'file.path <- system.file("extdata","euratrans", package='chromstaRData')
#'files <- list.files(file.path, full.names=TRUE, pattern='SHR.*bam$')[c(1:2,6)]
#'# Construct experiment structure
#'exp <- data.frame(file=files, mark=c("H3K27me3","H3K27me3","H3K4me3"),
#'                  condition=rep("SHR",3), replicate=c(1:2,1), pairedEndReads=FALSE,
#'                  controlFiles=NA)
#'states <- stateBrewer(exp, mode='combinatorial')
#'# Bin the data
#'data(rn4_chrominfo)
#'binned.data <- list()
#'for (file in files) {
#'  binned.data[[basename(file)]] <- binReads(file, binsizes=1000, stepsizes=500,
#'                                            experiment.table=exp,
#'                                            assembly=rn4_chrominfo, chromosomes='chr12')
#'}
#'# Obtain the univariate fits
#'models <- list()
#'for (i1 in 1:length(binned.data)) {
#'  models[[i1]] <- callPeaksUnivariate(binned.data[[i1]], max.time=60, eps=1)
#'}
#'# Call multivariate peaks
#'multimodel <- callPeaksMultivariate(models, use.states=states, eps=1, max.time=60)
#'# Check some plots
#'heatmapTransitionProbs(multimodel)
#'heatmapCountCorrelation(multimodel)
#'
callPeaksMultivariate <- function(hmms, use.states=NULL, max.states=NULL, per.chrom=TRUE, chromosomes=NULL, eps=0.01, keep.posteriors=FALSE, num.threads=1, max.time=NULL, max.iter=NULL, keep.densities=FALSE, verbosity=1, temp.savedir=NULL) {

    # Outline:
    #   1. Validate the user input.
    #   2. If only one univariate model is given, no multivariate fit is run;
    #      the univariate result is repackaged as a multivariate object.
    #   3. Otherwise prepare the multivariate input with prepareMultivariate()
    #      and fit either per chromosome (optionally in parallel and/or with
    #      intermediate results written to 'temp.savedir') or on all
    #      chromosomes at once.
    #
    # 'use.states' defaults to NULL so that the documented case "both
    # 'use.states' and 'max.states' are NULL" can be requested by simply
    # omitting the argument.

    ## Intercept user input
    if (!is.null(use.states)) {
        if (!is(use.states,'data.frame')) stop("argument 'use.states' expects a data.frame generated by function 'stateBrewer'")
    }
    if (!is.null(max.states)) {
        if (check.positive.integer(max.states)!=0) stop("argument 'max.states' expects a positive integer")
    }
    # Scalar condition: use the short-circuiting '&&' instead of elementwise '&'
    if (!is.null(use.states) && !is.null(max.states)) {
        # Cannot use more states than were supplied
        if (max.states > nrow(use.states)) {
            max.states <- nrow(use.states)
        }
    }
    if (check.positive(eps)!=0) stop("argument 'eps' expects a positive numeric")
    if (check.positive.integer(num.threads)!=0) stop("argument 'num.threads' expects a positive integer")
    # NULL (no limit) is passed on as -1
    if (is.null(max.time)) { max.time <- -1 } else if (check.nonnegative.integer(max.time)!=0) { stop("argument 'max.time' expects a non-negative integer") }
    if (is.null(max.iter)) { max.iter <- -1 } else if (check.nonnegative.integer(max.iter)!=0) { stop("argument 'max.iter' expects a non-negative integer") }
    if (check.logical(keep.posteriors)!=0) stop("argument 'keep.posteriors' expects a logical (TRUE or FALSE)")
    if (check.logical(keep.densities)!=0) stop("argument 'keep.densities' expects a logical (TRUE or FALSE)")
    if (check.integer(verbosity)!=0) stop("argument 'verbosity' expects an integer")
    if (length(hmms)==0) {
        stop("argument 'hmms' is of length=0. Cannot call multivariate peaks with no models.")
    }

    ## Special case: a single model. Nothing is fitted here, the univariate
    ## result is only reshaped into the layout of a multivariate object.
    if (length(hmms)==1) {
        hmm <- loadHmmsFromFiles(hmms, check.class=class.univariate.hmm)[[1]]
        ### Make return object ###
            result <- list()
            result$info <- hmm$info
            if (is.null(result$info)) {
                # No experiment information available: create a placeholder entry
                n <- 1
                result$info <- data.frame(file=rep(NA, n), mark=seq_len(n), condition=seq_len(n), replicate=1, pairedEndReads=rep(NA, n), controlFiles=rep(NA, n))
                result$info$ID <- paste0(result$info$mark, '-', result$info$condition, '-rep', result$info$replicate)
            }
        ## Counts
            # Insert a 'track' dimension of length 1 between bins and offsets
            result$bincounts <- hmm$bincounts
            result$bincounts$counts <- array(NA, dim = c(length(hmm$bincounts), 1, dim(hmm$bincounts$counts)[2]), dimnames = list(bin=NULL, track=result$info$ID, offset=dimnames(hmm$bincounts$counts)[[2]]))
            result$bincounts$counts[,1,] <- hmm$bincounts$counts
        ## Bin coordinates, posteriors and states
            result$bins <- hmm$bins
            result$bins$score <- NULL
            result$bins$posterior.modified <- NULL
            result$bins$counts.rpkm <- matrix(hmm$bins$counts.rpkm, ncol=1)
            # Map the three univariate state levels onto 0,0,1
            result$bins$state <- factor(c(0,0,1))[hmm$bins$state]
            result$bins$posteriors <- matrix(hmm$bins$posterior.modified, ncol=1, dimnames=list(NULL, result$info$ID))
            result$bins$differential.score <- 0
        ## Add combinations
            mapping <- NULL
            if (!is.null(use.states)) {
                mapping <- use.states$combination
                names(mapping) <- use.states$state
                result$bins$combination <- factor(mapping[as.character(result$bins$state)], levels=levels(use.states$combination))
            } else {
                result$bins$combination <- result$bins$state
            }
        ## Segmentation
            result$segments <- multivariateSegmentation(result$bins, column2collapseBy='state')
            mcols(result$segments)[2] <- NULL
            if (!keep.posteriors) {
                result$bins$posteriors <- NULL
            }
        ## Parameters
            result$mapping <- mapping
            # Weights
            tstates <- table(result$bins$state)
            result$weights <- sort(tstates/sum(tstates), decreasing=TRUE)
            result$weights.univariate <- list(hmm$weights)
            names(result$weights.univariate) <- result$info$ID
            # Transition matrices (not available, nothing was fitted)
            result$transitionProbs <- NA
            result$transitionProbs.initial <- NA
            # Initial probs
            result$startProbs <- NA
            result$startProbs.initial <- NA
            # Distributions
            result$distributions <- list(hmm$distributions)
            names(result$distributions) <- result$info$ID
        ## Convergence info
            result$convergenceInfo <- list(eps=NA, loglik=NA, loglik.delta=NA, num.iterations=NA, time.sec=NA)
        ## Correlation matrices
            result$correlation.matrix <- NA
        ## Add class
            class(result) <- class.multivariate.hmm

        return(result)
    }

    ## Prepare the HMM
    p <- prepareMultivariate(hmms, use.states=use.states, max.states=max.states, chromosomes=chromosomes)

    if (is.null(chromosomes)) {
        # Use all chromosomes that actually have bins, in seqlevels order
        chromosomes <- intersect(seqlevels(p$bincounts), unique(seqnames(p$bincounts)))
    }
    if (!is.null(temp.savedir)) {
        if (!file.exists(temp.savedir)) {
            dir.create(temp.savedir)
        }
    }

    ## Run multivariate per chromosome
    if (per.chrom) {

        # Set up parallel execution
        if (num.threads > 1) {
            ptm <- startTimedMessage("Setting up parallel multivariate with ", num.threads, " threads ...")
            cl <- parallel::makeCluster(num.threads)
            doParallel::registerDoParallel(cl)
            # Always shut the workers down again, also on error; add=TRUE keeps
            # any exit handler that might already be registered.
            on.exit(parallel::stopCluster(cl), add=TRUE)
            stopTimedMessage(ptm)
        }

        # Run the multivariate
        chrom <- NULL # please BiocCheck
        if (num.threads > 1) {
            ptm <- startTimedMessage("Running multivariate ...")
            # Each iteration yields either the fitted model or, if 'temp.savedir'
            # is used, the name of the file the model was saved to.
            models <- foreach (chrom = chromosomes, .packages='chromstaR') %dopar% {
                bincounts <- p$bincounts[seqnames(p$bincounts)==chrom]
                bins <- p$bins[seqnames(p$bins)==chrom]
                if (!is.null(temp.savedir)) {
                    temp.savename <- file.path(temp.savedir, paste0('chromosome_', chrom, '.RData'))
                    # An existing file is reused instead of refitting
                    if (!file.exists(temp.savename)) {
                        model <- runMultivariate(binned.data=bincounts, stepbins=bins, info=p$info, comb.states=p$comb.states, use.states=p$use.states, distributions=p$distributions, weights=p$weights, correlationMatrix=p$correlationMatrix, correlationMatrixInverse=p$correlationMatrixInverse, determinant=p$determinant, max.iter=max.iter, max.time=max.time, eps=eps, num.threads=1, keep.posteriors=keep.posteriors, keep.densities=keep.densities, verbosity=verbosity)
                        save(model, file=temp.savename)
                        rm(model); gc()
                    }
                    temp.savename
                } else {
                    model <- runMultivariate(binned.data=bincounts, stepbins=bins, info=p$info, comb.states=p$comb.states, use.states=p$use.states, distributions=p$distributions, weights=p$weights, correlationMatrix=p$correlationMatrix, correlationMatrixInverse=p$correlationMatrixInverse, determinant=p$determinant, max.iter=max.iter, max.time=max.time, eps=eps, num.threads=1, keep.posteriors=keep.posteriors, keep.densities=keep.densities, verbosity=verbosity)
                    model
                }
            }
            stopTimedMessage(ptm)
        } else {
            models <- list()
            for (chrom in chromosomes) {
                ptm <- messageU("Chromosome = ", chrom, overline="-", underline=NULL)
                bincounts <- p$bincounts[seqnames(p$bincounts)==chrom]
                bins <- p$bins[seqnames(p$bins)==chrom]
                if (!is.null(temp.savedir)) {
                    temp.savename <- file.path(temp.savedir, paste0('chromosome_', chrom, '.RData'))
                    # An existing file is reused instead of refitting
                    if (!file.exists(temp.savename)) {
                        model <- runMultivariate(binned.data=bincounts, stepbins=bins, info=p$info, comb.states=p$comb.states, use.states=p$use.states, distributions=p$distributions, weights=p$weights, correlationMatrix=p$correlationMatrix, correlationMatrixInverse=p$correlationMatrixInverse, determinant=p$determinant, max.iter=max.iter, max.time=max.time, eps=eps, num.threads=1, keep.posteriors=keep.posteriors, keep.densities=keep.densities, verbosity=verbosity)
                        # Separate timer variable so that 'ptm', which is reported
                        # as the time spent for this chromosome below, is not
                        # overwritten by the time spent on saving.
                        ptm.save <- startTimedMessage("Saving chromosome ", chrom, " to temporary file ", temp.savename, " ...")
                        save(model, file=temp.savename)
                        rm(model); gc()
                        stopTimedMessage(ptm.save)
                    }
                    models[[as.character(chrom)]] <- temp.savename
                } else {
                    model <- runMultivariate(binned.data=bincounts, stepbins=bins, info=p$info, comb.states=p$comb.states, use.states=p$use.states, distributions=p$distributions, weights=p$weights, correlationMatrix=p$correlationMatrix, correlationMatrixInverse=p$correlationMatrixInverse, determinant=p$determinant, max.iter=max.iter, max.time=max.time, eps=eps, num.threads=1, keep.posteriors=keep.posteriors, keep.densities=keep.densities, verbosity=verbosity)
                    models[[as.character(chrom)]] <- model
                }
                message("Time spent for chromosome = ", chrom, ":", appendLF=FALSE)
                stopTimedMessage(ptm)
            }
        }

        # Merge chromosomes into one multiHMM
        ptm <- startTimedMessage("Merging chromosomes ...")
        if (!is.null(temp.savedir)) {
            models <- as.character(models) # make sure 'models' is a character vector with filenames and not a list()
        }
        model <- suppressMessages( mergeChroms(models) )
        stopTimedMessage(ptm)

    ## Run multivariate for all chromosomes
    } else {

        model <- runMultivariate(binned.data=p$bincounts, stepbins=p$bins, info=p$info, comb.states=p$comb.states, use.states=p$use.states, distributions=p$distributions, weights=p$weights, correlationMatrix=p$correlationMatrix, correlationMatrixInverse=p$correlationMatrixInverse, determinant=p$determinant, max.iter=max.iter, max.time=max.time, eps=eps, num.threads=num.threads, keep.posteriors=keep.posteriors, keep.densities=keep.densities, verbosity=verbosity)

    }

    # The intermediate results are no longer needed once the model is complete
    if (!is.null(temp.savedir)) {
        if (file.exists(temp.savedir)) {
            unlink(temp.savedir, recursive = TRUE)
        }
    }
    return(model)

}

#' @importFrom stats ecdf
runMultivariate <- function(binned.data, stepbins, info, comb.states, use.states, distributions, weights, correlationMatrix, correlationMatrixInverse, determinant, max.iter, max.time, eps, num.threads, keep.posteriors, keep.densities, transitionProbs.initial=NULL, startProbs.initial=NULL, verbosity=1) {

    # Fit the multivariate HMM on 'binned.data' and assemble the result object.
    #
    # The C routine "C_multivariate_hmm" is called once per offset found in the
    # third dimension of 'binned.data$counts'. Parameters are fitted on the first
    # offset only; for every further offset the fitted transition and start
    # probabilities are reused with a single iteration to obtain states and
    # posteriors. For each bin in 'stepbins' the state and posteriors of the
    # offset with the highest maximum posterior are kept.
    #
    # Arguments (as used in this function):
    #   binned.data   Bins with a 'counts' array of dimension bins x tracks x
    #                 offsets; the offset names must be convertible to numeric.
    #   stepbins      Bins onto which states and posteriors of all offsets are
    #                 mapped via overlaps.
    #   info          Experiment table; its 'ID' column names the tracks.
    #   comb.states   Combinatorial states (decimal) to use in the HMM.
    #   use.states    data.frame with columns 'state' and 'combination', or NULL.
    #   distributions, weights
    #                 Per-track univariate distributions and weights.
    #   correlationMatrix, correlationMatrixInverse, determinant
    #                 Correlation input handed to the C routine / stored in the
    #                 result.
    #   max.iter, max.time, eps, num.threads, verbosity
    #                 Passed on to the C routine.
    #   keep.posteriors, keep.densities
    #                 Whether posteriors / densities are kept in the result.
    #   transitionProbs.initial, startProbs.initial
    #                 Optional starting values. Defaults: 0.9 on the diagonal of
    #                 the transition matrix with the remainder spread evenly,
    #                 and uniform start probabilities.
    #
    # Returns a list of class 'class.multivariate.hmm'.

    # The returned timer is not needed, only the message is of interest
    startTimedMessage("Starting multivariate HMM with ", length(comb.states), " combinatorial states")
    message("")

    ### Variables ###
    nummod <- dim(binned.data$counts)[2]
    offsets <- dimnames(binned.data$counts)[[3]]
    if (length(offsets) == 0) {
        # Without offsets nothing can be fitted and no result could be assembled
        stop("'binned.data$counts' needs named offsets in its third dimension")
    }
    binstates <- dec2bin(comb.states, ndigits=nummod)
    
    # Prepare input for C function
    rs <- unlist(lapply(distributions,"[",2:3,'size'))
    ps <- unlist(lapply(distributions,"[",2:3,'prob'))
    ws1 <- unlist(lapply(weights,"[",1))
    ws2 <- unlist(lapply(weights,"[",2))
    ws <- ws1 / (ws2+ws1)
    get.posteriors <- TRUE
    if (get.posteriors) { lenPosteriors <- length(binned.data) * length(comb.states) } else { lenPosteriors <- 1 }
    if (keep.densities) { lenDensities <- length(binned.data) * length(comb.states) } else { lenDensities <- 1 }

    if (is.null(transitionProbs.initial)) {
        transitionProbs.initial <- matrix((1-0.9)/(length(comb.states)-1), ncol=length(comb.states), nrow=length(comb.states))
        diag(transitionProbs.initial) <- 0.9
    }
    if (is.null(startProbs.initial)) {
        startProbs.initial <- rep(1/length(comb.states), length(comb.states))
    }

    ### Arrays for finding maximum posterior for each bin between offsets
    ## Make bins with offset
    ptm <- startTimedMessage("Making bins with offsets ...")
    # if (is.null(stepbins)) {
    #     if (length(offsets) > 1) {
    #         stepbins <- suppressMessages( fixedWidthBins(chrom.lengths = seqlengths(binned.data), binsizes = as.numeric(offsets[2]), chromosomes = unique(seqnames(binned.data)))[[1]] )
    #     } else {
    #         stepbins <- binned.data
    #         mcols(stepbins) <- NULL
    #     }
    #     counts.rpkm <- NULL
    # } else {
    #     counts.rpkm <- stepbins$counts.rpkm
    # }
    ## Dummy bins without mcols for shifting
    sbins <- binned.data
    mcols(sbins) <- NULL
    ## Initialize arrays
    if (get.posteriors) {
        aposteriors.step <- array(0, dim = c(length(stepbins), nummod, 2), dimnames = list(bin=NULL, track=info$ID, offset=c('previousOffsets', 'currentOffset'))) # to store posteriors-per-sample for current and max-of-previous offsets
    }
    # if (is.null(counts.rpkm)) {
    #     acounts.step <- array(0, dim = c(length(stepbins), nummod, 2), dimnames = list(bin=NULL, track=info$ID, offset=c('previousOffsets', 'currentOffset'))) # to store counts for current and sum-of-previous offsets
    # }
    amaxPosterior.step <- array(0, dim = c(length(stepbins), 2), dimnames = list(bin=NULL, offset=c('previousOffsets', 'currentOffset'))) # to store maximum posterior for current and max-of-previous offsets
    astates.step <- array(0, dim = c(length(stepbins), 2), dimnames = list(bin=NULL, offset=c('previousOffsets', 'currentOffset'))) # to store states for current and max-of-previous offsets
    stopTimedMessage(ptm)
    
    ### Loop over offsets ###
    for (ioffset in seq_along(offsets)) {
      
        offset <- offsets[ioffset]
        if (ioffset > 1) {
            ptm <- startTimedMessage("Obtaining states for offset = ", offset, " ...")
            ## Run only one iteration (no updating) if we are already over ioffset=1
            max.iter <- 1
            transitionProbs.initial <- hmm.A
            startProbs.initial <- hmm.proba
            verbosity <- 0
        } else {
            ptm <- startTimedMessage("Fitting Hidden Markov Model for offset = ", offset, "\n")
        }
      
        # Call the C function
        # The cleanup handler is (re-)registered before every call; without
        # 'add' it replaces the one from the previous iteration, so the cleanup
        # runs once when this function exits.
        on.exit(.C("C_multivariate_cleanup", as.integer(nummod)))
        hmm <- .C("C_multivariate_hmm",
            counts = as.integer(as.vector(binned.data$counts[,, offset])), # int* multiO
            num.bins = as.integer(length(binned.data)), # int* T
            max.states = as.integer(length(comb.states)), # int* N
            num.modifications = as.integer(nummod), # int* Nmod
            comb.states = as.numeric(comb.states), # double* comb_states
            size = as.double(rs), # double* size
            prob = as.double(ps), # double* prob
            w = as.double(ws), # double* w
            correlation.matrix.inverse = as.double(correlationMatrixInverse), # double* cor_matrix_inv
            determinant = as.double(determinant), # double* det
            num.iterations = as.integer(max.iter), # int* maxiter
            time.sec = as.integer(max.time), # double* maxtime
            loglik.delta = as.double(eps), # double* eps
            posteriors = double(length=lenPosteriors), # double* posteriors
            get.posteriors = as.logical(get.posteriors), # bool* keep_posteriors
            densities = double(length=lenDensities), # double* densities
            keep.densities = as.logical(keep.densities), # bool* keep_densities
            states = double(length=length(binned.data)), # double* states
            maxPosterior = double(length=length(binned.data)), # double* maxPosterior
            A = double(length=length(comb.states)*length(comb.states)), # double* A
            proba = double(length=length(comb.states)), # double* proba
            loglik = double(length=1), # double* loglik
            A.initial = as.double(transitionProbs.initial), # double* initial A
            proba.initial = as.double(startProbs.initial), # double* initial proba
            use.initial.params = as.logical(TRUE), # bool* use_initial_params
            num.threads = as.integer(num.threads), # int* num_threads
            error = as.integer(0), # error handling
            verbosity = as.integer(verbosity) # int* verbosity
            )
                
        ### Check convergence ###
        if (hmm$error == 1) {
            stop("A nan occurred during the Baum-Welch! Parameter estimation terminated prematurely. Check your read counts for very high numbers, they could be the cause for this problem.")
        } else if (hmm$error == 2) {
            stop("An error occurred during the Baum-Welch! Parameter estimation terminated prematurely.")
        }
        
        if (ioffset == 1) {
            # Model parameters are taken from the fit on the first offset only
            ### Make return object ###
                result <- list()
                result$info <- info
            ## Parameters
                if (!is.null(use.states)) {
                    mapping <- use.states$combination
                    names(mapping) <- use.states$state
                } else {
                    mapping <- NULL
                }
                result$mapping <- mapping
                combinations <- mapping[as.character(comb.states)]
                # Weights
                tstates <- table(hmm$states)
                result$weights <- sort(tstates/sum(tstates), decreasing=TRUE)
                result$weights.univariate <- weights
                names(result$weights.univariate) <- result$info$ID
                # Transition matrices
                result$transitionProbs <- matrix(hmm$A, ncol=length(comb.states), byrow=TRUE)
                colnames(result$transitionProbs) <- combinations
                rownames(result$transitionProbs) <- combinations
                result$transitionProbs.initial <- matrix(hmm$A.initial, ncol=length(comb.states), byrow=TRUE)
                colnames(result$transitionProbs.initial) <- combinations
                rownames(result$transitionProbs.initial) <- combinations
                # Initial probs
                result$startProbs <- hmm$proba
                names(result$startProbs) <- combinations
                result$startProbs.initial <- hmm$proba.initial
                names(result$startProbs.initial) <- combinations
                # Distributions
                result$distributions <- distributions
                names(result$distributions) <- result$info$ID
            ## Convergence info
                convergenceInfo <- list(eps=eps, loglik=hmm$loglik, loglik.delta=hmm$loglik.delta, num.iterations=hmm$num.iterations, time.sec=hmm$time.sec)
                result$convergenceInfo <- convergenceInfo
            ## Correlation matrices
                result$correlation.matrix <- correlationMatrix
            ## Add class
                class(result) <- class.multivariate.hmm
            
            ## Check convergence
            if (result$convergenceInfo$loglik.delta > result$convergenceInfo$eps) {
                warning("HMM did not converge!\n")
            }
        }
        
        ## Store counts and posteriors in list
        if (get.posteriors) {
            dim(hmm$posteriors) <- c(length(binned.data), length(comb.states))
            dimnames(hmm$posteriors) <- list(bin=NULL, state=comb.states)
        }
        dim(hmm$counts) <- c(length(binned.data), nummod)
        dimnames(hmm$counts) <- list(bin=NULL, track=info$ID)
        if (keep.densities) {
            # Overwritten in every iteration: the densities of the last offset are kept
            densities <- hmm$densities
            dim(densities) <- c(length(binned.data), length(comb.states))
            dimnames(densities) <- list(bin=NULL, state=comb.states)
        }
        # Remember the fitted parameters; they seed the runs for the next offsets
        hmm.A <- hmm$A
        hmm.proba <- hmm$proba
        
        ## Transform posteriors to 'per-sample' representation
        if (get.posteriors) {
            post.per.track <- hmm$posteriors %*% binstates
            colnames(post.per.track) <- info$ID
        }
        
        ## Inflate posteriors, states, counts to new offset
        bins.shift <- suppressWarnings( shift(sbins, shift = as.numeric(offset)) )
        ind <- findOverlaps(stepbins, bins.shift)
        if (get.posteriors) {
            aposteriors.step[ind@from, , 'currentOffset'] <- post.per.track[ind@to, , drop=FALSE]
        }
        # if (is.null(counts.rpkm)) {
        #     acounts.step[ind@from, , 'currentOffset'] <- hmm$counts[ind@to, , drop=FALSE]
        #     ## Sum over previous counts
        #     acounts.step[, , 'previousOffsets'] <- acounts.step[, , 'previousOffsets', drop=FALSE] + acounts.step[, , 'currentOffset', drop=FALSE]
        # 
        # }
        astates.step[ind@from, 'currentOffset'] <- hmm$states[ind@to]
        amaxPosterior.step[ind@from, 'currentOffset'] <- hmm$maxPosterior[ind@to]
        
        ## Find offset that maximizes the posteriors for each bin
        ##-- Start stuff to call C code
        # Work with changing dimensions to avoid copies being made
        dim_amaxPosterior.step <- dim(amaxPosterior.step)
        dimnames_amaxPosterior.step <- dimnames(amaxPosterior.step)
        dim(amaxPosterior.step) <- NULL
        z <- .C("C_array2D_which_max",
                array2D = amaxPosterior.step,
                dim = as.integer(dim_amaxPosterior.step),
                ind_max = integer(dim_amaxPosterior.step[1]),
                value_max = double(dim_amaxPosterior.step[1]))
        dim(amaxPosterior.step) <- dim_amaxPosterior.step
        dimnames(amaxPosterior.step) <- dimnames_amaxPosterior.step
        ind <- z$ind_max
        ##-- End stuff to call C code
        # Column 1 is 'previousOffsets', column 2 is 'currentOffset': copy the
        # winning column into 'previousOffsets' for each bin
        for (i1 in 1:2) {
            mask <- ind == i1
            if (get.posteriors) {
                aposteriors.step[mask, , 'previousOffsets'] <- aposteriors.step[mask,,i1, drop=FALSE]
            }
            astates.step[mask, 'previousOffsets'] <- astates.step[mask,i1, drop=FALSE]
            amaxPosterior.step[mask, 'previousOffsets'] <- amaxPosterior.step[mask,i1, drop=FALSE]
        }
        
        if (ioffset == 1) {
            message("Time spent in multivariate HMM: ", appendLF=FALSE)
        }
        stopTimedMessage(ptm)
    
        rm(hmm, ind); gc()
    } # loop over offsets
    states.step <- astates.step[, 'previousOffsets']
    rm(amaxPosterior.step, astates.step); gc()

    # Average and normalize counts to RPKM
    ptm <- startTimedMessage("Collecting counts and posteriors over offsets ...")
    # if (is.null(counts.rpkm)) {
    #     counts.step <- acounts.step[, , 'previousOffsets'] / length(offsets)
    #     rm(acounts.step); gc()
    #     counts.step <- rpkm.matrix(counts.step, binsize = width(binned.data)[1])
    # } else {
    #     counts.step <- counts.rpkm
    # }
    if (get.posteriors) {
        stepbins$posteriors <- aposteriors.step[,,'previousOffsets']
        rm(aposteriors.step); gc()
    }
    stopTimedMessage(ptm)
    
    ptm <- startTimedMessage("Compiling coordinates, posteriors, states ...")
    ## Counts ##
    result$bincounts <- binned.data
    result$bincounts$state <- NULL
    
    ## Bin coordinates, posteriors and states ##
    result$bins <- stepbins
    # result$bins$counts.rpkm <- counts.step
    if (!is.null(use.states$state)) {
        state.levels <- levels(use.states$state)
    } else {
        state.levels <- comb.states
    }
    result$bins$state <- factor(states.step, levels=state.levels)
    stopTimedMessage(ptm)
    
    if (keep.densities) {
        result$bincounts$densities <- densities
    }
    
    ## Add combinations ##
    ptm <- startTimedMessage("Adding combinations ...")
    if (!is.null(use.states)) {
        result$bins$combination <- factor(mapping[as.character(result$bins$state)], levels=levels(use.states$combination))
    } else {
        result$bins$combination <- result$bins$state
    }
    stopTimedMessage(ptm)
    
    ## Segmentation ##
    result$segments <- multivariateSegmentation(result$bins, column2collapseBy='state')
    ptm <- startTimedMessage("Adding differential score ...")
    result$segments$differential.score <- differentialScoreSum(result$segments$maxPostInPeak, result$info)
    stopTimedMessage(ptm)
    # The posteriors were needed for the segmentation above; drop them only now
    if (!keep.posteriors) {
        result$bins$posteriors <- NULL
    }
    if (get.posteriors) {
        ptm <- startTimedMessage("Getting maximum posterior in peaks ...")
        ind <- findOverlaps(result$bins, result$segments)
        result$bins$maxPostInPeak <- result$segments$maxPostInPeak[subjectHits(ind), , drop=FALSE]
        result$bins$differential.score <- result$segments$differential.score[subjectHits(ind)]
        stopTimedMessage(ptm)
            
        ## Peaks ##
        # One set of peaks per column of 'maxPostInPeak': all segments with a
        # maximum posterior > 0 in that column
        ptm <- startTimedMessage("Obtaining peaks ...")
        result$peaks <- list()
        for (i1 in seq_len(ncol(result$segments$maxPostInPeak))) {
            mask <- result$segments$maxPostInPeak[,i1] > 0
            peaks <- result$segments[mask]
            mcols(peaks) <- NULL
            peaks$maxPostInPeak <- result$segments$maxPostInPeak[mask,i1]
            result$peaks[[i1]] <- peaks
        }
        names(result$peaks) <- colnames(result$segments$maxPostInPeak)
        stopTimedMessage(ptm)
    }
        
    return(result)

}


#' @importFrom stats pnbinom qnorm dnbinom
## Prepare the input for the multivariate HMM from a set of univariate HMMs.
##
## Every univariate HMM is loaded via loadHmmsFromFiles(). Read counts and state
## calls are collected per track, the combinatorial state of each bin is encoded
## as a decimal number (first track = most significant bit), and for every
## selected combinatorial state the correlation matrix of the probit-transformed
## counts is computed together with its determinant and its inverse.
##
## @param hmms A list of univariate HMMs (or files containing such objects);
##   each element is handed to loadHmmsFromFiles().
## @param use.states NULL, or an object with a 'state' column/element that
##   restricts the combinatorial states which are used.
## @param max.states NULL, or the maximum number of combinatorial states to keep
##   (states are ordered by decreasing number of bins).
## @param chromosomes NULL, or a vector of chromosome names to which 'bincounts'
##   is restricted.
## @return A list with elements 'info', 'bincounts', 'bins', 'comb.states',
##   'use.states', 'distributions', 'weights', 'correlationMatrix',
##   'correlationMatrixInverse' and 'determinant'.
prepareMultivariate <- function(hmms, use.states=NULL, max.states=NULL, chromosomes=NULL) {

    nummod <- length(hmms)
    if (nummod < 1) {
        stop("Argument 'hmms' must contain at least one univariate HMM.")
    }
    # NOTE(review): the limit of 53 presumably stems from the decimal state encoding
    # below (sum of powers of two stored in a double) - confirm.
    if (nummod > 53) {
        stop("We can't handle more than 53 'hmms' Please decrease the number of input 'hmms'.")
    }

    # Reduce the per-bin state calls of one HMM to one call per 'bincounts' bin and
    # convert them into a logical indicator.
    # F,F,T corresponds to levels 'zero-inflation','unmodified','modified'
    binaryStates <- function(model) {
        step <- length(model$bins) / length(model$bincounts)
        bins.state <- model$bins$state[seq(from=1, to=length(model$bins), by=step)]
        c(FALSE,FALSE,TRUE)[bins.state]
    }

    ## Load first HMM for coordinates
    ptm <- startTimedMessage("Getting coordinates ...")
    hmm <- suppressMessages( loadHmmsFromFiles(hmms[[1]], check.class=class.univariate.hmm)[[1]] )
  # hmm$bins$counts <- array(hmm$bins$counts, dim=c(length(hmm$bins), 1), dimnames=list(bin=NULL, offset='0'))
  # hmm$bincounts <- hmm$bins
  # hmm$bins$counts.rpkm <- rpkm.matrix(hmm$bins$counts, width(hmm$bins)[1])
    bincounts <- hmm$bincounts
    mcols(bincounts) <- NULL
    bins <- hmm$bins
    mcols(bins) <- NULL
    stopTimedMessage(ptm)

    if (!is.null(chromosomes)) {
        chromsNotInData <- setdiff(chromosomes, unique(seqnames(bincounts)))
        if (length(chromsNotInData) == length(chromosomes)) {
            stop("None of the specified chromosomes '", paste0(chromsNotInData, collapse=', '), "' exists.")
        }
        if (length(chromsNotInData)>0) {
            warning("Chromosomes '", paste0(chromsNotInData, collapse=', '), "' could not be found.")
        }
    }

    ## Go through HMMs and extract stuff
    ptm <- startTimedMessage("Extracting read counts ...")
    # The first HMM is already loaded, so its values initialise all containers
    info <- list(hmm$info)
    distributions <- list(hmm$distributions)
    weights <- list(hmm$weights)
    offsets <- dimnames(hmm$bincounts$counts)[[2]]
    counts <- array(NA, dim = c(length(bincounts), nummod, length(offsets)), dimnames = list(bin=NULL, track=NULL, offset=offsets))
    counts[,1,] <- hmm$bincounts$counts
    counts.rpkm <- array(NA, dim = c(length(hmm$bins), nummod), dimnames = list(bin=NULL, track=NULL))
    counts.rpkm[,1] <- hmm$bins$counts.rpkm
    binary_statesmatrix <- matrix(NA, ncol=nummod, nrow=length(bincounts))
    binary_statesmatrix[,1] <- binaryStates(hmm)
    # seq_len(nummod)[-1] is empty for a single HMM, whereas 2:nummod would
    # count down (2, 1) and try to load a non-existing second HMM.
    for (i1 in seq_len(nummod)[-1]) {
        hmm <- suppressMessages( loadHmmsFromFiles(hmms[[i1]], check.class=class.univariate.hmm)[[1]] )
    # hmm$bins$counts <- array(hmm$bins$counts, dim=c(length(hmm$bins), 1), dimnames=list(bin=NULL, offset='0'))
    # hmm$bincounts <- hmm$bins
    # hmm$bins$counts.rpkm <- rpkm.matrix(hmm$bins$counts, width(hmm$bins)[1])
        info[[i1]] <- hmm$info
        distributions[[i1]] <- hmm$distributions
        weights[[i1]] <- hmm$weights
        counts[,i1,] <- hmm$bincounts$counts
        counts.rpkm[,i1] <- hmm$bins$counts.rpkm
        binary_statesmatrix[,i1] <- binaryStates(hmm)
    }
    info <- do.call(rbind, info)
    rownames(info) <- NULL
    if (is.null(info)) {
        # None of the HMMs carried experiment information: generate placeholder entries
        n <- nummod
        info <- data.frame(file=rep(NA, n), mark=1:n, condition=1:n, replicate=rep(1, n), pairedEndReads=rep(NA, n), controlFiles=rep(NA, n))
        info$ID <- paste0(info$mark, '-', info$condition, '-rep', info$replicate)
    }
    bincounts$counts <- counts
    colnames(bincounts$counts) <- info$ID
    bins$counts.rpkm <- counts.rpkm
    colnames(bins$counts.rpkm) <- info$ID
    # The maximum is taken over all chromosomes, i.e. before any subsetting
    maxcounts <- max(bincounts$counts)
    bincounts$binary_statesmatrix <- binary_statesmatrix
    if (!is.null(chromosomes)) {
        # Select only specified chromosomes
        bincounts <- bincounts[seqnames(bincounts) %in% chromosomes]
    }
    stopTimedMessage(ptm)

    # Clean up to reduce memory usage
    remove(hmm)

    ## Transform binary to decimal
    ptm <- startTimedMessage("Getting combinatorial states ...")
    decimal_states <- rep(0,length(bincounts))
    for (imod in seq_len(nummod)) {
        # The first track contributes the most significant bit
        decimal_states <- decimal_states + 2^(nummod-imod) * bincounts$binary_statesmatrix[,imod]
    }
    bincounts$binary_statesmatrix <- NULL
    bincounts$state <- decimal_states
    if (is.null(use.states)) {
        comb.states.table <- sort(table(bincounts$state), decreasing=TRUE)
        comb.states <- names(comb.states.table)
    } else {
        comb.states.table <- sort(table(bincounts$state)[as.character(use.states$state)], decreasing=TRUE)
        comb.states <- names(comb.states.table)
        # Requested states that do not occur in the data are appended at the end
        comb.states <- c(comb.states, setdiff(use.states$state, comb.states))
    }
    # Subselect states (min() ignores max.states if it is NULL)
    numstates2use <- min(length(comb.states), max.states)
    comb.states <- comb.states[seq_len(numstates2use)]
    stopTimedMessage(ptm)

    ## We pre-compute the z-values for each number of counts
    ptm <- startTimedMessage("Computing pre z-matrix ...")
    z.per.read <- array(NA, dim=c(maxcounts+1, nummod, 2))
    xcounts <- 0:maxcounts
    for (imod in seq_len(nummod)) {
        # Go through unmodified and modified
        for (i1 in 2:3) {

            size <- distributions[[imod]][i1,'size']
            prob <- distributions[[imod]][i1,'prob']

            if (i1 == 2) {
                # Unmodified with zero inflation
                w <- weights[[imod]][1] / (weights[[imod]][2] + weights[[imod]][1])
                u <- pzinbinom(xcounts, w, size, prob)
            } else if (i1 == 3) {
                # Modified
                u <- stats::pnbinom(xcounts, size, prob)
            }

            # Check for infinities in u and set them to max value which is not infinity
            qnorm_u <- stats::qnorm(u)
            mask <- qnorm_u==Inf
            qnorm_u[mask] <- stats::qnorm(1-1e-16)
#             testvec = qnorm_u!=Inf
#             qnorm_u = ifelse(testvec, qnorm_u, max(qnorm_u[testvec]))
            z.per.read[ , imod, i1-1] <- qnorm_u

        }
    }
    stopTimedMessage(ptm)

    ## Compute the z matrix
    ptm <- startTimedMessage("Transfering values into z-matrix ...")
    z.per.bin <- array(NA, dim=c(length(bincounts), nummod, 2), dimnames=list(bin=seq_len(length(bincounts)), track=info$ID, state.labels[2:3]))
    for (imod in seq_len(nummod)) {
        for (i1 in 1:2) {
            # Row k of z.per.read holds the z-value for a count of k-1
            z.per.bin[ , imod, i1] <- z.per.read[bincounts$counts[,imod,'0']+1, imod, i1]
        }
    }

    # Clean up to reduce memory usage
    remove(z.per.read)
    stopTimedMessage(ptm)

    ### Calculate correlation matrix
    ptm <- startTimedMessage("Computing inverse of correlation matrix ...")
    correlationMatrix <- array(NA, dim=c(nummod,nummod,length(comb.states)), dimnames=list(track=info$ID, track=info$ID, comb.state=comb.states))
    correlationMatrixInverse <- array(NA, dim=c(nummod,nummod,length(comb.states)), dimnames=list(track=info$ID, track=info$ID, comb.state=comb.states))
    determinant <- rep(NA, length(comb.states))
    names(determinant) <- comb.states

    # Convert a non-negative number into a logical vector of 'num.digits' binary
    # digits, most significant digit first.
    numericToBin <- function(x, num.digits) {
        bin <- logical(length=num.digits)
        xi <- x
        i1 <- num.digits
        while (xi > 0) {
            bin[i1] <- xi %% 2
            xi <- xi %/% 2
            i1 <- i1 - 1
        }
        return(bin)
    }

    ## Calculate correlation matrix serial
    for (istate in seq_along(comb.states)) {
        state <- comb.states[istate]
        mask <- which(bincounts$state==as.numeric(state))
        # Convert state to binary representation
        binary_state <- numericToBin(as.numeric(state), nummod)
        # Subselect z: for each track take the z-values of the component
        # (unmodified/modified) that this combinatorial state implies
        z.temp <- matrix(NA, ncol=nummod, nrow=length(mask))
        for (i1 in seq_along(binary_state)) {
            z.temp[,i1] <- z.per.bin[mask, i1, binary_state[i1]+1]
        }
        failed <- tryCatch({
            if (nrow(z.temp) > 100) {
                correlationMatrix[,,istate] <- suppressWarnings( cor(z.temp) )
                correlationMatrix[,,istate][is.na(correlationMatrix[,,istate])] <- 0
                determinant[istate] <- det( correlationMatrix[,,istate] )
                correlationMatrixInverse[,,istate] <- chol2inv(chol(correlationMatrix[,,istate]))
            } else {
                # Too few bins for a correlation estimate: use the identity matrix
                correlationMatrix[,,istate] <- diag(nummod)
                determinant[istate] <- 1 # det(diag(x)) == 1
                correlationMatrixInverse[,,istate] <- diag(nummod) # solve(diag(x)) == diag(x)
            }
            FALSE
        }, error = function(err) {
            TRUE
        })
        if (failed) {
            # Any error above (e.g. in det() or chol()) falls back to the identity matrix
            correlationMatrix[,,istate] <- diag(nummod)
            determinant[istate] <- 1
            correlationMatrixInverse[,,istate] <- diag(nummod)
        }
        if (any(is.na(correlationMatrixInverse[,,istate]))) {
            correlationMatrixInverse[,,istate] <- diag(nummod)
        }
    }
    remove(z.per.bin)
    stopTimedMessage(ptm)

#     ## Calculate emission densities (only debugging, do in C++ otherwise)
#     densities <- matrix(1, ncol=numstates2use, nrow=length(bincounts))
#     for (comb.state in comb.states) {
# print(comb.state)
#         istate <- which(comb.state==comb.states)
#         z.temp <- matrix(NA, nrow=length(bincounts), ncol=nummod)
#         bin.comb.state <- dec2bin(comb.state, colnames=info$ID)
#         product <- 1
#         for (imod in 1:nummod) {
#             z.temp[,imod] <- z.per.bin[,imod,bin.comb.state[imod]+1]
#             ind.modstate <- bin.comb.state[imod]+2
#             if (rownames(distributions[[imod]])[ind.modstate] == 'unmodified') {
#                 size <- distributions[[imod]][ind.modstate,'size']
#                 prob <- distributions[[imod]][ind.modstate,'prob']
#                 product <- product * dzinbinom(bincounts$counts[,imod,'0'], w=weights[[imod]]['zero-inflation'], size, prob)
#             } else if (rownames(distributions[[imod]])[ind.modstate] == 'modified') {
#                 size <- distributions[[imod]][ind.modstate,'size']
#                 prob <- distributions[[imod]][ind.modstate,'prob']
#                 product <- product * stats::dnbinom(bincounts$counts[,imod,'0'], size, prob)
#             }
#         }
#         exponent <- -0.5 * apply( ( z.temp %*% (correlationMatrixInverse[ , , istate] - diag(nummod)) ) * z.temp, 1, sum)
#         densities[,istate] <- product * determinant[istate]^(-0.5) * exp( exponent )
#     }

    # Return parameters
    out <- list(info = info,
                bincounts = bincounts,
                bins = bins,
                comb.states = comb.states,
                use.states = use.states,
                distributions = distributions,
                weights = weights,
                correlationMatrix = correlationMatrix,
                correlationMatrixInverse = correlationMatrixInverse,
                determinant = determinant
    )
    return(out)
}


### Get real transition probabilities ###
# Tabulate the transitions between the combinations of consecutive bins in
# 'model$bins', arrange the table like 'model$transitionProbs' and divide every
# row by its sum.
transProbs <- function(model) {
    combinations <- model$bins$combination
    num.bins <- length(model$bins)
    # Pair every bin (except the last) with its successor
    transitions <- data.frame(from = combinations[-num.bins], to = combinations[-1])
    transition.counts <- table(transitions)
    from.states <- rownames(model$transitionProbs)
    to.states <- colnames(model$transitionProbs)
    transition.counts <- transition.counts[from.states, to.states]
    sweep(transition.counts, MARGIN = 1, STATS = rowSums(transition.counts), FUN = '/')
}
# ataudt/chromstaR documentation built on Dec. 26, 2021, 12:07 a.m.