R/plot_heatmap.R

Defines functions plot_heatmap

Documented in plot_heatmap

#' Visualize a Distance or Similarity Matrix as a Heatmap with Clustering
#'
#' This function creates a heatmap from a square distance or similarity matrix.
#' If a similarity matrix is provided, it should first be converted to a distance matrix by the user.
#' The function supports hierarchical clustering, group annotations, row/column sampling (random or stratified),
#' and various customization options.
#'
#' @param dist_mat A square distance matrix (numeric matrix) or a \code{dist} object.
#' @param max_n Number (>= 1). Maximum number of observations (rows/columns) to display.
#'   If the matrix exceeds this size, observations are sampled: exactly \code{max_n} under simple
#'   random sampling, approximately \code{max_n} under stratified sampling (see \code{stratified_sampling}).
#' @param group Optional vector or factor providing group labels for rows/columns, used for color annotation.
#'   Must contain one entry per observation of \code{dist_mat}.
#' @param stratified_sampling Logical. If \code{TRUE} and \code{group} is provided, sampling is stratified by group.
#'   Each group receives a share of \code{max_n} proportional to its size (rounded), with at least one
#'   observation per group, so the total may differ slightly from \code{max_n}. Observations whose
#'   group is \code{NA} are not drawn in this mode. Default is \code{FALSE}.
#' @param main_title Optional character string specifying the main title of the heatmap.
#' @param palette Character string specifying the RColorBrewer palette for heatmap cells. Default is \code{"YlOrRd"}.
#'   Palettes offering fewer than 9 colors are used at their maximum size.
#' @param clustering_method Character string specifying the hierarchical clustering method,
#'   as accepted by \code{\link[stats]{hclust}} (e.g., \code{"complete"}, \code{"average"}, \code{"ward.D2"}).
#' @param cluster_rows Logical, whether to perform hierarchical clustering on rows. Default is \code{TRUE}.
#' @param cluster_cols Logical, whether to perform hierarchical clustering on columns. Default is \code{TRUE}.
#' @param fontsize_row Integer specifying the font size of row labels. Default is 10.
#' @param fontsize_col Integer specifying the font size of column labels. Default is 10.
#' @param show_rownames Logical, whether to display row names. Default is \code{TRUE}.
#' @param show_colnames Logical, whether to display column names. Default is \code{TRUE}.
#' @param border_color Color of the cell borders in the heatmap. Default is \code{"grey60"}.
#' @param annotation_legend Logical, whether to display the legend for group annotations. Default is \code{TRUE}.
#' @param seed Integer. Random seed used when sampling rows/columns if \code{max_n} is smaller than total observations.
#'   The caller's random number generator state is restored afterwards. Default is 123.
#'
#' @details
#' The function works as follows:
#' \itemize{
#'   \item Converts \code{dist} objects to matrices automatically.
#'   \item Samples rows/columns if the matrix is larger than \code{max_n}. Sampling can be random or stratified by group.
#'   \item In stratified sampling mode, each group contributes at least one observation if possible.
#'   \item Sampling is reproducible via \code{seed} and does not alter the global random stream.
#'   \item Supports row annotations for groups and automatically assigns colors.
#'   \item Uses \code{pheatmap} for plotting with customizable clustering, labels, fonts, and colors.
#' }
#'
#' This function is used internally by \code{\link{visualize_distances}()} but can be called directly for advanced usage.
#'
#' @return Invisibly returns the \code{pheatmap} object, allowing further customization if assigned.
#'
#' @examples
#' # Example: Euclidean distance heatmap on iris
#' eucli_dist <- stats::dist(iris[, 1:4])
#' dbrobust:::plot_heatmap(
#'   dist_mat = eucli_dist,
#'   max_n = 10,
#'   group = iris$Species,
#'   stratified_sampling = TRUE,
#'   main_title = "Euclidean Distance Heatmap",
#'   palette = "YlOrRd",
#'   clustering_method = "complete"
#' )
#'
#' # Example: GGower distances with small subset
#' data("Data_HC_contamination", package = "dbrobust")
#' Data_small <- Data_HC_contamination[1:50, ]
#' cont_vars <- c("V1", "V2", "V3", "V4")
#' cat_vars  <- c("V5", "V6", "V7")
#' bin_vars  <- c("V8", "V9")
#' w <- Data_small$w_loop
#' dist_sq_ggower <- dbrobust:::robust_distances(
#'   data = Data_small,
#'   cont_vars = cont_vars,
#'   bin_vars  = bin_vars,
#'   cat_vars  = cat_vars,
#'   w = w,
#'   alpha = 0.10,
#'   method = "ggower"
#' )
#' group_vec <- rep("Normal", nrow(dist_sq_ggower))
#' group_vec[attr(dist_sq_ggower, "outlier_idx")] <- "Outlier"
#' group_factor <- factor(group_vec, levels = c("Normal", "Outlier"))
#' dbrobust:::plot_heatmap(
#'   dist_mat = sqrt(dist_sq_ggower),
#'   max_n = 20,
#'   group = group_factor,
#'   main_title = "GGower Heatmap with Outliers",
#'   palette = "YlOrRd",
#'   clustering_method = "complete",
#'   annotation_legend = TRUE,
#'   stratified_sampling = TRUE,
#'   seed = 123
#' )
#'
#' @seealso
#' \code{\link[stats]{hclust}} for hierarchical clustering methods.
#' \code{\link[pheatmap]{pheatmap}} for additional heatmap customization options.
#' \code{\link[RColorBrewer]{brewer.pal}} for available color palettes.
#' @importFrom pheatmap pheatmap
#' @importFrom stats setNames
#' @importFrom RColorBrewer brewer.pal
#' @keywords internal
plot_heatmap <- function(
    dist_mat,
    max_n = 50,
    group = NULL,
    stratified_sampling = FALSE,
    main_title = NULL,
    palette = "YlOrRd",
    clustering_method = "complete",
    cluster_rows = TRUE,
    cluster_cols = TRUE,
    fontsize_row = 10,
    fontsize_col = 10,
    show_rownames = TRUE,
    show_colnames = TRUE,
    border_color = "grey60",
    annotation_legend = TRUE,
    seed = 123
) {
  if (missing(dist_mat)) stop("Argument 'dist_mat' is required.")

  # Validate and normalise the input matrix
  if (!is.matrix(dist_mat) && !inherits(dist_mat, "dist")) {
    stop("dist_mat must be a matrix or a dist object.")
  }
  if (inherits(dist_mat, "dist")) dist_mat <- as.matrix(dist_mat)
  if (nrow(dist_mat) != ncol(dist_mat)) stop("dist_mat must be square.")
  if (!is.numeric(max_n) || length(max_n) != 1L || is.na(max_n) ||
      max_n < 1) {
    stop("max_n must be a single number >= 1.")
  }

  # Ensure row/column names exist (annotation rows are matched by name)
  if (is.null(rownames(dist_mat))) {
    rownames(dist_mat) <- as.character(seq_len(nrow(dist_mat)))
  }
  if (is.null(colnames(dist_mat))) {
    colnames(dist_mat) <- as.character(seq_len(ncol(dist_mat)))
  }

  n <- nrow(dist_mat)

  # Coerce group to a factor (factor() drops unused levels) and check its
  # length BEFORE sampling: otherwise a longer vector would be silently
  # truncated by the index subset below and misaligned with the matrix.
  if (!is.null(group)) {
    group <- factor(group)
    if (length(group) != n) {
      stop("Length of group must match number of observations in dist_mat.")
    }
  }

  # Sampling if too large
  if (n > max_n) {
    sampled_idx <- plot_heatmap_sample_indices(
      n = n,
      max_n = max_n,
      group = group,
      stratified_sampling = stratified_sampling,
      seed = seed
    )
    # drop = FALSE keeps a single sampled observation as a 1 x 1 matrix
    dist_mat <- dist_mat[sampled_idx, sampled_idx, drop = FALSE]
    # Factor subsetting keeps the full level set, so colors stay stable
    if (!is.null(group)) group <- group[sampled_idx]
    message(sprintf(
      "Matrix has %d observations; sampling %d for heatmap.",
      n, length(sampled_idx)
    ))
  }

  # Row annotations: one colour per group level
  annotation_row <- NULL
  annotation_colors <- NULL
  if (!is.null(group)) {
    annotation_row <- data.frame(Group = group)
    rownames(annotation_row) <- rownames(dist_mat)
    group_levels <- levels(group)
    base_colors <- get_custom_palette(length(group_levels))
    annotation_colors <- list(Group = setNames(base_colors, group_levels))
  }

  # Cell colour palette. brewer.pal() warns when asked for more colours than
  # a palette offers, so request at most that palette's maximum.
  pal_info <- RColorBrewer::brewer.pal.info
  if (!is.character(palette) || length(palette) != 1L ||
      !(palette %in% rownames(pal_info))) {
    stop("palette must be the name of an RColorBrewer palette.")
  }
  n_base <- min(9L, pal_info[palette, "maxcolors"])
  colors <- rev(grDevices::colorRampPalette(
    RColorBrewer::brewer.pal(n_base, palette)
  )(100))

  # NOTE(review): no clustering_distance_rows/cols is passed, so pheatmap
  # uses its default (documented as "euclidean") on the rows/columns of
  # dist_mat treated as feature vectors, not dist_mat itself as the
  # distance. Passing stats::as.dist(dist_mat) would cluster on the supplied
  # distances -- confirm which behaviour is intended.
  p <- pheatmap::pheatmap(
    mat = dist_mat,
    clustering_method = clustering_method,
    cluster_rows = cluster_rows,
    cluster_cols = cluster_cols,
    annotation_row = annotation_row,
    annotation_colors = annotation_colors,
    main = main_title,
    color = colors,
    border_color = border_color,
    fontsize_row = fontsize_row,
    fontsize_col = fontsize_col,
    show_rownames = show_rownames,
    show_colnames = show_colnames,
    # annotation_legend toggles the group legend; pheatmap's separate
    # `legend` argument (colour-scale legend) keeps its default.
    annotation_legend = annotation_legend
  )

  invisible(p)
}

# Internal helper for plot_heatmap(): return the sorted row/column indices to
# keep when a matrix with `n` observations exceeds `max_n`. Draws are seeded
# with `seed`; the caller's .Random.seed in the global environment is
# restored (or removed, if it did not exist) on exit.
plot_heatmap_sample_indices <- function(n, max_n, group = NULL,
                                        stratified_sampling = FALSE,
                                        seed = 123) {
  rng_env <- globalenv()
  had_seed <- exists(".Random.seed", envir = rng_env, inherits = FALSE)
  if (had_seed) {
    old_seed <- get(".Random.seed", envir = rng_env, inherits = FALSE)
  }
  on.exit({
    if (had_seed) {
      assign(".Random.seed", old_seed, envir = rng_env)
    } else if (exists(".Random.seed", envir = rng_env, inherits = FALSE)) {
      rm(".Random.seed", envir = rng_env)
    }
  }, add = TRUE)
  set.seed(seed)

  if (!is.null(group)) group <- as.factor(group)
  if (is.null(group) || !isTRUE(stratified_sampling) ||
      nlevels(group) == 0L) {
    # Simple random sampling of exactly max_n observations
    return(sort(sample.int(n, max_n)))
  }

  # Stratified sampling: each level gets round(share * max_n) draws, but at
  # least one, capped by the level's size. NA labels are never drawn.
  lvls <- levels(group)
  group_sizes <- tabulate(group, nbins = length(lvls))
  n_per_group <- pmax(round(group_sizes / sum(group_sizes) * max_n), 1)
  codes <- as.integer(group)
  picked <- lapply(seq_along(lvls), function(i) {
    idx <- which(codes == i)
    # sample(idx, k) would draw from seq_len(idx) when idx is a single
    # number; indexing through sample.int() always draws from idx itself.
    idx[sample.int(length(idx), min(length(idx), n_per_group[i]))]
  })
  sort(as.integer(unlist(picked)))
}

Try the dbrobust package in your browser

Any scripts or data that you put into this service are public.

dbrobust documentation built on Nov. 5, 2025, 6:24 p.m.