R/upsetplot.R

Defines functions UpsetPlot UpsetPlotAtomic prepare_upset_data detect_upset_datatype

Documented in detect_upset_datatype prepare_upset_data UpsetPlot UpsetPlotAtomic

#' Detect the type of the input data for Upset plot
#'
#' @keywords internal
#' @param data A data frame or a list
#' @param group_by A character string specifying the column name of the data frame to group the data.
#' @param id_by A character string specifying the column name of the data frame to identify the instances.
#'   Required when `group_by` is a single column and data is a data frame.
#'   It is accepted here for interface symmetry but is not consulted by the detection.
#' @return A character string indicating the datatype of the input data.
#'   Possible values are "long", "wide", "list" and "upset".
#'   "long" indicates the data is in long format.
#'   "wide" indicates the data is in wide format.
#'   "list" indicates the data is a list.
#'   "upset" indicates the data is a UpsetPlotData object.
#'   An error is raised when `data` is none of the above.
detect_upset_datatype <- function(data, group_by = NULL, id_by = NULL) {
    # The UpsetPlotData test comes first so such objects are never
    # classified by the data frame branch below.
    if (inherits(data, "UpsetPlotData")) {
        "upset"
    } else if (inherits(data, "data.frame")) {
        # Exactly one grouping column (a non-NULL value shorter than 2) means
        # long format; NULL or several grouping columns mean wide format.
        single_group <- !is.null(group_by) && length(group_by) < 2L
        if (single_group) "long" else "wide"
    } else if (inherits(data, "list")) {
        "list"
    } else {
        stop("Invalid data type. Please provide a data frame, a list or an UpsetPlotData object generated by prepare_upset_data().")
    }
}


#' Prepare data for Upset plot
#'
#' @param data A data frame or a list or an UpsetPlotData object.
#' @param in_form A character string indicating the datatype of the input data.
#'   Possible values are "long", "wide", "list", "upset" or "auto".
#'   "long" indicates the data is in long format.
#'   "wide" indicates the data is in wide format.
#'   "list" indicates the data is a list.
#'   "upset" indicates the data is an UpsetPlotData object.
#'   "auto" indicates the function will detect the datatype of the input data.
#'   Any other value raises an error.
#'
#' A long format data would look like:
#' \preformatted{
#' group_by id_by
#' A        a1
#' A        a2
#' B        a1
#' B        a3
#' ...
#' }
#'
#' A wide format data would look like (columns may be logical or 0/1):
#' \preformatted{
#' A    B
#' TRUE TRUE
#' TRUE FALSE
#' FALSE TRUE
#' ...
#' }
#'
#' A list format data would look like:
#' \preformatted{
#' list(A = c("a1", "a2"), B = c("a1", "a3"))
#' }
#'
#' An UpsetPlotData object, as generated by prepare_upset_data(), would look like:
#' \preformatted{
#' group_by
#' --------
#' list("A")  # a2
#' list("B")  # a3
#' list(c("A", "B"))  # a1
#' ...
#' }
#'
#' @param group_by A character string specifying the column name of the data frame to group the data.
#' @param group_by_sep A character string to concatenate the columns in `group_by`,
#'   if multiple columns are provided and the in_form is "long".
#' @param id_by A character string specifying the column name of the data frame to identify the instances.
#'  Required when `group_by` is a single column and data is a data frame.
#' @param specific A logical value to show the specific intersections only.
#'  ggVennDiagram, by default, only return the specific subsets of a region. However,
#'  sometimes, we want to show all the overlapping items for two or more sets.
#'  See \url{https://github.com/gaospecial/ggVennDiagram/issues/64} for more details.
#' @return A UpsetPlotData object: a data frame (class prefixed with "UpsetPlotData")
#'   with an `id` column and an `Intersection` list column, where each intersection
#'   row is repeated once per element it contains. Groups without any element are
#'   dropped with a warning when the input is a data frame.
#' @keywords internal
#' @importFrom rlang sym
#' @importFrom dplyr distinct %>%
#' @importFrom utils getFromNamespace
#' @importFrom tidyr uncount
prepare_upset_data <- function(data, in_form = "auto", group_by = NULL, group_by_sep = "_", id_by = NULL, specific = TRUE) {
    # Reject unknown forms up front instead of silently treating them as "wide".
    in_form <- match.arg(in_form, c("auto", "long", "wide", "list", "upset"))
    if (in_form == "auto") {
        in_form <- detect_upset_datatype(data, group_by, id_by)
    }

    if (in_form == "upset") {
        if (!is.null(group_by)) {
            warning("The group_by argument is ignored when the input data is an UpsetPlotData object.", immediate. = TRUE)
        }
        return(data)
    }

    process_upset_data <- getFromNamespace("process_upset_data", "ggVennDiagram")
    if (in_form == "list") {
        listdata <- data
    } else if (in_form == "long") {
        if (length(id_by) != 1L || !id_by %in% colnames(data)) {
            stop("'id_by' must be a single column name of the data when the in_form is 'long'.")
        }
        # NOTE: the returned name is used to index `data` directly, so with
        # concat_multi = TRUE check_columns() presumably adds the concatenated
        # column to `data` in this frame -- keep the variable named `data`.
        group_by <- check_columns(data, group_by, force_factor = TRUE, allow_multi = TRUE, concat_multi = TRUE, concat_sep = group_by_sep)
        listdata <- split(data[[id_by]], data[[group_by]])
    } else { # in_form == "wide"
        group_by <- check_columns(data, group_by, allow_multi = TRUE)
        if (is.null(group_by)) {
            group_by <- colnames(data)
        }
        for (g in group_by) {
            # columns must be logical or 0/1
            if (!is.logical(data[[g]]) && !all(data[[g]] %in% c(0, 1))) {
                stop("The columns in group_by must be logical or 0/1 when the in_form is 'wide'.")
            }
        }
        # Wide data carries no instance ids, so every row gets a synthetic one.
        row_ids <- paste0("id", seq_len(nrow(data)))
        listdata <- lapply(group_by, function(g) {
            flag <- data[[g]]
            # as.logical() maps character/factor "0"/"1" to NA, so non-logical
            # columns go through numeric; which() treats NA as "not a member"
            # instead of leaking NA ids into the set.
            is_member <- if (is.logical(flag)) flag else as.numeric(as.character(flag)) == 1
            row_ids[which(is_member)]
        })
        names(listdata) <- group_by
    }

    # Drop empty groups for data-frame input: all-FALSE columns in wide data and
    # unused factor levels in long data (split() keeps them as empty vectors).
    if (in_form != "list") {
        empty <- lengths(listdata) == 0L
        for (nm in names(listdata)[empty]) {
            warning("The group '", nm, "' has no elements, ignored.", immediate. = TRUE)
        }
        listdata <- listdata[!empty]
    }

    venn <- process_upset_data(ggVennDiagram::Venn(listdata), specific = specific)
    # Intersection ids are built from set positions; map each position to a
    # "name (size)" label.
    set_labels <- paste0(as.character(venn$left_data$set), " (", as.character(venn$left_data$size), ")")
    names(set_labels) <- as.character(seq_along(set_labels))

    # Separator between set positions inside an intersection id.
    sep <- if (isTRUE(specific)) "/" else "~"
    out <- distinct(venn$main_data, !!sym("id"), !!sym("size"))

    out$Intersection <- lapply(as.character(out$id), function(x) {
        set_labels[strsplit(x, sep, fixed = TRUE)[[1]]]
    })

    # Repeat each intersection once per element it contains.
    out <- uncount(out, !!sym("size"))
    # Tag the result so detect_upset_datatype() recognises it when passed back in.
    class(out) <- unique(c("UpsetPlotData", class(out)))
    out
}

#' Atomic Upset plot
#'
#' Builds a single Upset plot: a bar chart of intersection sizes drawn over a
#' combination-matrix x axis, from data accepted by `prepare_upset_data()`.
#'
#' @inheritParams common_args
#' @inheritParams prepare_upset_data
#' @param label A logical value to show the labels on the bars.
#' @param label_fg A character string specifying the color of the label text.
#' @param label_size A numeric value specifying the size of the label text.
#'   When `NULL`, defaults to `3.5` scaled by `theme_args$base_size / 12`.
#' @param label_bg A character string specifying the background color of the label.
#' @param label_bg_r A numeric value specifying the radius of the background of the label.
#' @param ... Additional arguments passed to [ggupset::scale_x_upset].
#'   `n_sets` and `n_intersections`, when given, are also used to estimate the plot size.
#' @return A ggplot object with Upset plot. Suggested plot dimensions are attached
#'   as the "height" and "width" attributes.
#' @keywords internal
#' @importFrom rlang %||%
#' @importFrom ggplot2 geom_bar labs guide_colorbar scale_fill_gradientn
#' @importFrom ggrepel geom_text_repel
UpsetPlotAtomic <- function(
    data, in_form = "auto", group_by = NULL, group_by_sep = "_", id_by = NULL,
    label = TRUE, label_fg = "black", label_size = NULL, label_bg = "white", label_bg_r = 0.1,
    palette = "material-indigo", palcolor = NULL, alpha = 1, specific = TRUE,
    theme = "theme_this", theme_args = list(), title = NULL, subtitle = NULL, xlab = NULL, ylab = NULL,
    aspect.ratio = 0.6, legend.position = "right", legend.direction = "vertical", levels = NULL, ...) {
    ggplot <- if (getOption("plotthis.gglogger.enabled", FALSE)) {
        gglogger::ggplot
    } else {
        ggplot2::ggplot
    }

    data <- prepare_upset_data(data, in_form, group_by, group_by_sep, id_by, specific)
    # [[ ]] avoids `$` partial matching on the list.
    base_size <- theme_args[["base_size"]] %||% 12
    text_size_scale <- base_size / 12

    p <- ggplot(data, aes(x = !!sym("Intersection"))) +
        geom_bar(aes(fill = after_stat(!!sym("count"))), alpha = alpha, color = "black", width = 0.5) +
        scale_fill_gradientn(
            n.breaks = 3,
            colors = palette_this(palette = palette, palcolor = palcolor),
            na.value = "grey80",
            guide = guide_colorbar(
                title = "", alpha = alpha,
                frame.colour = "black", ticks.colour = "black", title.hjust = 0
            )
        )
    if (isTRUE(label)) {
        p <- p + geom_text_repel(aes(label = after_stat(!!sym("count"))),
            stat = "count",
            # The parentheses matter: `%||%` binds tighter than `*`, so without
            # them a user-supplied label_size would itself be multiplied by 3.5.
            colour = label_fg, size = label_size %||% (text_size_scale * 3.5),
            bg.color = label_bg, bg.r = label_bg_r,
            point.size = NA, max.overlaps = 100, force = 0,
            min.segment.length = 0, segment.colour = "black"
        )
    }
    p <- p +
        labs(title = title, subtitle = subtitle, x = xlab %||% "", y = ylab %||% "Intersection size") +
        ggupset::scale_x_upset(...) +
        do.call(theme, theme_args) +
        ggplot2::theme(
            aspect.ratio = aspect.ratio,
            legend.position = legend.position,
            legend.direction = legend.direction,
            panel.grid.major = element_line(colour = "grey80", linetype = 2)
        ) +
        ggupset::theme_combmatrix(
            combmatrix.label.text = element_text(size = 12, color = "black"),
            combmatrix.label.extra_spacing = 6
        )

    # Estimate the plot size from the number of sets/intersections actually
    # shown, capped by any n_sets / n_intersections given to scale_x_upset().
    upset_args <- list(...)
    set_labels <- unique(unlist(data$Intersection))
    n_sets <- min(upset_args[["n_sets"]] %||% 99, length(set_labels))
    n_intersections <- min(upset_args[["n_intersections"]] %||% 99, length(unique(data$Intersection)))
    # max() over an empty vector warns and yields -Inf; floor at 0 instead.
    maxchars <- max(0, nchar(set_labels))

    height <- 4.5 + n_sets * 0.5
    width <- n_intersections * aspect.ratio + maxchars * 0.05
    if (!identical(legend.position, "none")) {
        # isTRUE() keeps a numeric (length-2) legend.position from breaking `if`.
        if (isTRUE(legend.position %in% c("right", "left"))) {
            width <- width + 1
        } else if (isTRUE(legend.direction == "horizontal")) {
            height <- height + 1
        } else {
            width <- width + 2
        }
    }

    attr(p, "height") <- height
    attr(p, "width") <- width

    p
}

#' Upset Plot
#'
#' Draws one Upset plot per level of `split_by` (or a single plot when `split_by`
#' is `NULL`) and combines them.
#'
#' @inheritParams common_args
#' @inheritParams UpsetPlotAtomic
#' @return A ggplot object or wrap_plots object or a list of ggplot objects
#' @export
#' @examples
#' data <- list(
#'     A = 1:5,
#'     B = 2:6,
#'     C = 3:7,
#'     D = 4:8
#' )
#' UpsetPlot(data)
#' UpsetPlot(data, label = FALSE)
#' UpsetPlot(data, palette = "Reds", specific = FALSE)
UpsetPlot <- function(
    data, in_form = c("auto", "long", "wide", "list", "upset"), split_by = NULL, split_by_sep = "_",
    group_by = NULL, group_by_sep = "_", id_by = NULL, label = TRUE, label_fg = "black",
    label_size = NULL, label_bg = "white", label_bg_r = 0.1, palette = "material-indigo", palcolor = NULL,
    alpha = 1, specific = TRUE, theme = "theme_this", theme_args = list(), title = NULL, subtitle = NULL,
    xlab = NULL, ylab = NULL, aspect.ratio = 0.6, legend.position = "right", legend.direction = "vertical",
    combine = TRUE, nrow = NULL, ncol = NULL, byrow = TRUE, seed = 8525,
    axes = NULL, axis_titles = axes, guides = NULL, design = NULL, ...) {
    validate_common_args(seed)
    in_form <- match.arg(in_form)
    theme <- process_theme(theme)
    if (!is.null(split_by) && !inherits(data, "data.frame")) {
        stop("'split_by' is only available for data frame.")
    }
    split_by <- check_columns(data, split_by, force_factor = TRUE, allow_multi = TRUE, concat_multi = TRUE, concat_sep = split_by_sep)

    if (!is.null(split_by)) {
        datas <- split(data, data[[split_by]])
        # keep the order of levels
        datas <- datas[levels(data[[split_by]])]
    } else {
        datas <- list(data)
        names(datas) <- "..."
    }
    palette <- check_palette(palette, names(datas))
    palcolor <- check_palcolor(palcolor, names(datas))
    legend.direction <- check_legend(legend.direction, names(datas), "legend.direction")
    legend.position <- check_legend(legend.position, names(datas), "legend.position")

    plots <- lapply(
        names(datas), function(nm) {
            default_title <- if (length(datas) == 1 && identical(nm, "...")) NULL else nm
            if (is.function(title)) {
                title <- title(default_title)
            } else {
                title <- title %||% default_title
            }
            # Wide data defaults to every column except id/split columns. Skip this
            # for UpsetPlotData input: group_by is ignored there, and auto-filling
            # it would only trigger a spurious "group_by is ignored" warning.
            if (is.null(group_by) && (in_form %in% c("auto", "wide")) &&
                !inherits(datas[[nm]], "UpsetPlotData")) {
                group_by <- setdiff(colnames(datas[[nm]]), c(id_by, split_by))
            }
            UpsetPlotAtomic(datas[[nm]],
                in_form = in_form, group_by = group_by, group_by_sep = group_by_sep, id_by = id_by,
                label = label, label_fg = label_fg, label_size = label_size, label_bg = label_bg, label_bg_r = label_bg_r,
                palette = palette[[nm]], palcolor = palcolor[[nm]], alpha = alpha, specific = specific,
                theme = theme, theme_args = theme_args, title = title, subtitle = subtitle, xlab = xlab, ylab = ylab,
                aspect.ratio = aspect.ratio, legend.position = legend.position[[nm]], legend.direction = legend.direction[[nm]],
                ...
            )
        }
    )

    combine_plots(plots, combine = combine, nrow = nrow, ncol = ncol, byrow = byrow,
        axes = axes, axis_titles = axis_titles, guides = guides, design = design)
}

Try the plotthis package in your browser

Any scripts or data that you put into this service are public.

plotthis documentation built on June 8, 2025, 11:11 a.m.