# R/venndiagram.R
#
# Defines functions: VennDiagram, VennDiagramAtomic, prepare_venn_data, detect_venn_datatype
#
# Documented in: detect_venn_datatype, prepare_venn_data, VennDiagram, VennDiagramAtomic

#' Infer the input format of Venn diagram data
#'
#' Classifies `data` by its class and, for data frames, by the number of
#' `group_by` columns supplied. Data frames are tested first, then
#' VennPlotData objects, then plain lists.
#'
#' @keywords internal
#' @param data A data frame, a list or a VennPlotData object.
#' @param group_by A character vector with the column name(s) of the data frame used to group the data.
#'   For a data frame, a non-NULL `group_by` with fewer than two elements yields "long";
#'   `NULL` or two and more columns yield "wide".
#' @param id_by A character string specifying the column name of the data frame to identify the instances.
#'   Accepted so the signature mirrors `prepare_venn_data()`; it is not consulted during detection.
#' @return A character string naming the detected format. An error is raised for unsupported input.
#'   Possible values are "long", "wide", "list" and "venn".
#'   "long": a data frame in long format.
#'   "wide": a data frame in wide format.
#'   "list": a list.
#'   "venn": a VennPlotData object.
detect_venn_datatype <- function(data, group_by = NULL, id_by = NULL) {
    # Data frames first: they are lists too, so this test must precede is.list().
    if (inherits(data, "data.frame")) {
        single_group <- !is.null(group_by) && length(group_by) < 2
        return(if (single_group) "long" else "wide")
    }

    if (inherits(data, "VennPlotData")) {
        return("venn")
    }

    if (is.list(data)) {
        return("list")
    }

    stop("Invalid data type. Please provide a data frame, a list or a VennPlotData object generated by prepare_venn_data().")
}

#' Prepare data for Venn diagram
#'
#' Converts any of the supported input formats into a named list of sets and
#' hands it to `ggVennDiagram::Venn()` / `ggVennDiagram::process_data()`.
#' A VennPlotData object is returned unchanged.
#'
#' @keywords internal
#' @param data A data frame or a list or a VennPlotData object.
#' @param in_form A character string indicating the datatype of the input data.
#'   Possible values are "auto", "long", "wide", "list" and "venn".
#'   `NULL` is treated as "auto"; any other value raises an error.
#'   "long" indicates the data is in long format.
#'   "wide" indicates the data is in wide format.
#'   "list" indicates the data is a list.
#'   "venn" indicates the data is a VennPlotData object.
#'   "auto" indicates the function will detect the datatype of the input data.
#'
#' A long format data would look like:
#' \preformatted{
#' group_by id_by
#' A        a1
#' A        a2
#' B        a1
#' B        a3
#' ...
#' }
#'
#' A wide format data would look like:
#' \preformatted{
#' A    B
#' TRUE TRUE
#' TRUE FALSE
#' FALSE TRUE
#' ...
#' }
#' Each row is one instance; a row belongs to a set when its value is `TRUE` (or 1).
#' Missing values in logical columns are treated as "not a member".
#'
#' A list format data would look like:
#' \preformatted{
#' list(A = c("a1", "a2"), B = c("a1", "a3"))
#' }
#'
#' @param group_by A character string specifying the column name of the data frame to group the data.
#'   For the wide format, the columns holding the set memberships (all columns when `NULL`).
#' @param group_by_sep A character string to concatenate the columns in `group_by`,
#'   if multiple columns are provided and the in_form is "long".
#' @param id_by A character string specifying the column name of the data frame to identify the instances.
#'  Required (and must name an existing column) when the in_form is "long".
#' @return A VennPlotData object
prepare_venn_data <- function(data, in_form = "auto", group_by = NULL, group_by_sep = "_", id_by = NULL) {
    # if (!requireNamespace("ggVennDiagram", quietly = TRUE)) {
    #     stop("ggVennDiagram is required for Venn diagram and its data processing.")
    # }

    # match.arg() maps NULL to the first choice ("auto") and rejects unknown
    # values, which previously fell through to the "wide" branch silently.
    in_form <- match.arg(in_form, c("auto", "long", "wide", "list", "venn"))
    if (in_form == "auto") {
        in_form <- detect_venn_datatype(data, group_by, id_by)
    }
    if (in_form == "venn") {
        if (!is.null(group_by)) {
            warning("The 'group_by' argument is ignored when the input data is already a 'VennPlotData' object.", immediate. = TRUE)
        }
        return(data)
    }

    if (in_form == "list") {
        if (!is.list(data)) {
            stop("The input data must be a list when the in_form is 'list'.")
        }
        listdata <- data
    } else if (in_form == "long") {
        # Fail early with a clear message: data[[NULL]] errors cryptically and
        # an unknown column would silently yield empty sets.
        if (!is.character(id_by) || length(id_by) != 1 || !id_by %in% colnames(data)) {
            stop("'id_by' must be a single existing column name of the data when the in_form is 'long'.")
        }
        group_by <- check_columns(data, group_by, force_factor = TRUE, allow_multi = TRUE, concat_multi = TRUE, concat_sep = group_by_sep)
        listdata <- split(data[[id_by]], data[[group_by]])
    } else {  # in_form == "wide"
        group_by <- check_columns(data, group_by, allow_multi = TRUE)
        if (is.null(group_by)) {
            group_by <- colnames(data)
        }
        for (g in group_by) {
            # columns must be logical or 0/1
            if (!is.logical(data[[g]]) && !all(data[[g]] %in% c(0, 1))) {
                stop("The columns in group_by must be logical or 0/1 when the in_form is 'wide'.")
            }
        }
        # One synthetic id per row; a set is the ids of the rows flagged in its column.
        row_ids <- paste0("id", seq_len(nrow(data)))
        listdata <- lapply(group_by, function(g) {
            column <- data[[g]]
            # `== 1` also covers character/factor "0"/"1" columns, which pass the
            # validation above but become NA under as.logical().
            member <- if (is.logical(column)) column else column == 1
            # which() drops NA, so missing values never leak into a set as NA ids.
            row_ids[which(member)]
        })
        names(listdata) <- group_by
        is_empty <- lengths(listdata) == 0
        for (nm in names(listdata)[is_empty]) {
            warning("The group '", nm, "' has no elements, ignored.", immediate. = TRUE)
        }
        listdata <- listdata[!is_empty]
    }
    ggVennDiagram::process_data(ggVennDiagram::Venn(listdata))
}

#' Atomic Venn diagram
#'
#' Draws a single Venn diagram (no splitting): the input is converted with
#' `prepare_venn_data()`, the regions are filled either by intersection count or by
#' blended set colors, and region/set labels are added on top.
#' The returned plot carries suggested `height` and `width` attributes.
#'
#' @inheritParams common_args
#' @inheritParams prepare_venn_data
#' @param group_by A character string specifying the column name of the data frame to group the data.
#'  When in_form is "wide", it should be the columns for the groups.
#' @param label A character string specifying the label to show on the Venn diagram.
#'  Possible values are "count", "percent", "both", "none" and a function.
#'  "count" indicates the count of the intersection. "percent" indicates the percentage of the intersection.
#'  "both" indicates both the count and the percentage of the intersection. "none" indicates no label.
#'  If it is a function, it takes a data frame as input and returns a character vector as label.
#'  The data frame has columns "id", "X", "Y", "name", "item" and "count".
#' @param label_fg A character string specifying the color of the label text.
#' @param label_size A numeric value specifying the size of the label text.
#'  As implemented, the value is multiplied by 3.5 for region labels and by 4 for set labels;
#'  when `NULL`, `theme_args$base_size / 12` (1 by default) is used in its place.
#' @param label_bg A character string specifying the background color of the label.
#' @param label_bg_r A numeric value specifying the radius of the background of the label.
#' @param fill_mode A character string specifying the fill mode of the Venn diagram.
#'  Possible values are "count", "set", "count_rev".
#'  "count" indicates the fill color is based on the count of the intersection.
#'  "set" indicates the fill color is based on the set of the intersection.
#'  "count_rev" indicates the fill color is based on the count of the intersection in reverse order.
#'  The palette will be continuous for "count" and "count_rev". The palette will be discrete for "set".
#' @param fill_name A character string to name the legend of colorbar.
#' @return A ggplot object with Venn diagram
#' @keywords internal
#' @importFrom rlang %||%
#' @importFrom ggplot2 geom_polygon geom_path coord_equal aes scale_x_continuous labs
#' @importFrom ggrepel geom_text_repel
VennDiagramAtomic <- function(
    data, in_form = "auto", group_by = NULL, group_by_sep = "_", id_by = NULL,
    label = "count", label_fg = "black", label_size = NULL, label_bg = "white", label_bg_r = 0.1,
    fill_mode = "count", fill_name = NULL,
    palette = ifelse(fill_mode == "set", "Paired", "Spectral"), palcolor = NULL, alpha = 1,
    theme = "theme_this", theme_args = list(), title = NULL, subtitle = NULL,
    legend.position = "right", legend.direction = "vertical", ...
) {
    # Use gglogger's ggplot() when the plotthis.gglogger.enabled option is set,
    # ggplot2's otherwise.
    ggplot <- if (getOption("plotthis.gglogger.enabled", FALSE)) {
        gglogger::ggplot
    } else {
        ggplot2::ggplot
    }
    # if (!requireNamespace("ggVennDiagram", quietly = TRUE)) {
    #     stop("ggVennDiagram (v1.5+) is required for Venn diagram.")
    # }
    # Text sizes below are scaled relative to a base font size of 12.
    base_size <- theme_args$base_size %||% 12
    text_size_scale <- base_size / 12

    # `.split` is attached by VennDiagram() when the data is split; stash it and
    # remove it so it is not treated as part of the Venn input.
    s <- data$.split
    data$.split <- NULL
    # Resolved before `palette` is first used, so the lazily evaluated `palette`
    # default sees the matched value.
    fill_mode <- match.arg(fill_mode, c("count", "set", "count_rev"))
    data <- prepare_venn_data(data, in_form, group_by, group_by_sep, id_by)
    # Pull the drawing layers (region polygons, set outlines, label positions)
    # out of the prepared object.
    data_regionedge <- ggVennDiagram::venn_regionedge(data)
    data_setedge <- ggVennDiagram::venn_setedge(data)
    data_regionlabel <- ggVennDiagram::venn_regionlabel(data)
    # Re-attach the split name (first value only) to the region labels.
    data_regionlabel$.split <- if (is.null(s)) NULL else s[1]
    data_setlabel <- ggVennDiagram::venn_setlabel(data)
    # Calculate the fill colors for the regions when fill_mode is set
    if (fill_mode == "set") {
        # 1: red 2: blue 3: green 4: purple
        colors <- palette_this(data$setLabel$id, palette = palette, palcolor = palcolor)
        ids <- unique(data_regionedge$id)
        # 1: red 2: blue 3: green 4: purple
        # 1/2: blend(red, blue) 1/3: blend(red, green) 1/4: blend(red, purple)
        # ...
        # Region ids are "/"-separated set ids; each region gets the blend of
        # the colors of the sets it belongs to.
        blended_colors <- sapply(ids, function(id) {
            parts <- strsplit(id, "/")[[1]]
            blend_colors(colors[parts])
        })
        # "1" being transformed to "1.1" after sapply
        names(blended_colors) <- ids
        data_regionedge$fill <- blended_colors[data_regionedge$id]
        data_setedge$color <- colors[data_setedge$id]
    }

    # Set labels always show the set name with its size on a second line.
    data_setlabel$label <- paste0(data_setlabel$name, "\n(", data_setlabel$count, ")")
    # Region labels depend on `label`; percentages are relative to the sum of
    # all region counts.
    if (identical(label, "percent")) {
        data_regionlabel$label <- scales::percent(
            data_regionlabel$count / sum(data_regionlabel$count))
    } else if (identical(label, "both")) {
        # paste() uses its default separator here, so the newline is surrounded
        # by single spaces.
        data_regionlabel$label <- paste(
            data_regionlabel$count, "\n",
            scales::percent(data_regionlabel$count / sum(data_regionlabel$count)))
    } else if (identical(label, "count")) {
        data_regionlabel$label <- data_regionlabel$count
    } else if (is.function(label)) {
        data_regionlabel$label <- label(data_regionlabel)
    } else if (!identical(label, "none")) {
        stop("Invalid label argument. Possible values are 'count', 'percent', 'both', 'none' or a function.")
    }

    p <- ggplot()

    if (fill_mode == "set") {
        # Fill/outline colors were precomputed per row above and are passed as
        # constants rather than mapped aesthetics, so no fill legend is produced.
        p <- p +
            geom_polygon(data = data_regionedge, aes(!!sym("X"), !!sym("Y"), group = !!sym("id")),
                         fill = data_regionedge$fill, alpha = alpha) +
            geom_path(data = data_setedge, aes(!!sym("X"), !!sym("Y"), group = !!sym("id")),
                      color = data_setedge$color, show.legend = FALSE)
    } else {
        # "count" / "count_rev": map the region count onto a continuous
        # gradient; the palette is reversed when fill_mode contains "rev".
        p <- p +
            geom_polygon(data = data_regionedge, aes(!!sym("X"), !!sym("Y"), fill = !!sym("count"), group = !!sym("id")), alpha = alpha) +
            scale_fill_gradientn(
                n.breaks = 3,
                colors = palette_this(palette = palette, palcolor = palcolor, reverse = grepl("rev", fill_mode)),
                na.value = "grey80",
                guide = guide_colorbar(
                    title = fill_name %||% "",
                    frame.colour = "black", ticks.colour = "black", title.hjust = 0)
            ) +
            # NOTE(review): `color` is both mapped to `id` in aes() and set to the
            # constant "grey20"; the constant presumably wins, making the mapping
            # redundant - confirm.
            geom_path(data = data_setedge, aes(!!sym("X"), !!sym("Y"), color = !!sym("id"), group = !!sym("id")), color = "grey20", show.legend = FALSE)
    }

    if (!identical(label, "none")) {
        p <- p + geom_text_repel(
            data = data_regionlabel, aes(!!sym("X"), !!sym("Y"), label = !!sym("label")),
            color = label_fg, bg.color = label_bg, bg.r = label_bg_r,
            # NOTE(review): %||% binds tighter than `*` in R, so this evaluates as
            # (label_size %||% text_size_scale) * 3.5 - a user-supplied label_size
            # is multiplied by 3.5. Confirm this scaling is intended.
            size = label_size %||% text_size_scale * 3.5,
            point.size = NA, max.overlaps = 100, force = 0,
            min.segment.length = 0, segment.colour = "black")
    }

    # Set labels are always drawn, in bold; same precedence caveat as above,
    # with a multiplier of 4.
    p <- p + geom_text_repel(
        data = data_setlabel, aes(!!sym("X"), !!sym("Y"), label = !!sym("label")),
        color = label_fg, bg.color = label_bg, bg.r = label_bg_r,
            size = label_size %||% text_size_scale * 4, fontface = "bold",
            point.size = NA, max.overlaps = 100, force = 0,
            min.segment.length = 0, segment.colour = "black") +
        labs(title = title, subtitle = subtitle) +
        coord_equal() +
        do.call(theme, theme_args) +
        # Strip grid, axes and panel border: the coordinates carry no meaning.
        ggplot2::theme(
            legend.position = legend.position,
            legend.direction = legend.direction,
            panel.grid.major = element_blank(),
            panel.grid.minor = element_blank(),
            axis.text = element_blank(),
            axis.ticks = element_blank(),
            axis.title = element_blank(),
            panel.border = element_blank()
        )

    # Widen the x axis in proportion to the longest set name and the text scale.
    maxchars <- max(sapply(data_setlabel$name, nchar))
    # NOTE(review): the trailing comma passes an empty trailing argument to
    # scale_x_continuous(); presumably harmless - confirm.
    p <- p + scale_x_continuous(
        expand = expansion(add = 0.001 * maxchars * text_size_scale, mult = 0.1),
    )

    # Suggested figure size; extra room is reserved for the colorbar, which is
    # only present when fill_mode is not "set".
    height <- 5.5
    width <- 6
    if (fill_mode != "set") {
        if (legend.position %in% c("right", "left")) {
            width <- width + 1
        } else if (legend.direction == "horizontal") {
            height <- height + 1
        } else {
            width <- width + 2
        }
    }
    attr(p, "height") <- height
    attr(p, "width") <- width

    return(p)
}


#' Venn diagram
#'
#' Draws one Venn diagram, or one diagram per level of `split_by` when the input
#' is a data frame, and combines the resulting plots with `combine_plots()`.
#'
#' @inheritParams common_args
#' @inheritParams VennDiagramAtomic
#' @return A combined ggplot object or wrap_plots object or a list of ggplot objects
#' @export
#' @examples
#' \donttest{
#' set.seed(8525)
#' data <- list(
#'     A = sort(sample(letters, 8)),
#'     B = sort(sample(letters, 8)),
#'     C = sort(sample(letters, 8)),
#'     D = sort(sample(letters, 8))
#' )
#'
#' VennDiagram(data)
#' VennDiagram(data, fill_mode = "set")
#' VennDiagram(data, label = "both")
#' # label with a function
#' VennDiagram(data, label = function(df) df$name)
#' VennDiagram(data, palette = "material-indigo", alpha = 0.6)
#' }
VennDiagram <- function(
    data, in_form = c("auto", "long", "wide", "list", "venn"), split_by = NULL, split_by_sep = "_",
    group_by = NULL, group_by_sep = "_", id_by = NULL, label = "count", label_fg = "black",
    label_size = NULL, label_bg = "white", label_bg_r = 0.1, fill_mode = "count", fill_name = NULL,
    palette = ifelse(fill_mode == "set", "Paired", "Spectral"), palcolor = NULL, alpha = 1,
    theme = "theme_this", theme_args = list(), title = NULL, subtitle = NULL,
    legend.position = "right", legend.direction = "vertical",
    combine = TRUE, nrow = NULL, ncol = NULL, byrow = TRUE, seed = 8525,
    axes = NULL, axis_titles = axes, guides = NULL, design = NULL, ...
) {
    validate_common_args(seed)
    in_form <- match.arg(in_form)
    theme <- process_theme(theme)

    # Splitting needs a column to split on, hence a data frame.
    if (!inherits(data, "data.frame") && !is.null(split_by)) {
        stop("'split_by' is only available for data frame.")
    }
    split_by <- check_columns(data, split_by, force_factor = TRUE, allow_multi = TRUE, concat_multi = TRUE, concat_sep = split_by_sep)

    if (is.null(split_by)) {
        # No splitting: a single panel under the placeholder name "...".
        datas <- list(data)
        names(datas) <- "..."
    } else {
        # One panel per level, kept in the order of the factor levels.
        split_levels <- levels(data[[split_by]])
        datas <- split(data, data[[split_by]])[split_levels]
    }

    # Expand the per-panel settings so they can be looked up by panel name.
    palette <- check_palette(palette, names(datas))
    palcolor <- check_palcolor(palcolor, names(datas))
    legend.direction <- check_legend(legend.direction, names(datas), "legend.direction")
    legend.position <- check_legend(legend.position, names(datas), "legend.position")

    # Build the diagram for the panel called `nm`.
    draw_panel <- function(nm) {
        # TRUE only for the single placeholder panel of an unsplit call.
        unsplit <- length(datas) == 1 && identical(nm, "...")

        default_title <- if (unsplit) NULL else nm
        panel_title <- if (is.function(title)) {
            title(default_title)
        } else {
            title %||% default_title
        }

        # Real split panels carry their name along in a `.split` column.
        panel_data <- datas[[nm]]
        if (!unsplit) {
            panel_data$.split <- nm
        }

        VennDiagramAtomic(panel_data,
            in_form = in_form, group_by = group_by, group_by_sep = group_by_sep, id_by = id_by,
            label = label, label_fg = label_fg, label_size = label_size, label_bg = label_bg, label_bg_r = label_bg_r,
            fill_mode = fill_mode, fill_name = fill_name, palette = palette[[nm]], palcolor = palcolor[[nm]], alpha = alpha,
            theme = theme, theme_args = theme_args, title = panel_title, subtitle = subtitle,
            legend.position = legend.position[[nm]], legend.direction = legend.direction[[nm]], ...
        )
    }
    plots <- lapply(names(datas), draw_panel)

    combine_plots(plots, combine = combine, nrow = nrow, ncol = ncol, byrow = byrow,
        axes = axes, axis_titles = axis_titles, guides = guides, design = design)
}

# Try the plotthis package in your browser
#
# Any scripts or data that you put into this service are public.
#
# plotthis documentation built on June 21, 2025, 1:07 a.m.