# R/heatmaps.R

# Defines functions .tss_heatmap

## TSS Heatmap Count Matrix
##
## Generate the long-format count table used to draw a TSS heatmap.
## Every sample/feature combination is expanded to one row per position in
## [-upstream, downstream]; positions without a matching TSS get a score of 0.
##
## @param annotated_data Annotated data.table containing at least the columns
##   'sample', 'score', 'distanceToTSS', 'feature', and 'FHASH'.
## @param upstream Bases upstream of plot center
## @param downstream Bases downstream of plot center
##
## @return data.table with one row per sample, feature, and position.
##   'distanceToTSS' is a factor with levels running from -upstream to
##   downstream.

.tss_heatmap <- function(
  annotated_data,
  upstream=1000,
  downstream=1000
) {

  ## Keep only the columns needed for the matrix.
  ## This subset is a new table, so the caller's data is left untouched.
  annotated_data <- annotated_data[, .(sample, score, distanceToTSS, feature, FHASH)]

  ## Every position that should appear on the x-axis.
  positions <- seq(-upstream, downstream, 1)

  # Cross-join so that all TSS distances are generated.
  tss_mat <- annotated_data[
    CJ(sample=sample, feature=feature, distanceToTSS=positions, unique=TRUE),
    , on=.(sample, feature, distanceToTSS)
  ]

  ## Positions introduced by the cross-join have no score; set them to 0.
  ## The argument is spelled 'cols' in full rather than relying on partial
  ## argument matching.
  setnafill(tss_mat, cols="score", fill=0)

  ## Factor levels fix the left-to-right order of positions in the plot.
  tss_mat[, distanceToTSS := factor(distanceToTSS, levels=positions)]

  return(tss_mat)
}

#' Plot Heatmap
#'
#' @description
#' Plot heatmap from count matrix generated by tss_heatmap_matrix or tsr_heatmap_matrix
#'
#' @importFrom purrr keep
#'
#' @inheritParams common_params
#' @param upstream Bases upstream to consider
#' @param downstream bases downstream to consider
#' @param ... Additional arguments passed to geom_tile
#' @param max_value Truncate heatmap scale at this value.
#' @param low_color Color for minimum value.
#' @param high_color Color for maximum value.
#' @param log2_transform Log2 + 1 transform values for plotting.
#' @param x_axis_breaks The distance breaks to show values on the x-axis.
#' @param filtering Logical statement by which to filter data.
#' @param ordering Symbol/name specifying the column by which to order.
#' @param order_fun Function to aggregate variable by before ordering.
#' @param order_descending Whether to order in descending (TRUE) order.
#' @param order_samples Samples that are used to calculate ordering.
#' @param quantiling Character specifying column by which to quantile.
#' @param quantile_fun Function to aggregate variable by before quantiling.
#' @param n_quantiles Number of quantiles.
#' @param quantile_samples Samples to use for quantiling.
#' @param remove_antisense Remove antisense reads.
#' @param split_by Named list with split group as name and vector of genes,
#'   or data.frame with columns 'feature' and 'split_group'.
#' @param data_type Plot TSS ('tss') or TSR ('tsr') scores.
#' @param diff_heatmap_list Named list of sample pairs.
#'   The name will be the comparison name,
#'   and each list element should be the two samples to compare.
#'
#' @details
#' This plotting function generates a ggplot2 heatmap of TSS or TSR signal
#'   surrounding the annotated TSSs of genes or transcripts.
#' Whether genes or transcripts are used depends on the feature type chosen
#'   when annotating the TSSs with the 'annotate_features' function. 
#'
#' The region around the annotated TSS used for plotting is controlled by
#'   'upstream' and 'downstream', which should be positive integers.
#'
#' A set of arguments to control data structure for plotting are included.
#' 'use_normalized' will use the normalized scores as opposed to raw read counts.
#' 'threshold' defines the minimum number of reads a TSS or TSR
#'  must have to be considered.
#' 'dominant' specifies whether only the dominant TSS or TSR is considered 
#'   from the 'mark_dominant' function.
#' For TSSs this can be either dominant per TSR or gene, and for TSRs
#'   it is just the dominant TSR per gene.
#'
#' A set of arguments for data conditions are supplied separately from
#'   the 'conditionals' function used in many other core functions.
#' This is because each row (feature) can have multiple TSSs or TSRs,
#'   which is unique to this type of plot.
#' 'filtering' can be supplied with a logical statement to filter TSSs and TSRs
#'   by the given condition(s).
#' 'ordering' can be supplied with a symbol/name of the variable to order by,
#'   and 'order_descending' controls ordering direction.
#' 'order_fun' is the function used to aggregate the variable score for each row/feature,
#'   and 'order_samples' controls the samples used to order from these aggregated variables.
#' 'quantiling' is a character specifying the numeric variable to quantile by,
#'   and 'n_quantiles' controls the number of quantiles to split the data into.
#' Just as with ordering, 'quantile_fun' is the function to aggregate the numeric
#'   variable by per feature/row, and 'quantile_samples' are the samples used to
#'   determine the order.
#' Finally, 'split_by' can be given either a two column data.frame ('feature' and 'split_group'),
#'   or a named list, where the names are the split category and the list contents are
#'   a vector of genes.
#'
#' An option to rasterize the heatmaps using ggrastr is provided with the 'rasterize' argument,
#'   and the DPI (resolution) is controlled by 'raster_dpi'.
#'
#' If diff_heatmap_list is given, the heatmaps will represent the subtracted
#'   score between the sample pairs provided in the list.
#' If this argument is given the only data conditionals that will work are ordering related.
#'
#' @return ggplot2 object of TSS or TSR heatmap
#'
#' @seealso
#' \code{\link{annotate_features}} to annotate the TSSs or TSRs.
#'
#' @examples
#' data(TSSs_reduced)
#' annotation <- system.file("extdata", "S288C_Annotation.gtf", package="TSRexploreR")
#'
#' exp <- TSSs_reduced %>%
#'   tsr_explorer(genome_annotation=annotation) %>%
#'   format_counts(data_type="tss") %>%
#'   annotate_features(data_type="tss")
#'
#' p <- plot_heatmap(exp, data_type="tss")
#'
#' @export

plot_heatmap <- function(
  experiment,
  samples="all",
  data_type=c("tss", "tsr"),
  upstream=1000,
  downstream=1000,
  threshold=NULL,
  use_normalized=FALSE,
  dominant=FALSE,
  remove_antisense=TRUE,
  rasterize=FALSE,
  raster_dpi=150,
  max_value=NULL,
  low_color="white",
  high_color="blue",
  log2_transform=TRUE,
  x_axis_breaks=100,
  ncol=3,
  filtering=NULL,
  ordering=score,
  order_descending=TRUE,
  order_fun=sum,
  order_samples=NULL,
  quantiling=NULL,
  quantile_fun=sum,
  n_quantiles=5,
  quantile_samples=NULL,
  split_by=NULL,
  diff_heatmap_list=NULL,
  ...
) {

  ## Capture the unevaluated data-condition expressions once, up front.
  ## These arguments are expressions over data columns (e.g. the default
  ## 'ordering=score'), so they must never be evaluated directly.
  filtering <- enquo(filtering)
  ordering <- enquo(ordering)
  quantiling <- enquo(quantiling)

  ## Check inputs.
  assert_that(is(experiment, "tsr_explorer"))
  assert_that(is.character(samples))
  assert_that(is.count(upstream))
  assert_that(is.count(downstream))
  assert_that(is.null(threshold) || (is.numeric(threshold) && threshold >= 0))
  assert_that(is.flag(use_normalized))
  assert_that(is.flag(dominant))
  data_type <- match.arg(str_to_lower(data_type), c("tss", "tsr"))
  assert_that(is.flag(rasterize))
  assert_that(is.count(raster_dpi))
  assert_that(
    is.null(max_value) ||
    (is.numeric(max_value) && max_value > 0)
  )
  assert_that(is.character(low_color))
  assert_that(is.character(high_color))
  assert_that(is.flag(log2_transform))
  # A zero, negative, or non-scalar break width would make the axis-break
  # computation below produce NaN or a malformed set of breaks.
  assert_that(
    is.numeric(x_axis_breaks) &&
    length(x_axis_breaks) == 1 &&
    x_axis_breaks > 0
  )
  assert_that(is.count(ncol))
  assert_that(is.flag(order_descending))
  assert_that(is.function(order_fun))
  assert_that(is.null(order_samples) || is.character(order_samples))
  assert_that(is.function(quantile_fun))
  assert_that(is.count(n_quantiles))
  assert_that(is.null(quantile_samples) || is.character(quantile_samples))
  # A data.frame is also a named list, so it has to be tested first.
  # The column check is wrapped in all() so that it yields a single
  # TRUE/FALSE instead of one value per column.
  assert_that(
    is.null(split_by) ||
    (is.data.frame(split_by) &&
      all(c("feature", "split_group") %in% colnames(split_by))) ||
    (!is.data.frame(split_by) &&
      is.list(split_by) && has_attr(split_by, "names"))
  )
  assert_that(
    is.null(diff_heatmap_list) ||
    (is.list(diff_heatmap_list) &&
    has_attr(diff_heatmap_list, "names") &&
    all(lengths(diff_heatmap_list) == 2))
  )

  ## Whether a differential (sample B - sample A) heatmap was requested.
  is_diff <- !is.null(diff_heatmap_list)

  ## Check if ggrastr is installed if rasterization requested.
  if (rasterize && !requireNamespace("ggrastr", quietly = TRUE)) {
    stop("Package \"ggrastr\" needed for this function to work. Please install it.",
      call. = FALSE)
  }

  ## Get requested samples.
  annotated <- experiment %>%
    extract_counts(data_type, samples, use_normalized) %>%
    preliminary_filter(dominant, threshold)

  ## Filter data if data condition set.
  if (!quo_is_null(filtering)) {
    annotated <- .filter_heatmap(annotated, filtering)
  }

  ## Remove antisense TSSs/TSRs.
  if (remove_antisense) {
    annotated <- map(annotated, ~.x[simple_annotations != "Antisense"])
  }

  ## Rename feature ID.
  # The ID column depends on the feature type stored in the experiment's
  # annotation settings.
  feature_col <- if (
    experiment@settings$annotation[, feature_type] == "transcript"
  ) {
    "transcriptId"
  } else {
    "geneId"
  }
  # setnames() renames by reference, so the list elements are updated in place.
  walk(annotated, function(x) {
    setnames(x, old=feature_col, new="feature")
  })
  annotated <- rbindlist(annotated, idcol="sample")

  ## Create the count matrix.
  count_mat <- switch(
    data_type,
    "tss"=.tss_heatmap(annotated, upstream, downstream),
    "tsr"=.tsr_heatmap(annotated, upstream, downstream)
  )

  ## Generate differential heatmap if set.
  if (is_diff) {
    # For a diff heatmap the log2+1 transform is applied before subtracting.
    if (log2_transform) {
      count_mat[, score := log2(score + 1)]
    }

    count_mat <- map(diff_heatmap_list, function(x) {
      mat <- count_mat[sample %in% x]
      mat <- mat[, .(sample, score, distanceToTSS, feature)]
      # One column per sample, then second sample minus first sample.
      mat <- dcast(mat, ... ~ sample, value.var="score")
      mat[, score := get(x[2]) - get(x[1])]
      mat[, c(x) := NULL]
      return(mat)
    })
    # After this the 'sample' column holds the comparison names.
    count_mat <- rbindlist(count_mat, idcol="sample")
  }

  ## Quantile and/or order if required.

  # Apply ordering if set.
  if (!quo_is_null(ordering)) {
    if (is_diff) {
      count_mat <- .order_diff_heatmap(
        count_mat, data_type, order_fun,
        order_descending, order_samples
      )
    } else {
      count_mat <- .order_heatmap(
        count_mat, data_type, annotated,
        ordering, order_fun, order_descending,
        order_samples
      )
    }
  }

  # Apply quantiling if set (not supported for diff heatmaps).
  if (!quo_is_null(quantiling) && !is_diff) {
    count_mat <- .quantile_heatmap(
      count_mat, data_type, annotated,
      quantiling, quantile_fun, n_quantiles,
      quantile_samples
    )
  }

  # Change factor level of features if ordering.
  if ("row_order" %in% colnames(count_mat)) {
    count_mat[, feature := fct_rev(fct_reorder(feature, row_order))]
  }

  ## Use custom gene groups if set.
  if (!is.null(split_by) && !is_diff) {
    count_mat <- .split_heatmap(count_mat, split_by)
  }

  ## Log2 + 1 transform data if set.
  # Done after ordering/quantiling so those steps aggregate untransformed
  # scores.
  if (!is_diff && log2_transform) {
    count_mat[, score := log2(score + 1)]
  }

  ## Set sample order if required.
  if (!all(samples == "all")) {
    panel_levels <- if (is_diff) names(diff_heatmap_list) else samples
    count_mat[, sample := factor(sample, levels=panel_levels)]
  }

  ## Truncate scores at max_value if set.
  # The color scales below use limits=c(min_value, max_value), and ggplot2
  # turns values outside the limits into NA (drawn in grey). Clamping keeps
  # those tiles at the top color, matching the ">=" legend label.
  if (!is.null(max_value)) {
    count_mat[, score := pmin(score, max_value)]
  }

  ## Create heatmap.
  p <- ggplot(count_mat, aes(x=.data$distanceToTSS, y=.data$feature))

  # Apply rasterization if required.
  tile_layer <- geom_tile(aes(fill=.data$score, color=.data$score), ...)
  if (rasterize) {
    tile_layer <- ggrastr::rasterize(tile_layer, dpi=raster_dpi)
  }
  p <- p + tile_layer

  # Only positions that are whole multiples of x_axis_breaks get a label.
  positions <- seq(-upstream, downstream, 1)
  axis_breaks <- positions[(positions / x_axis_breaks) %% 1 == 0]

  # NOTE(review): the x-axis is a factor whose first level is -upstream, so
  # index 'upstream' corresponds to position -1 and position 0 sits at index
  # upstream + 1. Confirm whether the line is meant to sit on position 0.
  p <- p +
    geom_vline(xintercept=upstream, color="black", linetype="dashed", size=0.1) +
    scale_x_discrete(breaks=axis_breaks, labels=axis_breaks)

  # If diff_heatmap is set, set min value to be min value in heatmap.
  min_value <- if (is_diff) min(count_mat[["score"]]) else 0

  ## Build the color scale arguments shared by fill and color.
  # NOTE(review): the legend title is "Log2(Score)" even when
  # log2_transform=FALSE; left unchanged here.
  scale_args <- list(name="Log2(Score)", low=low_color, high=high_color)

  # Truncate scale if max_value is set.
  if (!is.null(max_value)) {
    scale_args <- c(scale_args, list(
      limits=c(min_value, max_value),
      breaks=seq(min_value, max_value, 1),
      labels=c(seq(min_value, max_value - 1, 1), paste0(">=", max_value))
    ))
  }

  if (is_diff) {
    # Diverging scale centered on zero for differences.
    scale_args <- c(scale_args, list(mid="white", midpoint=0))
    p <- p +
      do.call(scale_fill_gradient2, scale_args) +
      do.call(scale_color_gradient2, scale_args)
  } else {
    p <- p +
      do.call(scale_fill_continuous, scale_args) +
      do.call(scale_color_continuous, scale_args)
  }

  p <- p +
    ggplot2::theme_bw() +
    theme(
      axis.text.x=element_text(angle=45, hjust=1),
      panel.spacing=unit(1.5, "lines"),
      axis.text.y=element_blank(),
      axis.ticks.y=element_blank(),
      panel.grid=element_blank(),
      panel.background=element_rect(fill="white", color="white")
    ) +
    labs(x="Position", y="Feature")

  ## Facet by quantile or split group when present, otherwise by sample only.
  if ("row_quantile" %in% colnames(count_mat)) {
    p <- p + facet_wrap(row_quantile ~ sample, scales="free", ncol=ncol)
  } else if ("row_group" %in% colnames(count_mat)) {
    p <- p + facet_wrap(row_group ~ sample, scales="free", ncol=ncol)
  } else {
    p <- p + facet_wrap(sample ~ ., scales="free", ncol=ncol)
  }

  return(p)

}

## TSR Heatmap Count Matrix
##
## Build the long-format count table used to draw a TSR heatmap.
## Each TSR contributes its score to every position it spans, and all
## remaining positions in [-upstream, downstream] get a score of 0.
##
## @param annotated_data Annotated data.table. Three helper columns
##   ('startDist', 'endDist', 'tsr_id') are added to it by reference.
## @param upstream Bases upstream to consider
## @param downstream bases downstream to consider
##
## @return data.table with one row per sample, feature, and position, with
##   'distanceToTSS' as a factor running from -upstream to downstream.

.tsr_heatmap <- function(
  annotated_data,
  upstream=1000,
  downstream=1000
) {

  ## Distance of each TSR edge from the gene boundary: measured from
  ## 'geneStart' on the "+" strand and from 'geneEnd' otherwise.
  annotated_data[,
    startDist := ifelse(strand == "+", start - geneStart, (geneEnd - end))
  ]
  annotated_data[,
    endDist := ifelse(strand == "+", end - geneStart, (geneEnd - start))
  ]
  ## Row identifier so each TSR can be expanded on its own.
  annotated_data[, tsr_id := seq_len(.N)]

  ## One row per position covered by each TSR, all carrying the TSR score.
  expanded <- annotated_data[,
    .(sample, feature, score, FHASH,
    distanceToTSS=seq(as.numeric(startDist), as.numeric(endDist), 1)),
    by=tsr_id
  ]
  expanded[, tsr_id := NULL]

  ## Drop positions that fall outside the plotting window.
  expanded <- expanded[distanceToTSS >= -upstream & distanceToTSS <= downstream]

  ## Every position shown on the x-axis.
  positions <- seq(-upstream, downstream, 1)

  ## Cross-join so positions without a TSR appear as rows, then score them 0.
  expanded <- expanded[
    CJ(sample=sample, feature=feature, distanceToTSS=positions, unique=TRUE),
    , on=.(sample, feature, distanceToTSS)
  ]
  setnafill(expanded, cols="score", fill=0)

  ## Factor levels fix the left-to-right order of positions.
  expanded[, distanceToTSS := factor(distanceToTSS, levels=positions)]

  expanded
}

## Filter Heatmap.
##
## Apply the same row filter to every table in a list of sample data.
##
## @param sample_list List of sample data.
## @param filtering Quosure of filters.
##
## @return List of the same length as 'sample_list' with each element filtered.

.filter_heatmap <- function(
  sample_list,
  filtering
) {
  ## The quosure is unquoted inside dplyr::filter for each table in turn.
  map(sample_list, function(sample_data) {
    dplyr::filter(sample_data, !!filtering)
  })
}

## Order Heatmap.
##
## Rank features by an aggregated variable and attach that rank to the
## count data as a 'row_order' column.
##
## @inheritParams plot_heatmap
## @param count_data data.table of sample data.
## @param annotated_data Annotated sample data.
## @param data_type Either 'tss' or 'tsr'.
##
## @return 'count_data' merged by 'feature' with the 'row_order' column.

.order_heatmap <- function(
  count_data,
  data_type,
  annotated_data,
  ordering,
  order_fun,
  order_descending,
  order_samples
) {

  ## Annotation columns, minus the ones 'count_data' already carries.
  annotation_cols <- copy(annotated_data)
  annotation_cols[, c("score", "feature", "distanceToTSS") := NULL]

  ## Work on a copy so 'count_data' itself is not modified.
  if (data_type == "tsr") {
    # Drop the position column and collapse the duplicated rows.
    working <- copy(count_data)
    working[, distanceToTSS := NULL]
    working <- unique(working)
  } else if (data_type == "tss") {
    working <- copy(count_data)
  }

  ## Restrict to the samples that should drive the ordering.
  if (!is.null(order_samples)) {
    working <- working[sample %in% order_samples]
  }

  working <- merge(working, annotation_cols, by=c("FHASH", "sample"))

  ## Aggregate the ordering variable per feature.
  feature_rank <- dplyr::summarize(
    dplyr::group_by(working, feature),
    aggr_var=order_fun(!!ordering)
  )
  setDT(feature_rank)

  ## Sort by the aggregate; the row position after sorting is the rank.
  if (order_descending) {
    feature_rank <- feature_rank[order(-aggr_var)]
  } else {
    feature_rank <- feature_rank[order(aggr_var)]
  }
  feature_rank[, row_order := .I]
  feature_rank[, aggr_var := NULL]

  merge(count_data, feature_rank, by="feature")
}

## Quantile Heatmap
##
## Assign each feature to a quantile of an aggregated variable and attach
## the result to the count data as a 'row_quantile' factor column.
##
## @inheritParams plot_heatmap
## @param count_data data.table of sample data.
## @param annotated_data Annotated data.
##
## @return 'count_data' merged by 'feature' with the 'row_quantile' column.

.quantile_heatmap <- function(
  count_data,
  data_type,
  annotated_data,
  quantiling,
  quantile_fun,
  n_quantiles,
  quantile_samples
) {

  ## Annotation columns, minus the ones 'count_data' already carries.
  an_data <- copy(annotated_data)
  an_data[, c("score", "feature", "distanceToTSS") := NULL]

  ## Work on a copy so 'count_data' itself is not modified.
  if (data_type == "tss") {
    merged <- copy(count_data)
  } else if (data_type == "tsr") {
    # Drop the position column and collapse the duplicated rows.
    merged <- copy(count_data)
    merged[, distanceToTSS := NULL]
    merged <- unique(merged)
  }

  ## Restrict to the samples that should drive the quantiling.
  if (!is.null(quantile_samples)) {
    merged <- merged[sample %in% quantile_samples]
  }

  ## Merge the prepared table, not 'count_data'. Merging 'count_data' here
  ## would discard the sample filter and the TSR de-duplication above.
  merged <- merge(merged, an_data, by=c("FHASH", "sample"))

  ## Aggregate the quantiling variable per feature.
  merged <- merged %>%
    dplyr::group_by(feature) %>%
    dplyr::summarize(aggr_var=quantile_fun(!!quantiling))

  setDT(merged)
  ## Bin features into n_quantiles groups; reverse the factor level order.
  merged[, row_quantile := ntile(aggr_var, n=n_quantiles)]
  merged[, row_quantile := fct_rev(factor(row_quantile))]
  merged[, aggr_var := NULL]

  merged <- merge(count_data, merged, by="feature")

  return(merged)

}

## Split Heatmap
##
## Attach user-defined feature groups to the count matrix as a 'row_group'
## column. Features that are not in any group are dropped by the merge.
##
## @inheritParams plot_heatmap
## @param count_mat Count matrix.
##
## @return 'count_mat' merged by 'feature' with the 'row_group' column.

.split_heatmap <- function(
  count_mat,
  split_by
) {

  ## Change list to data.table if list provided.
  if (!is.data.frame(split_by)) {
    # The list names become the 'split_group' column.
    split_by <- map(split_by, ~data.table(feature=.x))
    split_by <- rbindlist(split_by, idcol="split_group")
  } else {
    # setDT() and setnames() work by reference, so take a copy first.
    # Otherwise the caller's data.frame would be converted and have its
    # column renamed, and a second call with the same object would fail.
    split_by <- copy(split_by)
    setDT(split_by)
  }

  ## Merge split groups into data.
  setnames(split_by, old="split_group", new="row_group")
  count_mat <- merge(count_mat, split_by, by="feature")

  return(count_mat)
}

## Order Diff Heatmap.
##
## Rank features by their aggregated 'score' and attach that rank to the
## count data as a 'row_order' column.
##
## @inheritParams plot_heatmap
## @param count_data data.table of sample data.
## @param data_type Either 'tss' or 'tsr'.
##
## @return 'count_data' merged by 'feature' with the 'row_order' column.

.order_diff_heatmap <- function(
  count_data,
  data_type,
  order_fun,
  order_descending,
  order_samples
) {

  ## Work on a copy so 'count_data' itself is not modified.
  if (data_type == "tsr") {
    # Drop the position column and collapse the duplicated rows.
    working <- copy(count_data)
    working[, distanceToTSS := NULL]
    working <- unique(working)
  } else if (data_type == "tss") {
    working <- copy(count_data)
  }

  ## Restrict to the entries of 'sample' that should drive the ordering.
  if (!is.null(order_samples)) {
    working <- working[sample %in% order_samples]
  }

  ## Aggregate the score per feature.
  feature_rank <- dplyr::summarize(
    dplyr::group_by(working, feature),
    aggr_var=order_fun(score)
  )
  setDT(feature_rank)

  ## Sort by the aggregate; the row position after sorting is the rank.
  if (order_descending) {
    feature_rank <- feature_rank[order(-aggr_var)]
  } else {
    feature_rank <- feature_rank[order(aggr_var)]
  }
  feature_rank[, row_order := .I]
  feature_rank[, aggr_var := NULL]

  merge(count_data, feature_rank, by="feature")
}
# zentnerlab/TSRexploreR documentation built on Dec. 30, 2022, 10:27 p.m.