# R/read_end_metaheatmap_plot.R

# Defines functions rends_heat

# Documented in rends_heat

#' Metaheatmaps of the two extremities of the reads.
#'
#' This function generates four metaheatmaps displaying the abundance of the 5'
#' and 3' extremity of reads mapping around the start and the stop codon of
#' annotated CDSs, stratified by their length. Multiple samples and replicates
#' can be handled.
#'
#' @param data Either list of data tables or GRangesList object from
#'   \code{\link{bamtolist}}, \code{\link{bedtolist}},
#'   \code{\link{length_filter}} or \code{\link{psite_info}}.
#' @param annotation Data table as generated by \code{\link{create_annotation}}.
#' @param sample Either character string, character string vector or named list
#'   of character string(s)/character string vector(s) specifying the name of
#'   the sample(s) and replicate(s) of interest. If a list is provided, each
#'   element of the list is considered as an independent sample associated with
#'   one or multiple replicates. Multiple samples and replicates are handled
#'   and visualised according to \code{multisamples} and \code{plot_style}.
#' @param multisamples Either "average" or "independent". It specifies how to
#'   handle multiple samples and replicates stored in \code{sample}:
#'    * if \code{sample} is a character string vector and \code{multisamples} is
#'   set to "average" the elements of the vector are considered as replicates
#'   of one sample and a single heatmap is returned.
#'    * if \code{sample} is a character string vector and \code{multisamples} is
#'   set to "independent", each element of the vector is analysed independently
#'   of the others. The number of plots returned and their organization is
#'   specified by \code{plot_style}.
#'    * if \code{sample} is a list, \code{multisamples} must be set to "average".
#'   Each element of the list is analysed independently of the others, its
#'   replicates averaged and its name reported in the plot. The number of plots
#'   returned and their organization is specified by \code{plot_style}.
#'   Note: when this parameter is set to "average" the heatmap associated with
#'   each sample displays the length- and position- specific mean signal
#'   computed across the replicates. Default is "average".
#' @param plot_style Either "split" or "facet". It specifies how to organize and
#'   display multiple heatmaps:
#'   * "split": one heatmap for each sample is returned as an independent
#'   ggplot object.
#'   * "facet": the heatmaps are placed one below the other, in independent
#'   boxes.
#'   Default is "split".
#' @param scale_factors Either "auto", a named numeric vector or "none". It
#'   specifies how heatmap values should be scaled before merging multiple
#'   replicates (if any):
#'   * "auto": each heatmap is scaled so that the average of all values is 1.
#'   * named numeric vector: \code{scale_factors} must be the same length as
#'   unlisted \code{sample} and each scale factor must be named after the
#'   corresponding string in unlisted \code{sample}. No specific order is
#'   required. Each heatmap value is multiplied by the matching scale factor.
#'   * "none": no scaling is applied.
#'   Default is "auto".
#' @param transcripts Character string vector listing the names of transcripts
#'   to be included in the analysis. Default is NULL, i.e. all transcripts are
#'   used. Please note: transcripts with a 5' UTR shorter than \code{utr5l}, a
#'   coding sequence not longer than \code{2 * (cdsl + 1)} nucleotides or a 3'
#'   UTR shorter than \code{utr3l} are automatically discarded.
#' @param length_range Integer or integer vector for restricting the plot to a
#'   chosen range of read lengths. Default is NULL, i.e. all read lengths are
#'   used. If specified, this parameter prevails over \code{cl}.
#' @param cl Integer value in [1,100] specifying a confidence level for
#'   restricting the plot to an automatically-defined range of read lengths. The
#'   new range is computed according to the most frequent read lengths, which
#'   accounts for the cl% of the sample and is defined by discarding the
#'   (100-cl)% of read lengths falling in the tails of the read lengths
#'   distribution. If multiple samples are analysed, a single range of read
#'   lengths is computed such that at least the cl% of all samples is
#'   represented. Default is 100.
#' @param utr5l Positive integer specifying the length (in nucleotides) of the
#'   5' UTR region flanking the start codon to be considered in the analysis.
#'   Default is 50.
#' @param cdsl Positive integer specifying the length (in nucleotides) of the
#'   CDS regions flanking both the start and stop codon to be considered in the
#'   analysis. Default is 50.
#' @param utr3l Positive integer specifying the length (in nucleotides) of the
#'   3' UTR region flanking the stop codon to be considered in the analysis.
#'   Default is 50.
#' @param colour Character string specifying the colour of the plot. The colour
#'   scheme is as follows: tiles corresponding to the lowest signal are always
#'   white, tiles corresponding to the highest signal are of the specified
#'   colour and the progression between these two colours follows either linear
#'   or logarithmic gradients (see \code{log_colour}). Default is "black".
#' @param log_colour Logical value whether to use a logarithmic colour scale
#'   (strongly suggested in case of large signal variations). Default is FALSE.
#' @return List containing: one or more ggplot object(s) and the data table with
#'   the corresponding x- and y-axis values and the values defining the color of
#'   the tiles ("plot_dt"); an additional data table with raw and scaled number
#'   of read extremities mapping around the start and the stop codon, per
#'   length, for each sample ("count_dt").
#' @examples
#' data(reads_list)
#' data(mm81cdna)
#' 
#' ## Generate fake samples and replicates
#' for(i in 2:6){
#'   samp_name <- paste0("Samp", i)
#'   set.seed(i)
#'   reads_list[[samp_name]] <- reads_list[["Samp1"]][sample(.N, 5000)]
#' }
#' 
#' ## Define the list of samples and replicates to use as input
#' input_samples <- list("S1" = c("Samp1", "Samp2"),
#'                       "S2" = c("Samp3", "Samp4", "Samp5"),
#'                       "S3" = c("Samp6"))
#'
#' ## Generate metaheatmaps for a sub-range of read lengths:
#' example_ends_heatmap <- rends_heat(reads_list, mm81cdna,
#'                                    sample = input_samples,
#'                                    multisamples = "average",
#'                                    plot_style = "split",
#'                                    cl = 85,
#'                                    utr5l = 25, cdsl = 40, utr3l = 25)
#' @import data.table
#' @import ggplot2
#' @export
rends_heat <- function(data, annotation, sample, multisamples = "average",
                       plot_style = "split", scale_factors = "auto",
                       transcripts = NULL, length_range = NULL, cl = 100,
                       utr5l = 50, cdsl = 50, utr3l = 50, log_colour = FALSE,
                       colour = "black") {

  # ---- Validate sample names -------------------------------------------------
  # Done before any conversion so that a misspelled name is reported with a
  # clear message instead of failing while converting a missing element.
  if (length(sample) == 0) {
    cat("\n")
    stop("at least one sample name must be spcified\n\n")
  }

  check_sample <- setdiff(unlist(sample), names(data))
  if (length(check_sample) != 0) {
    cat("\n")
    stop(sprintf("incorrect sample name(s): \"%s\" not found\n\n",
                 paste(check_sample, collapse = ", ")))
  }

  # ---- GRangesList input -> list of data tables ------------------------------
  # Only the requested samples are converted; columns are renamed to the
  # layout (transcript, end5, end3, ...) used in the rest of the function.
  if (length(data) > 0 && inherits(data[[1]], "GRanges")) {
    data_tmp <- list()
    for (i in unique(unlist(sample))) {
      data_tmp[[i]] <- as.data.table(data[[i]])[, c("width", "strand") := NULL
                                                ][, seqnames := as.character(seqnames)]
      setnames(data_tmp[[i]], c("seqnames", "start", "end"),
               c("transcript", "end5", "end3"))
    }
    data <- data_tmp
  }

  # ---- Reconcile multisamples / plot_style with the shape of sample ----------
  if (!(multisamples %in% c("average", "independent"))) {
    cat("\n")
    warning("parameter multisamples must be either \"average\" or \"independent\".
            Set to default \"average\"\n", call. = FALSE)
    multisamples <- "average"
  }

  if (multisamples == "independent" && is.list(sample)) {
    cat("\n")
    warning("parameter multisamples is set to \"independent\" but parameter sample is a list:
            parameter multisamples will be coerced to default average\n", call. = FALSE)
    multisamples <- "average"
  }

  if (is.character(sample) && length(sample) == 1) {
    multisamples <- "independent"
    plot_style <- "split"
  }

  if (is.character(sample) && length(sample) > 1 && multisamples == "average") {
    sample <- list("Average" = sample)
    plot_style <- "split"
    cat("\n")
    warning("Default name of averaged samples is \"Average\":
            consider to use a named list of one element to provide a meaningful plot title\n", call. = FALSE)
  }

  if (is.list(sample) && length(sample) == 1) {
    plot_style <- "split"
  }

  if (!(plot_style %in% c("split", "facet"))) {
    cat("\n")
    warning("parameter plot_style must be either \"split\" or \"facet\".
            Set to default \"split\"\n", call. = FALSE)
    plot_style <- "split"
  }

  # ---- Transcript selection --------------------------------------------------
  # Keep only transcripts long enough to host the full windows around both the
  # start and the stop codon.
  l_transcripts <- as.character(annotation[l_utr5 >= utr5l &
                                             l_cds > 2 * (cdsl + 1) &
                                             l_utr3 >= utr3l, transcript])

  if (length(transcripts) == 0) {
    c_transcripts <- l_transcripts
  } else {
    c_transcripts <- intersect(l_transcripts, transcripts)
  }

  # Without transcripts every count is zero and the plot is meaningless.
  if (length(c_transcripts) == 0) {
    cat("\n")
    stop("no transcripts satisfy the length requirements defined by utr5l, cdsl and utr3l (or none of them is listed in transcripts)\n\n")
  }

  # ---- Read length range (union over all unlisted samples) -------------------
  if (length(length_range) == 0) {
    lower_q <- (1 - cl / 100) / 2
    upper_q <- 1 - lower_q
    for (samp in as.character(unlist(sample))) {
      read_len <- data[[samp]][transcript %in% c_transcripts]$length
      if (length(read_len) == 0) {
        # No reads on the selected transcripts: the sample cannot contribute
        # to the range and is dropped by the length check below.
        next
      }
      # quantile() may interpolate to non-integer values; widen to the
      # enclosing integers so that length_range only holds real read lengths.
      samp_min <- floor(quantile(read_len, lower_q, names = FALSE))
      samp_max <- ceiling(quantile(read_len, upper_q, names = FALSE))
      if (length(length_range) != 0) {
        samp_min <- min(min(length_range), samp_min)
        samp_max <- max(max(length_range), samp_max)
      }
      length_range <- seq(samp_min, samp_max)
    }
    if (length(length_range) == 0) {
      cat("\n")
      stop("none of the data tables listed in sample contains reads mapping on the selected transcripts\n\n")
    }
  }

  minlen <- min(length_range)
  maxlen <- max(length_range)

  # ---- Drop samples without reads of the selected lengths --------------------
  # Especially relevant when a single read length is requested.
  if (is.list(sample)) {
    samp_dt <- data.table(stack(sample))
    setnames(samp_dt, c("sample", "sample_l"))
  } else {
    samp_dt <- data.table("sample" = sample, "sample_l" = sample)
  }

  for (samp in samp_dt$sample) {
    dt <- data[[samp]][cds_start != 0 & cds_stop != 0 & transcript %in% c_transcripts]
    if (!any(length_range %in% dt$length)) {
      cat("\n")
      warning(sprintf("\"%s\" doesn't contain any reads of the selected lengths: sample removed\n", samp), call. = FALSE)
      if (is.list(sample)) {
        # Remove the replicate from every list element (group) containing it.
        for (grp in as.character(samp_dt[sample == samp, sample_l])) {
          sample[[grp]] <- sample[[grp]][sample[[grp]] != samp]
        }
      } else {
        sample <- sample[sample != samp]
      }
    }
  }

  # Groups left without replicates carry no data.
  if (is.list(sample)) {
    sample <- sample[lengths(sample) > 0]
  }

  # unlist() of emptied vectors/lists gives character(0) or NULL, hence length.
  if (length(unlist(sample)) == 0) {
    cat("\n")
    stop("none of the data tables listed in sample contains any reads of the specified lengths\n\n")
  }

  # ---- Scaling mode ----------------------------------------------------------
  sf_auto <- is.character(scale_factors) && isTRUE(scale_factors[1] == "auto")
  sf_none <- is.character(scale_factors) && isTRUE(scale_factors[1] == "none")
  if (is.numeric(scale_factors)) {
    # A missing name would silently turn the whole sample into NA values.
    missing_sf <- setdiff(as.character(unlist(sample)), names(scale_factors))
    if (length(missing_sf) != 0) {
      cat("\n")
      stop(sprintf("no scale factor provided for sample(s): \"%s\"\n\n",
                   paste(missing_sf, collapse = ", ")))
    }
  }

  # Count the reads whose extremity, at distance `dist_col` from a codon, lies
  # in [from, to], for every (read length, distance) pair. Pairs with no
  # reads are kept with count 0 thanks to the cross join and by = .EACHI.
  count_window <- function(dt, dist_col, from, to, region_name, read_lengths) {
    win <- seq(from, to)
    keep <- dt[[dist_col]] %in% win
    sub <- dt[keep]
    setkeyv(sub, c("length", dist_col))
    tab <- sub[CJ(read_lengths, win), .N, by = .EACHI]
    setnames(tab, c("length", "distance", "count"))
    tab[, region := region_name]
  }

  # ---- Signal per sample -----------------------------------------------------
  samples_vec <- as.character(unlist(sample))
  per_sample <- vector("list", length(samples_vec))
  for (k in seq_along(samples_vec)) {
    samp <- samples_vec[k]
    dt <- data[[samp]][transcript %in% c_transcripts]

    # distances of both extremities from the start and the stop codon
    dt[, `:=`(start_dist_end5 = end5 - cds_start,
              stop_dist_end5 = end5 - cds_stop,
              start_dist_end3 = end3 - cds_start,
              stop_dist_end3 = end3 - cds_stop)]

    tab5 <- rbind(count_window(dt, "start_dist_end5", -utr5l, cdsl, "start", length_range),
                  count_window(dt, "stop_dist_end5", -cdsl, utr3l, "stop", length_range))
    tab5[, end := "5end"]

    tab3 <- rbind(count_window(dt, "start_dist_end3", -utr5l, cdsl, "start", length_range),
                  count_window(dt, "stop_dist_end3", -cdsl, utr3l, "stop", length_range))
    tab3[, end := "3end"]

    samp_tab <- rbind(tab5, tab3)

    # scaling/normalization
    if (sf_auto) {
      mean_count <- mean(samp_tab$count)
      # avoid 0/0 (NaN) when a sample has no signal in any window
      samp_tab[, scaled_count := if (mean_count > 0) count / mean_count else 0]
    } else if (is.numeric(scale_factors)) {
      samp_tab[, scaled_count := count * scale_factors[[samp]]]
    } else {
      samp_tab[, scaled_count := count]
    }

    samp_tab[, tmp_samp := samp]
    per_sample[[k]] <- samp_tab
  }
  final_dt <- rbindlist(per_sample)

  final_dt[, region := factor(region, levels = c("start", "stop"),
                              labels = c("Distance from start (nt)", "Distance from stop (nt)"))
           ][, end := factor(end, levels = c("5end", "3end"), labels = c("5' end", "3' end"))]

  output <- list()
  output[["count_dt"]] <- copy(final_dt[, c("tmp_samp", "region", "end", "length",
                                            "distance", "count", "scaled_count")])
  if (sf_none) {
    output[["count_dt"]][, scaled_count := NULL]
  }
  setnames(output[["count_dt"]], "tmp_samp", "sample")

  # ---- Average replicates when sample is a list ------------------------------
  if (is.list(sample)) {
    samp_dt <- data.table(stack(sample))
    setnames(samp_dt, c("tmp_samp", "sample"))

    # set names of samples as specified in parameter sample
    final_dt <- merge(final_dt, samp_dt, by = "tmp_samp", sort = FALSE)[, tmp_samp := NULL]

    plot_dt <- final_dt[, .(mean_scaled_count = mean(scaled_count)),
                        by = .(length, region, end, distance, sample)]
  } else {
    plot_dt <- final_dt[, sample := tmp_samp]
    setnames(plot_dt, "scaled_count", "mean_scaled_count")
  }

  output[["plot_dt"]] <- copy(plot_dt[, c("sample", "region", "end", "distance",
                                          "length", "mean_scaled_count")])
  setnames(output[["plot_dt"]], c("distance", "length", "mean_scaled_count"), c("x", "y", "z"))

  plot_dt[, sample := factor(sample, levels = unique(sample))]
  zmax <- max(plot_dt$mean_scaled_count)

  # ---- Plots -----------------------------------------------------------------
  # y-axis breaks start at the first even length; with a single odd length
  # that would overshoot maxlen and make seq() fail, so fall back to minlen.
  first_break <- minlen + (minlen %% 2)
  if (first_break > maxlen) {
    first_break <- minlen
  }
  y_breaks <- seq(first_break, maxlen, by = max(2, floor((maxlen - minlen) / 7)))

  # Fresh fill scale per plot: white for the lowest signal, `colour` for the
  # highest, linear or logarithmic progression depending on log_colour.
  make_fill_scale <- function() {
    if (isTRUE(log_colour)) {
      mid_break <- 10^(log10(zmax) / 2 - 0.5)
      scale_fill_gradient("# read\nextremities\n", low = "white", high = colour,
                          limits = c(0.1, zmax), breaks = c(0.1, mid_break, floor(zmax)),
                          labels = c("0", floor(mid_break), floor(zmax)),
                          trans = "log", na.value = "transparent")
    } else {
      scale_fill_gradient("# read\nextremities\n", low = "white", high = colour,
                          limits = c(0.1, zmax), breaks = c(0.1, zmax / 2, zmax),
                          labels = c("0", floor(zmax / 2), floor(zmax)),
                          na.value = "white")
    }
  }

  # Shared heatmap layout; `facets` selects split (end ~ region) or facet
  # (sample + end ~ region) organisation.
  build_heatmap <- function(dt, facets, title = NULL) {
    p <- ggplot(dt, aes(distance, length)) +
      geom_tile(aes(fill = mean_scaled_count)) +
      labs(y = "Read length") +
      theme_bw(base_size = 22) +
      facet_grid(facets, scales = "free", switch = "x") +
      theme(panel.grid.major.x = element_blank(), panel.grid.minor.x = element_blank(),
            panel.grid.major.y = element_blank(), panel.grid.minor.y = element_blank(),
            axis.title.x = element_blank(), plot.title = element_text(hjust = 0.5),
            strip.background = element_blank(), strip.placement = "outside") +
      scale_y_continuous(limits = c(minlen - 0.5, maxlen + 0.5), breaks = y_breaks) +
      geom_vline(xintercept = 0, linetype = 2, color = "red") +
      make_fill_scale()
    if (!is.null(title)) {
      p <- p + labs(title = title)
    }
    p
  }

  # Silence warnings while building plots; restored even if an error occurs.
  oldw <- getOption("warn")
  options(warn = -1)
  on.exit(options(warn = oldw), add = TRUE)

  if (plot_style == "split") {
    # one independent ggplot object per sample
    for (samp in as.character(unique(plot_dt$sample))) {
      output[[paste0("plot_", samp)]] <- build_heatmap(plot_dt[sample == samp],
                                                       end ~ region, title = samp)
    }
  } else {
    output[["plot"]] <- build_heatmap(plot_dt, sample + end ~ region)
  }

  output
}
# LabTranslationalArchitectomics/riboWaltz documentation built on Jan. 17, 2024, 12:18 p.m.