R/visualize_distances.R

Defines functions visualize_distances

Documented in visualize_distances

#' Visualize Distance Matrices via MDS, Heatmap, or Network Graph
#'
#' This function provides a unified interface to visualize distance matrices
#' using classical or weighted Multidimensional Scaling (MDS),
#' heatmaps, or network graphs. Group annotations can be provided for coloring.
#'
#' @param dist_mat A square numeric distance matrix or a \code{dist} object.
#' @param method Character string specifying the visualization method. Options are:
#'   \itemize{
#'     \item \code{"mds_classic"}: Classical MDS (cmdscale).
#'     \item \code{"mds_weighted"}: Weighted MDS (wcmdscale). Uses \code{weights}
#'       when supplied, otherwise uniform weights.
#'     \item \code{"heatmap"}: Heatmap with optional clustering and group annotations.
#'     \item \code{"qgraph"}: Network graph representation of similarity.
#'   }
#' @param k Integer. Number of dimensions to retain for MDS (default 3). Must be a
#'   whole number with \code{1 <= k <= min(4, n_obs - 1)}. Only validated and used
#'   by the MDS methods.
#' @param weights Optional non-negative numeric vector of weights for weighted MDS,
#'   one per observation, with a positive sum. Weights are normalized to sum to 1.
#'   If \code{NULL} (default), uniform weights are used.
#' @param group Optional factor or vector indicating group membership for coloring
#'   plots. Must have one entry per observation.
#' @param main_title Optional character string specifying the main title of the plot.
#' @param tol Numeric tolerance used when a matrix is supplied, both for the
#'   approximate symmetry check and for the zero-diagonal check (default 1e-10).
#' @param ... Additional arguments passed to internal plotting functions (\code{plot_heatmap} or \code{plot_qgraph}).
#'
#' @details
#' \code{visualize_distances} is a wrapper around three internal plotting functions:
#' \itemize{
#'   \item \code{plot_mds}: Creates a pairwise scatterplot matrix of MDS coordinates with density plots on the diagonal.
#'   \item \code{plot_heatmap}: Plots a heatmap of the distance matrix with hierarchical clustering and optional group annotations.
#'   \item \code{plot_qgraph}: Plots a network graph where nodes represent observations and edges represent similarity.
#' }
#'
#' When \code{dist_mat} is a matrix, it must be numeric, free of missing values,
#' square, symmetric within \code{tol}, and have a zero diagonal within \code{tol}.
#' It is then symmetrized exactly and converted with \code{as.dist}.
#' At least two observations are required.
#' If \code{dist_mat} has a \code{trimmed_idx} attribute and \code{group} is not provided,
#' a factor indicating "Trimmed" vs "Outlier" is created automatically.
#' A warning is issued when the distances carry no observation labels.
#'
#' @return The plotting object, returned invisibly:
#' \itemize{
#'   \item MDS methods print the result of \code{plot_mds} (a \code{ggmatrix} from \code{GGally}) and return it.
#'   \item \code{"heatmap"} returns the result of \code{plot_heatmap} (a \code{pheatmap} object).
#'   \item \code{"qgraph"} returns the result of \code{plot_qgraph}; the graph is plotted directly (returns \code{NULL}).
#' }
#'
#' @examples
#' # Load iris dataset
#' data(iris)
#'
#' # Compute Euclidean distances on numeric columns
#' dist_iris <- dist(iris[, 1:4])
#'
#' # Create a grouping factor based on Species
#' group_species <- iris$Species
#'
#' # --------------------------------------
#' # Classical MDS (2D)
#' # --------------------------------------
#' visualize_distances(
#'   dist_mat = dist_iris,
#'   method = "mds_classic",
#'   k = 2,
#'   group = group_species,
#'   main_title = "Classical MDS - Iris Dataset - Euclidean Distance"
#' )
#'
#' # --------------------------------------
#' # Weighted MDS (uniform weights)
#' # --------------------------------------
#' weights <- rep(1, nrow(iris))
#' visualize_distances(
#'   dist_mat = dist_iris,
#'   method = "mds_weighted",
#'   k = 2,
#'   weights = weights,
#'   group = group_species,
#'   main_title = "Weighted MDS - Iris Dataset - Euclidean Distance"
#' )
#'
#' # --------------------------------------
#' # Heatmap (limit rows to 30)
#' # --------------------------------------
#' visualize_distances(
#'   dist_mat = dist_iris,
#'   method = "heatmap",
#'   group = group_species,
#'   main_title = "Iris Heatmap by Species - Euclidean Distance",
#'   max_n = 30,
#'   palette = "YlGnBu",
#'   clustering_method = "complete",
#'   annotation_legend = TRUE,
#'   stratified_sampling = TRUE,
#'   seed = 123
#' )
#'
#' # --------------------------------------
#' # Network Graph (limit nodes to 30)
#' # --------------------------------------
#' visualize_distances(
#'   dist_mat = dist_iris,
#'   method = "qgraph",
#'   group = group_species,
#'   max_nodes = 30,
#'   label_size = 2,
#'   edge_threshold = 0.1,
#'   layout = "spring",
#'   seed = 123,
#'   main_title = "Iris Network Graph by Species - Euclidean Distance"
#' )
#'
#' @seealso
#' \code{\link[stats]{cmdscale}} for classical MDS.
#' \code{\link[vegan]{wcmdscale}} for weighted MDS.
#' \code{\link[pheatmap]{pheatmap}} for heatmaps.
#' \code{\link[qgraph]{qgraph}} for network graphs.
#' \code{\link[GGally]{ggpairs}} for MDS scatterplot matrices.
#'
#' @importFrom stats cmdscale dist as.dist
#' @importFrom vegan wcmdscale
#' @importFrom pheatmap pheatmap
#' @importFrom qgraph qgraph
#' @importFrom GGally ggpairs
#' @importFrom ggplot2 element_line
#' @importFrom graphics title
#' @importFrom grDevices colorRampPalette
#' @importFrom RColorBrewer brewer.pal
#' @export
visualize_distances <- function(
    dist_mat,
    method = c("mds_classic", "mds_weighted", "heatmap", "qgraph"),
    k = 3,
    weights = NULL,
    group = NULL,
    main_title = NULL,
    tol = 1e-10,
    ...) {

  # Match method
  method <- match.arg(method)
  is_mds <- method %in% c("mds_classic", "mds_weighted")

  # Read the trimming attribute up front, before dist_mat is transformed.
  trimmed_idx <- attr(dist_mat, "trimmed_idx")

  # Check the distance input and convert it to a 'dist' object
  if (inherits(dist_mat, "dist")) {
    dist_obj <- dist_mat
  } else {
    if (!is.matrix(dist_mat)) {
      stop("dist_mat must be either a 'dist' object or a square symmetric matrix.")
    }
    if (!is.numeric(dist_mat)) stop("dist_mat must be numeric.")
    if (nrow(dist_mat) != ncol(dist_mat)) stop("Matrix must be square.")
    # Without this, NA entries surface as an obscure
    # "missing value where TRUE/FALSE needed" error in the checks below.
    if (anyNA(dist_mat)) stop("dist_mat must not contain missing values.")

    # Approximate symmetry check
    if (any(abs(dist_mat - t(dist_mat)) > tol)) {
      stop("Input matrix is not symmetric (beyond tolerance).")
    }

    # Quietly remove any residual asymmetry that is within tolerance
    dist_mat <- (dist_mat + t(dist_mat)) / 2

    # Same tolerance as the symmetry check: exact `!= 0` on doubles would
    # reject matrices that differ from zero only by rounding noise.
    if (any(abs(diag(dist_mat)) > tol)) stop("Diagonal elements must be zero.")
    dist_obj <- stats::as.dist(dist_mat)
  }

  n_obs <- attr(dist_obj, "Size")
  if (is.null(n_obs) || n_obs < 2) {
    stop("At least two observations are required.")
  }

  # Automatic trimming group
  if (is.null(group) && !is.null(trimmed_idx)) {
    group <- rep("Outlier", n_obs)
    group[trimmed_idx] <- "Trimmed"
    group <- factor(group, levels = c("Trimmed", "Outlier"))
  }

  # Validate k only where it is used; otherwise the default k = 3 would
  # reject heatmaps/graphs of 3 or fewer observations.
  if (is_mds) {
    k_max <- min(4, n_obs - 1)
    if (!is.numeric(k) || length(k) != 1 || is.na(k) ||
        k != round(k) || k < 1 || k > k_max) {
      stop(sprintf("k must be an integer between 1 and %d.", k_max))
    }
    k <- as.integer(k)
  }

  # Validate weights; default to uniform weights summing to 1
  if (is.null(weights)) {
    weights <- rep(1 / n_obs, n_obs)
  } else {
    if (length(weights) != n_obs) {
      stop("Length of weights must match number of observations.")
    }
    if (!is.numeric(weights) || anyNA(weights)) {
      stop("Weights must be numeric without missing values.")
    }
    if (any(weights < 0)) stop("Weights must be non-negative.")
    # Guard against NaN from normalizing by a zero total.
    if (sum(weights) <= 0) stop("Weights must have a positive sum.")
    weights <- weights / sum(weights)
  }

  # Validate group
  if (!is.null(group) && length(group) != n_obs) {
    stop(sprintf("Length of 'group' must equal %d.", n_obs))
  }

  # as.matrix() on a 'dist' object always invents 1..n row names, so the
  # Labels attribute is the only reliable indicator of missing labels.
  if (is.null(attr(dist_obj, "Labels"))) {
    warning("Distance matrix has no row names. This may affect plotting.")
  }

  # Show Euclidean message only for MDS
  if (is_mds) {
    message("Note: No internal Euclidean distance transformation is performed. Ensure dist_mat is Euclidean if using MDS.")
  }

  # One group label per observation; a single "All" level when no grouping is given.
  mds_group <- if (!is.null(group)) factor(group) else factor(rep("All", n_obs))

  # Main switch. plot_mds() computes the MDS itself from dist_obj, so no
  # coordinates are precomputed here.
  plot_result <- switch(method,
    mds_classic = {
      print(plot_mds(dist_mat = dist_obj, k = k, group = mds_group,
                     main_title = main_title))
    },
    mds_weighted = {
      message("Using weighted MDS (wcmdscale).")
      print(plot_mds(dist_mat = dist_obj, k = k, weights = weights,
                     group = mds_group, main_title = main_title))
    },
    heatmap = {
      plot_heatmap(as.matrix(dist_obj), group = group,
                   main_title = main_title, ...)
    },
    qgraph = {
      plot_qgraph(dist_mat = dist_obj, group = group,
                  main_title = main_title, ...)
    }
  )

  invisible(plot_result)
}

Try the dbrobust package in your browser

Any scripts or data that you put into this service are public.

dbrobust documentation built on Nov. 5, 2025, 6:24 p.m.