R/other.R

Defines functions save_plots_to_pdf plot_sample_clustering plot_chromosome plot_biotypes get_gene_id

Documented in get_gene_id plot_biotypes plot_chromosome plot_sample_clustering save_plots_to_pdf

#' Get all gene IDs in a DESeqDataSet for a given gene name.
#'
#' Looks up the rows of `rowData(dds)` whose `gene_name` column equals
#' `gene_name`. Features whose gene name is missing (`NA`) are never matched.
#'
#' @param gene_name A gene name
#' @param dds A DESeqDataSet whose `rowData` contains a `gene_name` column
#'
#' @return The rows of `rowData(dds)` that match `gene_name`, restricted to the
#'   `gene_name` column. The result keeps its tabular shape (`drop = FALSE`);
#'   presumably its row names are the gene IDs of the matching features.
#' @examples
#' get_gene_id("HBA1", T47D)
#' @export
get_gene_id <- function(gene_name, dds){
  row_data <- rowData(dds)
  # which() turns the comparison into integer positions and silently drops
  # NA results, so features without a gene name never enter the row subscript.
  hits <- which(row_data$gene_name == gene_name)
  row_data[hits, "gene_name", drop = FALSE]
}


#' Plot number of counts per sample and biotype
#'
#' Plot the total number of counts for each sample and the major classes of
#' ENSEMBL gene biotypes (protein coding, lncRNA, etc.)
#'
#' Size factors are estimated on `dds` before the normalized counts are summed
#' per sample and biotype. Biotypes are read from the `gene_biotype` column of
#' `rowData(dds)`; values that match none of the listed classes are shown as
#' "other".
#' @param dds A DESeqDataSet
#'
#' @return A ggplot object of the ggplot2 package.
#' @examples
#' plot_biotypes(T47D)
#'
#' @export
plot_biotypes <- function(dds) {

  # use alternative sizeFactor estimation, if
  # every gene contains a zero count
  if (all(rowSums(assay(dds) == 0) > 0)) {
    message("Every gene contains a zero count. Using slow size factor estimation type 'iterate'.")
    dds <- DESeq2::estimateSizeFactors(dds, type = "iterate")
  } else {
    dds <- DESeq2::estimateSizeFactors(dds)
  }

  # collapse the detailed ENSEMBL biotypes into a handful of major classes
  gene_classes <- rowData(dds) %>%
    as_tibble(rownames = "gene_id") %>%
    mutate(biotype = dplyr::case_when(
      gene_biotype == "protein_coding" ~ "protein coding",
      gene_biotype %in% c("lncRNA","lincRNA") ~ "lncRNA",
      gene_biotype == "snRNA" ~ "snRNA",
      gene_biotype %in% c("Mt_rRNA", "Mt_tRNA") ~ "MT-RNA",
      str_detect(gene_biotype, "pseudogene") ~ "pseudogene",
      gene_biotype == "TR_C_gene" ~ "T-cell receptor C",
      gene_biotype == "TR_D_gene" ~ "T-cell receptor D",
      gene_biotype == "TR_J_gene" ~ "T-cell receptor J",
      gene_biotype == "TR_V_gene" ~ "T-cell receptor V",
      TRUE ~ "other"
    ))

  # long table of normalized counts, summed per sample and biotype class;
  # sample IDs become a factor in natural (numeric-aware) order for the x axis
  biotypes_df <- counts(dds, normalized = TRUE) %>%
    as_tibble(rownames = "gene_id") %>%
    pivot_longer(-gene_id, names_to = "sample_id", values_to = "count") %>%
    mutate(sample_id = factor(sample_id, levels = str_sort(unique(sample_id), numeric = TRUE))) %>%
    left_join(gene_classes, by = "gene_id") %>%
    group_by(sample_id, biotype) %>%
    summarize(total_count = sum(count), .groups = "drop")

  # order biotypes legend by total count
  biotype_order <- biotypes_df %>%
    group_by(biotype) %>%
    summarize(biotype_total = sum(total_count)) %>%
    arrange(-biotype_total) %>%
    pull(biotype)

  # ten fixed colours: one for each of the ten possible biotype classes above
  pal_jco <- c(
    "#0073C2FF", "#EFC000FF", "#868686FF", "#CD534CFF", "#7AA6DCFF",
    "#003C67FF", "#8F7700FF", "#3B3B3BFF", "#A73030FF", "#4A6990FF"
  )

  # NOTE(review): ggplot2 >= 3.4.0 deprecates `size` for line elements in
  # favour of `linewidth`; `size` is kept so older ggplot2 versions still work.
  # Confirm the package's minimum ggplot2 version before switching.
  biotypes_df %>%
    ggplot(aes(sample_id, total_count, group = biotype, color = biotype)) +
    geom_point() +
    geom_line() +
    scale_y_log10(breaks = 10^(0:10)) +
    scale_color_manual(values = pal_jco, breaks = biotype_order) +
    labs(x = "sample ID", y = "total normalized count") +
    cowplot::theme_cowplot() +
    theme(
      axis.text.x = element_text(angle = 90, hjust = 1, vjust = .5),
      panel.grid.major.x = element_line(linetype = "66", size = rel(.4), color = "gray")
    )
}

#' Plot gene expression along a chromosome
#'
#' Genes annotated on `chr` are ordered by their `chr_start` position, each
#' gene is centered (and optionally scaled) across samples, and the resulting
#' samples x genes matrix is drawn as a heatmap.
#'
#' @param vsd An object generated by `DESeq2::vst()`
#' @param chr A string denoting a chromosome as annotated by ENSEMBL, e.g.
#' '1', '2', 'X', 'Y', 'MT'
#' @param scale Whether to scale the columns of the heatmap
#' @param trunc_val Truncate the expression matrix to this value prior to plotting. This is useful
#' if some very high expression values dominate the heatmap. By default, the heatmap is truncated
#' to expression values at most 3 standard deviations from the mean.
#'
#' @return A Heatmap-class object of the `ComplexHeatmap` package that contains the heatmap of expression values.
#' @examples
#' library("DESeq2")
#' chr1 <- T47D[which(mcols(T47D)$chromosome=="1"),]
#' vsd <- vst(chr1)
#' plot_chromosome(vsd, chr="1")
#'
#' @export
plot_chromosome <- function(vsd, chr, scale = FALSE, trunc_val = NULL) {

  # fail with a clear message instead of an obscure downstream error when
  # no gene is annotated on the requested chromosome
  if (!any(rowData(vsd)$chromosome == chr, na.rm = TRUE)) {
    stop("No genes found on chromosome '", chr, "'.")
  }

  # Build a samples x genes matrix, genes ordered by chromosomal start.
  # In `scale(scale = scale)` the call resolves to the function base::scale(),
  # while the argument value is this function's logical `scale` parameter.
  chr_exp_mat <- assay(vsd) %>%
    as_tibble(rownames = "gene_id") %>%
    pivot_longer(-gene_id, names_to = "sample_id", values_to = "log_count") %>%
    left_join(
      rowData(vsd) %>%
        as_tibble(rownames = "gene_id"),
      by = "gene_id"
    ) %>%
    filter(chromosome == chr) %>%
    arrange(chr_start, sample_id) %>%
    select(gene_id, sample_id, log_count) %>%
    pivot_wider(names_from = "sample_id", values_from = "log_count") %>%
    column_to_rownames("gene_id") %>%
    as.matrix() %>%
    t() %>%
    scale(scale = scale)

  # With `scale = TRUE`, genes that are constant across samples are divided by
  # a zero standard deviation and become NaN. `na.rm = TRUE` keeps those
  # entries from turning the truncation threshold and the colour range into
  # NA, which would blank out the whole matrix.
  if (is.null(trunc_val)) {
    trunc_val <- 3 * stats::sd(chr_exp_mat, na.rm = TRUE)
  }
  # clamp values to [-trunc_val, trunc_val]
  chr_exp_mat <- pmax(pmin(chr_exp_mat, trunc_val), -trunc_val)
  # symmetric colour scale around zero
  max_abs <- max(abs(chr_exp_mat), na.rm = TRUE)

  suppressMessages(
    ComplexHeatmap::Heatmap(chr_exp_mat,
      col = circlize::colorRamp2(c(-max_abs, 0, max_abs), c("royalblue2", "white", "red2")),
      show_row_names = TRUE,
      show_column_names = FALSE,
      # keep genes in chromosomal order; only samples (rows) are clustered
      cluster_columns = FALSE,
      row_names_gp = grid::gpar(fontsize = 10),
      row_dend_width = grid::unit(.05, "npc"),
      column_title = paste0("genes along chromosome ", chr),
      name = "expression"
    )
  )
}

#' Plot clustering of samples in a distance heatmap
#'
#' The `n_feats` features with the highest variance across samples are
#' selected, pairwise sample distances are computed on them, and the distance
#' matrix is drawn as a clustered heatmap.
#'
#' @param se A SummarizedExperiment object.
#' @param n_feats Number of top-variable features (genes) to consider
#' @param anno_vars Character vector of columns in `colData(se)` to annotate samples
#' @param anno_title The title of the color legend for `anno_vars`
#' @param distance The type of distance metric to consider. Either 'euclidean', 'pearson' or 'spearman'
#' @param ... Other arguments passed on to ComplexHeatmap::Heatmap()
#'
#' @return A Heatmap-class object of the `ComplexHeatmap` package that contains the heatmap of pairwise sample distances.
#'
#' @examples
#' \donttest{
#' library("DESeq2")
#' dds <- makeExampleDESeqDataSet(m=8, interceptMean=10)
#' vsd <- vst(dds)
#' plot_sample_clustering(vsd)
#' }
#' @export
plot_sample_clustering <- function(se, n_feats = 500, anno_vars = NULL, anno_title="group", distance = "euclidean", ...) {
  if (!all(anno_vars %in% colnames(colData(se)))) {
    stop("An element of 'anno_vars' is not a column of colData(se).")
  }

  # get metadata
  meta <- data.frame(colData(se))

  # column annotation
  if (!is.null(anno_vars)) {
    # drop = FALSE: with a single annotation variable the subset must stay a
    # data frame, otherwise `df` would receive a bare vector
    top_anno <- ComplexHeatmap::HeatmapAnnotation(
      df = meta[, anno_vars, drop = FALSE],
      which = "column",
      name = anno_title
    )
  } else {
    top_anno <- NULL
  }

  # keep the n_feats features with the highest variance across samples;
  # drop = FALSE keeps a matrix even if only one feature is selected
  feat_vars <- matrixStats::rowVars(assay(se))
  top_feats <- head(order(feat_vars, decreasing = TRUE), n_feats)
  assay_mat <- assay(se)[top_feats, , drop = FALSE]

  # prepare different plot types
  if (distance == "euclidean") {
    mat <- as.matrix(dist(t(assay_mat)))
    color_name <- "euclidean\ndistance"
  } else if (distance %in% c("pearson", "spearman")) {
    # map correlation in [-1, 1] to a distance in [0, 1]
    mat <- as.matrix(as.dist((1 - cor(assay_mat, method = distance)) / 2))
    color_name <- paste0(distance, " corr.\ndistance")
  } else {
    stop("'distance' must be one of 'euclidean', 'pearson' or 'spearman'.")
  }

  # nine-step blue palette, reversed below so that small distances are dark
  RColorBrewer_blues <- c(
    "#F7FBFF", "#DEEBF7", "#C6DBEF", "#9ECAE1", "#6BAED6",
    "#4292C6", "#2171B5", "#08519C", "#08306B")

  ComplexHeatmap::Heatmap(
    mat,
    col = circlize::colorRamp2(seq(0, max(mat), length.out = 9), rev(RColorBrewer_blues)),
    show_column_names = FALSE,
    row_dend_width = grid::unit(.07, "npc"),
    column_dend_height = grid::unit(.07, "npc"),
    top_annotation = top_anno,
    name = color_name,
    ...
  )
}

#' Save list of plots to PDF
#'
#' This function takes a list of plots as input and makes a pdf with `ncol` x `nrow` plots per page.
#' @param plots List of plots that is passed to the `plotlist` argument of `cowplot::plot_grid`
#' @param file file where the plots are saved
#' @param ncol number of columns per page for the grid of plots
#' @param nrow number of rows per page for the grid of plots
#' @param subfig_width width of a plot of the grid in inches
#' @param subfig_height height of a plot of the grid in inches
#' @param legend_position either 'original' if the original legend of each sub-plot is shown, 'none', if no legend should be shown in any of the sub-plots, 'bottom',
#' if no legend should be shown in the sub plots and one shared legend at the bottom or 'right', which is same as 'bottom', but shown on the right.
#' The shared legend is taken from the first plot. Any other value is an error.
#'
#' @return The function is called for its side effect, which is to save a pdf of plots to the filesystem.
#' The value of the final `grDevices::dev.off()` call is returned as a by-product.
#'
#' @examples
#' \donttest{
#' library("ggplot2")
#' manuf <- unique(mpg$manufacturer)
#' plots <- lapply(manuf, function(x){
#'   df <- mpg[mpg$manufacturer==x,]
#'   ggplot(df, aes(cty, hwy)) +
#'     geom_point() +
#'     labs(title=x)
#' })
#' save_plots_to_pdf(plots, ncol=3, nrow=2)
#' }
#'
#' @export
save_plots_to_pdf <- function(plots, file="plots.pdf", ncol, nrow, subfig_width=subfig_height*16/9, subfig_height=2.5, legend_position="original") {
  # reject unsupported values up front, before any file is created
  legend_position <- match.arg(legend_position, c("original", "none", "bottom", "right"))

  # make sure the pdf device is closed even if drawing a page fails;
  # `pdf_dev` stays NULL until the device has actually been opened
  pdf_dev <- NULL
  on.exit(if (!is.null(pdf_dev)) grDevices::dev.off(pdf_dev), add = TRUE)

  # sometimes plots have warnings and suppression is needed that the console is not spammed with these messages
  suppressWarnings({
    num_plots <- length(plots)
    per_page <- ncol * nrow
    pages <- max(ceiling(num_plots / per_page), 1)

    # extract shared legend and if necessary remove legend from all plots
    shared_legend <- NULL
    if (legend_position != "original") {
      if (legend_position %in% c("bottom", "right")) {
        shared_legend <- cowplot::get_legend(plots[[1]] + theme(legend.position = legend_position))
      }
      # remove legend from all plots
      plots <- purrr::map(plots, ~.x+theme(legend.position = "none"))
    }

    grDevices::pdf(file=file, width=subfig_width*ncol, height=subfig_height*nrow)
    pdf_dev <- grDevices::dev.cur()

    purrr::walk(seq_len(pages), function(p) {
      message(paste0("Page ",p,"/",pages))

      # grid holding the plots that belong on page p
      page_grid <- cowplot::plot_grid(
        plotlist = plots[ ((p-1)*per_page+1) : min(p*per_page,num_plots) ],
        ncol=ncol, nrow=nrow
      )

      # attach the shared legend below or to the right of the grid
      if (legend_position == "bottom") {
        page_grid <- cowplot::plot_grid(page_grid, shared_legend, rel_heights=c(1,.03), ncol=1)
      } else if (legend_position == "right") {
        page_grid <- cowplot::plot_grid(page_grid, shared_legend, rel_widths=c(1,.05), nrow=1)
      }

      print(page_grid)
    })

    # normal exit: disarm the on.exit handler and close the device here
    pdf_dev <- NULL
    grDevices::dev.off()
  })
}

Try the RNAseqQC package in your browser

Any scripts or data that you put into this service are public.

RNAseqQC documentation built on July 1, 2024, 9:07 a.m.