R/visualize_gsea.R

Defines functions create_heatmap_plot create_network_plot .build_heatmap_col_fun .build_discrete_fill_for_direction .build_continuous_scale .as_color_vector .is_ggplot_scale create_empty_plot visualize_gsea

Documented in .as_color_vector .build_continuous_scale .build_discrete_fill_for_direction .build_heatmap_col_fun create_empty_plot create_heatmap_plot create_network_plot .is_ggplot_scale visualize_gsea

#' Visualize GSEA results
#'
#' This function creates various visualizations for Gene Set Enrichment Analysis (GSEA) results.
#' It automatically detects whether pathway names are available (from gsea_pathway_annotation())
#' and uses them for better readability, falling back to pathway IDs if names are not available.
#'
#' @param gsea_results A data frame containing GSEA results from the pathway_gsea function
#' @param plot_type A character string specifying the visualization type: "enrichment_plot", "dotplot", "barplot", "network", or "heatmap"
#' @param n_pathways A single number giving the number of pathways to display. Fractional values are
#'   truncated to an integer, `Inf` (or any value beyond the integer range) keeps every pathway, and
#'   values <= 0 return an empty placeholder plot.
#' @param sort_by A character string specifying the sorting criterion: "NES", "pvalue", or "p.adjust"
#' @param colors A character vector of colors. In this function it is used to build the default
#'   group annotation colors of the heatmap; it is recycled when there are more groups than colors.
#' @param abundance A data frame containing the original abundance data (required for heatmap visualization)
#' @param metadata A data frame containing sample metadata (required for heatmap visualization)
#' @param group A character string specifying the column name in metadata that contains the grouping variable (required for heatmap visualization)
#' @param network_params A list of parameters for network visualization
#' @param heatmap_params A list of parameters for heatmap visualization
#' @param pathway_label_column A character string specifying which column to use for pathway labels.
#'   If NULL (default), the function will automatically use 'pathway_name' if available, otherwise 'pathway_id'.
#'   This allows for custom labeling when using annotated GSEA results.
#' @param scale Optional palette/scale for customizing colors. Accepts: (1) a character vector of colors,
#'   (2) a function that returns colors given an integer (e.g., viridisLite::viridis), or
#'   (3) a ggplot2 scale object (e.g., ggplot2::scale_fill_gradientn(...)).
#'   When NULL, defaults keep current behavior. Applies to: enrichment_plot (fill, continuous),
#'   dotplot (color, continuous), barplot (fill, discrete Positive/Negative), network (color, diverging around 0),
#'   heatmap (main heatmap col; row annotation stays default unless overridden in heatmap_params).
#'
#' @return A ggplot2 object or ComplexHeatmap object
#' @export
#'
#' @examples
#' \dontrun{
#' # Load example data
#' data(ko_abundance)
#' data(metadata)
#'
#' # Prepare abundance data
#' abundance_data <- as.data.frame(ko_abundance)
#' rownames(abundance_data) <- abundance_data[, "#NAME"]
#' abundance_data <- abundance_data[, -1]
#'
#' # Run GSEA analysis (using camera method - recommended)
#' gsea_results <- pathway_gsea(
#'   abundance = abundance_data,
#'   metadata = metadata,
#'   group = "Environment",
#'   pathway_type = "KEGG",
#'   method = "camera"
#' )
#'
#' # Create enrichment plot with pathway IDs (default)
#' visualize_gsea(gsea_results, plot_type = "enrichment_plot", n_pathways = 10)
#'
#' # Annotate results for better pathway names
#' annotated_results <- gsea_pathway_annotation(
#'   gsea_results = gsea_results,
#'   pathway_type = "KEGG"
#' )
#'
#' # Create plots with readable pathway names
#' visualize_gsea(annotated_results, plot_type = "dotplot", n_pathways = 20)
#' visualize_gsea(annotated_results, plot_type = "barplot", n_pathways = 15)
#'
#' # Create network plot with custom labels
#' visualize_gsea(annotated_results, plot_type = "network", n_pathways = 15)
#'
#' # Use custom column for labels (if available)
#' visualize_gsea(annotated_results, plot_type = "barplot",
#'                pathway_label_column = "pathway_name", n_pathways = 10)
#'
#' # Create heatmap
#' visualize_gsea(
#'   annotated_results,
#'   plot_type = "heatmap",
#'   n_pathways = 15,
#'   abundance = abundance_data,
#'   metadata = metadata,
#'   group = "Environment"
#' )
#' }
visualize_gsea <- function(gsea_results,
                          plot_type = "enrichment_plot",
                          n_pathways = 20,
                          sort_by = "p.adjust",
                          colors = NULL,
                          abundance = NULL,
                          metadata = NULL,
                          group = NULL,
                          network_params = list(),
                          heatmap_params = list(),
                          pathway_label_column = NULL,
                          scale = NULL) {

  # Input validation using unified functions
  validate_dataframe(gsea_results, param_name = "gsea_results")
  validate_choice(plot_type, c("enrichment_plot", "dotplot", "barplot", "network", "heatmap"), "plot_type")
  validate_choice(sort_by, c("NES", "pvalue", "p.adjust"), "sort_by")
  validate_dataframe(gsea_results, required_cols = sort_by, param_name = "gsea_results")

  if (!is.null(colors) && !is.character(colors)) {
    stop("colors must be NULL or a character vector")
  }
  # The label column is used with `[[`, so it must be exactly one non-NA string.
  if (!is.null(pathway_label_column) &&
      (!is.character(pathway_label_column) ||
       length(pathway_label_column) != 1 ||
       is.na(pathway_label_column))) {
    stop("pathway_label_column must be NULL or a character string")
  }
  if (!is.numeric(n_pathways) || length(n_pathways) != 1 || is.na(n_pathways)) {
    stop("n_pathways must be a single numeric value")
  }

  # Convert to integer if it's not already. as.integer() turns Inf and values
  # beyond the integer range into NA, which would make the `n_pathways <= 0`
  # test below fail, so clamp such values to the representable range first.
  if (is.infinite(n_pathways) || abs(n_pathways) > .Machine$integer.max) {
    n_pathways <- sign(n_pathways) * .Machine$integer.max
  }
  n_pathways <- as.integer(n_pathways)

  # Note: enrichment_plot / dotplot / barplot are built entirely with
  # ggplot2 (see the branches below); they do not call into enrichplot
  # at runtime, so we don't require that Bioconductor package here.
  # Network and heatmap branches still depend on their own stacks and
  # are checked below.
  if (plot_type == "network") {
    require_package("igraph", "network plots")
    require_package("ggraph", "network plots")
    require_package("tidygraph", "network plots")
    validate_dataframe(
      gsea_results,
      required_cols = c("pathway_id", "NES", "pvalue", "p.adjust", "size", "leading_edge"),
      param_name = "gsea_results"
    )
  }

  if (plot_type == "heatmap") {
    require_package("ComplexHeatmap", "heatmap plots")
    require_package("circlize", "heatmap plots")
    validate_dataframe(
      gsea_results,
      required_cols = c("pathway_id", "NES", "leading_edge"),
      param_name = "gsea_results"
    )

    # Check if required parameters are provided
    if (is.null(abundance) || is.null(metadata) || is.null(group)) {
      stop("For heatmap visualization, 'abundance', 'metadata', and 'group' parameters are required")
    }
  }

  # Set default colors if not provided
  if (is.null(colors)) {
    colors <- c("#E41A1C", "#377EB8", "#4DAF4A", "#984EA3", "#FF7F00", "#FFFF33", "#A65628", "#F781BF", "#999999")
  }

  # Determine which column to use for pathway labels
  if (!is.null(pathway_label_column)) {
    # User specified a custom column
    require_column(gsea_results, pathway_label_column, "gsea_results")
    pathway_label_col <- pathway_label_column
  } else {
    # Auto-detect: prefer pathway_name if available, otherwise use pathway_id
    if ("pathway_name" %in% colnames(gsea_results)) {
      pathway_label_col <- "pathway_name"
    } else if ("pathway_id" %in% colnames(gsea_results)) {
      pathway_label_col <- "pathway_id"
    } else {
      # Check if we have any results first
      if (nrow(gsea_results) == 0) {
        return(create_empty_plot(plot_type))
      }
      stop("GSEA results must contain either 'pathway_name' or 'pathway_id' column")
    }
  }

  # Create a standardized pathway_label column for consistent use throughout the function
  gsea_results$pathway_label <- gsea_results[[pathway_label_col]]

  # Sort results based on the specified criterion
  if (sort_by == "NES") {
    gsea_results <- gsea_results[order(abs(gsea_results$NES), decreasing = TRUE), ]
  } else if (sort_by == "pvalue") {
    gsea_results <- gsea_results[order(gsea_results$pvalue), ]
  } else if (sort_by == "p.adjust") {
    gsea_results <- gsea_results[order(gsea_results$p.adjust), ]
  }

  # Handle boundary conditions for n_pathways
  if (n_pathways <= 0) {
    return(create_empty_plot(plot_type))
  }

  # Check if we have any results after filtering
  if (nrow(gsea_results) == 0) {
    return(create_empty_plot(plot_type))
  }

  # Limit to top n_pathways
  if (nrow(gsea_results) > n_pathways) {
    gsea_results <- head(gsea_results, n_pathways)
  }

  # Create visualization based on plot_type
  if (plot_type == "enrichment_plot") {
    # The "enrichment plot" is a horizontal bar chart of NES per pathway,
    # filled by adjusted p-value, drawn with ggplot2 only.

    # Check if we have the necessary data
    if (!all(c("pathway_id", "NES", "pvalue", "p.adjust") %in% colnames(gsea_results))) {
      stop("GSEA results missing required columns for enrichment plot")
    }

    # Sort by NES
    gsea_results <- gsea_results[order(gsea_results$NES), ]

    # Apply user-provided scale if available (fallback to default)
    fill_scale <- .build_continuous_scale(aes = "fill", scale = scale, diverging = FALSE, name = "Adjusted p-value")
    if (is.null(fill_scale)) {
      fill_scale <- ggplot2::scale_fill_gradient(low = "red", high = "blue")
    }

    p <- ggplot2::ggplot(gsea_results, ggplot2::aes(x = reorder(.data$pathway_label, .data$NES), y = .data$NES, fill = .data$p.adjust)) +
      ggplot2::geom_bar(stat = "identity") +
      ggplot2::coord_flip() +
      fill_scale +
      ggplot2::labs(
        title = "GSEA Enrichment Results",
        x = "Pathway",
        y = "Normalized Enrichment Score (NES)",
        fill = "Adjusted p-value"
      ) +
      ggplot2::theme_minimal() +
      ggplot2::theme(
        axis.text.y = ggplot2::element_text(size = 8),
        plot.title = ggplot2::element_text(hjust = 0.5)
      )

  } else if (plot_type == "dotplot") {
    # Create dotplot
    # Order labels by NES. unique() is required because factor() rejects
    # duplicated levels, which happens when several pathways share a label
    # (for example several pathways with a missing name).
    gsea_results$pathway_label <- factor(
      gsea_results$pathway_label,
      levels = unique(gsea_results$pathway_label[order(gsea_results$NES)])
    )

    # Apply user-provided scale if available (fallback to default)
    color_scale <- .build_continuous_scale(aes = "color", scale = scale, diverging = FALSE, name = "Adjusted p-value")
    if (is.null(color_scale)) {
      color_scale <- ggplot2::scale_color_gradient(low = "red", high = "blue")
    }

    p <- ggplot2::ggplot(gsea_results,
                       ggplot2::aes(x = .data$NES, y = .data$pathway_label, color = .data$p.adjust, size = .data$size)) +
      ggplot2::geom_point() +
      color_scale +
      ggplot2::labs(
        title = "GSEA Results",
        x = "Normalized Enrichment Score (NES)",
        y = "Pathway",
        color = "Adjusted p-value",
        size = "Gene Set Size"
      ) +
      ggplot2::theme_minimal() +
      ggplot2::theme(
        axis.text.y = ggplot2::element_text(size = 8),
        plot.title = ggplot2::element_text(hjust = 0.5)
      )

  } else if (plot_type == "barplot") {
    # Create barplot
    # Order labels by NES (unique() guards against duplicated labels, see dotplot)
    gsea_results$pathway_label <- factor(
      gsea_results$pathway_label,
      levels = unique(gsea_results$pathway_label[order(gsea_results$NES)])
    )

    # Add color based on NES direction
    gsea_results$direction <- ifelse(gsea_results$NES > 0, "Positive", "Negative")

    # Apply user-provided palette if available (fallback to default)
    direction_scale <- .build_discrete_fill_for_direction(scale = scale)
    if (is.null(direction_scale)) {
      direction_scale <- ggplot2::scale_fill_manual(values = c("Positive" = "#E41A1C", "Negative" = "#377EB8"))
    }

    p <- ggplot2::ggplot(gsea_results,
                       ggplot2::aes(x = .data$pathway_label, y = .data$NES, fill = .data$direction)) +
      ggplot2::geom_bar(stat = "identity") +
      direction_scale +
      ggplot2::coord_flip() +
      ggplot2::labs(
        title = "GSEA Results",
        x = "Pathway",
        y = "Normalized Enrichment Score (NES)",
        fill = "Direction"
      ) +
      ggplot2::theme_minimal() +
      ggplot2::theme(
        axis.text.y = ggplot2::element_text(size = 8),
        plot.title = ggplot2::element_text(hjust = 0.5)
      )

  } else if (plot_type == "network") {
    # Set default network parameters
    default_params <- list(
      similarity_measure = "jaccard",
      similarity_cutoff = 0.3,
      layout = "fruchterman",
      node_color_by = "NES",
      edge_width_by = "similarity"
    )

    # Merge with user-provided parameters
    network_params <- utils::modifyList(default_params, network_params)
    validate_choice(network_params$similarity_measure,
                    c("jaccard", "overlap", "correlation"),
                    "network_params$similarity_measure")
    validate_choice(network_params$layout,
                    c("fruchterman", "kamada", "circle"),
                    "network_params$layout")
    validate_choice(network_params$node_color_by,
                    c("NES", "pvalue", "p.adjust"),
                    "network_params$node_color_by")
    validate_choice(network_params$edge_width_by,
                    c("similarity", "constant"),
                    "network_params$edge_width_by")
    if (!is.numeric(network_params$similarity_cutoff) ||
        length(network_params$similarity_cutoff) != 1 ||
        is.na(network_params$similarity_cutoff) ||
        network_params$similarity_cutoff < 0 ||
        network_params$similarity_cutoff > 1) {
      stop("network_params$similarity_cutoff must be a single numeric value between 0 and 1",
           call. = FALSE)
    }

    # Create network plot
    p <- create_network_plot(
      gsea_results = gsea_results,
      n_pathways = n_pathways,
      similarity_measure = network_params$similarity_measure,
      similarity_cutoff = network_params$similarity_cutoff,
      layout = network_params$layout,
      node_color_by = network_params$node_color_by,
      edge_width_by = network_params$edge_width_by,
      scale = scale
    )

  } else if (plot_type == "heatmap") {
    # Set default heatmap parameters. The palette is recycled with rep_len()
    # so that every group receives a color even when there are more groups
    # than palette entries (plain indexing would yield NA colors).
    group_levels <- unique(metadata[[group]])
    default_params <- list(
      cluster_rows = TRUE,
      cluster_columns = TRUE,
      show_rownames = TRUE,
      annotation_colors = list(
        Group = stats::setNames(rep_len(colors, length(group_levels)), group_levels)
      )
    )

    # Merge with user-provided parameters. modifyList() drops entries that
    # the user set to NULL, so restore the default annotation colors then.
    heatmap_params <- utils::modifyList(default_params, heatmap_params)
    if (is.null(heatmap_params$annotation_colors)) {
      heatmap_params$annotation_colors <- default_params$annotation_colors
    }

    # Prefer explicit col_fun if provided in heatmap_params; else build from `scale`
    heatmap_col_fun <- if (!is.null(heatmap_params$col_fun)) {
      heatmap_params$col_fun
    } else {
      .build_heatmap_col_fun(scale)
    }

    # Create heatmap
    p <- create_heatmap_plot(
      gsea_results = gsea_results,
      abundance = abundance,
      metadata = metadata,
      group = group,
      n_pathways = n_pathways,
      cluster_rows = heatmap_params$cluster_rows,
      cluster_columns = heatmap_params$cluster_columns,
      show_rownames = heatmap_params$show_rownames,
      annotation_colors = heatmap_params$annotation_colors,
      col_fun = heatmap_col_fun
    )
  }

  return(p)
}

#' Build a placeholder plot for edge cases with nothing to draw
#'
#' The placeholder is a blank (`theme_void`) ggplot whose title names the
#' requested plot type and whose panel carries a short grey message. Network
#' and heatmap requests get the "No significant pathways found" message; every
#' other plot type gets "No pathways to display".
#'
#' @param plot_type A character string specifying the visualization type
#'
#' @return A ggplot2 object
#' @keywords internal
create_empty_plot <- function(plot_type) {
  # Pick the panel message first; the rest of the plot is shared by all types.
  panel_message <- if (plot_type %in% c("network", "heatmap")) {
    "No significant pathways found"
  } else {
    "No pathways to display"
  }

  placeholder <- ggplot2::ggplot() +
    ggplot2::theme_void() +
    ggplot2::labs(title = paste("No pathways to display for", plot_type))

  placeholder +
    ggplot2::annotate("text", x = 0, y = 0,
                      label = panel_message,
                      size = 4, color = "gray50")
}

#' Internal: test whether an object is a ggplot2 scale
#'
#' Returns TRUE when `x` inherits from the class "Scale", FALSE otherwise
#' (including for NULL).
#' @keywords internal
.is_ggplot_scale <- function(x) {
  has_scale_class <- inherits(x, what = "Scale", which = FALSE)
  has_scale_class
}

#' Internal: coerce user 'scale' input to a vector of colors (or NULL)
#'
#' Accepts a character vector or a function(n) -> colors.
#' - A non-empty character vector is returned unchanged; an empty one yields
#'   NULL so callers fall back to their defaults.
#' - A function is called with n = 100. Errors are treated as "no palette"
#'   (NULL). NA entries in the result are dropped, since some palette
#'   functions pad with NA beyond their maximum size. At least two usable
#'   colors are required, otherwise NULL is returned.
#' - Anything else (including NULL and ggplot2 scale objects) yields NULL.
#' @keywords internal
.as_color_vector <- function(scale) {
  if (is.character(scale)) {
    if (length(scale) == 0) return(NULL)
    return(scale)
  }
  if (!is.function(scale)) return(NULL)

  # Try to get 100 colors as a reasonable default resolution
  palette <- tryCatch(scale(100), error = function(e) NULL)
  if (!is.character(palette)) return(NULL)

  # Only subset when needed so that a clean palette is returned untouched
  if (anyNA(palette)) {
    palette <- palette[!is.na(palette)]
  }
  if (length(palette) > 1) palette else NULL
}

#' Internal: turn user 'scale' input into a continuous ggplot2 scale
#'
#' A ggplot2 scale object is handed back untouched. Otherwise the input is
#' converted to a color vector; when that fails the function returns NULL so
#' the caller can apply its own default. With `diverging = TRUE` and at least
#' three colors, the first, middle and last colors form a three-point gradient
#' centred on `midpoint` (0 when `midpoint` is NULL). In every other case all
#' colors are used in an n-color gradient.
#'
#' @param aes Which aesthetic the scale is for: "fill" or "color".
#' @param scale NULL, a character vector of colors, a palette function, or a
#'   ggplot2 scale object.
#' @param diverging Whether to build a diverging (three-point) gradient.
#' @param midpoint Centre of the diverging gradient; NULL means 0.
#' @param name Legend title passed on to the scale.
#' @return A ggplot2 scale object, or NULL when no palette could be derived.
#' @keywords internal
.build_continuous_scale <- function(aes = c("fill", "color"), scale = NULL,
                                   diverging = FALSE, midpoint = NULL, name = NULL) {
  aes <- match.arg(aes)

  # Ready-made ggplot2 scales pass straight through
  if (.is_ggplot_scale(scale)) return(scale)

  palette <- .as_color_vector(scale)
  if (is.null(palette)) return(NULL)

  n_colors <- length(palette)
  for_fill <- aes == "fill"

  # Sequential palettes (or too few colors for a midpoint): use every color
  if (!(diverging && n_colors >= 3)) {
    if (for_fill) {
      return(ggplot2::scale_fill_gradientn(colors = palette, name = name))
    }
    return(ggplot2::scale_color_gradientn(colors = palette, name = name))
  }

  # Diverging palette: anchor on the two ends and the middle entry
  low_color <- palette[1]
  mid_color <- palette[ceiling(n_colors / 2)]
  high_color <- palette[n_colors]

  if (for_fill) {
    return(ggplot2::scale_fill_gradient2(low = low_color, mid = mid_color, high = high_color,
                                         midpoint = if (is.null(midpoint)) 0 else midpoint,
                                         name = name))
  }
  ggplot2::scale_color_gradient2(low = low_color, mid = mid_color, high = high_color,
                                 midpoint = if (is.null(midpoint)) 0 else midpoint,
                                 name = name)
}

#' Internal: build a discrete fill scale for barplot direction
#'
#' The barplot maps fill to a two-valued "direction" variable
#' ("Positive" / "Negative").
#' - A ggplot2 discrete scale whose aesthetics include "fill" (for example
#'   `ggplot2::scale_fill_manual(...)`) is returned unchanged, so a
#'   user-supplied scale object is honoured. Continuous scales and scales for
#'   other aesthetics are not passed through, because they cannot be applied
#'   to this discrete fill mapping.
#' - A character vector or palette function with at least two colors maps the
#'   first color to "Negative" and the last color to "Positive".
#' - Anything else returns NULL so the caller applies its default colors.
#' @keywords internal
.build_discrete_fill_for_direction <- function(scale = NULL) {
  # Honour a ready-made discrete fill scale (class "ScaleDiscrete" is the
  # ggplot2 class of discrete scales and implies "Scale").
  if (inherits(scale, "ScaleDiscrete") && "fill" %in% scale$aesthetics) {
    return(scale)
  }

  cols <- .as_color_vector(scale)
  if (is.null(cols) || length(cols) < 2) return(NULL)

  # Palette ends: first color for negative NES, last color for positive NES
  values <- c("Positive" = cols[length(cols)], "Negative" = cols[1])
  ggplot2::scale_fill_manual(values = values)
}

#' Internal: derive a circlize::colorRamp2 color function from user 'scale'
#'
#' The input is converted with `.as_color_vector()`. The color function is
#' anchored at the breakpoints -2, 0 and 2 (the heatmap shows row z-scores):
#' with three or more colors the first, middle and last colors are used; with
#' exactly two colors a -2 to 2 gradient without midpoint is built. When fewer
#' than two colors are available (including when the input cannot be converted
#' to colors at all) NULL is returned so the heatmap default applies.
#' @keywords internal
.build_heatmap_col_fun <- function(scale = NULL) {
  palette <- .as_color_vector(scale)
  n_colors <- length(palette)

  # No palette, or a single color: nothing to interpolate between
  if (n_colors < 2) return(NULL)

  # Two colors: plain gradient across the z-score range
  if (n_colors == 2) {
    return(circlize::colorRamp2(c(-2, 2), c(palette[1], palette[2])))
  }

  # Three or more colors: low / mid / high anchors taken from the palette
  anchor_colors <- c(palette[1], palette[ceiling(n_colors / 2)], palette[n_colors])
  circlize::colorRamp2(c(-2, 0, 2), anchor_colors)
}

#' Create network visualization of GSEA results
#'
#' Pathways are nodes; two pathways are connected when the similarity of their
#' leading-edge gene sets reaches `similarity_cutoff`. When no pair reaches the
#' cutoff an empty placeholder plot is returned instead.
#'
#' @param gsea_results A data frame containing GSEA results from the pathway_gsea function.
#'   The columns `pathway_id`, `NES`, `pvalue`, `p.adjust`, `size`, `leading_edge`
#'   (genes separated by ";") and `pathway_label` are read.
#' @param similarity_measure A character string specifying the similarity measure: "jaccard", "overlap", or "correlation"
#' @param similarity_cutoff A numeric value specifying the similarity threshold for filtering connections
#' @param n_pathways An integer specifying the number of pathways to display
#'   (kept for interface compatibility; not used inside this function, the
#'   caller truncates `gsea_results` beforehand)
#' @param layout A character string specifying the network layout algorithm: "fruchterman", "kamada", or "circle"
#' @param node_color_by A character string specifying the node color mapping: "NES", "pvalue", or "p.adjust"
#' @param edge_width_by A character string specifying the edge width mapping: "similarity" or "constant"
#' @param scale Optional palette/scale for customizing node color mapping (same conventions as visualize_gsea)
#'
#' @return A ggplot2 object
#' @keywords internal
create_network_plot <- function(gsea_results,
                               similarity_measure = "jaccard",
                               similarity_cutoff = 0.3,
                               n_pathways = 20,
                               layout = "fruchterman",
                               node_color_by = "NES",
                               edge_width_by = "similarity",
                               scale = NULL) {
  # Note: Input validation (packages, n_pathways, nrow) is done in visualize_gsea()

  # Extract leading edge genes. Empty tokens (e.g. from ";;" or a leading ";")
  # and NA entries are not genes, and duplicates would distort the set sizes
  # used by the overlap/correlation measures, so clean each set first. This
  # mirrors the empty-token filtering done in create_heatmap_plot().
  raw_leading_edge <- gsea_results$leading_edge
  if (is.factor(raw_leading_edge)) {
    raw_leading_edge <- as.character(raw_leading_edge)
  }
  leading_edges <- lapply(strsplit(raw_leading_edge, ";"), function(genes) {
    unique(genes[!is.na(genes) & nzchar(genes)])
  })
  names(leading_edges) <- gsea_results$pathway_id

  # Calculate pathway similarity
  n <- length(leading_edges)
  pathway_ids <- names(leading_edges)
  similarity_matrix <- matrix(0, nrow = n, ncol = n)
  rownames(similarity_matrix) <- pathway_ids
  colnames(similarity_matrix) <- pathway_ids

  # All three measures are symmetric, so compute each unordered pair once and
  # mirror it. The diagonal stays 0 (no self-similarity edges).
  if (n > 1) {
    for (i in seq_len(n - 1L)) {
      set1 <- leading_edges[[i]]
      for (j in (i + 1L):n) {
        set2 <- leading_edges[[j]]

        # Handle empty sets: similarity stays 0
        if (length(set1) == 0 || length(set2) == 0) {
          next
        }

        n_shared <- length(intersect(set1, set2))
        pair_similarity <- switch(
          similarity_measure,
          # Jaccard similarity: |A n B| / |A u B|
          jaccard = n_shared / length(union(set1, set2)),
          # Overlap coefficient: |A n B| / min(|A|, |B|)
          overlap = n_shared / min(length(set1), length(set2)),
          # Simplified correlation measure: |A n B| / sqrt(|A| * |B|)
          correlation = n_shared / sqrt(length(set1) * length(set2)),
          # Unknown measure: no edge
          0
        )

        similarity_matrix[i, j] <- pair_similarity
        similarity_matrix[j, i] <- pair_similarity
      }
    }
  }

  # Apply similarity cutoff
  similarity_matrix[similarity_matrix < similarity_cutoff] <- 0

  # Check if there are any connections after applying cutoff
  if (sum(similarity_matrix) == 0) {
    return(create_empty_plot("network"))
  }

  # Create graph object
  graph <- igraph::graph_from_adjacency_matrix(
    similarity_matrix,
    mode = "undirected",
    weighted = TRUE,
    diag = FALSE
  )

  # Check if graph is empty
  if (igraph::vcount(graph) == 0) {
    return(create_empty_plot("network"))
  }

  # Add node attributes (first matching row per pathway ID)
  row_index <- match(pathway_ids, gsea_results$pathway_id)
  vertex_attr <- data.frame(
    name = pathway_ids,
    NES = gsea_results$NES[row_index],
    pvalue = gsea_results$pvalue[row_index],
    p.adjust = gsea_results$p.adjust[row_index],
    size = gsea_results$size[row_index],
    pathway_label = gsea_results$pathway_label[row_index],
    stringsAsFactors = FALSE
  )

  # Create a tidygraph object
  tbl_graph <- tidygraph::as_tbl_graph(graph) %>%
    tidygraph::activate("nodes") %>%
    dplyr::mutate(
      name = vertex_attr$name,
      NES = vertex_attr$NES,
      pvalue = vertex_attr$pvalue,
      p.adjust = vertex_attr$p.adjust,
      size = vertex_attr$size,
      pathway_label = vertex_attr$pathway_label
    )

  # Select layout algorithm (unknown names fall back to Fruchterman-Reingold)
  layout_name <- switch(
    layout,
    fruchterman = "fr",
    kamada = "kk",
    circle = "circle",
    "fr"
  )

  # Edge width either follows the similarity weight or is constant;
  # edge transparency always follows the weight.
  edge_layer <- if (edge_width_by == "similarity") {
    ggraph::geom_edge_link(ggplot2::aes(width = .data$weight, alpha = .data$weight))
  } else {
    ggraph::geom_edge_link(ggplot2::aes(alpha = .data$weight), width = 0.5)
  }
  edge_width_scale <- if (edge_width_by == "similarity") {
    ggraph::scale_edge_width(range = c(0.1, 2))
  } else {
    NULL
  }

  # Apply user-provided diverging scale for node color if available (fallback to default)
  # NOTE(review): the diverging scale is centred on 0 even when node_color_by
  # is "pvalue" or "p.adjust" -- confirm this is intended.
  node_color_scale <- .build_continuous_scale(aes = "color", scale = scale, diverging = TRUE, midpoint = 0, name = node_color_by)
  if (is.null(node_color_scale)) {
    node_color_scale <- ggplot2::scale_color_gradient2(low = "blue", mid = "white", high = "red", midpoint = 0, name = node_color_by)
  }

  # Create ggraph visualization
  p <- ggraph::ggraph(tbl_graph, layout = layout_name) +
    edge_layer +
    ggraph::geom_node_point(ggplot2::aes(color = .data[[node_color_by]], size = .data$size)) +
    ggraph::geom_node_text(ggplot2::aes(label = .data$pathway_label), repel = TRUE, size = 3) +
    edge_width_scale +
    ggraph::scale_edge_alpha(range = c(0.1, 0.8)) +
    node_color_scale +
    ggplot2::scale_size(range = c(2, 8), name = "Gene Set Size") +
    ggraph::theme_graph() +
    ggplot2::labs(
      title = "GSEA Pathway Network",
      subtitle = paste("Similarity measure:", similarity_measure, "| Cutoff:", similarity_cutoff)
    )

  return(p)
}

#' Create heatmap visualization of GSEA results
#'
#' Each heatmap row is one pathway: the mean abundance of its leading-edge
#' genes per sample, scaled to row z-scores. Samples are annotated by group
#' (top) and pathways by NES (right).
#'
#' @param gsea_results A data frame containing GSEA results from the pathway_gsea function.
#'   The columns `pathway_id`, `NES` and `leading_edge` (genes separated by ";") are read.
#' @param abundance A data frame containing the original abundance data
#' @param metadata A data frame containing sample metadata
#' @param group A character string specifying the column name in metadata that contains the grouping variable
#' @param n_pathways An integer specifying the number of pathways to display
#'   (kept for interface compatibility; not used inside this function, the
#'   caller truncates `gsea_results` beforehand)
#' @param cluster_rows A logical value indicating whether to cluster rows
#' @param cluster_columns A logical value indicating whether to cluster columns
#' @param show_rownames A logical value indicating whether to show row names
#' @param annotation_colors A list of colors for annotations; its `Group` element is used for the group annotation
#' @param col_fun A color function (e.g., circlize::colorRamp2) to control the main heatmap colors (optional)
#'
#' @return A ComplexHeatmap object
#' @keywords internal
create_heatmap_plot <- function(gsea_results,
                               abundance,
                               metadata,
                               group,
                               n_pathways = 20,
                               cluster_rows = TRUE,
                               cluster_columns = TRUE,
                               show_rownames = TRUE,
                               annotation_colors = NULL,
                               col_fun = NULL) {
  # Note: Input validation (packages, n_pathways, nrow) is done in visualize_gsea()

  # Extract leading edge genes (dropping empty tokens)
  leading_edges <- lapply(strsplit(gsea_results$leading_edge, ";"), function(x) x[x != ""])
  names(leading_edges) <- gsea_results$pathway_id

  # Check if any leading edges are empty
  if (all(lengths(leading_edges) == 0)) {
    return(ComplexHeatmap::Heatmap(
      matrix(0, nrow = 1, ncol = 1),
      name = "Empty",
      show_row_names = FALSE,
      show_column_names = FALSE,
      row_title = "No leading edge genes found",
      column_title = "No gene expression data"
    ))
  }

  # Align abundance columns with metadata rows using the package-wide
  # sample-alignment utility. Previously this branch relied solely on
  # `rownames(metadata)` to locate samples, which silently broke on the
  # common case of metadata carrying a `sample_name`/`sample_id` column
  # with default integer rownames -- the column annotation then came
  # out as all-NA, diverging from how every other function in the
  # package matches abundance to metadata.
  aligned <- align_samples(abundance, metadata, verbose = FALSE)
  abundance <- as.matrix(aligned$abundance)
  metadata <- aligned$metadata
  validate_group(metadata, group, min_groups = 1)

  # Create heatmap data matrix
  # For each pathway, calculate the average expression of leading edge genes
  heatmap_data <- matrix(0, nrow = length(leading_edges), ncol = ncol(abundance))
  rownames(heatmap_data) <- names(leading_edges)
  colnames(heatmap_data) <- colnames(abundance)

  for (i in seq_along(leading_edges)) {
    genes <- leading_edges[[i]]

    # Ensure all genes are in abundance data
    genes <- genes[genes %in% rownames(abundance)]

    if (length(genes) > 0) {
      # Calculate average abundance
      heatmap_data[i, ] <- colMeans(abundance[genes, , drop = FALSE])
    }
  }

  # Scale data. Constant rows produce NA after t(scale(t(.))); coerce
  # them to 0 so ComplexHeatmap doesn't crash on clustering.
  heatmap_data_scaled <- t(scale(t(heatmap_data)))
  heatmap_data_scaled[!is.finite(heatmap_data_scaled)] <- 0

  # Column annotation now derives from the aligned metadata, which is
  # guaranteed to match `colnames(heatmap_data)` by construction.
  column_annotation <- metadata[[group]]
  names(column_annotation) <- colnames(heatmap_data)

  # Create column annotation object
  ha <- ComplexHeatmap::HeatmapAnnotation(
    Group = column_annotation,
    col = list(Group = annotation_colors$Group),
    show_legend = TRUE
  )

  # Row annotation (pathway enrichment scores), looked up per heatmap row by
  # pathway ID. match() is used instead of a data frame keyed by row names so
  # that duplicated pathway IDs do not raise a "duplicate row.names" error
  # (the first matching row is used, as in create_network_plot()).
  nes_values <- gsea_results$NES[
    match(rownames(heatmap_data), as.character(gsea_results$pathway_id))
  ]

  # Create row annotation object
  # Use symmetric breaks centered at 0 to ensure colorRamp2 gets
  # strictly increasing values even when all NES are same sign.
  # na.rm = TRUE keeps the breaks finite when some NES values are missing;
  # the 0.1 floor covers the all-missing / all-zero case.
  nes_abs_max <- max(abs(nes_values), 0.1, na.rm = TRUE)
  ra <- ComplexHeatmap::rowAnnotation(
    NES = nes_values,
    col = list(NES = circlize::colorRamp2(
      c(-nes_abs_max, 0, nes_abs_max),
      c("blue", "white", "red")
    )),
    show_legend = TRUE
  )

  # Use user-provided col_fun if present; else default blue-white-red
  main_col_fun <- if (!is.null(col_fun)) {
    col_fun
  } else {
    circlize::colorRamp2(c(-2, 0, 2), c("blue", "white", "red"))
  }

  # Create heatmap
  heatmap <- ComplexHeatmap::Heatmap(
    heatmap_data_scaled,
    name = "Z-score",
    col = main_col_fun,
    cluster_rows = cluster_rows,
    cluster_columns = cluster_columns,
    show_row_names = show_rownames,
    row_names_gp = grid::gpar(fontsize = 8),
    top_annotation = ha,
    right_annotation = ra,
    row_title = "Pathways",
    column_title = "Samples",
    row_names_max_width = grid::unit(15, "cm")
  )

  return(heatmap)
}

Try the ggpicrust2 package in your browser

Any scripts or data that you put into this service are public.

ggpicrust2 documentation built on May 20, 2026, 5:07 p.m.