R/enrichmentAnalysis.R

Defines functions enrichmentAtAnnotation plotEnrichment plotEnrichCountHeatmap plotFoldEnrichHeatmap

Documented in enrichmentAtAnnotation plotEnrichCountHeatmap plotEnrichment plotFoldEnrichHeatmap

#' Enrichment analysis
#' 
#' Plotting functions for enrichment analysis of \code{\link{multiHMM}} or \code{\link{combinedMultiHMM}} objects with any annotation of interest, specified as a \code{\link[GenomicRanges]{GRanges-class}} object.
#' 
#' @name enrichment_analysis
#' @param combinations A vector with combinations for which the enrichment will be calculated. If \code{NULL} all combinations will be considered.
#' @param marks A vector with marks for which the enrichment is plotted. If \code{NULL} all marks will be considered.
#' @param logscale Set to \code{TRUE} to plot fold enrichment on log-scale. Ignored if \code{statistic = 'fraction'}.
#' @return A \code{\link[ggplot2:ggplot]{ggplot}} object containing the plot or a list() with \code{\link[ggplot2:ggplot]{ggplot}} objects if several plots are returned. For \code{plotFoldEnrichHeatmap} a named array with fold enrichments if \code{plot=FALSE}.
#' @author Aaron Taudt
#' @seealso \code{\link{plotting}}
#' @examples 
#'### Get an example multiHMM ###
#'file <- system.file("data","multivariate_mode-combinatorial_condition-SHR.RData",
#'                     package="chromstaR")
#'model <- get(load(file))
#'
#'### Obtain gene coordinates for rat from biomaRt ###
#'library(biomaRt)
#'ensembl <- useMart('ENSEMBL_MART_ENSEMBL', host='may2012.archive.ensembl.org',
#'                   dataset='rnorvegicus_gene_ensembl')
#'genes <- getBM(attributes=c('ensembl_gene_id', 'chromosome_name', 'start_position',
#'                            'end_position', 'strand', 'external_gene_id',
#'                            'gene_biotype'),
#'               mart=ensembl)
#'# Transform to GRanges for easier handling
#'genes <- GRanges(seqnames=paste0('chr',genes$chromosome_name),
#'                 ranges=IRanges(start=genes$start, end=genes$end),
#'                 strand=genes$strand,
#'                 name=genes$external_gene_id, biotype=genes$gene_biotype)
#'# Rename chrMT to chrM
#'seqlevels(genes)[seqlevels(genes)=='chrMT'] <- 'chrM'
#'print(genes)
#'
#'### Make the enrichment plots ###
#'# We expect promoter [H3K4me3] and bivalent-promoter signatures [H3K4me3+H3K27me3]
#'# to be enriched at transcription start sites.
#'    plotEnrichment(hmm = model, annotation = genes, bp.around.annotation = 15000) +
#'    ggtitle('Fold enrichment around genes') +
#'    xlab('distance from gene body')
#'  
#'# Plot enrichment only at TSS. We make use of the fact that TSS is the start of a gene.
#'    plotEnrichment(model, genes, region = 'start') +
#'    ggtitle('Fold enrichment around TSS') +
#'    xlab('distance from TSS in [bp]')
#'# Note: If you want to facet the plot because you have many combinatorial states you
#'# can do that with
#'    plotEnrichment(model, genes, region = 'start') +
#'    facet_wrap(~ combination)
#'  
#'# Another form of visualization that shows every TSS in a heatmap
#'# If transparency is not supported try to plot to pdf() instead.
#'    tss <- resize(genes, width = 3, fix = 'start')
#'    plotEnrichCountHeatmap(model, tss) +
#'    theme(strip.text.x = element_text(size=6))
#'  
#'# Fold enrichment with different biotypes, showing that protein coding genes are
#'# enriched with (bivalent) promoter combinations [H3K4me3] and [H3K4me3+H3K27me3],
#'# while rRNA is enriched with the empty [] and repressive combinations [H3K27me3].
#'    tss <- resize(genes, width = 3, fix = 'start')
#'    biotypes <- split(tss, tss$biotype)
#'    plotFoldEnrichHeatmap(model, annotations=biotypes) + coord_flip()
#'
NULL


#' @describeIn enrichment_analysis Compute the fold enrichment of combinatorial states for multiple annotations.
#' @param hmm A \code{\link{combinedMultiHMM}} or \code{\link{multiHMM}} object or a file that contains such an object.
#' @param annotations A \code{list()} with \code{\link{GRanges-class}} objects containing coordinates of multiple annotations. The names of the list entries will be used to name the return values.
#' @param plot A logical indicating whether the plot or an array with the fold enrichment values is returned.
#' @param what One of \code{c('combinations','peaks','counts','transitions')} specifying on which feature the statistic is calculated. \code{plotFoldEnrichHeatmap} supports \code{'combinations'}, \code{'peaks'} and \code{'transitions'}; \code{plotEnrichment} supports \code{'combinations'}, \code{'peaks'} and \code{'counts'}.
#' @importFrom S4Vectors subjectHits queryHits
#' @importFrom IRanges subsetByOverlaps
#' @importFrom reshape2 melt
#' @export
plotFoldEnrichHeatmap <- function(hmm, annotations, what="combinations", combinations=NULL, marks=NULL, plot=TRUE, logscale=TRUE) {
    
    ## Check user input
    # Unsupported values used to fall through every branch below and ended in
    # an obscure "subscript out of bounds" error or an empty result.
    if (length(what) != 1 || !what %in% c('combinations','peaks','transitions')) {
        stop("argument 'what' must be one of c('combinations','peaks','transitions')")
    }
    if (length(annotations) == 0) {
        stop("argument 'annotations' must contain at least one annotation")
    }
    
    hmm <- loadHmmsFromFiles(hmm, check.class=c(class.multivariate.hmm, class.combined.multivariate.hmm))[[1]]
    ## Variables
    bins <- hmm$bins
    if (!is(hmm,class.combined.multivariate.hmm) && is(hmm,class.multivariate.hmm)) {
        # Rename 'combination' to 'combination.' for coherence with combinedMultiHMM
        names(mcols(bins))[grep('combination', names(mcols(bins)))] <- paste0('combination.', unique(hmm$info$condition))
    }
    conditions <- sub('combination.', '', grep('combination', names(mcols(bins)), value=TRUE))
    if (is.null(combinations)) {
        comb.levels <- levels(mcols(bins)[,paste0('combination.', conditions[1])])
    } else {
        comb.levels <- combinations
    }
    if (is.null(marks)) {
        mark.levels <- unique(hmm$info$mark)
    } else {
        mark.levels <- marks
    }
    genome <- sum(as.numeric(width(bins)))
    annotationsAtBins <- lapply(annotations, function(x) { IRanges::subsetByOverlaps(x, bins) })
    feature.lengths <- lapply(annotationsAtBins, function(x) { sum(as.numeric(width(x))) })
    
    ## Helper: fold enrichment of every level (columns) at every annotation (rows).
    # 'level.names' labels the columns, 'dim.name' names that dimension and
    # 'maskForLevel(i)' returns the logical bin mask selecting the i-th level.
    # fold = (bp of level-bins overlapping the annotation / bp of all level-bins)
    #        / (bp of annotation / bp of genome)
    foldEnrichmentArray <- function(level.names, dim.name, maskForLevel) {
        dn <- list(names(annotationsAtBins), level.names)
        names(dn) <- c('annotation', dim.name)
        fold <- array(NA, dim=c(length(annotationsAtBins), length(level.names)), dimnames=dn)
        for (ilevel in seq_along(level.names)) {
            bins.mask <- bins[maskForLevel(ilevel)]
            level.length <- sum(as.numeric(width(bins.mask)))
            for (ifeat in seq_along(annotationsAtBins)) {
                ind <- findOverlaps(bins.mask, annotationsAtBins[[ifeat]])
                binsinfeature <- bins.mask[unique(S4Vectors::queryHits(ind))]
                sum.binsinfeature <- sum(as.numeric(width(binsinfeature)))
                fold[ifeat,ilevel] <- sum.binsinfeature / level.length / feature.lengths[[ifeat]] * genome
            }
        }
        return(fold)
    }
    
    ## Helper: turn a fold array into a tile plot and report its finite value range.
    # The range is needed afterwards to give all returned plots a common color scale.
    foldTilePlot <- function(fold, x) {
        df <- reshape2::melt(fold, value.name='foldEnrichment')
        if (logscale) {
            df$foldEnrichment <- log(df$foldEnrichment)
        }
        foldsnoinf <- setdiff(df$foldEnrichment, c(Inf, -Inf))
        ggplt <- ggplot(df) + geom_tile(aes_string(x=x, y='annotation', fill='foldEnrichment'))
        ggplt <- ggplt + theme(axis.text.x = element_text(angle=90, hjust=1, vjust=0.5))
        list(ggplt=ggplt, maxfold=max(foldsnoinf, na.rm=TRUE), minfold=min(foldsnoinf, na.rm=TRUE))
    }
    
    ggplts <- list()
    folds <- list()
    maxfolds <- list()
    minfolds <- list()
    
    if (what == 'transitions') {
        ## Combine the combinations of all conditions into one transition label per bin
        combs <- list()
        for (condition in conditions) {
            combs[[condition]] <- paste0(condition, ":", mcols(bins)[,paste0('combination.', condition)])
        }
        combs$sep <- ', '
        transitions <- factor(do.call(paste, combs))
        # Most frequent transitions first; the 'combinations' argument is not used here
        comb.levels <- names(sort(table(transitions), decreasing = TRUE))
        fold <- foldEnrichmentArray(comb.levels, 'combination', function(icomb) {
            transitions == comb.levels[icomb]
        })
        if (plot) {
            tile <- foldTilePlot(fold, x='combination')
            maxfolds[[1]] <- tile$maxfold
            minfolds[[1]] <- tile$minfold
            ggplts[[1]] <- tile$ggplt
        } else {
            folds[[1]] <- fold
        }
        
    } else {
        ## Fold enrichment per condition
        if (what == 'peaks') {
            binstates <- dec2bin(bins$state, colnames=hmm$info$ID)
        }
        for (condition in conditions) {
            if (what == 'combinations') {
                state <- mcols(bins)[,paste0('combination.', condition)]
                fold <- foldEnrichmentArray(comb.levels, 'combination', function(icomb) {
                    state == comb.levels[icomb]
                })
                xvar <- 'combination'
            } else if (what == 'peaks') {
                fold <- foldEnrichmentArray(mark.levels, 'mark', function(imark) {
                    colmask <- hmm$info$mark == mark.levels[imark]
                    colmask <- colmask & (!duplicated(paste0(hmm$info$mark, hmm$info$condition))) # remove replicates
                    if (condition != "") {
                        colmask <- colmask & (hmm$info$condition == condition)
                    }
                    binstates[,colmask]
                })
                xvar <- 'mark'
            }
          
            if (plot) {
                tile <- foldTilePlot(fold, x=xvar)
                maxfolds[[condition]] <- tile$maxfold
                minfolds[[condition]] <- tile$minfold
                ggplts[[condition]] <- tile$ggplt
            } else {
                folds[[condition]] <- fold
            }
        }
    }
    
    ## Give all plots the same color scale
    if (plot) {
        maxfold <- max(unlist(maxfolds), na.rm=TRUE)
        minfold <- min(unlist(minfolds), na.rm=TRUE)
        limits <- max(abs(maxfold), abs(minfold))
        fill.colors <- grDevices::colorRampPalette(c("blue","white","red"))(20)
        # Note: 'oob=identity' was formerly passed to scale_fill_gradientn() but broke in ggplot2 3.0.0
        if (logscale) {
            ggplts <- lapply(ggplts, function(ggplt) {
                ggplt <- ggplt + scale_fill_gradientn(name='log(observed/expected)', colors=fill.colors, values=c(seq(-limits,0,length.out=10), seq(0,limits,length.out=10)), rescaler=function(x,...) {x}, limits=c(-limits, limits))
                # log(0) = -Inf would be outside of the scale, draw it at the lower limit instead
                ggplt$data$foldEnrichment[ggplt$data$foldEnrichment == -Inf] <- -limits
                ggplt
            })
        } else {
            ggplts <- lapply(ggplts, function(ggplt) {
                ggplt + scale_fill_gradientn(name='observed/expected', colors=fill.colors, values=c(seq(0,1,length.out=10), seq(1,maxfold,length.out=10)), rescaler=function(x,...) {x}, limits=c(0,maxfold))
            })
        }
    }

    ## A single result for multiHMM or transitions, otherwise one result per condition
    if (is(hmm,class.multivariate.hmm) || what == 'transitions') {
        if (plot) {
            return(ggplts[[1]])
        } else {
            return(folds[[1]])
        }
    } else if (is(hmm,class.combined.multivariate.hmm)) {
        if (plot) {
            return(ggplts)
        } else {
            return(folds)
        }
    }
    
}


#' @describeIn enrichment_analysis Plot read counts around annotation as heatmap.
#' @inheritParams enrichmentAtAnnotation
#' @param max.rows An integer specifying the number of randomly subsampled rows that are plotted from the \code{annotation} object. This is necessary to avoid crashing for heatmaps with too many rows.
#' @param colorByCombinations A logical indicating whether or not to color the heatmap by combinations.
#' @param sortByCombinations A logical indicating whether or not to sort the heatmap by combinations.
#' @param sortByColumns An integer vector specifying the column numbers by which to sort the rows. If \code{sortByColumns} is specified, will force \code{sortByCombinations=FALSE}.
#' @importFrom reshape2 melt
#' @importFrom IRanges subsetByOverlaps
#' @export
plotEnrichCountHeatmap <- function(hmm, annotation, bp.around.annotation=10000, max.rows=1000, combinations=NULL, colorByCombinations=sortByCombinations, sortByCombinations=is.null(sortByColumns), sortByColumns=NULL) {

    # Sorting by columns takes precedence. This must happen before
    # 'colorByCombinations' is first evaluated because its default refers to
    # 'sortByCombinations'.
    if (!is.null(sortByColumns)) {
        sortByCombinations <- FALSE
    }
  
    hmm <- loadHmmsFromFiles(hmm, check.class=c(class.multivariate.hmm, class.combined.multivariate.hmm))[[1]]
    ## Variables
    bins <- hmm$bins
    if (is(hmm,class.combined.multivariate.hmm)) {
        conditions <- sub('combination.', '', grep('combination', names(mcols(bins)), value=TRUE))
        ## Create new column combination with all conditions combined
        combs <- list()
        for (condition in conditions) {
            combs[[condition]] <- paste0(condition, ":", mcols(bins)[,paste0('combination.', condition)])
        }
        combs$sep <- ', '
        bins$combination <- factor(do.call(paste, combs))
    }
    # The width of the first bin is used as bin size for all bins
    binsize <- width(bins)[1]
    around <- round(bp.around.annotation/binsize)
    
    # Subsampling for plotting of huge data.frames
    annotation <- IRanges::subsetByOverlaps(annotation, bins)
    if (length(annotation) == 0) {
        # Without this check the array() call below fails with an unrelated
        # dimnames error because 1:0 is c(1, 0).
        stop("None of the ranges in 'annotation' overlaps the bins of 'hmm'.")
    }
    if (length(annotation)>max.rows) {
        warning("Subsampled the data to 'max.rows=", max.rows, "'. Set 'max.rows=Inf' to turn this off, but be aware that plotting might take a very long time.")
        annotation <- sample(annotation, size=max.rows, replace=FALSE)
    }
  
    # Get bins that overlap the start of annotation
    # ('+' and '*' start at the first overlapping bin, '-' at the last one)
    ptm <- startTimedMessage("Overlaps with annotation ...")
    annotation$index <- NA
    mask.plus <- as.logical(strand(annotation)=='+' | strand(annotation)=='*')
    index.plus <- findOverlaps(annotation[mask.plus], bins, select="first")
    annotation$index[mask.plus] <- index.plus
    mask.minus <- as.logical(strand(annotation)=='-')
    index.minus <- findOverlaps(annotation[mask.minus], bins, select="last")
    annotation$index[mask.minus] <- index.minus
    stopTimedMessage(ptm)
    
    # Get surrounding indices
    ptm <- startTimedMessage("Getting surrounding indices ...")
    seq.around <- seq(-around, around, 1)
    ext.index <- array(NA, dim=c(length(annotation), length(seq.around)), dimnames=list(anno=seq_along(annotation), position=binsize*seq.around))
    for (icol in seq_len(ncol(ext.index))) {
        shift.index <- seq.around[icol]
        ext.index[mask.plus, icol] <- annotation$index[mask.plus] + shift.index
        # Minus strand: positive distances point towards lower bin indices
        ext.index[mask.minus, icol] <- annotation$index[mask.minus] - shift.index
    }
    # Positions beyond the first or last bin have no data
    ext.index[ext.index <= 0] <- NA
    ext.index[ext.index > length(bins)] <- NA
    annotation$ext.index <- ext.index
    stopTimedMessage(ptm)
    
    ## Add combination to annotation
    annotation$combination <- bins$combination[annotation$index]
    
    ## Go through combinations and then IDs to get the read counts
    ptm <- startTimedMessage("Getting read counts")
    counts <- list()
    if (is.null(combinations)) {
        comb.levels <- levels(bins$combination)
    } else {
        comb.levels <- combinations
    }
    for (combination in comb.levels) {
        counts[[combination]] <- list()
        ext.index.combination <- annotation$ext.index[annotation$combination == combination,, drop=FALSE]
        for (nID in colnames(bins$counts.rpkm)) {
            counts[[combination]][[nID]] <- array(bins$counts.rpkm[ext.index.combination,nID], dim=dim(ext.index.combination), dimnames=dimnames(ext.index.combination))
        }
    }
    stopTimedMessage(ptm)
    
    ## Prepare data.frame
    ptm <- startTimedMessage("Making the plot ...")
    df <- reshape2::melt(counts)
    names(df) <- c('id','position','RPKM','ID','combination')
    # Unsorted rows
    df$id <- factor(df$id, levels=rev(dimnames(annotation$ext.index)[['anno']]))
    if (sortByCombinations) {
        # Rows appear grouped by combination in 'df', keep that order on the y-axis
        df$id <- factor(df$id, levels=rev(unique(df$id)))
    }
    if (!is.null(sortByColumns)) {
        counts2order <- bins$counts.rpkm[annotation$index, sortByColumns, drop=FALSE]
        l <- as.data.frame(counts2order)
        names(l) <- NULL # remove names to make the do.call(order, ...) safe (see ?order)
        l$decreasing <- FALSE
        ordr <- do.call(order, l)
        df$id <- factor(df$id, levels=dimnames(annotation$ext.index)[['anno']][ordr])
    }
    df$combination <- factor(df$combination, levels=comb.levels)
    df$ID <- factor(df$ID, levels=hmm$info$ID)
    
    ## Theme
    custom_theme <- theme(
        panel.grid = element_blank(),
        panel.border = element_rect(fill='NA'),
        panel.background = element_rect(fill='white'),
        axis.text.y = element_blank(),
        axis.ticks.y = element_blank(),
        axis.line.y = element_blank()
    )
    ## Plot as heatmap
    ggplt <- ggplot(df, mapping=aes_string(x='position', y='id'))
    if (colorByCombinations) {
        ggplt <- ggplt + geom_tile(aes_string(color='combination'))
        ggplt <- ggplt + scale_color_manual(values = getDistinctColors(levels(df$combination)))
    } else {
        ggplt <- ggplt + geom_tile()
    }
    # Semi-transparent RPKM layer on top of the (optionally colored) tiles
    ggplt <- ggplt + geom_tile(aes_string(x='position', y='id', fill='RPKM'), alpha=0.6)
    ggplt <- ggplt + facet_wrap( ~ ID, nrow=1) + custom_theme
    ggplt <- ggplt + xlab('distance from annotation in [bp]') + ylab('')
    ggplt <- ggplt + scale_fill_continuous(trans='log1p', low='white', high='black')
    if (sortByCombinations) {
        # Insert horizontal lines at the topmost row of each combination
        y.lines <- vapply(split(df$id, df$combination), function(x) {
            if (length(x) > 0) {
                as.numeric(max(as.integer(x)))
            } else {
                -Inf
            }
        }, numeric(1))
        df.lines <- data.frame(y=sort(y.lines[-1]) + 0.5)
        ggplt <- ggplt + geom_hline(data=df.lines, mapping=aes_string(yintercept='y'), linetype=2)
    }
    # Increase color size in legend
    ggplt <- ggplt + guides(color=guide_legend(override.aes = list(size=2)))
    stopTimedMessage(ptm)
    
    return(ggplt)
}


#' @describeIn enrichment_analysis Plot fold enrichment of combinatorial states around and inside of annotation.
#' @inheritParams enrichmentAtAnnotation
#' @importFrom reshape2 melt
#' @export
plotEnrichment <- function(hmm, annotation, bp.around.annotation=10000, region=c("start","inside","end"), num.intervals=20, what='combinations', combinations=NULL, marks=NULL, statistic='fold', logscale=TRUE) {

    ## Check user input
    # The length is tested first with '||' so that a vector-valued 'what'
    # reaches the intended message instead of a "condition has length > 1" error.
    if (length(what) != 1 || !what %in% c('combinations','peaks','counts')) {
        stop("argument 'what' must be one of c('combinations','peaks','counts')")
    }
    if (!is.null(marks) && what != 'peaks') {
        stop("Please set argument 'what=\"peaks\"' if you want to plot marks instead of combinations.")
    }
  
    ## Variables
    hmm <- loadHmmsFromFiles(hmm, check.class=c(class.univariate.hmm, class.multivariate.hmm, class.combined.multivariate.hmm))[[1]]
    bins <- hmm$bins
    if (is(hmm,class.univariate.hmm)) {
        # Make the univariate HMM look like a multivariate one with a single track
        mcols(bins)['combination.'] <- bins$state
        bins$state <- c('zero-inflation' = 0, 'unmodified' = 0, 'modified' = 1)[bins$state]
        hmm$info <- data.frame(file=NA, mark=1, condition=1, replicate=1, pairedEndReads=NA, controlFiles=NA, ID='1-1-rep1')
    } else if (is(hmm,class.combined.multivariate.hmm)) {
        # Nothing to do, columns are already named 'combination.<condition>'
    } else if (is(hmm,class.multivariate.hmm)) {
        # Rename 'combination' to 'combination.' for coherence with combinedMultiHMM
        names(mcols(bins))[grep('combination', names(mcols(bins)))] <- 'combination.'
    }
    conditions <- sub('combination.', '', grep('combination', names(mcols(bins)), value=TRUE))
    if (is.null(combinations)) {
        comb.levels <- levels(mcols(bins)[,paste0('combination.', conditions[1])])
    } else {
        comb.levels <- combinations
    }
    if (is.null(marks)) {
        mark.levels <- unique(hmm$info$mark)
    } else {
        mark.levels <- marks
    }

    if (what %in% c('peaks','counts')) {
        ### Get fold enrichment
        # Does not depend on the condition, therefore computed only once
        enrich <- enrichmentAtAnnotation(bins, hmm$info, annotation, bp.around.annotation=bp.around.annotation, region=region, what=what, num.intervals=num.intervals, statistic=statistic)
    }
    # Column of the plotting data.frame that defines the line color
    color.by <- c(combinations='combination', peaks='mark', counts='ID')[[what]]
    has.inside <- 'inside' %in% region
    ggplts <- list()
    maxfolds <- list()
    minfolds <- list()
    maxcols <- list()
    for (condition in conditions) {
        if (what == 'combinations') {
            ### Get fold enrichment
            bins$combination <- mcols(bins)[,paste0('combination.', condition)]
            enrich.cond <- enrichmentAtAnnotation(bins, hmm$info, annotation, bp.around.annotation=bp.around.annotation, region=region, what=what, num.intervals=num.intervals, statistic=statistic)
        } else {
            enrich.cond <- enrich
        }
        ### Prepare for plotting
        df <- reshape2::melt(enrich.cond)
        df$L1 <- factor(df$L1, levels=c('start','inside','end'))
        df <- rbind(df[df$L1 == 'start',], df[df$L1 == 'inside',], df[df$L1 == 'end',])
        if (length(region) >= 2 && has.inside) {
            # Upstream part of 'start' and downstream part of 'end' flank the inside
            df <- df[!(df$L1 == 'start' & df$lag > 0),]
            df <- df[!(df$L1 == 'end' & df$lag < 0),]
            df$position <- apply(data.frame(df$interval, df$lag), 1, max, na.rm = TRUE)
        } else if (length(region) == 2 && !has.inside) {
            df <- df[!(df$L1 == 'start' & df$lag > 0),]
            df <- df[!(df$L1 == 'end' & df$lag < 0),]
            df$position <- df$lag
        } else if (length(region) == 1) {
            df <- df[df$L1 == region,]
            if (region == 'inside') {
                # Values inside of the annotation are indexed by 'interval', not by 'lag'
                df$position <- df$interval
            } else {
                df$position <- df$lag
            }
        }
        if (has.inside) {
            # 'end' is drawn to the right of the inside, which spans bp.around.annotation on the x-axis
            df$position[df$L1 == 'end'] <- df$position[df$L1 == 'end'] + bp.around.annotation
        }
        # Stretch the relative inside positions to the width of bp.around.annotation
        df$position[df$L1 == 'inside'] <- df$position[df$L1 == 'inside'] * bp.around.annotation
        if (what == 'combinations') {
            df <- df[df$combination %in% comb.levels,]
            df$combination <- factor(df$combination, levels=comb.levels)
        } else if (what %in% c('peaks','counts')) {
            # IDs are split at '-': first part is used as mark, second as condition
            df$mark <- sub("-.*", "", df$ID)
            df <- df[df$mark %in% mark.levels, ]
            df$condition <- sapply(strsplit(as.character(df$ID), '-'), '[[', 2)
            if (condition != "") {
                df <- df[df$condition == condition, ]
            }
        }
        if (logscale && !(statistic == 'fraction')) {
            df$value <- log(df$value)
        }

        ### Plot
        ggplt <- ggplot(df) + geom_line(aes_string(x='position', y='value', col=color.by), size=2)
        if (what == 'counts') {
            if (logscale) {
                ggplt <- ggplt + ylab('log(RPKM)')
            } else {
                ggplt <- ggplt + ylab('RPKM')
            }
        } else if (statistic == 'fold') {
            # Dashed line marks "no enrichment"
            if (logscale) {
                ggplt <- ggplt + ylab('log(observed/expected)')
                ggplt <- ggplt + geom_hline(yintercept=0, lty=2)
            } else {
                ggplt <- ggplt + ylab('observed/expected')
                ggplt <- ggplt + geom_hline(yintercept=1, lty=2)
            }
        } else if (statistic == 'fraction') {
            ggplt <- ggplt + ylab('fraction')
        }
        maxcols[[condition]] <- length(unique(df[[color.by]]))
        ggplt <- ggplt + theme_bw() + xlab('distance from annotation in [bp]')
        if (length(region) >= 2 && has.inside) {
            breaks <- c(c(-1, -0.5, 0, 0.5, 1, 1.5, 2) * bp.around.annotation)
            labels <- c(-bp.around.annotation, -bp.around.annotation/2, '0%', '50%', '100%', bp.around.annotation/2, bp.around.annotation)
            ggplt <- ggplt + scale_x_continuous(breaks=breaks, labels=labels)
        }
        foldsnoinf <- setdiff(df$value, c(Inf, -Inf))
        maxfolds[[condition]] <- max(foldsnoinf, na.rm=TRUE)
        minfolds[[condition]] <- min(foldsnoinf, na.rm=TRUE)
        ggplts[[condition]] <- ggplt
    }
    
    ### Same y-axis limits and colors for all plots
    maxfold <- max(unlist(maxfolds), na.rm=TRUE)
    minfold <- min(unlist(minfolds), na.rm=TRUE)
    maxcol <- max(unlist(maxcols), na.rm=TRUE)
    if (statistic == 'fraction' && what %in% c('combinations','peaks')) {
        ggplts <- lapply(ggplts, function(ggplt) { ggplt + scale_y_continuous(limits=c(0,1)) })
    } else {
        # Widen the observed range by 10% on both sides
        ggplts <- lapply(ggplts, function(ggplt) { ggplt + scale_y_continuous(limits=c(minfold*(1-sign(minfold)*0.1),maxfold*(1+sign(maxfold)*0.1))) })
    }
    ggplts <- lapply(ggplts, function(ggplt) { ggplt + scale_color_manual(values=getDistinctColors(maxcol)) }) # Add color here like this because of weird bug
    
    ## A single plot for univariate and multivariate HMMs, a list of plots for combined HMMs
    if (is(hmm,class.univariate.hmm) || is(hmm,class.multivariate.hmm)) {
        return(ggplts[[1]])
    } else if (is(hmm,class.combined.multivariate.hmm)) {
        return(ggplts)
    }
    
}


#' Enrichment of (combinatorial) states for genomic annotations
#'
#' The function calculates the enrichment of a genomic feature with peaks or combinatorial states. Input is a \code{\link{multiHMM}} object (containing the peak calls and combinatorial states) and a \code{\link{GRanges-class}} object containing the annotation of interest (e.g. transcription start sites or genes).
#'
#' @author Aaron Taudt
#' @param bins The \code{$bins} entry from a \code{\link{multiHMM}} or \code{\link{combinedMultiHMM}} object.
#' @param info The \code{$info} entry from a \code{\link{multiHMM}} or \code{\link{combinedMultiHMM}} object.
#' @param annotation A \code{\link{GRanges-class}} object with the annotation of interest.
#' @param bp.around.annotation An integer specifying the number of basepairs up- and downstream of the annotation for which the enrichment will be calculated.
#' @param region A combination of \code{c('start','inside','end')} specifying the region of the annotation for which the enrichment will be calculated. Select \code{'start'} if you have a point-sized annotation like transcription start sites. Select \code{c('start','inside','end')} if you have long annotations like genes.
#' @param what One of \code{c('combinations','peaks','counts')} specifying on which feature the statistic is calculated.
#' @param num.intervals Number of intervals for enrichment 'inside' of annotation.
#' @param statistic The statistic to calculate. Either 'fold' for fold enrichments or 'fraction' for fraction of bins falling into the annotation.
#' @return A \code{list()} containing \code{data.frame()}s for enrichment of combinatorial states and binary states at the start, end and inside of the annotation.
#' @importFrom S4Vectors as.factor subjectHits queryHits
enrichmentAtAnnotation <- function(bins, info, annotation, bp.around.annotation=10000, region=c('start','inside','end'), what='combinations', num.intervals=21, statistic='fold') {

    ## Check user input
    # Test the length first and use the scalar '||': a vector-valued 'what' must
    # reach the informative stop() instead of failing inside the if() condition.
    if (length(what) != 1 || !what %in% c('combinations','peaks','counts')) {
        stop("argument 'what' must be one of c('combinations','peaks','counts')")
    }
    # 'statistic' is only used for 'combinations' and 'peaks', so it is only
    # validated for those. An unknown value would otherwise surface as an
    # undefined variable or as arrays that silently stay NA.
    if (what != 'counts' && (length(statistic) != 1 || !statistic %in% c('fold','fraction'))) {
        stop("argument 'statistic' must be one of c('fold','fraction')")
    }
    seqlevels.only.in.bins <- setdiff(seqlevels(bins), seqlevels(annotation))
    seqlevels.only.in.annotation <- setdiff(seqlevels(annotation), seqlevels(bins))
    if (length(seqlevels.only.in.bins) > 0 || length(seqlevels.only.in.annotation) > 0) {
        warning("Sequence levels in 'bins' but not in 'annotation': ", paste0(seqlevels.only.in.bins, collapse = ', '), "\n  Sequence levels in 'annotation' but not in 'bins': ", paste0(seqlevels.only.in.annotation, collapse = ', '))
    }

    ## Variables
    binsize <- width(bins)[1]
    lag <- round(bp.around.annotation/binsize) # number of bins on each side of the anchor bin
    lags <- -lag:lag
    num.bins <- length(bins)
    # Strand masks, used to pick the strand-aware first/last overlapping bin
    is.plus <- strand(annotation)=='+' | strand(annotation)=='*'
    is.minus <- strand(annotation)=='-'

    ## Get the feature selected by 'what' and the columns of the result arrays
    if (what == 'peaks') {
        info.dedup <- info[!duplicated(paste0(info$mark, info$condition)), ]
        binstates <- dec2bin(bins$state, colnames=info$ID)
        # Remove replicates; drop=FALSE keeps a matrix even if a single ID is left
        binstates <- binstates[ , info.dedup$ID, drop=FALSE]
        colsums.binstates <- colSums(binstates)
        columns <- list(ID=info.dedup$ID)
    } else if (what == 'combinations') {
        combinations <- bins$combination
        tcombinations <- table(combinations)
        columns <- list(combination=levels(bins$combination))
    } else {
        counts <- bins$counts.rpkm
        columns <- list(ID=info$ID)
    }

    ## Helper: NA-filled result array; 'rows' is a named list with the row labels
    new.array <- function(rows) {
        array(dim=c(length(rows[[1]]), length(columns[[1]])), dimnames=c(rows, columns))
    }

    ## Helper: one row of the result, i.e. the statistic over the bins in 'index'
    enrichment.at <- function(index) {
        if (what == 'peaks') {
            # drop=FALSE: a single selected bin must stay a matrix for colSums()
            values <- colSums(binstates[index, , drop=FALSE]) / length(index) # fraction of selected bins in peak state
            if (statistic == 'fold') {
                values <- values / colsums.binstates * num.bins # relative to the genome-wide fraction
            }
        } else if (what == 'combinations') {
            values <- table(combinations[index]) / length(index) # fraction of selected bins per combination
            if (statistic == 'fold') {
                values <- values / tcombinations * num.bins # fold enrichment
            }
            # NaN from 0/0 (no bins selected, or combination absent from 'bins') is reported as 0
            values[is.na(values)] <- 0
        } else {
            values <- colMeans(counts[index, , drop=FALSE])
        }
        values
    }

    ## Helper: enrichment in a window of 'lag' bins around an anchor bin of each annotation.
    ## 'select.plus' / 'select.minus' choose the first or last overlapping bin as anchor
    ## for annotations on the '+'/'*' strand and on the '-' strand, respectively.
    enrichment.around <- function(select.plus, select.minus) {
        index.plus <- suppressWarnings( findOverlaps(annotation[is.plus], bins, select=select.plus) )
        index.minus <- suppressWarnings( findOverlaps(annotation[is.minus], bins, select=select.minus) )
        # Annotations without an overlapping bin yield NA
        index.plus <- index.plus[!is.na(index.plus)]
        index.minus <- index.minus[!is.na(index.minus)]
        # Occurrences at every bin position relative to feature
        around <- new.array(list(lag=lags))
        for (i in seq_along(lags)) {
            # Walk in transcription direction: forward on '+', backward on '-'
            index <- c(index.plus + lags[i], index.minus - lags[i])
            index <- index[index>0 & index<=num.bins] # index could cross chromosome boundaries, but we risk it
            around[i, ] <- enrichment.at(index)
        }
        # Convert row labels from bin offsets to basepairs
        rownames(around) <- as.numeric(rownames(around)) * binsize
        around
    }

    result <- list()

    ### Calculating enrichment inside of annotation ###
    if ('inside' %in% region) {
        ptm <- startTimedMessage("Enrichment inside of annotations ...")

        # Relative positions along the annotation, from its start (0) to its end (1)
        intervals <- seq(from=0, to=1, length.out=num.intervals+1)
        widths.annotation <- width(annotation) - 1
        annotation.1bp <- resize(annotation, 1, fix='start')
        # Direction in which to move from the start: +1 for '+' and '*', -1 for '-'
        strand.sign <- c(1,-1,1)[as.integer(strand(annotation))]
        inside <- new.array(list(interval=intervals))

        for (i in seq_along(intervals)) {
            shifted.starts <- start(annotation.1bp) + widths.annotation * intervals[i] * strand.sign
            annotation.shifted <- GRanges(seqnames = seqnames(annotation.1bp), ranges = IRanges(start = shifted.starts, end = shifted.starts), strand = strand(annotation.1bp))
            # Get bins that overlap the shifted annotation
            index.inside.plus <- suppressWarnings( findOverlaps(annotation.shifted[is.plus], bins, select="first") )
            index.inside.minus <- suppressWarnings( findOverlaps(annotation.shifted[is.minus], bins, select="last") )
            index.inside.plus <- index.inside.plus[!is.na(index.inside.plus)]
            index.inside.minus <- index.inside.minus[!is.na(index.inside.minus)]
            index <- c(index.inside.plus, index.inside.minus)
            index <- index[index>0 & index<=num.bins]
            inside[i, ] <- enrichment.at(index)
        }
        result$inside <- inside
        stopTimedMessage(ptm)
    }

    ### Window of bp.around.annotation around the start of the annotation ###
    if ('start' %in% region) {
        ptm <- startTimedMessage("Enrichment ",bp.around.annotation,"bp before annotations")
        # The start is the first overlapping bin on '+'/'*' and the last one on '-'
        result$start <- enrichment.around(select.plus="first", select.minus="last")
        stopTimedMessage(ptm)
    }

    ### Window of bp.around.annotation around the end of the annotation ###
    if ('end' %in% region) {
        ptm <- startTimedMessage("Enrichment ",bp.around.annotation,"bp after annotations")
        # The end is the last overlapping bin on '+'/'*' and the first one on '-'
        result$end <- enrichment.around(select.plus="last", select.minus="first")
        stopTimedMessage(ptm)
    }

    return(result)

}

Try the chromstaR package in your browser

Any scripts or data that you put into this service are public.

chromstaR documentation built on Nov. 8, 2020, 8:29 p.m.