R/simpleUpSet.R

Defines functions .bg_stripes .add_intersections .check_sets .plot_sets .get_set_levels .plot_intersect .plot_grid .add_upper_plots simpleUpSet

Documented in simpleUpSet

#' Make simple UpSet plots
#'
#' Make simple UpSet plots using ggplot2 and patchwork
#'
#' @details
#' Taking a subset of columns from a data.frame, create an UpSet plot showing
#' all intersections as specified.
#' Columns chosen for the sets and intersections must contain logical values
#' or be strictly 0/1 values.
#'
#' Internally, data objects will have the variables `set` and `intersect` which
#' can be referred to when passing custom aes() mappings to various layers.
#' If specifying highlights, the column `highlight` will also be added as a
#' column to the data.frame containing intersections data, following the
#' `case_when` output provided as the argument.
#' Scales can be passed to the intersections and grid panels, taking this
#' structure into account.
#'
#' Any additional layers passed using `annotations()` will have layers added
#' after an initial, internal call to `ggplot(data, aes(x = intersect))`.
#' Additional columns can be used where appropriate for creating boxplots etc.
#'
#' A list of ggplot2 layers, scales, guides and themes is expected in each of
#' the `set_layers`, `intersect_layers` or `grid_layers` arguments, with
#' defaults generated by calls to [default_set_layers()],
#' [default_intersect_layers()] or [default_grid_layers()].
#' These can be used as templates for full customisation by creating a custom
#' list object, or modified directly by passing additional arguments through
#' the ellipsis (`...`) of these functions.
#'
#' @param x Input data frame
#' @param sets Character vector listing columns of x to plot
#' @param sort_sets <[`data-masking`][rlang::args_data_masking]> specification
#' for set order, using variables such as size, desc(size) or NULL. Passed
#' internally to [dplyr::arrange()]. The only possible options are `size`,
#' `desc(size)` or NULL (for sets in the order passed). Can additionally accept
#' the arguments "ascending", "descending" or "none"
#' @param sort_intersect list of <[`data-masking`][rlang::args_data_masking]>
#' specifications for intersection order. Passed internally to
#' [dplyr::arrange()]. The available columns are `size`, `degree` and `set`,
#' along with `highlight` if specified. Any other column names will cause an
#' error. The default order is in descending sizes, using degree and set to
#' break ties.
#' @param n_intersect Maximum number of intersections to show
#' @param min_size Only show intersections larger than this value
#' @param min_degree,max_degree Only show intersections within this range
#' @param set_layers List of `ggplot2` layers, scales and themes to define the
#' appearance of the sets panel. Can be obtained and extended using
#' [default_set_layers()]
#' @param intersect_layers List of `ggplot2` layers, scales and themes to define
#' the appearance of the intersections panel. Can be obtained and extended
#' using [default_intersect_layers()]
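#' @param grid_layers List of `ggplot2` layers, scales and themes to define the
#' appearance of the intersections matrix panel. Can be obtained and extended
#' using [default_grid_layers()]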
#' @param annotations list where each element is a list of ggplot2 layers.
#' Each element will be added as an upper annotation panel above the
#' intersections plot. All layer types (geom, scale, aes, stat, labs etc) can be
#' passed with the exception of facets.
#' @param highlight [case_when()] statement defining all intersections to
#' highlight using `geom_intersect` and `scale_fill/colour_intersect`.
#' Will add a column named `highlight` which can be called from any geom passed
#' to the intersections barplot or matrix
#' @param highlight_levels Given the highlight column will be coerced to a factor
#' when setting colours etc, levels can be manually set here for finer control.
#' @param width,height Proportional width and height of the intersection panel
#' @param stripe_colours Colours for background stripes in the lower two panels.
#' For no stripes, set as NULL
#' @param vjust_ylab Used to nudge the y-axis labels closer to the axis
#' @param guides Passed to [plot_layout()]
#' @param top_left Optional ggplot object to show in the top left panel. Will
#' default to an empty ggplot object
#' @param ... Not used
#' @param na.rm `NA` handling
#'
#' @return Object of class 'patchwork' containing multiple ggplot panels
#'
#' @examples
#' ## Use a modified version of the movies data provided with the package UpSetR
#' library(tidyverse)
#' theme_set(theme_bw())
#' sets <- c("Action", "Comedy", "Drama", "Thriller", "Romance")
#' movies <- system.file("extdata", "movies.tsv.gz", package = "SimpleUpset") %>%
#'   read_tsv() %>%
#'   mutate(
#'     Decade = fct_inorder(Decade) %>% fct_rev()
#'   )
#' simpleUpSet(movies, sets)
#'
#' ## Add a detailed upper plot
#' simpleUpSet(
#'   movies, sets, n_intersect = 10,
#'   annotations = list(
#'     list(
#'       aes(y = AvgRating),
#'       geom_jitter(aes(colour = Decade), height = 0, width = 0.3, alpha = 0.5),
#'       geom_violin(fill = NA, quantiles = 0.5, quantile.linetype = 1),
#'       scale_colour_brewer(palette = "Paired"),
#'       guides(colour = guide_legend(nrow = 2, reverse = TRUE))
#'     )
#'   ), guides = "collect"
#' ) &
#'   theme(legend.position = "bottom")
#'
#' ## Modify set colours
#' set_cols <- c(
#'   Action = "red", Comedy = "grey23", Drama = "red",
#'   Romance = "grey23", Thriller = "grey23"
#' )
#' simpleUpSet(
#'   movies, sets,
#'   set_layers = default_set_layers(
#'     fill = "set", scale_fill_manual(values = set_cols), guides(fill = guide_none())
#'    )
#' )
#'
#' @import patchwork
#' @import ggplot2
#' @importFrom rlang enquo
#' @importFrom dplyr desc
#' @importFrom S7 prop
#' @export
simpleUpSet <- function(
    x,
    sets = NULL,
    sort_sets = size,
    sort_intersect = list(desc(size), degree, set),
    n_intersect = 20, min_size = 0,
    min_degree = 1, max_degree = length(sets),
    set_layers = default_set_layers(),
    intersect_layers = default_intersect_layers(),
    grid_layers = default_grid_layers(),
    highlight = NULL, highlight_levels = NULL,
    annotations = list(),
    width = 0.75, height = 0.75, vjust_ylab = 0.8,
    stripe_colours = c("grey90", "white"),
    guides = "keep", top_left = NULL, ..., na.rm = TRUE
){

  ## Initial checks & argument handling
  ## width & height are used as c(1 - width, width) and c(height, 1 - height)
  ## in plot_layout() below, so both must lie strictly between 0 and 1 for
  ## every panel to be given a positive share of the figure
  stopifnot(
    is.numeric(width), is.numeric(height),
    all(c(width, height) > 0), all(c(width, height) < 1)
  )

  ## Need to define set levels here for all downstream private funs
  ## After these two calls `sets` holds the validated set names in plot order
  sets <- .check_sets(x, sets, na.rm)
  sets <- .get_set_levels(x, sets, enquo(sort_sets), na.rm)

  ## Get intersections table
  ## sort_intersect is captured unevaluated so the list elements can be
  ## spliced into arrange() inside .add_intersections()
  intersect_tbl <- .add_intersections(
    x, sets, substitute(sort_intersect), na.rm, enquo(highlight),
    highlight_levels
  )

  ## Sets panel
  p_sets <- .plot_sets(intersect_tbl, sets, set_layers, stripe_colours)

  ## Intersections panel
  vjust <- max(nchar(sets)) * vjust_ylab ## Place labels closer to y
  p_int <- .plot_intersect(
    intersect_tbl, min_size, n_intersect, min_degree, max_degree,
    intersect_layers, vjust
  )

  ## Intersections matrix
  p_mat <- .plot_grid(p_int, p_sets, grid_layers, stripe_colours)

  ## Top-left panel: the supplied ggplot if there is one, otherwise blank
  p_null <- ggplot() + theme_void() + theme(margins = margin())
  if (is_ggplot(top_left)) p_null <- top_left

  ## Additional annotation figures
  ## Only the intersections retained in the intersections panel are passed on
  keep_intersect <- droplevels(prop(p_int, "data")$intersect)
  intersect_tbl <- dplyr::filter(intersect_tbl, intersect %in% keep_intersect)
  p_upper <- .add_upper_plots(intersect_tbl, annotations, vjust)
  if (length(p_upper) > 0) {
    p_int <- p_int + theme(plot.margin = margin(0, 5.5, 0, 0))
    p_int <- wrap_plots(c(p_upper, list(p_int)), ncol = 1)
  }

  ## Assemble the 2 x 2 layout. The panels are passed as an explicit list so
  ## each plot stays a single element regardless of how c() treats plot objects
  wrap_plots(list(p_null, p_int, p_sets, p_mat), ncol = 2) +
    plot_layout(
      axes = "collect", guides = guides,
      widths = c(1 - width, width), heights = c(height, 1 - height)
    )

}

#' @import ggplot2
#' @keywords internal
.add_upper_plots <- function(tbl, annotations, vj){

  ## Build the optional annotation panels shown above the intersections plot.
  ## Returns a list with one ggplot per element of `annotations`, or an empty
  ## list when no annotations were supplied.
  if (length(annotations) == 0) return(list())

  ## A flat list of layers is treated as a single panel, so nest it one level
  ## deeper to give the same shape as a list of panels
  all_lists <- all(vapply(annotations, is.list, logical(1)))
  if (!all_lists) {
    all_gg <- all(vapply(annotations, .check_gg_layers, logical(1)))
    if (all_gg) annotations <- list(annotations)
  }

  ## Every panel may only contain ggplot2 components
  valid <- vapply(
    annotations,
    function(panel) all(vapply(panel, .check_gg_layers, logical(1))),
    logical(1)
  )
  if (any(!valid))
    stop("All elements of annotations must be a list of ggplot2 layers etc")

  ## Make sure degree is discrete if included
  has_degree <- "degree" %in% colnames(tbl)
  if (has_degree) tbl$degree <- as.factor(tbl$degree)

  ## The same theme tweaks are applied to each panel: no x-axis decorations,
  ## and the y-axis title nudged by `vj`
  panel_theme <- theme(
    axis.ticks.x.bottom = element_blank(),
    axis.text.x = element_blank(),
    axis.title.x = element_blank(),
    axis.title.y = element_text(vjust = -vj),
    margins = margin(0, 5.5, 5.5, 0)
  )

  ## One plot per panel, with its layers added in the order supplied
  lapply(
    annotations,
    function(panel) {
      p <- ggplot(tbl, aes(x = intersect))
      for (k in seq_along(panel)) p <- p + panel[[k]]
      p + panel_theme
    }
  )
}

#' @importFrom dplyr distinct summarise anti_join join_by left_join
#' @importFrom rlang !!! syms
#' @importFrom tidyr pivot_longer complete
#' @importFrom tidyselect any_of
#' @importFrom methods is
#' @importFrom S7 prop
#' @import ggplot2
#' @keywords internal
## Usage: .plot_grid(p_int, p_sets, layers, stripe_colours)
.plot_grid <- function(p_int, p_sets, layers, stripe_colours){

  ## Draw the intersections matrix from the data held in the intersections
  ## and sets panels. Returns a ggplot object.

  ## Check the layers, unless they carry the "default_layers" class
  if (!is(layers, "default_layers")) {
    is_gg <- vapply(layers, .check_gg_layers, logical(1))
    stopifnot(all(is_gg))
    is_layer <- vapply(layers, is_layer, logical(1))
    stopifnot(any(is_layer))
    ## At least one geom must be a segment
    is_segment <- vapply(
      layers[is_layer], function(lyr) is(lyr$geom, "GeomSegment"), logical(1)
    )
    stopifnot(any(is_segment))
    ## At least one geom must be a point
    is_points <- vapply(
      layers[is_layer], function(lyr) is(lyr$geom, "GeomPoint"), logical(1)
    )
    stopifnot(any(is_points))
  }

  ## Set order comes from the sets panel; intersections from the bar panel
  set_lvls <- levels(prop(p_sets, "data")$set)
  int_df <- droplevels(prop(p_int, "data"))
  keep_cols <- intersect(
    c(set_lvls, "intersect", "highlight", "degree"), colnames(int_df)
  )

  ## One row per (intersection, member set), which gives the filled points
  grid_tbl <- int_df |>
    distinct(!!!syms(keep_cols)) |>
    pivot_longer(all_of(set_lvls), names_to = "set", values_to = "in_group")
  grid_tbl <- grid_tbl[grid_tbl[["in_group"]], ]
  grid_tbl$set <- factor(grid_tbl$set, levels = set_lvls)
  grid_tbl$set_int <- as.integer(grid_tbl$set)

  ## Pivoting consumed the logical set columns, so join them back on
  membership <- distinct(
    int_df, !!!syms(c(set_lvls, "intersect", "degree"))
  )
  grid_tbl <- left_join(
    grid_tbl, membership, by = join_by(intersect, degree)
  )

  ## For the segments, we only need the outer intersections on the grid
  seg_tbl <- summarise(
    grid_tbl,
    y_max = max(!!sym("set_int")), y_min = min(!!sym("set_int")),
    .by = any_of(c("intersect", "highlight", "degree", set_lvls))
  )

  ## TRUE when an element is a layer instance drawn with the given geom class
  has_geom <- function(lyr, geom_class) {
    if (!is(lyr, "LayerInstance")) return(FALSE)
    is(lyr$geom, geom_class)
  }

  ## These layers are more challenging to wrangle using defaults. However,
  ## assume that any GeomSegment layer should have data = seg_tbl
  seg_idx <- which(
    vapply(layers, has_geom, logical(1), geom_class = "GeomSegment")
  )
  for (i in seg_idx) layers[[i]]$data <- seg_tbl

  ## The first point layer keeps the plot data; any further point layers are
  ## given the (intersection, set) pairs which are absent from grid_tbl
  pt_idx <- which(
    vapply(layers, has_geom, logical(1), geom_class = "GeomPoint")
  )
  if (length(pt_idx) > 1) {
    empty_tbl <- complete(grid_tbl, intersect, set)
    empty_tbl <- empty_tbl[is.na(empty_tbl[["in_group"]]), ]
    for (i in pt_idx[-1]) layers[[i]]$data <- empty_tbl
  }

  ## And the main figure: stripes first, then every layer in order
  p <- ggplot(grid_tbl) + .bg_stripes(set_lvls, stripe_colours)
  for (i in seq_along(layers)) p <- p + layers[[i]]
  p

}

#' @importFrom rlang !! sym
#' @importFrom dplyr summarise
#' @importFrom methods is
#' @import ggplot2
#' @keywords internal
.plot_intersect <- function(
    tbl, min_size, n_intersect, min_degree, max_degree, layers, vj
){

  ## Build the intersections panel, restricted to the intersections passing
  ## the size/degree filters and capped at n_intersect. Returns a ggplot
  ## whose data is the filtered input table.

  ## Check the layers, unless they carry the "default_layers" class
  if (!is(layers, "default_layers")) {
    is_gg <- vapply(layers, .check_gg_layers, logical(1))
    stopifnot(all(is_gg))
    ## At least one geom must be a bar
    is_layer <- vapply(layers, is_layer, logical(1))
    stopifnot(any(is_layer))
    is_bar <- vapply(
      layers[is_layer], function(lyr) is(lyr$geom, "GeomBar"), logical(1)
    )
    stopifnot(any(is_bar))
  }

  ## Row count before filtering: the cap for n_intersect (handles Inf) and
  ## the denominator for the proportions
  n_total <- nrow(tbl)
  n_intersect <- min(n_total, n_intersect)

  ## The totals summarised by intersect (ignoring any fill columns), filtered
  ## then truncated to the first n_intersect in factor-level order
  totals_df <- tbl |>
    summarise(size = dplyr::n(), .by = any_of(c("intersect", "degree"))) |>
    dplyr::filter(
      size > min_size, degree >= min_degree, degree <= max_degree
    ) |>
    dplyr::arrange(intersect) |>
    dplyr::slice(seq_len(n_intersect))
  totals_df$intersect <- droplevels(totals_df$intersect)
  totals_df$prop <- totals_df$size / n_total

  ## Keep only the rows belonging to retained intersections
  tbl <- dplyr::filter(tbl, intersect %in% levels(totals_df$intersect))
  if ("degree" %in% colnames(tbl)) tbl$degree <- as.factor(tbl$degree)

  ## Label/text layers are drawn from the per-intersection totals
  is_labels <- vapply(
    layers,
    function(lyr) {
      is(lyr, "LayerInstance") &&
        (is(lyr$geom, "GeomLabel") || is(lyr$geom, "GeomText"))
    },
    logical(1)
  )
  for (i in which(is_labels)) layers[[i]]$data <- totals_df

  ## Add the vjust as a new theme, after all supplied layers
  layers[[length(layers) + 1]] <- theme(
    axis.title.y = element_text(vjust = -vj)
  )

  ## The initial plot
  p <- ggplot(tbl)
  for (i in seq_along(layers)) p <- p + layers[[i]]
  p

}

#' @importFrom dplyr summarise across arrange desc
#' @importFrom rlang !! quo_is_null quo_get_expr sym
#' @importFrom tidyselect everything all_of
#' @importFrom tidyr pivot_longer
.get_set_levels <- function(tbl, sets, sort_sets, na.rm) {

  ## Return the set names in the order requested by `sort_sets`, a quosure
  ## holding either an expression for arrange() or one of the strings
  ## "ascending", "descending" or "none".

  ## One row per set, with the column total as its size
  sum_tbl <- tbl |>
    summarise(
      across(all_of(sets), function(col) sum(col, na.rm = na.rm))
    ) |>
    pivot_longer(everything(), names_to = "set", values_to = "size")

  sort_quo_expr <- quo_get_expr(sort_sets)
  if (!quo_is_null(sort_sets)) {
    if (is(sort_quo_expr, "name") || is(sort_quo_expr, "call")) {
      ## Handle size/desc(size). Anything else will error
      sum_tbl <- arrange(sum_tbl, !!sort_sets)
    } else {
      ## Now handle ascending/descending as character arguments
      direction <- match.arg(
        sort_quo_expr, c("ascending", "descending", "none")
      )
      sum_tbl <- switch(
        direction,
        ascending = arrange(sum_tbl, !!sym("size")),
        descending = arrange(sum_tbl, desc(!!sym("size"))),
        none = sum_tbl
      )
    }
  }

  sum_tbl$set

}

#' @importFrom dplyr across summarise
#' @importFrom tidyr pivot_longer
#' @importFrom methods is
#' @import ggplot2
#' @keywords internal
.plot_sets <- function(tbl, sets, layers, stripe_colours){

  ## Build the set-size panel from the intersections table.
  ## Returns a ggplot object.

  ## Check the layers, unless they carry the "default_layers" class
  if (!is(layers, "default_layers")) {
    is_gg <- vapply(layers, .check_gg_layers, logical(1))
    stopifnot(all(is_gg))
    ## At least one geom must be a bar
    is_layer <- vapply(layers, is_layer, logical(1))
    stopifnot(any(is_layer))
    is_bar <- vapply(
      layers[is_layer], function(lyr) is(lyr$geom, "GeomBar"), logical(1)
    )
    stopifnot(any(is_bar))
  }

  ## Sets will contain logical values here, so after pivoting only the rows
  ## flagged TRUE (i.e. set membership) are retained
  sets_tbl <- tbl |> pivot_longer(all_of(sets), names_to = "set")
  sets_tbl <- sets_tbl[sets_tbl$value, ]
  sets_tbl$set <- factor(sets_tbl$set, sets)

  ## Calculate set totals & proportions for optional labels
  col_sums <- colSums(tbl[sets])
  totals_df <- data.frame(
    set = factor(names(col_sums), levels = sets),
    size = col_sums, prop = col_sums / nrow(tbl)
  )

  ## Label/text layers are drawn from the totals instead of the plot data
  is_labels <- vapply(
    layers,
    function(lyr) {
      is(lyr, "LayerInstance") &&
        (is(lyr$geom, "GeomLabel") || is(lyr$geom, "GeomText"))
    },
    logical(1)
  )
  for (i in which(is_labels)) layers[[i]]$data <- totals_df

  ## The main plot: stripes first, then every layer in order
  p <- ggplot(sets_tbl) + .bg_stripes(sets, stripe_colours)
  for (i in seq_along(layers)) p <- p + layers[[i]]
  p

}


#' @importFrom dplyr summarise
#' @importFrom tidyselect all_of
#' @keywords internal
.check_sets <- function(x, sets, na.rm){

  ## Validate the requested set columns and return the names of those with at
  ## least one non-zero entry, in the order supplied.
  ## x: data frame; sets: character vector of column names, or NULL for all
  ## columns; na.rm: currently not used in the body.
  ## Errors if a column is missing from x, is character/factor, or contains
  ## values other than 0, 1 or NA.

  ## Get intersections
  if (is.null(sets)) sets <- colnames(x)
  stopifnot(all(sets %in% colnames(x)))

  ## Error if character/factors. typeof() reports a factor as "integer", so
  ## class-aware predicates are needed for factors to be caught here
  is_text <- vapply(
    sets, \(i) is.character(x[[i]]) || is.factor(x[[i]]), logical(1)
  )
  if (any(is_text)) stop(
    "Column(s) ", paste(sets[is_text], collapse = ", "),
    " cannot be character or factor"
  )
  ## From here we can assume compatible columns

  ## Check each column only has 0,1 entries
  is_lgl <- vapply(x[sets], \(col) all(col %in% c(0, 1, NA)), logical(1))
  if (!all(is_lgl)) stop(
    "Column(s) ", paste(names(which(!is_lgl)), collapse = ", "),
    " not logical or strictly in 0,1"
  )
  ## Check for non-zero values in at least one position
  ## Drop these from the sets going forward
  ## na.rm = TRUE stops an all-zero column containing NA from returning NA,
  ## which would otherwise insert NA into the returned set names
  has_non_zero <- vapply(
    x[sets], \(col) any(col == 1, na.rm = TRUE), logical(1)
  )
  sets[has_non_zero]
}

#' @keywords internal
#' @importFrom dplyr if_any summarise left_join mutate arrange
#' @importFrom tidyr pivot_wider
#' @importFrom rlang quo_is_null !! !!!
#' @importFrom tidyselect all_of
.add_intersections <- function(x, sets, sort_expr, na.rm, hl, hl_levels){

  ## Return x with logical set columns plus two new columns: `intersect`
  ## (a factor id for each intersection, numbered in sorted order) and
  ## `degree` (the number of sets in that intersection). A `highlight` column
  ## is also added when `hl` is not a NULL quosure.
  ## sort_expr: unevaluated `list(...)` call of sorting expressions, or NULL.
  ## na.rm: currently not used in the body.

  ## Coerce all set columns to be logical & remove rows where all are FALSE
  x[sets] <- lapply(x[sets], as.logical)
  x <- dplyr::filter(x, if_any(all_of(sets)))
  x <- droplevels(x)
  ## Add highlights if supplied
  if (!quo_is_null(hl)) {
    if (!grepl("^case_when", as_label(hl)))
      stop("highlight can only be specified as a `case_when() statement")
    x <- mutate(x, highlight = !!hl)
    if (is.null(hl_levels)) hl_levels <- sort(unique(x$highlight))
    x$highlight <- droplevels(factor(x$highlight, levels = hl_levels))
  }

  ## Determine the intersect number to a column in the original
  ## One row per combination of set membership (and highlight, if present)
  set_tbl <- summarise(
    x, size = dplyr::n(), .by = any_of(c(sets, "highlight"))
  )
  set_tbl$degree <- rowSums(set_tbl[sets])
  set_tbl$temp <- seq_len(nrow(set_tbl))

  ## Now use the data masking approach for arrange
  if (!is.null(sort_expr)) {

    ## Ensure sorting args are passed as a list. Checking is.call() first
    ## means a bare symbol (e.g. a variable holding the list) gives this
    ## error instead of a failure when subsetting the symbol
    if (!is.call(sort_expr) || !identical(sort_expr[[1]], quote(list)))
      stop("Sorting for intersections must be passed as a list")
    ## Sort in long form so `set` is available as a sorting column, then
    ## return to one row per intersection in the sorted order
    set_tbl <- pivot_longer(set_tbl, all_of(sets), names_to = "set")
    set_tbl$set <- factor(set_tbl$set, levels = sets)
    set_tbl <- arrange(set_tbl, !!!as.list(sort_expr[-1]))
    set_tbl <- set_tbl[set_tbl[["value"]],]
    set_tbl <- pivot_wider(
      set_tbl, names_from = "set", values_from = "value", values_fill = FALSE,
      id_cols = any_of(c("temp", "size", "degree", "highlight"))
    )
  }

  ## Add the intersection id & remove the temp columns
  set_tbl$intersect <- as.factor(seq_len(nrow(set_tbl)))
  ## Rows were counted by set membership and highlight (when present), so
  ## both are needed as join keys. Joining on the sets alone would match one
  ## row of x to several intersections whenever highlight varies within a
  ## set combination
  key_cols <- c(sets, if ("highlight" %in% colnames(set_tbl)) "highlight")
  set_tbl <- dplyr::select(
    set_tbl, all_of(c(key_cols, "intersect", "degree"))
  )

  ## Return the original table with intersect numbers & logical columns
  left_join(x, set_tbl, by = key_cols)

}

#' @keywords internal
#' @import ggplot2
.bg_stripes <- function(sets, stripe_colours){
  ## Background stripes, one full-width rectangle per set, with the colours
  ## recycled along the sets. Returns NULL when stripe_colours is NULL.
  if (is.null(stripe_colours)) {
    NULL
  } else {
    n_sets <- length(sets)
    stripes <- data.frame(
      set = factor(sets, levels = sets),
      xmin = -Inf, xmax = Inf,
      col = rep_len(stripe_colours, n_sets)
    )
    ## The fill is passed as a fixed vector, not as a mapped aesthetic
    geom_rect(
      mapping = aes(xmin = !!sym("xmin"), xmax = !!sym("xmax"), y = set),
      data = stripes, height = 1, fill = stripes$col, inherit.aes = FALSE
    )
  }
}

#' @keywords internal
#' @import ggplot2
.check_gg_layers <- function(x) {
  ## TRUE when x is recognised by at least one of the ggplot2 predicates below
  predicates <- list(
    is_theme, is_scale, is_position, is_guides, is_layer,
    is_mapping, is_coord, is_stat, is_facet
  )
  any(vapply(predicates, function(pred) pred(x), logical(1)))
}

## Key options to develop
## Highlighting sets or bars using queries & passing colours to them
##   - Perhaps these could be set directly in a column of actual colours?
## Highlighting points/segments on the grid plot

Try the SimpleUpset package in your browser

Any scripts or data that you put into this service are public.

SimpleUpset documentation built on Nov. 29, 2025, 5:08 p.m.