#' Create Parallel Sets diagrams
#'
#' A parallel sets diagram is a type of visualisation showing the interaction
#' between multiple categorical variables. If the variables have an intrinsic
#' order, the representation can be thought of as a Sankey diagram. If each
#' variable is a point in time, it will resemble an alluvial diagram.
#'
#' In a parallel sets visualization each categorical variable is assigned
#' a position on the x-axis. The size of the intersection of categories from
#' neighboring variables is then shown as thick diagonals, scaled by the sum of
#' elements shared between the two categories. The natural data representation
#' for such a plot is to have each categorical variable in a separate column
#' and then have a column giving the amount/magnitude of the combination of
#' levels in the row. Unfortunately, this representation does not fit the
#' `ggplot2` API, which needs every position encoding in the same column. To
#' make it easier to work with, `ggforce` provides a helper,
#' [gather_set_data()], which takes care of the transformation.
#'
#' @section Aesthetics:
#' geom_parallel_sets understands the following aesthetics
#' (required aesthetics are in bold):
#'
#' - **x**
#' - **id**
#' - **split**
#' - **value**
#' - color
#' - fill
#' - size
#' - linetype
#' - alpha
#' - lineend
#'
#' @inheritParams geom_diagonal_wide
#' @param sep The proportional separation between categories within a variable
#' @param axis.width The width of the area around each variable axis
#' @param angle The angle of the axis label text
#'
#' @name geom_parallel_sets
#' @rdname geom_parallel_sets
#'
#' @author Thomas Lin Pedersen
#'
#' @examples
#' data <- reshape2::melt(Titanic)
#' data <- gather_set_data(data, 1:4)
#'
#' ggplot(data, aes(x, id = id, split = y, value = value)) +
#' geom_parallel_sets(aes(fill = Sex), alpha = 0.3, axis.width = 0.1) +
#' geom_parallel_sets_axes(axis.width = 0.1) +
#' geom_parallel_sets_labels(colour = 'white')
#'
NULL # Documentation-only object carrying the geom_parallel_sets help topic
#' @rdname ggforce-extensions
#' @format NULL
#' @usage NULL
#' @importFrom ggplot2 ggproto Stat
#' @export
StatParallelSets <- ggproto('StatParallelSets', Stat,
  # Reject data where one id carries more than one value, then coerce split
  # to a factor so complete_data() can append its placeholder level to it.
  setup_data = function(data, params) {
    values_per_id <- lapply(split(data$value, data$id), unique)
    if (!all(lengths(values_per_id) == 1)) {
      stop('value must be kept constant across id', call. = FALSE)
    }
    data$split <- as.factor(data$split)
    data
  },
  # Turn the long-format set data into the wide-diagonal polygons joining
  # each pair of neighbouring axes.
  compute_panel = function(data, scales, sep = 0.05, strength = 0.5, n = 100,
                           axis.width = 0) {
    data <- complete_data(remove_group(data))
    # One row of aesthetics per group: the first non-missing value of each
    # aesthetic column (placeholder rows added by complete_data() are NA).
    aes_columns <- c('group', 'colour', 'color', 'fill', 'size', 'alpha', 'linetype')
    aes_data <- data[, names(data) %in% aes_columns, drop = FALSE]
    first_present <- function(column) na.omit(column)[1]
    per_group <- lapply(split(aes_data, data$group), function(rows) {
      as.data.frame(lapply(rows, first_present), stringsAsFactors = FALSE)
    })
    data_groups <- do.call(rbind, per_group)
    # Category extents per axis, then the ribbons between neighbouring axes
    axis_extents <- sankey_axis_data(data, sep)
    diagonals <- sankey_diag_data(data, axis_extents, data_groups, axis.width)
    StatDiagonalWide$compute_panel(diagonals, scales, strength, n)
  },
  required_aes = c('x', 'id', 'split', 'value'),
  extra_params = c('na.rm', 'n', 'sep', 'strength', 'axis.width')
)
#' @rdname geom_parallel_sets
#' @importFrom ggplot2 layer
#' @export
stat_parallel_sets <- function(mapping = NULL, data = NULL, geom = "shape",
                               position = "identity", n = 100, strength = 0.5,
                               sep = 0.05, axis.width = 0, na.rm = FALSE,
                               show.legend = NA, inherit.aes = TRUE, ...) {
  # Stat parameters are forwarded together with any extra arguments in `...`.
  layer_params <- list(na.rm = na.rm, n = n, strength = strength, sep = sep,
                       axis.width = axis.width, ...)
  layer(
    data = data, mapping = mapping, stat = StatParallelSets, geom = geom,
    position = position, show.legend = show.legend,
    inherit.aes = inherit.aes, params = layer_params
  )
}
#' @rdname geom_parallel_sets
#' @importFrom ggplot2 layer
#' @export
geom_parallel_sets <- function(mapping = NULL, data = NULL, stat = "parallel_sets",
                               position = "identity", n = 100, na.rm = FALSE, sep = 0.05,
                               strength = 0.5, axis.width = 0, show.legend = NA,
                               inherit.aes = TRUE, ...) {
  # The ribbons are drawn with GeomShape; the stat parameters travel together
  # with any extra arguments supplied in `...`.
  layer_params <- list(na.rm = na.rm, n = n, strength = strength, sep = sep,
                       axis.width = axis.width, ...)
  layer(
    geom = GeomShape, stat = stat, data = data, mapping = mapping,
    position = position, show.legend = show.legend,
    inherit.aes = inherit.aes, params = layer_params
  )
}
#' @rdname ggforce-extensions
#' @format NULL
#' @usage NULL
#' @importFrom ggplot2 ggproto Stat
#' @export
StatParallelSetsAxes <- ggproto('StatParallelSetsAxes', Stat,
  # Same validation as StatParallelSets: one value per id, split as a factor.
  setup_data = function(data, params) {
    value_check <- lapply(split(data$value, data$id), unique)
    if (any(lengths(value_check) != 1)) {
      stop('value must be kept constant across id', call. = FALSE)
    }
    data$split <- as.factor(data$split)
    data
  },
  # Compute one box per (axis, category) pair: its vertical extent
  # (ymin/ymax), label position (y), horizontal extent (xmin/xmax) and the
  # category name as `label`.
  compute_panel = function(data, scales, sep = 0.05, axis.width = 0) {
    split_levels <- levels(data$split)
    data <- remove_group(data)
    data <- complete_data(data)
    # Calculate axis sizes. The '.ggforce_missing' placeholder category that
    # complete_data() adds for ids absent from an axis is not drawn.
    data_axes <- sankey_axis_data(data, sep)
    data_axes <- data_axes[data_axes$split != '.ggforce_missing', ]
    aes_cols <- c('x', 'split', 'colour', 'color', 'fill', 'size', 'alpha', 'linetype')
    aes <- unique(data[, names(data) %in% aes_cols, drop = FALSE])
    # Placeholder rows have no axis box (and NA aesthetics), so drop them as
    # well; otherwise any incomplete data would fail the row-count check below.
    aes <- aes[!aes$split %in% '.ggforce_missing', , drop = FALSE]
    # Exactly one aesthetic row per box means aesthetics are constant per split
    if (nrow(aes) != nrow(data_axes)) {
      stop('Axis aesthetics must be constant in each split', call. = FALSE)
    }
    data_axes$split <- factor(as.character(data_axes$split), levels = split_levels)
    aes$split <- factor(as.character(aes$split), levels = split_levels)
    data <- merge(data_axes, aes, by = c('x', 'split'), all.x = TRUE, sort = FALSE)
    names(data)[names(data) == 'split'] <- 'label'
    # Labels sit at the vertical centre of their box
    data$y <- data$ymin + data$value / 2
    data$xmin <- data$x - axis.width / 2
    data$xmax <- data$x + axis.width / 2
    data
  },
  required_aes = c('x', 'id', 'split', 'value'),
  extra_params = c('na.rm', 'sep', 'axis.width')
)
#' @rdname geom_parallel_sets
#' @importFrom ggplot2 layer
#' @export
stat_parallel_sets_axes <- function(mapping = NULL, data = NULL, geom = "parallel_sets_axes",
                                    position = "identity", sep = 0.05, axis.width = 0,
                                    na.rm = FALSE, show.legend = NA,
                                    inherit.aes = TRUE, ...) {
  # Axis-box parameters plus anything extra for the chosen geom.
  layer_params <- list(na.rm = na.rm, sep = sep, axis.width = axis.width, ...)
  layer(
    data = data, mapping = mapping, stat = StatParallelSetsAxes, geom = geom,
    position = position, show.legend = show.legend,
    inherit.aes = inherit.aes, params = layer_params
  )
}
#' @rdname ggforce-extensions
#' @format NULL
#' @usage NULL
#' @importFrom ggplot2 ggproto Stat
#' @export
GeomParallelSetsAxes <- ggproto('GeomParallelSetsAxes', GeomShape,
  # Expand every rectangle (xmin/xmax/ymin/ymax) into a four-corner polygon
  # so GeomShape can draw it. Each rectangle becomes its own group, and its
  # corners run bottom-left, bottom-right, top-right, top-left.
  setup_data = function(data, params) {
    data$group <- seq_len(nrow(data))
    # Copy of `boxes` with x/y set to one corner of each rectangle
    corner <- function(boxes, x_col, y_col) {
      boxes$x <- boxes[[x_col]]
      boxes$y <- boxes[[y_col]]
      boxes
    }
    corners <- rbind(
      corner(data, 'xmin', 'ymin'),
      corner(data, 'xmax', 'ymin'),
      corner(data, 'xmax', 'ymax'),
      corner(data, 'xmin', 'ymax')
    )
    # order() is stable, so corner order is kept within each group
    corners[order(corners$group), ]
  },
  required_aes = c('xmin', 'ymin', 'xmax', 'ymax')
)
#' @rdname geom_parallel_sets
#' @importFrom ggplot2 layer
#' @export
geom_parallel_sets_axes <- function(mapping = NULL, data = NULL,
                                    stat = "parallel_sets_axes",
                                    position = "identity", na.rm = FALSE,
                                    show.legend = NA, inherit.aes = TRUE, ...) {
  # Stat parameters such as sep/axis.width are passed through `...`.
  layer_params <- list(na.rm = na.rm, ...)
  layer(
    geom = GeomParallelSetsAxes, stat = stat, data = data, mapping = mapping,
    position = position, show.legend = show.legend,
    inherit.aes = inherit.aes, params = layer_params
  )
}
#' @rdname geom_parallel_sets
#' @importFrom ggplot2 layer GeomText
#' @export
geom_parallel_sets_labels <- function(mapping = NULL, data = NULL,
                                      stat = "parallel_sets_axes", angle = -90,
                                      position = "identity", na.rm = FALSE,
                                      show.legend = NA, inherit.aes = TRUE, ...) {
  # Text labels reuse the axes stat; `angle` rotates the label text.
  layer_params <- list(na.rm = na.rm, angle = angle, ...)
  layer(
    geom = GeomText, stat = stat, data = data, mapping = mapping,
    position = position, show.legend = show.legend,
    inherit.aes = inherit.aes, params = layer_params
  )
}
#' Tidy data for use with geom_parallel_sets
#'
#' This helper function makes it easy to change tidy data into a tidy(er) format
#' that can be used by geom_parallel_sets.
#'
#' @param data A tidy dataframe with some categorical columns
#' @param x The columns to use for axes in the parallel sets diagram, given
#' either as column names or as numeric column indices
#' @param id_name The name of the column that will contain the original index of
#' the row.
#'
#' @return A data.frame containing one copy of `data` per column in `x`,
#' stacked row-wise. In each copy, the `x` column holds the name of the source
#' column, the `y` column holds that column's values, and the `id_name` column
#' holds the original row index.
#'
#' @export
#'
#' @examples
#' data <- reshape2::melt(Titanic)
#' head(gather_set_data(data, 1:4))
#'
gather_set_data <- function(data, x, id_name = 'id') {
  if (is.numeric(x)) x <- names(data)[x]
  data[[id_name]] <- seq_len(nrow(data))
  do.call(rbind, lapply(x, function(n) {
    # Read the column before assigning `x`: if one of the selected columns is
    # itself called "x", writing data$x first would clobber its values.
    values <- data[[n]]
    data$x <- n
    data$y <- values
    data
  }))
}
#' @importFrom stats na.omit
complete_data <- function(data) {
  # Make every id appear on every axis. Ids absent from an axis get a
  # placeholder row in the '.ggforce_missing' category, with the id's value
  # and NA for every other column. Errors if an id occurs twice on one axis
  # or if an id maps to more than one (non-NA) group. Returns the data
  # ordered by x then id, with group filled in for every row.
  levels(data$split) <- c(levels(data$split), '.ggforce_missing')
  all_obs <- unique(data[, c('id', 'value')])
  data <- do.call(rbind, lapply(split(data, data$x), function(d) {
    if (anyDuplicated(d$id) != 0) {
      stop('id must be unique within axes', call. = FALSE)
    }
    x <- d$x[1]
    # Observations (id + value) that never appear on this axis
    missing_obs <- all_obs[!all_obs$id %in% d$id, , drop = FALSE]
    n_miss <- nrow(missing_obs)
    if (n_miss > 0) {
      # n_miss all-NA rows with the same columns as d
      fill <- d[seq_len(n_miss), ][NA, ]
      fill$x <- x
      fill[, c('id', 'value')] <- missing_obs
      fill$split <- '.ggforce_missing'
      d <- rbind(d, fill)
    }
    d
  }))
  # Ensure id grouping: placeholder rows have NA group, so take the single
  # non-missing group of each id and propagate it to all of its rows
  id_groups <- lapply(split(data$group, data$id), function(x) unique(na.omit(x)))
  if (any(lengths(id_groups) != 1)) {
    stop('id must keep grouping across data', call. = FALSE)
  }
  data$group <- unlist(id_groups)[match(as.character(data$id), names(id_groups))]
  data[order(data$x, data$id), ]
}
sankey_axis_data <- function(data, sep) {
  # For every axis (value of x), compute the vertical extent of each category.
  # Categories are stacked upwards from 0 in the order split() gives their
  # names. Neighbouring categories are separated by a gap of sep times the
  # axis total. Returns one row per (axis, category) with columns split,
  # value (category total), x, ymax and ymin.
  per_axis <- lapply(split(data, data$x), function(axis) {
    totals <- split(axis$value, as.character(axis$split))
    out <- data.frame(
      split = names(totals),
      value = sapply(totals, sum),
      x = axis$x[1],
      stringsAsFactors = TRUE
    )
    gap <- sep * sum(out$value)
    out$ymax <- cumsum(out$value) + gap * (seq_len(nrow(out)) - 1)
    out$ymin <- out$ymax - out$value
    out
  })
  do.call(rbind, per_axis)
}
# Build the polygon data for the ribbons ("diagonals") joining each pair of
# neighbouring axes.
#
# data:       long-format set data as returned by complete_data(). That
#             function gives every id a row on every axis and sorts by x then
#             id. The code below relies on this, so that row k of one axis and
#             row k of the next axis refer to the same id.
# axes_data:  per-axis category extents from sankey_axis_data().
# groups:     one row of aesthetics per group, with a `group` column.
# axis.width: width of the axis boxes; ribbons start and end at the box edges.
#             NOTE(review): x is used in arithmetic, so it is assumed to be
#             already mapped to numeric positions.
#
# Returns a data.frame with four rows per ribbon. The rows are: bottom and top
# at the start axis, then bottom and top at the end axis. Each ribbon has a
# `group` number that is unique across all axis pairs, and the ribbon's
# group aesthetics are attached as extra columns.
sankey_diag_data <- function(data, axes_data, groups, axis.width) {
axes <- sort(unique(data$x))
diagonals <- lapply(seq_len(length(axes) - 1), function(i) {
from <- data[data$x == axes[i], , drop = FALSE]
to <- data[data$x == axes[i + 1], , drop = FALSE]
# Bucket row indices by (group, category on this axis, category on next
# axis); every non-empty combination becomes one ribbon
diagonals <- split(seq_len(nrow(from)), list(from$group, from$split, to$split))
diagonals <- diagonals[lengths(diagonals) != 0]
# First row of each combination, used to look up its group and categories
diag_rep <- sapply(diagonals, `[`, 1)
diag_from <- data.frame(
group = from$group[diag_rep],
split = from$split[diag_rep],
value = sapply(diagonals, function(ii) sum(from$value[ii])),
x = from$x[1] + axis.width/2,
stringsAsFactors = FALSE
)
# Same ribbons at the next axis: the category there, left edge of its box
diag_to <- diag_from
diag_to$split <- to$split[diag_rep]
diag_to$x <- to$x[1] - axis.width/2
# add_y_pos() doubles the rows: bottom-edge copy first, then top-edge copy
diag_from <- add_y_pos(diag_from, axes_data[axes_data$x == axes[i], ])
diag_to <- add_y_pos(diag_to, axes_data[axes_data$x == axes[i+1], ])
diagonals <- rbind(diag_from, diag_to)
# Keep the original group for the aesthetics lookup, then number ribbons
# 1..k locally (diag_from has 2k rows; each ribbon owns 4 rows in total)
main_groups <- diagonals$group
diagonals$group <- rep(seq_len(nrow(diag_from)/2), 4)
if (length(setdiff(names(groups), 'group')) > 0) {
diagonals <- cbind(diagonals, groups[match(main_groups, groups$group), names(groups) != 'group', drop = FALSE])
}
diagonals
})
# Shift each axis pair's local ribbon numbers so groups are unique overall
n_groups <- sapply(diagonals, nrow) / 4
group_offset <- c(0, cumsum(n_groups)[-length(n_groups)])
do.call(rbind, Map(function(d, i) {d$group <- d$group + i; d}, d = diagonals, i = group_offset))
}
add_y_pos <- function(data, axes_data) {
  # Stack ribbons inside each category. Starting from the category's ymin in
  # axes_data, ribbons are piled upwards in ascending group order.
  # Returns `data` twice, stacked row-wise: first with y at each ribbon's
  # bottom edge, then with y at its top edge.
  rows_by_split <- split(seq_len(nrow(data)), as.character(data$split))
  tops <- Map(function(rows, split_name) {
    stack_order <- order(data$group[rows])
    stacked <- axes_data$ymin[axes_data$split == split_name] +
      cumsum(data$value[rows][stack_order])
    # Put the stacked tops back into the original row order
    stacked[stack_order] <- stacked
    stacked
  }, rows_by_split, names(rows_by_split))
  data$y[unlist(rows_by_split)] <- unlist(tops)
  bottoms <- data
  bottoms$y <- data$y - data$value
  rbind(bottoms, data)
}
remove_group <- function(data) {
  # Decide whether the incoming `group` column is meaningful for the sets.
  # - Every category has a single group: grouping carries no information, so
  #   all rows are set to group -1.
  # - The per-category group sets share at least one group: data is returned
  #   unchanged.
  # - The per-category group sets are disjoint: grouping was presumably
  #   derived from `split`, so it is recomputed from the other discrete
  #   columns (excluding split/label/PANEL), or set to -1 if there are none.
  groups_per_split <- lapply(split(data$group, data$split), unique)
  if (all(lengths(groups_per_split) == 1)) {
    data$group <- -1
    return(data)
  }
  if (length(Reduce(intersect, groups_per_split)) != 0) {
    return(data)
  }
  disc <- vapply(data, is.discrete, logical(1))
  disc[names(disc) %in% c('split', 'label', 'PANEL')] <- FALSE
  data$group <- if (any(disc)) plyr::id(data[disc], drop = TRUE) else -1
  data
}
is.discrete <- function(x) {
  # Factors, character vectors and logical vectors count as discrete
  any(c(is.factor(x), is.character(x), is.logical(x)))
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.