R/DrawSummitHeatmaps.R

Defines functions DrawSummitHeatmaps

Documented in DrawSummitHeatmaps

#' @title DrawSummitHeatmaps
#'
#' @description Plots heatmaps of pre-calculated read counts around the centers of provided regions of interest.
#'
#' @details Plots heatmaps of pre-calculated read counts around the centers of provided regions of interest, for example of ChIP peaks.
#' If multiple heatmap count tables are supplied as a named list, it clusters or sorts the rows of a chosen heatmap based on its count values 
#' and keeps the rows of all other heatmaps in the same order. 
#'
#' @param counts A named list of read counts across (peak) region centers, preferentially generated by \code{SummitHeatmap} for compatibility.
#' @param bamNames Character vector containing the names to describe the samples in \code{counts} (for example: "H3K9me3"). 
#' If no names are supplied, the names of the list provided in \code{counts} are used.
#' @param plotcols Character vector of colors of the same length as the number of heatmaps drawn, corresponding to the colors for the medianCpm 
#' value for the heatmaps. Default is darkblue for all heatmaps.
#' @param topCpm Numeric vector of the same length as the number of heatmaps drawn, indicating the CPM values that correspond to the black color. Any value above that will appear black.
#' If not specified, maximum CPM value in the current count table is used.
#' @param medianCpm Numeric vector of the same length as the number of heatmaps drawn, indicating the CPM values that correspond to the \code{plotcols}. 
#' Any value above that will appear darker, any value below that will appear more white. If not specified, the median CPM value in the current count table is used.
#' @param bottomCpm Numeric vector of the same length as the number of heatmaps drawn, indicating the CPM values that correspond to the white color. 
#' Any value below that will appear white. If not specified, the minimum CPM value in the current count table is used.
#' @param orderSample Integer scalar giving the index of the heatmap whose values the rows of all heatmaps will be ordered by. 
#' Defaults to 1 (the first heatmap). If set to 0, the \code{clusterSample} heatmap will be clustered by hierarchical clustering instead of being ordered.
#' @param orderWindows The heatmap windows that will be used for ordering. By default, the heatmap will be ordered by the mean of all windows.
#' If an integer scalar is supplied, it will be ordered by the middle window plus n windows on each side.
#' @param clusterSample Integer scalar giving the index of the heatmap whose values the rows of all heatmaps will be ordered by hierarchical clustering. Defaults to 1 (the first heatmap).
#' @param summarizing The function to average the heatmaps to generate the metaplot. The cpm or count values are first log2 transformed
#'  (with a pseudo count of 1 added, if necessary). Default is mean.
#' @param show_axis Show the axis of the cumulative plot above the heatmap. Default is FALSE.
#' @param use.log Logical scalar indicating whether the heatmap will be plotted in log2 scale. FALSE by default. A pseudo-count of 1 is added if Zeros are encountered in the matrix.
#' @param splitHM Character vector of the same length as the number of rows in each table in \code{counts}, which indicates for each row in \code{counts} the membership within a certain group.
#' The heatmap will then be split into these groups. For example, it could indicate if a peak overlaps a TSS or not. Default is that the heatmap will not be split. 
#' @param TargetHeight Integer scalar giving the number of rows the plotted heatmap should have after averaging. Default = 500.
#' @param MetaScale A character vector of the same length as heatmaps, with values "all" or "individual". Defaults to "all",
#' which means that the maximum of the y-axis for the metaplots will correspond to the maximum of all values in all heatmaps.
#' Otherwise the scale will correspond to the maximum of the current metaplot.
#'
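#' @return A ComplexHeatmap heatmap list containing one heatmap for each sample provided in \code{counts}.
#' Rows are sorted by the values of the chosen \code{orderSample} (by default the mean over all windows, or the
#' windows around the middle window if \code{orderWindows} is given), or clustered by \code{clusterSample} if \code{orderSample} is 0.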
#'
#' @importFrom ComplexHeatmap Heatmap HeatmapAnnotation anno_text max_text_width anno_lines ht_opt
#' @importFrom grid gpar
#' @importFrom grid unit
#' @importFrom grid pushViewport
#' @importFrom grid grid.rect
#' @importFrom grid grid.points
#' @importFrom grid grid.lines
#' @importFrom grid upViewport
#' @importFrom circlize colorRamp2
#' @importFrom stats dist median
#' @importFrom stats hclust
#'
#' @examples
#' library(methods)
#' counts <- list(matrix(abs(rnorm(2100,2,1)),ncol=21,
#' nrow=100,dimnames=list(1:100,-10:10)),
#'                matrix(abs(rnorm(2100,2,1)),ncol=21,
#'                nrow=100,dimnames=list(1:100,-10:10)))
#' names(counts) <- c("counts1","counts2")
#' DrawSummitHeatmaps(counts, orderSample=1, bottomCpm = c(0,0), 
#' topCpm=c(10,10),use.log=FALSE,TargetHeight=0)
#'
#' @export
DrawSummitHeatmaps <- function(counts, bamNames=names(counts), plotcols= rep("darkblue",length(bamNames)),use.log = FALSE,
                               topCpm, medianCpm, bottomCpm, splitHM, TargetHeight = 500,
                               orderSample = 1, orderWindows=NULL, clusterSample = 1, summarizing = "mean",show_axis=FALSE,
                               MetaScale=rep("all",length(bamNames))){

  # Overview:
  #   1. validate all arguments,
  #   2. derive the shared row order (sorting) or row dendrogram (clustering),
  #   3. build one ComplexHeatmap::Heatmap per sample (metaplot on top, bin
  #      labels at the bottom) and add them up into a single heatmap list.

  # Record which optional arguments were supplied once, at the top level of
  # the function, so the answer does not depend on where it is asked later.
  has_top <- !missing(topCpm)
  has_median <- !missing(medianCpm)
  has_bottom <- !missing(bottomCpm)
  has_split <- !missing(splitHM)

  # ---- Local helpers -------------------------------------------------------

  # log2-transform a numeric matrix. The pseudocount is added only when the
  # matrix contains values <= 0; NA entries are ignored for that decision so
  # that a matrix with missing values does not abort the `if`.
  Log2Transform <- function(mat, pseudocount = 1) {
    if (any(mat <= 0, na.rm = TRUE)) {
      log2(mat + pseudocount)
    } else {
      log2(mat)
    }
  }

  # Resize a matrix to `target_height` rows: consecutive rows are averaged
  # when shrinking, rows are repeated when expanding. Columns are kept as-is,
  # so `target_width` has to equal ncol(mat).
  redim_matrix <- function(mat, target_height, target_width) {
    if (is.null(mat) || nrow(mat) == 0 || ncol(mat) == 0) {
      stop("Cannot redimension an empty matrix")
    }
    n_rows <- nrow(mat)

    if (n_rows < target_height) {
      warning(paste("Target height", target_height, "is larger than current matrix height",
                    n_rows, ". Matrix will be expanded."))
    }
    if (ncol(mat) != target_width) {
      # Column resizing was never implemented; fail with a clear message
      # instead of an obscure dimnames error further down.
      stop(paste("Target width", target_width, "differs from current matrix width",
                 ncol(mat), ". Changing the number of columns is not supported."))
    }

    new_matrix <- matrix(0, nrow = target_height, ncol = target_width,
                         dimnames = list(NULL, colnames(mat)))

    if (n_rows >= target_height) {
      # cut() refuses fewer than two intervals, so a single target row is
      # handled explicitly (all rows go into one bin).
      if (target_height == 1) {
        bin <- rep(1L, n_rows)
      } else {
        bin <- cut(seq_len(n_rows), target_height, labels = FALSE)
      }
      # Select members by bin id rather than by position in a split() list,
      # so an empty bin cannot shift the following bins.
      for (i in seq_len(target_height)) {
        members <- which(bin == i)
        if (length(members) > 0) {
          new_matrix[i, ] <- colMeans(mat[members, , drop = FALSE])
        }
      }
    } else {
      # Expanding: map every target row back to its source row.
      source_row <- pmin(ceiling(seq_len(target_height) * n_rows / target_height), n_rows)
      new_matrix[] <- mat[source_row, , drop = FALSE]
    }

    new_matrix
  }

  # TRUE for a single non-NA logical value.
  is_flag <- function(x) {
    is.logical(x) && length(x) == 1 && !is.na(x)
  }

  # TRUE for a single non-NA numeric value.
  is_number <- function(x) {
    is.numeric(x) && length(x) == 1 && !is.na(x)
  }

  # Validate one of the per-sample colour cutoff vectors.
  check_cpm <- function(value, label) {
    if (!is.numeric(value) || length(value) != length(bamNames)) {
      stop(paste0("'", label, "' must be a numeric vector with length equal to the number of samples (", length(bamNames), ")"))
    }
  }

  # ---- Input validation ----------------------------------------------------

  if (!is.list(counts)) {
    stop("'counts' must be a list of matrices")
  }
  if (length(counts) == 0) {
    stop("'counts' list cannot be empty")
  }

  for (i in seq_along(counts)) {
    if (!is.matrix(counts[[i]])) {
      stop(paste0("Element ", i, " in 'counts' list is not a matrix"))
    }
    if (is.null(rownames(counts[[i]]))) {
      stop(paste0("Matrix ", i, " in 'counts' list does not have rownames"))
    }
    if (is.null(colnames(counts[[i]]))) {
      stop(paste0("Matrix ", i, " in 'counts' list does not have colnames"))
    }
    # Column names are the bin positions and are parsed as numbers later on.
    if (any(is.na(suppressWarnings(as.numeric(colnames(counts[[i]])))))) {
      stop(paste0("Matrix ", i, " in 'counts' list has non-numeric colnames, which cannot be used for binning"))
    }
  }

  # All matrices must share the dimensions of the first one.
  first_ncol <- ncol(counts[[1]])
  first_nrow <- nrow(counts[[1]])
  for (i in seq_along(counts)[-1]) {
    if (ncol(counts[[i]]) != first_ncol) {
      stop(paste0("Matrix ", i, " has ", ncol(counts[[i]]), " columns, but matrix 1 has ", first_ncol, " columns. All matrices should have the same number of columns."))
    }
    if (nrow(counts[[i]]) != first_nrow) {
      stop(paste0("Matrix ", i, " has ", nrow(counts[[i]]), " rows, but matrix 1 has ", first_nrow, " rows. All matrices should have the same number of rows."))
    }
  }

  if (length(bamNames) != length(counts)) {
    stop(paste0("Length of 'bamNames' (", length(bamNames), ") must match length of 'counts' (", length(counts), ")"))
  }

  # Map every entry of bamNames to an element of 'counts'. When all entries
  # are names of 'counts' they are resolved by name (so a reordered bamNames
  # reorders the heatmaps); otherwise bamNames are treated as display labels
  # and matched by position.
  if (!is.null(names(counts)) && all(bamNames %in% names(counts))) {
    sample_idx <- match(bamNames, names(counts))
  } else {
    sample_idx <- seq_along(counts)
  }

  if (length(plotcols) != length(bamNames)) {
    stop(paste0("Length of 'plotcols' (", length(plotcols), ") must match length of 'bamNames' (", length(bamNames), ")"))
  }
  if (!all(vapply(plotcols, is.character, logical(1)))) {
    stop("'plotcols' must contain character strings representing colors")
  }

  if (!is_number(orderSample)) {
    stop("'orderSample' must be a single integer")
  }
  if (orderSample < 0 || orderSample > length(bamNames)) {
    stop(paste0("'orderSample' must be between 0 and the number of samples (", length(bamNames), ")"))
  }
  if (!is_number(clusterSample)) {
    stop("'clusterSample' must be a single integer")
  }
  if (clusterSample < 1 || clusterSample > length(bamNames)) {
    stop(paste0("'clusterSample' must be between 1 and the number of samples (", length(bamNames), ")"))
  }

  if (!is.null(orderWindows)) {
    if (!is_number(orderWindows) || orderWindows < 0) {
      stop("'orderWindows' must be a single non-negative integer")
    }
    # The window range only matters when rows are ordered; with
    # orderSample == 0 (clustering) there is no reference sample to check.
    if (orderSample > 0) {
      ref_cols <- colnames(counts[[sample_idx[orderSample]]])
      center <- match("0", ref_cols)
      if (!is.na(center) &&
          (center - orderWindows < 1 || center + orderWindows > length(ref_cols))) {
        stop(paste0("'orderWindows' value of ", orderWindows, " is too large for the matrix with ", length(ref_cols), " columns"))
      }
    }
  }

  # Resolve the summarizing function. is.function() is tested first because
  # `%in%` cannot be applied to a function object.
  valid_summarizing <- c("mean", "median", "sum", "min", "max")
  if (is.function(summarizing)) {
    summarizing_fn <- summarizing
  } else if (is.character(summarizing) && length(summarizing) == 1 &&
             summarizing %in% valid_summarizing) {
    summarizing_fn <- switch(summarizing,
                             mean = mean,
                             median = stats::median,
                             sum = sum,
                             min = min,
                             max = max)
  } else {
    stop(paste0("'summarizing' must be one of: ", paste(valid_summarizing, collapse=", "), " or a custom function"))
  }

  if (!is_flag(use.log)) {
    stop("'use.log' must be a single logical value (TRUE or FALSE)")
  }

  if (!is_number(TargetHeight) || TargetHeight < 0) {
    stop("'TargetHeight' must be a single non-negative integer")
  }

  valid_scales <- c("all", "individual")
  if (length(MetaScale) != length(bamNames)) {
    stop(paste0("Length of 'MetaScale' (", length(MetaScale), ") must match length of 'bamNames' (", length(bamNames), ")"))
  }
  if (!all(MetaScale %in% valid_scales)) {
    stop(paste0("'MetaScale' values must be one of: ", paste(valid_scales, collapse=", ")))
  }

  if (!is_flag(show_axis)) {
    stop("'show_axis' must be a single logical value (TRUE or FALSE)")
  }

  if (has_top) check_cpm(topCpm, "topCpm")
  if (has_median) check_cpm(medianCpm, "medianCpm")
  if (has_bottom) check_cpm(bottomCpm, "bottomCpm")

  if (has_split && length(splitHM) != nrow(counts[[1]])) {
    stop(paste0("'splitHM' must be a vector with length equal to the number of rows in counts matrices (", nrow(counts[[1]]), ")"))
  }

  # Checked after the arguments so that argument errors are reported even on
  # a system where the plotting packages are not available.
  required_packages <- c("ComplexHeatmap", "grid", "circlize", "stats")
  installed <- vapply(required_packages, requireNamespace, logical(1), quietly = TRUE)
  if (!all(installed)) {
    stop(paste0("The following required packages are not installed: ",
                paste(required_packages[!installed], collapse=", "),
                ". Please install them before using this function."))
  }

  # ---- Global heatmap options ----------------------------------------------
  # NOTE: ht_opt() changes session-wide ComplexHeatmap options and they are
  # intentionally left in place after the function returns.
  tryCatch({
    ComplexHeatmap::ht_opt(heatmap_column_names_gp = grid::gpar(fontface = "italic"),
                           heatmap_column_title_gp = grid::gpar(fontsize = 10, fontface = "bold"),
                           legend_border = "black",
                           heatmap_border = TRUE,
                           annotation_border = FALSE)
  }, error = function(e) {
    stop(paste0("Error setting heatmap options: ", e$message,
                ". Make sure the ComplexHeatmap package is installed and loaded."))
  })

  # ---- Row order (sorting) or dendrogram (clustering) ------------------------
  if (orderSample > 0) {
    counts1 <- counts[[sample_idx[orderSample]]]
    if (is.null(orderWindows)) {
      # Score every row by the mean over all windows.
      row_score <- rowMeans(counts1)
    } else {
      # Score every row by the mean of the centre window +/- orderWindows.
      center <- match("0", colnames(counts1))
      if (is.na(center)) {
        warning("Column '0' not found. Using the middle column instead.")
        center <- ceiling(ncol(counts1)/2)
      }
      window_cols <- seq(center - orderWindows, center + orderWindows, 1)
      window_cols <- window_cols[window_cols >= 1 & window_cols <= ncol(counts1)]
      if (length(window_cols) == 0) {
        stop("No valid columns for ordering. Check 'orderWindows' parameter.")
      }
      row_score <- rowMeans(counts1[, window_cols, drop = FALSE])
    }
    # Highest-scoring rows first. Only the row names are kept; every sample
    # is reordered by name below. Working on the name vector (instead of a
    # reversed matrix) also keeps single-row inputs intact.
    ordered_rows <- rev(rownames(counts1)[order(row_score)])
  } else {
    counts1 <- counts[[sample_idx[clusterSample]]]
    hclust_avg <- tryCatch({
      dist_mat <- stats::dist(counts1, method = 'euclidean')
      stats::hclust(dist_mat, method = 'average')
    }, error = function(e) {
      stop(paste0("Error in clustering: ", e$message))
    })
  }

  # ---- Bottom annotation: bin labels -----------------------------------------
  # Only bins at multiples of 1000 (wide ranges) or 100 (narrow ranges) are
  # labelled; all other bins get an empty label.
  bin_pos <- as.numeric(colnames(counts1))
  label_step <- if (max(abs(bin_pos)) > 1000) 1000 else 100
  binnames <- ifelse((bin_pos / label_step) %% 1 == 0, colnames(counts1), "")

  bin_anno <- tryCatch({
    ComplexHeatmap::HeatmapAnnotation(
      text = ComplexHeatmap::anno_text(
        binnames,
        rot = 45,
        location = grid::unit(1, "npc"),
        just = "right",
        gp = grid::gpar(fontsize=8, fontface = "italic")
      ),
      annotation_height = ComplexHeatmap::max_text_width(binnames)
    )
  }, error = function(e) {
    stop(paste0("Error creating bin annotations: ", e$message))
  })

  # ---- log2 tables and per-bin summaries for the metaplots -------------------
  all.counts.log <- tryCatch({
    lapply(counts, Log2Transform, pseudocount = 1)
  }, error = function(e) {
    stop(paste0("Error in log2 transformation: ", e$message))
  })

  all.counts.median <- tryCatch({
    lapply(all.counts.log, function(x) apply(x, 2, summarizing_fn))
  }, error = function(e) {
    stop(paste0("Error in summarizing counts: ", e$message))
  })

  # Shared upper y-axis limit for samples with MetaScale "all".
  all.counts.max <- max(unlist(all.counts.median))
  if (is.infinite(all.counts.max) || is.na(all.counts.max)) {
    warning("Maximum value calculation resulted in NA or Inf. Using 1 as default.")
    all.counts.max <- 1
  }

  # ---- One heatmap per sample ------------------------------------------------
  ht_list <- NULL

  for (bam.sample in seq_along(bamNames)) {
    sample_name <- bamNames[bam.sample]
    src <- sample_idx[bam.sample]

    tryCatch({
      counts.median <- all.counts.median[[src]]

      # y-axis limits of the metaplot: shared across samples or per sample.
      if (MetaScale[bam.sample] == "all") {
        anno.lims <- c(0, all.counts.max)
      } else {
        anno.lims <- c(0, max(0.1, max(counts.median, na.rm = TRUE)))
      }

      # Metaplot drawn as a line annotation above the heatmap.
      cumulative_anno <- ComplexHeatmap::HeatmapAnnotation(
        x = ComplexHeatmap::anno_lines(
          counts.median,
          size = grid::unit(1, "mm"),
          axis = show_axis,
          smooth = FALSE,
          ylim = anno.lims,
          gp = grid::gpar(lwd = 2, col = plotcols[bam.sample])
        ),
        show_annotation_name = FALSE
      )

      # Values shown in the heatmap body: log2 or raw.
      counts3 <- if (use.log) all.counts.log[[src]] else counts[[src]]

      # Colour breakpoints: user supplied, or min / median / max of the table.
      BottomCpm <- if (has_bottom) bottomCpm[bam.sample] else min(counts3, na.rm = TRUE)
      MedianCpm <- if (has_median) medianCpm[bam.sample] else stats::median(counts3, na.rm = TRUE)
      TopCpm <- if (has_top) topCpm[bam.sample] else max(counts3, na.rm = TRUE)

      if (BottomCpm == TopCpm) {
        warning(paste0("bottomCpm = topCpm for sample ", sample_name, ": using slightly different values to avoid color mapping issues"))
        TopCpm <- TopCpm + 0.01
      }

      plotcolor <- circlize::colorRamp2(c(BottomCpm, MedianCpm, TopCpm), c("white", plotcols[bam.sample], "black"))
    }, error = function(e) {
      stop(paste0("Error in sample preparation for ", sample_name, ": ", e$message))
    })

    if (orderSample > 0) {
      tryCatch({
        # Bring the current table into the shared row order (matched by name).
        counts.sorted <- counts3[ordered_rows, , drop = FALSE]

        if (!has_split) {
          row_split <- NULL
          if (TargetHeight > 0) {
            # Average neighbouring rows to make the heatmap look smoother.
            counts.sorted <- redim_matrix(counts.sorted, target_height = TargetHeight,
                                          target_width = ncol(counts.sorted))
          }
        } else {
          # Group label of every row, in the shared row order.
          row_split1 <- splitHM[match(ordered_rows, rownames(counts3))]

          if (TargetHeight > 0) {
            # Resize every group separately, giving it a share of the target
            # height proportional to its number of rows (at least one row).
            split.groups <- unique(row_split1)
            pieces <- vector("list", length(split.groups))
            piece_labels <- vector("list", length(split.groups))

            for (g in seq_along(split.groups)) {
              in_group <- which(row_split1 == split.groups[g])
              if (length(in_group) == 0) {
                warning(paste0("No rows found for group ", split.groups[g], ". Skipping this group."))
                next
              }
              target_rows <- max(1, round(TargetHeight * length(in_group) / length(row_split1)))
              pieces[[g]] <- redim_matrix(
                counts.sorted[in_group, , drop = FALSE],
                target_height = target_rows,
                target_width = ncol(counts.sorted)
              )
              # as.character() keeps the group labels (not the integer codes)
              # when splitHM is a factor.
              piece_labels[[g]] <- rep(as.character(split.groups[g]), nrow(pieces[[g]]))
            }

            pieces <- Filter(Negate(is.null), pieces)
            if (length(pieces) == 0) {
              stop("No valid data after splitting. Check your splitHM parameter.")
            }
            counts.sorted <- do.call("rbind", pieces)
            row_split <- unlist(piece_labels, use.names = FALSE)
          } else {
            row_split <- row_split1
          }
        }
      }, error = function(e) {
        stop(paste0("Error in heatmap sorting/splitting: ", e$message))
      })

      # Heatmap with fixed (pre-sorted) row order.
      ht_list <- tryCatch({
        ht_list + ComplexHeatmap::Heatmap(
          counts.sorted,
          name = sample_name,
          cluster_rows = FALSE,
          cluster_columns = FALSE,
          column_order = colnames(counts.sorted),
          col = plotcolor,
          column_title = sample_name,
          column_title_rot = 45,
          show_row_names = FALSE,
          show_column_names = FALSE,
          bottom_annotation = bin_anno,
          top_annotation = cumulative_anno,
          use_raster = FALSE,
          row_split = row_split
        )
      }, error = function(e) {
        stop(paste0("Error creating ordered heatmap for sample ", sample_name, ": ", e$message))
      })

    } else {
      # Heatmap whose rows follow the dendrogram of the cluster sample.
      ht_list <- tryCatch({
        row_split <- if (has_split) splitHM else NULL

        ht_list + ComplexHeatmap::Heatmap(
          counts3,
          name = sample_name,
          cluster_rows = hclust_avg,
          cluster_columns = FALSE,
          column_order = colnames(counts3),
          col = plotcolor,
          column_title = sample_name,
          column_title_gp = grid::gpar(fontsize = 10, fontface = "bold"),
          show_row_names = FALSE,
          show_column_names = FALSE,
          bottom_annotation = bin_anno,
          top_annotation = cumulative_anno,
          use_raster = TRUE,
          raster_device = "CairoPNG",
          raster_quality = 1,
          row_split = row_split,
          show_row_dend = FALSE,
          show_column_dend = FALSE
        )
      }, error = function(e) {
        stop(paste0("Error creating clustered heatmap for sample ", sample_name, ": ", e$message))
      })
    }
  }

  ht_list
}
fmi-basel/gbuehler-MiniChip documentation built on June 13, 2025, 6:15 a.m.