# Nothing
#' Visualize a Distance or Similarity Matrix as a Heatmap with Clustering
#'
#' This function creates a heatmap from a square distance or similarity matrix.
#' If a similarity matrix is provided, it should first be converted to a distance matrix by the user.
#' The function supports hierarchical clustering, group annotations, row/column sampling (random or stratified),
#' and various customization options.
#'
#' @param dist_mat A square distance matrix (numeric matrix) or a \code{dist} object.
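#' @param max_n Integer. Approximate maximum number of observations (rows/columns) to display.
#' If the matrix exceeds this size, a subset of \code{max_n} observations is selected at random.
#' With stratified sampling the number selected can differ slightly from \code{max_n}, because
#' per-group sample sizes are rounded and every group keeps at least one observation.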
#' @param group Optional vector or factor providing group labels for rows/columns, used for color annotation.
#' @param stratified_sampling Logical. If \code{TRUE} and \code{group} is provided, sampling is stratified by group.
#' Each group will contribute at least one observation if possible. Default is \code{FALSE}.
#' @param main_title Optional character string specifying the main title of the heatmap.
#' @param palette Character string specifying the RColorBrewer palette for heatmap cells. Default is \code{"YlOrRd"}.
#' @param clustering_method Character string specifying the hierarchical clustering method,
#' as accepted by \code{\link[stats]{hclust}} (e.g., \code{"complete"}, \code{"average"}, \code{"ward.D2"}).
#' @param cluster_rows Logical, whether to perform hierarchical clustering on rows. Default is \code{TRUE}.
#' @param cluster_cols Logical, whether to perform hierarchical clustering on columns. Default is \code{TRUE}.
#' @param fontsize_row Integer specifying the font size of row labels. Default is 10.
#' @param fontsize_col Integer specifying the font size of column labels. Default is 10.
#' @param show_rownames Logical, whether to display row names. Default is \code{TRUE}.
#' @param show_colnames Logical, whether to display column names. Default is \code{TRUE}.
#' @param border_color Color of the cell borders in the heatmap. Default is \code{"grey60"}.
#' @param annotation_legend Logical, whether to display the legend for group annotations. Default is \code{TRUE}.
#' @param seed Integer. Random seed used when sampling rows/columns if \code{max_n} is smaller than total observations. Default is 123.
#'
#' @details
#' The function works as follows:
#' \itemize{
#' \item Converts \code{dist} objects to matrices automatically.
#' \item Samples rows/columns if the matrix is larger than \code{max_n}. Sampling can be random or stratified by group.
#' \item In stratified sampling mode, each group contributes at least one observation if possible.
#' \item Supports row annotations for groups and automatically assigns colors.
#' \item Uses \code{pheatmap} for plotting with customizable clustering, labels, fonts, and colors.
#' }
#'
#' This function is used internally by \code{\link{visualize_distances}()} but can be called directly for advanced usage.
#'
#' @return Invisibly returns the \code{pheatmap} object, allowing further customization if assigned.
#'
#' @examples
#' # Example: Euclidean distance heatmap on iris
#' eucli_dist <- stats::dist(iris[, 1:4])
#' dbrobust:::plot_heatmap(
#' dist_mat = eucli_dist,
#' max_n = 10,
#' group = iris$Species,
#' stratified_sampling = TRUE,
#' main_title = "Euclidean Distance Heatmap",
#' palette = "YlOrRd",
#' clustering_method = "complete"
#' )
#'
#' # Example: GGower distances with small subset
#' data("Data_HC_contamination", package = "dbrobust")
#' Data_small <- Data_HC_contamination[1:50, ]
#' cont_vars <- c("V1", "V2", "V3", "V4")
#' cat_vars <- c("V5", "V6", "V7")
#' bin_vars <- c("V8", "V9")
#' w <- Data_small$w_loop
#' dist_sq_ggower <- dbrobust:::robust_distances(
#' data = Data_small,
#' cont_vars = cont_vars,
#' bin_vars = bin_vars,
#' cat_vars = cat_vars,
#' w = w,
#' alpha = 0.10,
#' method = "ggower"
#' )
#' group_vec <- rep("Normal", nrow(dist_sq_ggower))
#' group_vec[attr(dist_sq_ggower, "outlier_idx")] <- "Outlier"
#' group_factor <- factor(group_vec, levels = c("Normal", "Outlier"))
#' dbrobust:::plot_heatmap(
#' dist_mat = sqrt(dist_sq_ggower),
#' max_n = 20,
#' group = group_factor,
#' main_title = "GGower Heatmap with Outliers",
#' palette = "YlOrRd",
#' clustering_method = "complete",
#' annotation_legend = TRUE,
#' stratified_sampling = TRUE,
#' seed = 123
#' )
#'
#' @seealso
#' \code{\link[stats]{hclust}} for hierarchical clustering methods.
#' \code{\link[pheatmap]{pheatmap}} for additional heatmap customization options.
#' \code{\link[RColorBrewer]{brewer.pal}} for available color palettes.
#' @importFrom pheatmap pheatmap
#' @importFrom stats setNames
#' @importFrom RColorBrewer brewer.pal
#' @keywords internal
plot_heatmap <- function(
    dist_mat,
    max_n = 50,
    group = NULL,
    stratified_sampling = FALSE,
    main_title = NULL,
    palette = "YlOrRd",
    clustering_method = "complete",
    cluster_rows = TRUE,
    cluster_cols = TRUE,
    fontsize_row = 10,
    fontsize_col = 10,
    show_rownames = TRUE,
    show_colnames = TRUE,
    border_color = "grey60",
    annotation_legend = TRUE,
    seed = 123
) {
  # Draws a pheatmap of a square distance matrix, optionally sub-sampling the
  # observations (randomly or stratified by `group`) and annotating rows by
  # group. Returns the pheatmap object invisibly.
  if (missing(dist_mat)) stop("Argument 'dist_mat' is required.")

  # Accept either a plain matrix or a `dist` object (converted to a matrix).
  if (!is.matrix(dist_mat) && !inherits(dist_mat, "dist")) {
    stop("dist_mat must be a matrix or a dist object.")
  }
  if (inherits(dist_mat, "dist")) dist_mat <- as.matrix(dist_mat)
  if (nrow(dist_mat) != ncol(dist_mat)) stop("dist_mat must be square.")

  if (!is.numeric(max_n) || length(max_n) != 1L || is.na(max_n) || max_n < 1) {
    stop("max_n must be a single positive number.")
  }

  # pheatmap matches annotations to the matrix by row name, so names must exist.
  if (is.null(rownames(dist_mat))) {
    rownames(dist_mat) <- as.character(seq_len(nrow(dist_mat)))
  }
  if (is.null(colnames(dist_mat))) {
    colnames(dist_mat) <- as.character(seq_len(ncol(dist_mat)))
  }
  n <- nrow(dist_mat)

  # Validate `group` against the FULL matrix before any sampling; checking only
  # after sampling would let a wrong-length vector be silently mis-indexed.
  if (!is.null(group)) {
    if (length(group) != n) {
      stop("Length of group must match number of observations in dist_mat.")
    }
    group <- factor(group)
    # Remember all levels so colours stay stable even if sampling drops a group.
    orig_levels <- levels(group)
  }

  # Sub-sample observations when the matrix is larger than `max_n`.
  if (n > max_n) {
    # NOTE: set.seed() changes the global RNG state for the calling session.
    set.seed(seed)
    if (!is.null(group) && stratified_sampling) {
      group_sizes <- table(group)
      # Proportional allocation, with at least one observation per group.
      # Because of rounding, the total may differ slightly from `max_n`.
      n_per_group <- pmax(round((group_sizes / sum(group_sizes)) * max_n), 1)
      sampled_idx <- unlist(lapply(levels(group), function(g) {
        idx <- which(group == g)
        n_samp <- min(length(idx), n_per_group[g])
        # Index via sample.int(): sample(idx, k) with a single numeric `idx`
        # would draw from 1:idx and could return another group's observation.
        idx[sample.int(length(idx), n_samp)]
      }))
      sampled_idx <- sort(sampled_idx)
    } else {
      # Simple random sampling without replacement.
      sampled_idx <- sort(sample(seq_len(n), max_n))
    }
    # drop = FALSE keeps a matrix even if only one observation is selected.
    dist_mat <- dist_mat[sampled_idx, sampled_idx, drop = FALSE]
    if (!is.null(group)) {
      group <- factor(group[sampled_idx], levels = orig_levels)
    }
    message(sprintf(
      "Matrix has %d observations; sampling %d for heatmap.",
      n, length(sampled_idx)
    ))
  }

  # Row annotation (group membership) and its colour mapping.
  annotation_row <- NULL
  annotation_colors <- NULL
  if (!is.null(group)) {
    annotation_row <- data.frame(Group = factor(group, levels = orig_levels))
    rownames(annotation_row) <- rownames(dist_mat)
    n_groups <- length(levels(annotation_row$Group))
    # get_custom_palette() is defined elsewhere in the package; it is expected
    # to return one colour per group level.
    base_colors <- get_custom_palette(n_groups)
    annotation_colors <- list(
      Group = stats::setNames(base_colors, levels(annotation_row$Group))
    )
  }

  # 100-step gradient interpolated from the Brewer palette, reversed.
  colors <- rev(
    grDevices::colorRampPalette(RColorBrewer::brewer.pal(9, palette))(100)
  )

  # NOTE(review): pheatmap clusters on its default distance between the rows
  # of `mat`, not on `dist_mat` itself; confirm this is the intended behavior.
  p <- pheatmap::pheatmap(
    mat = dist_mat,
    clustering_method = clustering_method,
    cluster_rows = cluster_rows,
    cluster_cols = cluster_cols,
    annotation_row = annotation_row,
    annotation_colors = annotation_colors,
    main = main_title,
    color = colors,
    border_color = border_color,
    fontsize_row = fontsize_row,
    fontsize_col = fontsize_col,
    show_rownames = show_rownames,
    show_colnames = show_colnames,
    # Controls the group-annotation legend (not the colour-scale `legend`).
    annotation_legend = annotation_legend
  )
  invisible(p)
}
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.