R/plotting_funcs.R

Defines functions color_levels quantile_breaks group_colors cluster_pair_mr_volcano cluster_mr_heatmap cluster_mr_volcano qc_plot

Documented in cluster_mr_heatmap cluster_mr_volcano cluster_pair_mr_volcano color_levels group_colors qc_plot quantile_breaks

#' Generates QC violin plots for depth, detected genes, and MT% metrics.
#'
#' For every sample (column) of `raw.counts` this computes the total count
#' (depth), the number of features with a count above zero, and the fraction of
#' counts falling on mitochondrial genes, then draws one violin plot per metric.
#'
#' @param raw.counts Matrix of raw gene expression data (features X samples).
#' @param species One of `c('hum', 'mur')`, specifying human or murine data respectively. If not specified, assumes human.
#' @param genes One of `c('symb', 'ensg')`, specifying gene symbols or ENSG names respectively. If not specified, assumes symbols.
#' @return The `ggarrange` result holding the three violin plots side by side.
#' @export
qc_plot <- function(raw.counts, species = c('hum', 'mur'), genes = c('symb', 'ensg')) {
  require(ggplot2)
  require(ggpubr)
  # load the MT gene lists into this function's frame instead of the global environment
  data(mt_genes, envir = environment())

  # validate arguments; match.arg() returns the first choice when the argument is missing
  species <- match.arg(species)
  genes <- match.arg(genes)

  # collect per-sample statistics with vectorized column sums
  sample.depth <- colSums(raw.counts)
  # na.rm keeps the original which(x > 0) semantics of ignoring NA entries
  sample.genes <- colSums(raw.counts > 0, na.rm = TRUE)
  mtg <- intersect(mt_genes[[paste(species, genes, sep = '.')]], rownames(raw.counts))
  sample.mt <- colSums(raw.counts[mtg, , drop = FALSE]) / sample.depth
  qc.df <- data.frame('Depth' = sample.depth,
                      'Genes' = sample.genes,
                      'MT.per' = sample.mt,
                      'Sample' = rep('raw', length(sample.depth)))

  # plot
  depth.plot <- ggplot(qc.df, aes(x = Sample, y = Depth)) + geom_violin(color = '#F8766D', fill = '#F8766D') +
    labs(y = 'Depth', x = '') + theme_bw() + theme(axis.text.x = element_blank())
  gene.plot <- ggplot(qc.df, aes(x = Sample, y = Genes)) + geom_violin(color = '#00BA38', fill = '#00BA38') +
    labs(y = 'Detected Genes', x = '') + theme_bw() + theme(axis.text.x = element_blank())
  mt.plot <- ggplot(qc.df, aes(x = Sample, y = MT.per)) + geom_violin(color = '#619CFF', fill = '#619CFF') +
    labs(y = 'MT%', x = '') + theme_bw() + theme(axis.text.x = element_blank()) +
    scale_y_continuous(labels = scales::percent)
  plot.obj <- ggarrange(plotlist = list(depth.plot, gene.plot, mt.plot), ncol = 3)

  return(plot.obj)
}

#' Generates cluster-specific volcano plots from the given master regulator object.
#'
#' One facet is drawn per cluster (column of the NES matrix). A protein counts as
#' significant when the BH-adjusted upper-tail normal p-value of its |NES| is below
#' 0.05. Significant proteins are colored red (positive PES) or blue (negative
#' PES). Up to `num.mrs` significant proteins at each end of the PES range are labeled.
#'
#' @param mr.obj Master regulator object, such as that generated by `cluster_signature_mrs`.
#' Should be a list with a `cluster.narnea` element, containing cluster-specific NaRnEA results (NES and PES matrices).
#' @param num.mrs Number of top / bottom MRs to label. Default of 10.
#' @return A ggplot object.
#' @export
cluster_mr_volcano <- function(mr.obj, num.mrs = 10) {
  require(ggplot2)
  require(ggrepel)

  nes.mat <- mr.obj$cluster.narnea$NES
  pes.mat <- mr.obj$cluster.narnea$PES
  clust.names <- colnames(nes.mat)
  # build one plot data frame per cluster, then stack them
  df.list <- lapply(clust.names, function(cn) {
    # setNames keeps protein names even when a subset has a single element
    pes.vec <- setNames(pes.mat[, cn], rownames(pes.mat))
    plot.df <- data.frame('PES' = pes.vec,
                          'NES' = abs(nes.mat[, cn]))
    # significance: BH-adjusted upper-tail p-value of |NES|, signed by PES
    sig.prots <- which(p.adjust(pnorm(plot.df$NES, lower.tail = FALSE), method = 'BH') < 0.05)
    col.vec <- rep(0, nrow(plot.df))
    col.vec[sig.prots] <- sign(plot.df$PES[sig.prots])
    # fixed levels so every cluster's factor agrees once the frames are stacked
    plot.df$Significance <- factor(col.vec, levels = c(-1, 0, 1))
    # label the strongest significant MRs at both ends of the PES range
    pes.sort <- sort(pes.vec[sig.prots])
    top.mrs <- union(names(tail(pes.sort, num.mrs)), names(head(pes.sort, num.mrs)))
    label.vec <- setNames(rep('', nrow(plot.df)), rownames(plot.df))
    label.vec[top.mrs] <- top.mrs
    plot.df$Label <- label.vec
    plot.df$Cluster <- rep(cn, nrow(plot.df))
    plot.df
  })
  # collapse and render
  plot.df <- do.call(rbind, df.list)
  # named values map colors by level name, independent of which levels are present
  sig.colors <- c('-1' = 'blue', '0' = 'darkgrey', '1' = 'red')
  volcano.plot <- ggplot(plot.df, aes(PES, NES)) + geom_point(aes(color = Significance)) +
    scale_color_manual(values = sig.colors) +
    geom_label_repel(aes(label = Label, color = Significance), box.padding = 0.5,
                     max.overlaps = Inf, show.legend = FALSE) +
    facet_wrap(vars(Cluster), ncol = 3) +
    xlim(-1, 1) +
    labs(x = 'PES', y = '|NES|', title = 'Cluster Master Regulators') +
    geom_hline(yintercept = 0) + geom_vline(xintercept = 0) +
    theme(legend.position = 'none')

  return(volcano.plot)
}

#' Generates a heatmap of cluster-specific MR activity.
#' Alternatively, can be provided with a pre-selected set of proteins, typically canonical markers.
#'
#' @param dat.mat Matrix of either scaled expression data or PES generated by NaRnEA (features X samples).
#' @param dat.type One of `c('gexp', 'pact')`, specifying expression or activity data. Assumes activity if not specified.
#' @param clust.vec Named clustering vector; names must match column names of `dat.mat`.
#' @param mr.list Master regulator object, such as that generated by `cluster_signature_mrs`.
#' Passed to `get_mr_set` to choose the displayed features when `marker.set` is NULL.
#' @param num.mrs Number of master regulators to display. Default of 10.
#' @param reg.class One of `c('regulator', 'marker')`. Forwarded to `get_mr_set` and used for the default title. Assumes 'regulator' if not specified.
#' @param marker.set Optional markers to use in lieu of master regulators: either a vector of feature names,
#' or a data frame whose first column holds feature names and whose optional second column holds marker groups.
#' @param group.means Flag to plot the mean of each group rather than each sample. Default of FALSE.
#' @param scale.rows Flag to scale rows for plot color. Default of FALSE.
#' @param plot.title Optional plot title argument. Will be generated according to other arguments if not specified.
#' @param clust.rows Clusters rows if set to true. Default of FALSE.
#' @return A ComplexHeatmap object, or NULL when `get_mr_set` returns NULL.
#' @export
cluster_mr_heatmap <- function(dat.mat, dat.type = c('gexp', 'pact'), clust.vec, mr.list, num.mrs = 10, reg.class = c('regulator', 'marker'),
                               marker.set = NULL, group.means = FALSE, scale.rows = FALSE, plot.title = NULL, clust.rows = FALSE) {
  require(circlize)
  require(ComplexHeatmap)
  dat.mat <- as.matrix(dat.mat)

  # validate arguments; note that dat.type defaults to 'pact', not to its first choice
  reg.class <- match.arg(reg.class)
  if (missing(dat.type)) { dat.type <- 'pact' }
  dat.type <- match.arg(dat.type)

  # set title
  if (!is.null(plot.title)) {
    map.title <- plot.title
  } else if (is.null(marker.set)) {
    if (reg.class == 'regulator') {
      map.title <- switch(dat.type,
                          'gexp' = 'Cluster Top DE Gene Expression - Candidate Regulators',
                          'pact' = 'Cluster Master Regulator Activity')
    } else {
      map.title <- switch(dat.type,
                          'gexp' = 'Cluster Top DE Gene Expression - Surface Markers',
                          'pact' = 'Cluster Surface Marker Activity')
    }
  } else {
    map.title <- switch(dat.type,
                        'gexp' = 'Cluster Marker Expression',
                        'pact' = 'Cluster Marker Activity')
  }
  # collect parameters
  clust.names <- sort(unique(clust.vec))
  num.clust <- length(clust.names)
  clust.vec <- sort(clust.vec)
  clust.colors <- group_colors(num.clust); names(clust.colors) <- clust.names
  # find protein set; is.null() (not missing()) so an explicit NULL selects MRs
  if (!is.null(marker.set)) {
    if (!is.data.frame(marker.set)) {
      feature.set <- as.data.frame(intersect(marker.set, rownames(dat.mat)))
    } else {
      # as.data.frame() normalizes tibbles; drop = FALSE keeps one-column inputs a data frame
      present.markers <- which(marker.set[[1]] %in% rownames(dat.mat))
      feature.set <- as.data.frame(marker.set[present.markers, , drop = FALSE])
    }
  } else {
    feature.set <- get_mr_set(mr.list, num.mrs, reg.class, reg.sig = 'pos')
    if (is.null(feature.set)) {
      return(NULL)
    }
    feature.set[,2] <- tryCatch(as.numeric(feature.set[,2]), warning = function(w) {return(feature.set[,2])})
  }
  # set plot data; as.character() avoids indexing by factor codes, drop = FALSE keeps a matrix
  plot.mat <- dat.mat[as.character(feature.set[, 1]), names(clust.vec), drop = FALSE]
  # take mean if specified; build column annotations appropriately
  if (group.means) {
    # take group means; do.call(cbind) yields a matrix even for a single cluster
    plot.mat <- do.call(cbind, lapply(clust.names, function(x) {
      rowMeans(plot.mat[, clust.vec == x, drop = FALSE])
    }))
    colnames(plot.mat) <- clust.names
    # set annotations
    col.annot <- HeatmapAnnotation('Cluster' = clust.names, col = list('Cluster' = clust.colors))
    col.gaps <- seq_along(clust.names)
  } else {
    # set annotations
    col.annot <- HeatmapAnnotation('Cluster' = clust.vec, col = list('Cluster' = clust.colors))
    col.gaps <- clust.vec
  }
  # set plot colors
  if (scale.rows) {
    # scale each row; t(scale(t())) keeps the matrix orientation even with one column
    plot.mat <- t(scale(t(plot.mat)))
    # constant rows give 0/0 = NaN; show them as 0
    plot.mat[is.nan(plot.mat)] <- 0
  }
  col.breaks <- quantile_breaks(plot.mat)
  col.fun <- colorRamp2(col.breaks, color_levels(dat.type, length(col.breaks)))
  # build row annotations
  if (is.null(marker.set)) {
    row.annot <- rowAnnotation('Cluster' = feature.set[,2],
                               col = list('Cluster' = clust.colors),
                               show_annotation_name = FALSE,
                               show_legend = FALSE)
    row.gaps <- feature.set[,2]
  } else if (ncol(feature.set) > 1) {
    m.group.colors <- group_colors(length(unique(feature.set[,2])), offset = 30)
    names(m.group.colors) <- unique(feature.set[,2])
    row.annot <- rowAnnotation('Marker Group' = feature.set[,2],
                               col = list('Marker Group' = m.group.colors),
                               show_annotation_name = FALSE,
                               show_legend = TRUE)
    row.gaps <- feature.set[,2]
  } else {
    row.annot <- NULL
    row.gaps <- NULL
  }
  # create plot
  heatmap.obj <- Heatmap(plot.mat, name = switch(dat.type, 'gexp' = 'Expression', 'pact' = 'PES'),
                         col = col.fun,
                         top_annotation = col.annot, column_split = col.gaps,
                         left_annotation = row.annot, row_split = row.gaps,
                         cluster_rows = clust.rows, cluster_columns = FALSE,
                         show_row_names = TRUE, show_column_names = FALSE,
                         column_title = map.title, row_title = NULL)
  return(heatmap.obj)
}

#' Generates pairwise cluster-vs-cluster volcano plots from the given master regulator object.
#'
#' For every pair of clusters (i < j in sorted order), the `i.v.j` column of the
#' NaRnEA results is plotted in a grid facet (rows = test cluster, columns =
#' reference cluster). Significance and labeling follow the same rule as
#' `cluster_mr_volcano`: BH-adjusted upper-tail p-value of |NES| below 0.05,
#' colored by the sign of PES.
#'
#' @param mr.obj Master Regulator object generated by pairwise_cluster_mrs.
#' Should be a list with a `mr.narnea` element, containing cluster-v-cluster NaRnEA results (NES and PES matrices).
#' Names should be structured as `cn1.v.cn2`
#' @param clust.vec Clustering vector; must contain at least two distinct clusters.
#' @param num.mrs Number of top / bottom MRs to label. Default of 10.
#' @return A ggplot object.
#' @export
cluster_pair_mr_volcano <- function(mr.obj, clust.vec, num.mrs = 10) {
  require(ggplot2)
  require(ggrepel)

  clust.names <- as.character(sort(unique(clust.vec)))
  num.clust <- length(clust.names)
  if (num.clust < 2) {
    stop("At least two clusters are required for pairwise volcano plots.", call. = FALSE)
  }
  pes.mat <- mr.obj$mr.narnea$PES
  nes.mat <- mr.obj$mr.narnea$NES
  df.list <- list()
  for (i in seq_len(num.clust - 1)) {
    for (j in seq(i + 1, num.clust)) {
      i.clust <- clust.names[i]
      j.clust <- clust.names[j]
      ivj.name <- paste(i.clust, 'v', j.clust, sep = '.')
      # setNames keeps protein names even when a subset has a single element
      pes.vec <- setNames(pes.mat[, ivj.name], rownames(pes.mat))
      # build df for this comparison
      plot.df <- data.frame('PES' = pes.vec,
                            'NES' = abs(nes.mat[, ivj.name]),
                            'Test' = rep(i.clust, nrow(pes.mat)),
                            'Ref' = rep(j.clust, nrow(pes.mat)))
      # significance: BH-adjusted upper-tail p-value of |NES|, signed by PES
      sig.prots <- which(p.adjust(pnorm(plot.df$NES, lower.tail = FALSE), method = 'BH') < 0.05)
      col.vec <- rep(0, nrow(plot.df))
      col.vec[sig.prots] <- sign(plot.df$PES[sig.prots])
      # fixed levels so every comparison's factor agrees once the frames are stacked
      plot.df$Significance <- factor(col.vec, levels = c(-1, 0, 1))
      # label the strongest significant MRs at both ends of the PES range
      pes.sort <- sort(pes.vec[sig.prots])
      top.mrs <- union(names(tail(pes.sort, num.mrs)), names(head(pes.sort, num.mrs)))
      label.vec <- setNames(rep('', nrow(plot.df)), rownames(plot.df))
      label.vec[top.mrs] <- top.mrs
      plot.df$Label <- label.vec
      df.list[[ivj.name]] <- plot.df
    }
  }
  # collapse and render
  plot.df <- do.call(rbind, unname(df.list))
  # named values map colors by level name, independent of which levels are present
  sig.colors <- c('-1' = 'blue', '0' = 'darkgrey', '1' = 'red')
  volcano.plot <- ggplot(plot.df, aes(PES, NES)) + geom_point(aes(color = Significance)) +
    scale_color_manual(values = sig.colors) +
    geom_label_repel(aes(label = Label, color = Significance), box.padding = 0.5,
                     max.overlaps = Inf, show.legend = FALSE) +
    facet_grid(Test ~ Ref, drop = TRUE) +
    xlim(-1, 1) +
    labs(x = 'PES', y = '|NES|', title = 'Cluster Master Regulators') +
    geom_hline(yintercept = 0) + geom_vline(xintercept = 0) +
    theme(legend.position = 'none')

  return(volcano.plot)
}

#' Identifies a vector of colors for a given number of groups.
#'
#' Hues are spaced evenly around the HCL color wheel starting at 15 degrees
#' (plus `offset`), with luminance 65 and chroma 100.
#'
#' @param k Number of groups. A value of 0 returns an empty vector.
#' @param offset Optional argument to shift colors along color wheel (in degrees).
#' @return A character vector of `k` hex colors.
#' @export
group_colors <- function(k, offset = 0) {
  # k + 1 evenly spaced points span a full turn (15 to 375 degrees); the last one
  # coincides with the first, so only the first k are kept
  hues <- seq(15, 375, length.out = k + 1) + offset
  # seq_len(k), unlike 1:k, is empty for k = 0
  hcl(h = hues, l = 65, c = 100)[seq_len(k)]
}

#' Generates breaks for a color scale based on quantiles.
#'
#' Missing values are ignored when computing the quantiles. Duplicate break
#' values (e.g. from constant data) are removed.
#'
#' @param dat.mat Data matrix (features X samples).
#' @param n Number of evenly spaced quantile breaks (from 0 to 1) to generate.
#' If not specified (or NULL), uses the fixed probabilities
#' 0.003, 0.05, 0.32, 0.5, 0.68, 0.95 and 0.997.
#' NOTE(review): these defaults are described elsewhere as "the first three stdevs";
#' for a normal distribution those would be 0.0015, 0.025, 0.16, 0.5, 0.84, 0.975, 0.9985.
#' @return Named numeric vector of unique break values.
#' @export
quantile_breaks <- function(dat.mat, n) {
  probs <- if (missing(n) || is.null(n)) {
    c(0.003, 0.05, 0.32, 0.5, 0.68, 0.95, 0.997)
  } else {
    seq(from = 0, to = 1, length.out = n)
  }
  # na.rm = TRUE so NA entries in the data do not abort the color scale
  unique(quantile(dat.mat, probs = probs, na.rm = TRUE))
}

#' Returns color gradient for the specified data type (green/purple for gene expression; red/blue for protein activity).
#'
#' The gradient is interpolated from the reversed 11-class ColorBrewer palette
#' 'PRGn' (gene expression) or 'RdBu' (protein activity).
#'
#' @param data.type Type of data to use; either 'gexp' or 'pact'. Defaults to 'gexp'; any other value is an error.
#' @param num.colors Number of colors to return.
#' @return Vector of colors.
#' @export
color_levels <- function(data.type = c('gexp', 'pact'), num.colors) {
  require(grDevices)
  require(RColorBrewer)

  # keep match.arg()'s result: without it a missing data.type stays length 2 and
  # breaks the palette selection; invalid values error here
  data.type <- match.arg(data.type)
  palette.name <- switch(data.type,
                         'gexp' = 'PRGn',
                         'pact' = 'RdBu')
  col.func <- colorRampPalette(rev(brewer.pal(11, palette.name)))
  return(col.func(num.colors))
}
califano-lab/PISCES documentation built on Jan. 11, 2023, 5:34 a.m.