R/parallel_sets.R

Defines functions stat_parallel_sets geom_parallel_sets stat_parallel_sets_axes geom_parallel_sets_axes geom_parallel_sets_labels gather_set_data complete_data sankey_axis_data sankey_diag_data add_y_pos remove_group is.discrete

Documented in gather_set_data geom_parallel_sets geom_parallel_sets geom_parallel_sets_axes geom_parallel_sets_labels stat_parallel_sets stat_parallel_sets_axes

#' Create Parallel Sets diagrams
#'
#' A parallel sets diagram is a type of visualisation showing the interaction
#' between multiple categorical variables. If the variables have an intrinsic
#' order the representation can be thought of as a Sankey Diagram. If each
#' variable is a point in time it will resemble an alluvial diagram.
#'
#' In a parallel sets visualization each categorical variable will be assigned
#' a position on the x-axis. The size of the intersection of categories from
#' neighboring variables are then shown as thick diagonals, scaled by the sum of
#' elements shared between the two categories. The natural data representation
#' for such a plot is to have each categorical variable in a separate column
#' and then have a column giving the amount/magnitude of the combination of
#' levels in the row. This representation is unfortunately not fitting for the
#' `ggplot2` API which needs every position encoding in the same column. To make
#' it easier to work with `ggforce` provides a helper [gather_set_data()], which
#' takes care of the transformation.
#'
#' @section Aesthetics:
#' geom_parallel_sets understands the following aesthetics
#' (required aesthetics are in bold):
#'
#' - **x**
#' - **id**
#' - **split**
#' - **value**
#' - color
#' - fill
#' - size
#' - linetype
#' - alpha
#' - lineend
#'
#' @inheritParams geom_diagonal_wide
#' @param sep The proportional separation between categories within a variable
#' @param axis.width The width of the area around each variable axis
#' @param angle The angle of the axis label text
#'
#' @name geom_parallel_sets
#' @rdname geom_parallel_sets
#'
#' @author Thomas Lin Pedersen
#'
#' @examples
#' data <- reshape2::melt(Titanic)
#' data <- gather_set_data(data, 1:4)
#'
#' ggplot(data, aes(x, id = id, split = y, value = value)) +
#'   geom_parallel_sets(aes(fill = Sex), alpha = 0.3, axis.width = 0.1) +
#'   geom_parallel_sets_axes(axis.width = 0.1) +
#'   geom_parallel_sets_labels(colour = 'white')
#'
NULL

#' @rdname ggforce-extensions
#' @format NULL
#' @usage NULL
#' @importFrom ggplot2 ggproto Stat
#' @export
StatParallelSets <- ggproto('StatParallelSets', Stat,
    # Validate the layer data before any panel computation: every id must map
    # to exactly one value. The split aesthetic is coerced to a factor.
    setup_data = function(data, params) {
        values_by_id <- split(data$value, data$id)
        n_distinct <- lengths(lapply(values_by_id, unique))
        if (any(n_distinct != 1)) {
            stop('value must be kept constant across id', call. = FALSE)
        }
        data$split <- as.factor(data$split)
        data
    },
    # Build the diagonal bands between neighbouring axes and hand them to
    # StatDiagonalWide for interpolation.
    compute_panel = function(data, scales, sep = 0.05, strength = 0.5, n = 100, axis.width = 0) {
        data <- complete_data(remove_group(data))

        # One row of aesthetics per group: the first non-missing value of each
        # aesthetic column found within that group.
        aes_cols <- c('group', 'colour', 'color', 'fill', 'size', 'alpha', 'linetype')
        aes_data <- data[, names(data) %in% aes_cols, drop = FALSE]
        per_group <- lapply(split(aes_data, data$group), function(group_rows) {
            first_known <- lapply(group_rows, function(column) na.omit(column)[1])
            as.data.frame(first_known, stringsAsFactors = FALSE)
        })
        data_groups <- do.call(rbind, per_group)

        # Vertical extent of every category on every axis
        data_axes <- sankey_axis_data(data, sep)

        # Start and end positions of every diagonal
        diagonals <- sankey_diag_data(data, data_axes, data_groups, axis.width)

        StatDiagonalWide$compute_panel(diagonals, scales, strength, n)
    },
    required_aes = c('x', 'id', 'split', 'value'),
    extra_params = c('na.rm', 'n', 'sep', 'strength', 'axis.width')
)
#' @rdname geom_parallel_sets
#' @importFrom ggplot2 layer
#' @export
stat_parallel_sets <- function(mapping = NULL, data = NULL, geom = "shape",
                               position = "identity", n = 100, strength = 0.5,
                               sep = 0.05, axis.width = 0, na.rm = FALSE,
                               show.legend = NA, inherit.aes = TRUE, ...) {
    # Stat parameters, followed by anything forwarded through `...`
    layer_params <- list(na.rm = na.rm, n = n, strength = strength, sep = sep,
                         axis.width = axis.width, ...)
    layer(
        data = data,
        mapping = mapping,
        stat = StatParallelSets,
        geom = geom,
        position = position,
        show.legend = show.legend,
        inherit.aes = inherit.aes,
        params = layer_params
    )
}
#' @rdname geom_parallel_sets
#' @importFrom ggplot2 layer
#' @export
geom_parallel_sets <- function(mapping = NULL, data = NULL, stat = "parallel_sets",
                        position = "identity", n = 100, na.rm = FALSE, sep = 0.05,
                        strength = 0.5, axis.width = 0, show.legend = NA,
                        inherit.aes = TRUE, ...) {
    # Parameters handed to the layer, followed by anything forwarded through `...`
    layer_params <- list(na.rm = na.rm, n = n, strength = strength, sep = sep,
                         axis.width = axis.width, ...)
    layer(
        data = data,
        mapping = mapping,
        stat = stat,
        geom = GeomShape,
        position = position,
        show.legend = show.legend,
        inherit.aes = inherit.aes,
        params = layer_params
    )
}
#' @rdname ggforce-extensions
#' @format NULL
#' @usage NULL
#' @importFrom ggplot2 ggproto Stat
#' @export
StatParallelSetsAxes <- ggproto('StatParallelSetsAxes', Stat,
    # Validate the layer data before any panel computation: every id must map
    # to exactly one value. The split aesthetic is coerced to a factor.
    setup_data = function(data, params) {
        values_by_id <- split(data$value, data$id)
        n_distinct <- lengths(lapply(values_by_id, unique))
        if (any(n_distinct != 1)) {
            stop('value must be kept constant across id', call. = FALSE)
        }
        data$split <- as.factor(data$split)
        data
    },
    # Compute one box per category per axis (xmin/xmax/ymin/ymax) plus a
    # vertically centred y position and a `label` column holding the category.
    compute_panel = function(data, scales, sep = 0.05, axis.width = 0) {
        # Remember the factor levels so they can be restored after the merge
        split_levels <- levels(data$split)
        data <- complete_data(remove_group(data))

        # Vertical extent of every category; filler categories are not drawn
        axes <- sankey_axis_data(data, sep)
        axes <- axes[axes$split != '.ggforce_missing', ]

        # Aesthetics attached to each (x, split) pair
        aes_cols <- c('x', 'split', 'colour', 'color', 'fill', 'size', 'alpha', 'linetype')
        axis_aes <- unique(data[, names(data) %in% aes_cols])
        if (nrow(axis_aes) != nrow(axes)) {
            stop('Axis aesthetics must be constant in each split', call. = FALSE)
        }

        as_split <- function(values) factor(as.character(values), levels = split_levels)
        axes$split <- as_split(axes$split)
        axis_aes$split <- as_split(axis_aes$split)

        boxes <- merge(axes, axis_aes, by = c('x', 'split'), all.x = TRUE, sort = FALSE)
        names(boxes)[names(boxes) == 'split'] <- 'label'
        half_width <- axis.width/2
        boxes$y <- boxes$ymin + boxes$value/2
        boxes$xmin <- boxes$x - half_width
        boxes$xmax <- boxes$x + half_width
        boxes
    },
    required_aes = c('x', 'id', 'split', 'value'),
    extra_params = c('na.rm', 'sep')
)
#' @rdname geom_parallel_sets
#' @importFrom ggplot2 layer
#' @export
stat_parallel_sets_axes <- function(mapping = NULL, data = NULL, geom = "parallel_sets_axes",
                               position = "identity", sep = 0.05, axis.width = 0,
                               na.rm = FALSE, show.legend = NA,
                               inherit.aes = TRUE, ...) {
    # Stat parameters, followed by anything forwarded through `...`
    layer_params <- list(na.rm = na.rm, sep = sep, axis.width = axis.width, ...)
    layer(
        data = data,
        mapping = mapping,
        stat = StatParallelSetsAxes,
        geom = geom,
        position = position,
        show.legend = show.legend,
        inherit.aes = inherit.aes,
        params = layer_params
    )
}
#' @rdname ggforce-extensions
#' @format NULL
#' @usage NULL
#' @importFrom ggplot2 ggproto Stat
#' @export
GeomParallelSetsAxes <- ggproto('GeomParallelSetsAxes', GeomShape,
    # Expand every box (xmin/xmax/ymin/ymax) into its four corner points.
    # Each input row becomes its own group, and the returned rows are ordered
    # by that group.
    setup_data = function(data, params) {
        data$group <- seq_len(nrow(data))
        corner <- function(x_col, y_col) {
            pts <- data
            pts$x <- pts[[x_col]]
            pts$y <- pts[[y_col]]
            pts
        }
        # Corner order: left-bottom, right-bottom, right-top, left-top
        corners <- rbind(
            corner('xmin', 'ymin'),
            corner('xmax', 'ymin'),
            corner('xmax', 'ymax'),
            corner('xmin', 'ymax')
        )
        corners[order(corners$group), ]
    },
    required_aes = c('xmin', 'ymin', 'xmax', 'ymax')
)
#' @rdname geom_parallel_sets
#' @importFrom ggplot2 layer
#' @export
geom_parallel_sets_axes <- function(mapping = NULL, data = NULL,
                                    stat = "parallel_sets_axes",
                                    position = "identity", na.rm = FALSE,
                                    show.legend = NA, inherit.aes = TRUE, ...) {
    # Everything in `...` is forwarded to the layer as a parameter
    layer_params <- list(na.rm = na.rm, ...)
    layer(
        data = data,
        mapping = mapping,
        stat = stat,
        geom = GeomParallelSetsAxes,
        position = position,
        show.legend = show.legend,
        inherit.aes = inherit.aes,
        params = layer_params
    )
}
#' @rdname geom_parallel_sets
#' @importFrom ggplot2 layer GeomText
#' @export
geom_parallel_sets_labels <- function(mapping = NULL, data = NULL,
                                    stat = "parallel_sets_axes", angle = -90,
                                    position = "identity", na.rm = FALSE,
                                    show.legend = NA, inherit.aes = TRUE, ...) {
    # The label angle is passed on as a layer parameter, along with `...`
    layer_params <- list(na.rm = na.rm, angle = angle, ...)
    layer(
        data = data,
        mapping = mapping,
        stat = stat,
        geom = GeomText,
        position = position,
        show.legend = show.legend,
        inherit.aes = inherit.aes,
        params = layer_params
    )
}
#' Tidy data for use with geom_parallel_sets
#'
#' This helper function makes it easy to change tidy data into a tidy(er) format
#' that can be used by geom_parallel_sets.
#'
#' @param data A tidy dataframe with some categorical columns
#' @param x The columns to use for axes in the parallel sets diagram
#' @param id_name The name of the column that will contain the original index of
#' the row.
#'
#' @return A data.frame
#'
#' @export
#'
#' @examples
#' data <- reshape2::melt(Titanic)
#' head(gather_set_data(data, 1:4))
#'
gather_set_data <- function(data, x, id_name = 'id') {
    # Numeric selections are translated to column names
    if (is.numeric(x)) {
        x <- names(data)[x]
    }
    # Tag each original row so it can be traced across the stacked copies
    data[[id_name]] <- seq_len(nrow(data))

    # One full copy of the data per axis column: `x` holds the axis name and
    # `y` the category the row has on that axis.
    stack_axis <- function(axis) {
        long <- data
        long$x <- axis
        long$y <- long[[axis]]
        long
    }
    do.call(rbind, lapply(x, stack_axis))
}
#' @importFrom stats na.omit
complete_data <- function(data) {
    # Make sure every id is present on every axis.
    #
    # `data` must contain the columns x, id, value, split (a factor) and group.
    # For each axis (distinct x) the ids that do not occur there are added as
    # filler rows whose split is '.ggforce_missing' and whose remaining columns
    # are NA. Afterwards the group of every row is set to the single group
    # observed for its id, and the rows are returned ordered by x and then id.
    #
    # Errors if an id occurs more than once on an axis, or if an id is not
    # associated with exactly one group.
    levels(data$split) <- c(levels(data$split), '.ggforce_missing')
    all_obs <- unique(data[, c('id', 'value')])
    data <- do.call(rbind, lapply(split(data, data$x), function(d) {
        if (anyDuplicated(d$id) != 0) {
            stop('id must be unique within axes', call. = FALSE)
        }
        # Observations known elsewhere that are absent from this axis. The
        # lookup has to run over `all_obs`: testing `d$id` against
        # `all_obs$id` is always TRUE and would never select a filler row.
        missing_obs <- all_obs[!all_obs$id %in% d$id, , drop = FALSE]
        n_miss <- nrow(missing_obs)
        if (n_miss > 0) {
            # All-NA rows with the same columns (and column types) as `d`
            fill <- d[seq_len(n_miss), ][NA, ]
            fill$x <- d$x[1]
            fill[, c('id', 'value')] <- missing_obs
            fill$split <- '.ggforce_missing'
            d <- rbind(d, fill)
        }
        d
    }))

    # Ensure id grouping: filler rows carry an NA group, so it is ignored when
    # looking up the one group that belongs to each id
    id_groups <- lapply(split(data$group, data$id), function(x) unique(na.omit(x)))
    if (any(lengths(id_groups) != 1)) {
       stop('id must keep grouping across data', call. = FALSE)
    }
    data$group <- unlist(id_groups)[match(as.character(data$id), names(id_groups))]
    data[order(data$x, data$id), ]
}

# Compute the vertical extent of every category on every axis.
#
# For each distinct x the values are summed per split and stacked on top of
# each other, separated by a gap of `sep` times the axis total. Returns one row
# per (x, split) with the columns split, value, x, ymax and ymin.
sankey_axis_data <- function(data, sep) {
    axis_layout <- function(axis_rows) {
        by_split <- split(axis_rows$value, as.character(axis_rows$split))
        layout <- data.frame(
            split = names(by_split),
            value = sapply(by_split, sum),
            x = axis_rows$x[1],
            stringsAsFactors = TRUE
        )
        # Absolute gap between neighbouring categories on this axis
        gap <- sum(layout$value) * sep
        gaps_below <- (seq_len(nrow(layout)) - 1) * gap
        layout$ymax <- gaps_below + cumsum(layout$value)
        layout$ymin <- layout$ymax - layout$value
        layout
    }
    do.call(rbind, lapply(split(data, data$x), axis_layout))
}

# Build the corner points of every diagonal band between neighbouring axes.
#
# data       Rows with at least the columns x, id-aligned split, group and
#            value. The code pairs row k at one axis with row k at the next
#            axis, so it presumably expects the output of complete_data()
#            (sorted by x then id, every id present on every axis) - confirm
#            against callers.
# axes_data  Per-axis category extents as produced by sankey_axis_data()
#            (columns x, split, ymin are read here and in add_y_pos()).
# groups     One row per original group holding a `group` column plus any
#            aesthetic columns to copy onto the diagonals.
# axis.width Width of the axis boxes; diagonals start/end half a width away
#            from the axis position.
#
# Returns a data frame with four rows per diagonal (lower and upper edge at
# the start axis, lower and upper edge at the end axis) and a `group` column
# that is unique per diagonal across all axis pairs.
#
# NOTE(review): with a single axis the lapply() below yields an empty list and
# `sapply(diagonals, nrow) / 4` is then applied to `list()`, which looks like
# it would error - confirm whether single-axis input can reach this function.
sankey_diag_data <- function(data, axes_data, groups, axis.width) {
    axes <- sort(unique(data$x))
    # One iteration per pair of neighbouring axes (i, i + 1)
    diagonals <- lapply(seq_len(length(axes) - 1), function(i) {
        from <- data[data$x == axes[i], , drop = FALSE]
        to <- data[data$x == axes[i + 1], , drop = FALSE]
        # Row indices bundled by (group, category at start, category at end);
        # combinations that do not occur are dropped
        diagonals <- split(seq_len(nrow(from)), list(from$group, from$split, to$split))
        diagonals <- diagonals[lengths(diagonals) != 0]
        # First row index of each bundle, used to look up its group and splits
        diag_rep <- sapply(diagonals, `[`, 1)
        # Start side: one row per bundle with the summed value, placed at the
        # right edge of the start axis box
        diag_from <- data.frame(
            group = from$group[diag_rep],
            split = from$split[diag_rep],
            value = sapply(diagonals, function(ii) sum(from$value[ii])),
            x = from$x[1] + axis.width/2,
            stringsAsFactors = FALSE
        )
        # End side: same bundles, but with the category of the end axis and
        # placed at the left edge of the end axis box
        diag_to <- diag_from
        diag_to$split <- to$split[diag_rep]
        diag_to$x <- to$x[1] - axis.width/2

        # add_y_pos() returns two rows per bundle (lower edge, then upper edge)
        diag_from <- add_y_pos(diag_from, axes_data[axes_data$x == axes[i], ])
        diag_to <- add_y_pos(diag_to, axes_data[axes_data$x == axes[i+1], ])
        diagonals <- rbind(diag_from, diag_to)
        # Keep the original group for the aesthetics lookup, then renumber so
        # that each bundle is its own group: nrow(diag_from)/2 bundles, each
        # appearing once in the four stacked blocks
        main_groups <- diagonals$group
        diagonals$group <- rep(seq_len(nrow(diag_from)/2), 4)
        # Copy the per-group aesthetic columns, if there are any besides `group`
        if (length(setdiff(names(groups), 'group')) > 0) {
            diagonals <- cbind(diagonals, groups[match(main_groups, groups$group), names(groups) != 'group', drop = FALSE])
        }
        diagonals
    })
    # Number of diagonals per axis pair (four rows each); shift the group ids
    # of later pairs so they do not collide with earlier ones
    n_groups <- sapply(diagonals, nrow) / 4
    group_offset <- c(0, cumsum(n_groups)[-length(n_groups)])
    do.call(rbind, Map(function(d, i) {d$group <- d$group + i; d}, d = diagonals, i = group_offset))
}

# Stack the rows of `data` inside the category they belong to.
#
# Within each split the rows are stacked in order of their group, starting at
# that split's `ymin` in `axes_data`. The result holds every input row twice:
# first with `y` at the lower edge of its segment, then with `y` at the upper
# edge.
add_y_pos <- function(data, axes_data) {
    rows_by_split <- split(seq_len(nrow(data)), as.character(data$split))
    upper_edges <- lapply(rows_by_split, function(rows) {
        split_name <- as.character(data$split[rows[1]])
        axis_bottom <- axes_data$ymin[axes_data$split == split_name]
        stack_order <- order(data$group[rows])
        stacked <- axis_bottom + cumsum(data$value[rows][stack_order])
        # Put the stacked positions back into the original row order
        restored <- stacked
        restored[stack_order] <- stacked
        restored
    })
    data$y[unlist(rows_by_split)] <- unlist(upper_edges)
    lower <- data
    lower$y <- data$y - data$value
    rbind(lower, data)
}

# Reset or recompute the `group` column depending on how it relates to `split`.
#
# - Every split holds exactly one group: group is set to -1.
# - Otherwise, if no group is shared by all splits: group is rebuilt from the
#   discrete columns other than split, label and PANEL (or set to -1 when
#   there are none).
# - Otherwise the data is returned unchanged.
remove_group <- function(data) {
    groups_per_split <- lapply(split(data$group, data$split), unique)

    if (all(lengths(groups_per_split) == 1)) {
        data$group <- -1
        return(data)
    }

    shared_groups <- Reduce(intersect, groups_per_split)
    if (length(shared_groups) != 0) {
        return(data)
    }

    disc <- vapply(data, is.discrete, logical(1))
    ignored <- names(disc) %in% c('split', 'label', 'PANEL')
    disc[ignored] <- FALSE
    data$group <- if (any(disc)) plyr::id(data[disc], drop = TRUE) else -1
    data
}

# TRUE when `x` is a factor, a character vector or a logical vector
is.discrete <- function(x) {
    if (is.factor(x)) {
        return(TRUE)
    }
    if (is.character(x)) {
        return(TRUE)
    }
    is.logical(x)
}
YTLogos/ggforce documentation built on May 6, 2019, 4:37 p.m.