R/plot_qgraph.R

Defines functions plot_qgraph

Documented in plot_qgraph

#' Plot a Network Graph from a Distance Matrix
#'
#' This internal function visualizes a network graph representation of a distance matrix,
#' where nodes represent observations and edges represent similarity. Groups can be specified
#' for node coloring. A maximum number of nodes can be set to avoid overcrowding, and weak edges
#' are thresholded.
#'
#' @param dist_mat A square distance matrix or a \code{dist} object. Distances are divided by the
#'   largest finite distance (so non-negative distances fall in [0,1]) and converted to similarity
#'   via \code{1 - distance}. If every finite distance is zero, no rescaling is applied.
#' @param group Optional factor or vector indicating group membership for nodes, used for coloring.
#'   Must have one entry per observation.
#' @param max_nodes Single number. Target maximum number of nodes to plot. If the number of
#'   observations exceeds this, stratified sampling by group is performed to reduce the node count.
#'   Every non-empty group keeps at least one node and per-group counts are rounded, so the sampled
#'   total can slightly exceed \code{max_nodes}. Observations whose group is \code{NA} are dropped
#'   when sampling.
#' @param label_size Numeric. Size of the node labels (passed to \code{qgraph} as \code{label.cex}).
#' @param edge_threshold Numeric. Edges with similarity below this threshold, or with a missing or
#'   non-finite similarity, are removed.
#' @param layout Character string specifying the layout algorithm for \code{qgraph}. Default is \code{"spring"}.
#' @param seed Integer. Random seed used for reproducibility during sampling and layout. The
#'   caller's global random-number state is restored when the function returns.
#' @param main_title Optional character string specifying the main title of the plot.
#'
#' @details
#' This function is internal and not intended for direct use. It is called by
#' \code{\link{visualize_distances}()} to display network graphs of robust distances.
#'
#' Features:
#' \itemize{
#'   \item Converts \code{dist} objects to matrices automatically.
#'   \item Downsamples nodes if the number of observations exceeds \code{max_nodes}, using stratified sampling by group.
#'   \item Rescales distances by their largest finite value and converts them to similarity (1 - distance).
#'   \item Removes weak edges below \code{edge_threshold}, as well as edges with missing/non-finite similarity.
#'   \item Colors nodes according to group membership.
#'   \item Adds a main title using \code{title()} after plotting with \code{qgraph}.
#'   \item Leaves the caller's random-number stream unchanged.
#' }
#'
#' @return Invisibly returns \code{NULL}. The plot is drawn as a side effect.
#'
#' @examples
#' # --------------------------------------
#' # Network Graph Example from Robust Distances
#' # --------------------------------------
#' data("Data_HC_contamination", package = "dbrobust")
#' # Subset small dataset
#' Data_small <- Data_HC_contamination[1:20, ]
#'
#' cont_vars <- c("V1", "V2", "V3", "V4")
#' cat_vars  <- c("V5", "V6", "V7")
#' bin_vars  <- c("V8", "V9")
#' w <- Data_small$w_loop
#'
#' # Compute GGower robust distances
#' dist_sq_ggower <- dbrobust:::robust_distances(
#'   data = Data_small,
#'   cont_vars = cont_vars,
#'   bin_vars  = bin_vars,
#'   cat_vars  = cat_vars,
#'   w = w,
#'   alpha = 0.10,
#'   method = "ggower"
#' )
#'
#' # Create factor indicating Normal vs Outlier
#' n_obs <- nrow(dist_sq_ggower)
#' group_vec <- rep("Normal", n_obs)
#' group_vec[attr(dist_sq_ggower, "outlier_idx")] <- "Outlier"
#' group_factor <- factor(group_vec, levels = c("Normal", "Outlier"))
#'
#' # Plot network graph (small, for CRAN)
#' dbrobust:::plot_qgraph(
#'   dist_mat = sqrt(dist_sq_ggower),
#'   group = group_factor,
#'   max_nodes = 10,
#'   label_size = 2,
#'   edge_threshold = 0.1,
#'   layout = "spring",
#'   seed = 123,
#'   main_title = "GGower Network Graph with Outliers"
#' )
#'
#' @keywords internal
#' @importFrom qgraph qgraph
#' @importFrom graphics title
plot_qgraph <- function(
    dist_mat,
    group = NULL,
    max_nodes = 100,
    label_size = 2,
    edge_threshold = 0.1,
    layout = "spring",
    seed = 123,
    main_title = NULL
) {
  if (missing(dist_mat)) stop("Argument 'dist_mat' is required.")
  if (!requireNamespace("qgraph", quietly = TRUE)) stop("Package 'qgraph' is required.")

  # Convert to matrix if 'dist' object
  if (inherits(dist_mat, "dist")) dist_mat <- as.matrix(dist_mat)
  if (!is.matrix(dist_mat) || nrow(dist_mat) != ncol(dist_mat)) {
    stop("dist_mat must be a square matrix or a 'dist' object.")
  }
  if (!is.numeric(max_nodes) || length(max_nodes) != 1L || is.na(max_nodes)) {
    stop("'max_nodes' must be a single non-missing number.")
  }

  # set.seed() below would otherwise permanently replace the caller's global
  # RNG stream; snapshot it now and put it back when the function exits.
  genv <- globalenv()
  had_seed <- exists(".Random.seed", envir = genv, inherits = FALSE)
  old_seed <- if (had_seed) get(".Random.seed", envir = genv, inherits = FALSE)
  on.exit({
    if (had_seed) {
      assign(".Random.seed", old_seed, envir = genv)
    } else if (exists(".Random.seed", envir = genv, inherits = FALSE)) {
      rm(".Random.seed", envir = genv)
    }
  }, add = TRUE)

  n <- nrow(dist_mat)
  labels <- rownames(dist_mat)
  if (is.null(labels)) labels <- as.character(seq_len(n))

  # Check group
  if (!is.null(group)) {
    if (length(group) != n) stop("Length of 'group' must match number of observations.")
    if (!is.factor(group)) group <- factor(group)
  } else {
    group <- factor(rep("All", n))
  }
  orig_levels <- levels(group)

  # Downsampling if needed
  if (n > max_nodes) {
    message(sprintf("Too many nodes (%d). Sampling max %d nodes with stratified sampling.", n, max_nodes))
    set.seed(seed)
    # split() drops observations whose group is NA.
    sample_idx <- unlist(lapply(split(seq_along(group), group), function(idx) {
      n_group <- length(idx)
      # Proportional allocation, but at least one node per non-empty group.
      n_take <- min(n_group, max(1, round((n_group / n) * max_nodes)))
      # Index through sample.int(): sample(idx, k) treats a length-one idx as
      # 1:idx and would return an unrelated row number. For longer idx this
      # draws exactly what sample(idx, k) would.
      idx[sample.int(n_group, n_take)]
    }), use.names = FALSE)
    # drop = FALSE keeps a matrix even if only one node is sampled.
    dist_mat <- dist_mat[sample_idx, sample_idx, drop = FALSE]
    group <- factor(group[sample_idx], levels = orig_levels)
    labels <- labels[sample_idx]
  }

  # Rescale by the largest finite distance and convert to similarity. An
  # all-zero matrix is left as-is (dividing by 0 would give NaN everywhere).
  finite_d <- dist_mat[is.finite(dist_mat)]
  max_d <- if (length(finite_d) > 0L) max(finite_d) else 0
  if (max_d > 0) dist_mat <- dist_mat / max_d
  sim_mat <- 1 - dist_mat
  # Weak edges and missing/non-finite similarities become "no edge".
  sim_mat[!is.finite(sim_mat) | sim_mat < edge_threshold] <- 0

  # Assign colors to groups (palette helper defined elsewhere in the package)
  group_levels <- levels(group)
  color_map <- get_custom_palette(length(group_levels))
  names(color_map) <- group_levels
  node_colors <- color_map[as.character(group)]

  # Plot network graph; the seed also fixes the spring layout's random start.
  set.seed(seed)
  qgraph::qgraph(
    input = sim_mat,
    layout = layout,
    labels = labels,
    color = node_colors,
    label.cex = label_size,
    edge.color = "darkblue",
    label.scale.equal = TRUE,
    vsize = 5,
    directed = FALSE
  )

  # Add main title
  if (!is.null(main_title)) {
    title(main_title, line = 3, cex.main = 1, font.main = 2)
  }

  invisible(NULL)
}

Try the dbrobust package in your browser

Any scripts or data that you put into this service are public.

dbrobust documentation built on Nov. 5, 2025, 6:24 p.m.