# R/sankey.R
#
# Defines functions theme_sankey_bump theme_alluvial theme_sankey make_long sigmoid find_default_space prepare_params
#
# Documented in make_long theme_alluvial theme_sankey theme_sankey_bump

# modified from: https://github.com/davidsjoberg/ggsankey/blob/main/R/sankey.R

# ** Support functions ----------
prepare_params <- function(...) {
  # Split `...` into separate parameter lists for the flow and node layers.
  #
  # - `flow.<name>` goes only to the flow layer (renamed to `<name>`).
  # - `node.<name>` goes only to the node layer (renamed to `<name>`).
  # - Unprefixed arguments go to both layers.
  # - A prefixed argument overrides an unprefixed one of the same name for
  #   its own layer.
  # - Any argument whose name contains the other layer's word ("node" for
  #   flows, "flow" for nodes) is dropped from this layer.
  #
  # Returns an unnamed list: [[1]] flow parameters, [[2]] node parameters.
  dots <- list(...)
  arg_names <- names(dots)
  if (is.null(arg_names)) {
    arg_names <- character(length(dots))
  }

  # Parameters for one layer. The prefix is matched literally at the start
  # of the name (the former regex used an unescaped, unanchored `.`, so e.g.
  # `xflow.y` was mangled into `xy`).
  layer_params <- function(prefix, other) {
    tag <- paste0(prefix, ".")
    prefixed <- startsWith(arg_names, tag)
    bare_names <- arg_names
    bare_names[prefixed] <- substring(arg_names[prefixed], nchar(tag) + 1L)

    overridden <- !prefixed & arg_names %in% bare_names[prefixed]
    for_other <- grepl(other, arg_names, fixed = TRUE)
    keep <- !(overridden | for_other)

    out <- dots[keep]
    names(out) <- bare_names[keep]
    out
  }

  list(layer_params("flow", "node"), layer_params("node", "flow"))
}

find_default_space <- function(.df) {
  # Default vertical gap between stacked nodes.
  #
  # For every x position (`n_x`), the total size (`freq`, NAs ignored) is
  # divided by the number of distinct `node`s at that position and by 4.
  # The largest of these values across positions is returned.
  #
  # An empty `.df` yields 0 rather than max()'s -Inf (and its warning).
  # Written in base R: the previous pull(.data$v) used the `.data` pronoun
  # inside tidyselect, which is deprecated.
  if (nrow(.df) == 0L) {
    return(0)
  }

  # match() gives NA its own key, mirroring dplyr::group_by().
  position_id <- match(.df$n_x, unique(.df$n_x))
  rows_by_position <- split(seq_along(position_id), position_id)

  space_per_position <- vapply(
    rows_by_position,
    function(rows) {
      total <- sum(.df$freq[rows], na.rm = TRUE)
      n_nodes <- length(unique(.df$node[rows]))
      total / n_nodes / 4
    },
    numeric(1)
  )
  max(space_per_position)
}

sigmoid <- function(x_from, x_to, y_from, y_to, smooth = 5, n = 300) {
  # Points along a logistic curve from (x_from, y_from) to (x_to, y_to).
  #
  # The logistic function is sampled at `n` equally spaced points in
  # [-smooth, smooth], so a larger `smooth` gives a steeper transition.
  #
  # Returns (visibly) a data.frame with numeric columns `x` and `y`.
  s <- seq(-smooth, smooth, length.out = n)
  data.frame(
    x = (s + smooth) / (smooth * 2) * (x_to - x_from) + x_from,
    # plogis(s) equals exp(s) / (exp(s) + 1), but does not become NaN
    # when exp(s) overflows for large `smooth`.
    y = stats::plogis(s) * (y_to - y_from) + y_from
  )
}

#' @title Make a long data frame for sankey plot
#'
#' @description Prepares a 'wide' data frame into a format that `geom_sankey` or `geom_alluvial` understands.
#' Useful to show flows between dimensions in dataset.
#'
#' @param .df a data frame. Any grouping (including rowwise) is ignored.
#' @param ... unquoted column names of `.df` to include in the plot, in the
#'   order in which they should appear along the x-axis.
#' @param value optional unquoted column name holding a weight for each row.
#'   It is kept in the output as a column named `value`.
#'
#' @return A longer data frame with columns `x` (factor of the selected
#'   column names), `node`, `next_x`, `next_node` and, when `value` is
#'   supplied, `value` as the last column.
#'
#' @export
make_long <- function(.df, ..., value = NULL) {
  if ("..r" %in% names(.df)) stop("The column name '..r' is not allowed")
  .vars <- dplyr::quos(...)
  has_value <- !missing(value)

  # The row id below must be global: on a grouped data frame row_number()
  # restarts in every group, and select() would add the grouping columns
  # as extra dimensions.
  .df <- dplyr::ungroup(.df)

  if (has_value) {
    value_var <- dplyr::enquo(value)
    wide <- dplyr::select(.df, !!!.vars, value = !!value_var)
  } else {
    wide <- dplyr::select(.df, !!!.vars)
  }
  wide <- dplyr::mutate(wide, ..r = dplyr::row_number())

  # One row per (original row, selected column); `value` stays attached.
  long <- if (has_value) {
    tidyr::gather(wide, x, node, -..r, -value)
  } else {
    tidyr::gather(wide, x, node, -..r)
  }

  # Within each original row, link every column to the following one.
  out <- long |>
    dplyr::arrange(.data$..r) |>
    dplyr::group_by(.data$..r) |>
    dplyr::mutate(
      next_x = dplyr::lead(.data$x),
      next_node = dplyr::lead(.data$node)
    ) |>
    dplyr::ungroup() |>
    dplyr::select(-..r)

  if (has_value) {
    out <- dplyr::relocate(out, value, .after = dplyr::last_col())
  }

  # Keep the x-axis in the order the columns were selected.
  x_levels <- unique(out$x)
  out |>
    dplyr::mutate(dplyr::across(c(x, next_x), ~ factor(., levels = x_levels)))
}


#' @title Themes for sankey plot
#'
#' @description Minimal themes for sankey, alluvial and sankey bump plots.
#' Each one starts from [ggplot2::theme_bw()] and removes the panel border,
#' grid lines and most axis decorations.
#'
#' @param base_size Base font size, given in pts.
#' @param base_family Base font family.
#' @param base_line_size Base size for line elements.
#' @param base_rect_size Base size for rect elements.
#'
#' @name theme_sankey
#' @aliases theme_alluvial
#' @aliases theme_sankey_bump
#'
#' @export
theme_sankey <- function(base_size = 11,
                         base_family = "",
                         base_line_size = base_size / 22,
                         base_rect_size = base_size / 22) {
  bw <- ggplot2::theme_bw(
    base_size = base_size,
    base_family = base_family,
    base_line_size = base_line_size,
    base_rect_size = base_rect_size
  )
  blank <- ggplot2::element_blank()
  strip <- ggplot2::element_rect(
    fill = "white",
    colour = "transparent",
    linewidth = ggplot2::rel(2)
  )

  # Sankey plots: no grid, no y-axis, and x-axis labels without line/ticks.
  tweaks <- ggplot2::theme(
    panel.border = blank,
    panel.grid.major = blank,
    panel.grid.minor = blank,
    axis.line = ggplot2::element_line(
      colour = "black",
      linewidth = ggplot2::rel(1)
    ),
    legend.key = blank,
    strip.background = strip,
    complete = TRUE,
    axis.line.y = blank,
    axis.line.x = blank,
    axis.text.y = blank,
    axis.ticks.y = blank,
    axis.ticks.x = blank
  )

  bw %+replace% tweaks
}

#' @rdname theme_sankey
#' @export
theme_alluvial <- function(base_size = 11,
                           base_family = "",
                           base_line_size = base_size / 22,
                           base_rect_size = base_size / 22) {
  bw <- ggplot2::theme_bw(
    base_size = base_size,
    base_family = base_family,
    base_line_size = base_line_size,
    base_rect_size = base_rect_size
  )
  blank <- ggplot2::element_blank()
  strip <- ggplot2::element_rect(
    fill = "white",
    colour = "transparent",
    linewidth = ggplot2::rel(2)
  )

  # Alluvial plots keep the y-axis text but drop grid, axis lines and
  # x ticks.
  tweaks <- ggplot2::theme(
    panel.border = blank,
    panel.grid.major = blank,
    panel.grid.minor = blank,
    axis.line = blank,
    legend.key = blank,
    strip.background = strip,
    complete = TRUE,
    axis.line.x = blank,
    axis.ticks.x = blank
  )

  bw %+replace% tweaks
}

#' @rdname theme_sankey
#' @export
theme_sankey_bump <- function(base_size = 11,
                              base_family = "",
                              base_line_size = base_size / 22,
                              base_rect_size = base_size / 22) {
  bw <- ggplot2::theme_bw(
    base_size = base_size,
    base_family = base_family,
    base_line_size = base_line_size,
    base_rect_size = base_rect_size
  )
  blank <- ggplot2::element_blank()
  strip <- ggplot2::element_rect(
    fill = "white",
    colour = "transparent",
    linewidth = ggplot2::rel(2)
  )

  # Bump plots keep light horizontal major grid lines as a value guide.
  tweaks <- ggplot2::theme(
    panel.border = blank,
    panel.grid.major.x = blank,
    panel.grid.minor = blank,
    axis.line = blank,
    legend.key = blank,
    strip.background = strip,
    complete = TRUE,
    axis.line.x = blank,
    axis.ticks.x = blank,
    panel.grid.major.y = ggplot2::element_line("gray90")
  )

  bw %+replace% tweaks
}

# FLOW LAYER ---------
# Stat for the flow ribbons of geom_sankey() / geom_alluvial().
# setup_data() lays out the nodes of every panel and returns one row per
# (node -> next_node) flow with its start/end coordinates; compute_group()
# turns each such row into a closed polygon made of two sigmoid curves.
StatSankeyFlow <- ggplot2::ggproto("StatSankeyFlow", ggplot2::Stat,
  extra_params = c("n_grid", "na.rm", "type", "width", "space", "smooth"),
  setup_data = function(data, params) {
    purrr::map_dfr(
      unique(data$PANEL),
      ~ {
        data <- data |> dplyr::filter(PANEL == .x)

        # Numeric copies of the x positions: n_x and n_next_x.
        data <- data |>
          dplyr::mutate(
            dplyr::across(c(x, next_x), ~ as.numeric(.), .names = "n_{.col}")
          )

        # Flow sizes (flow_data) and node sizes (data). Each is a row count,
        # or the sum of the `value` aesthetic when that is mapped.
        if (!("value" %in% names(data))) {
          flow_data <- data |>
            dplyr::mutate(group = 1) |>
            dplyr::group_by(n_x, node, n_next_x, next_node) |>
            dplyr::summarise(flow_freq = dplyr::n(), .groups = "keep") |>
            dplyr::ungroup()

          data <- data |>
            dplyr::mutate(group = 1) |>
            dplyr::select(-n_next_x, -next_node, -next_x) |>
            dplyr::group_by_all() |>
            dplyr::summarise(freq = dplyr::n(), .groups = "keep") |>
            dplyr::ungroup()
        } else {
          flow_data <- data |>
            dplyr::mutate(group = 1) |>
            dplyr::group_by(n_x, node, n_next_x, next_node) |>
            dplyr::summarise(
              flow_freq = sum(value, na.rm = TRUE), .groups = "keep"
            ) |>
            dplyr::ungroup()

          data <- data |>
            dplyr::mutate(group = 1) |>
            dplyr::select(-n_next_x, -next_node, -next_x) |>
            dplyr::group_by_at(dplyr::vars(dplyr::everything(), -value)) |>
            dplyr::summarise(
              freq = sum(value, na.rm = TRUE), .groups = "keep"
            ) |>
            dplyr::ungroup()
        }

        # This assignment creates a copy of `params` local to this panel's
        # function call, so the default gap is derived per panel.
        if (is.null(params$space)) {
          params$space <- find_default_space(data)
        }

        # Stack the nodes at each x position, `space` apart.
        data <- data |>
          dplyr::group_by(n_x) |>
          dplyr::mutate(
            ymax = cumsum(freq) + (dplyr::row_number() - 1) * params$space,
            ymin = ymax - freq
          ) |>
          dplyr::ungroup()

        # "sankey" centres every stack vertically around 0. Any other type
        # (i.e. "alluvial") keeps the stacks starting at 0.
        if (params$type == "sankey") {
          data <- data |>
            dplyr::group_by(n_x) |>
            dplyr::mutate(
              ymin = ymin - max(ymax) / 2,
              ymax = ymax - max(ymax) / 2
            ) |>
            dplyr::ungroup()
        }

        data <- data |>
          dplyr::mutate(
            xmin = n_x - params$width / 2,
            xmax = n_x + params$width / 2
          )

        # The `shift` aesthetic moves nodes vertically. Only the y
        # coordinates are shifted: matching every column that merely
        # contains a "y" would also hit aesthetics such as `linetype`.
        if ("shift" %in% names(data)) {
          data <- data |>
            dplyr::mutate(dplyr::across(dplyr::starts_with("y"), ~ . + shift))
        }

        df <- data |>
          dplyr::left_join(flow_data, by = c("n_x", "node"))

        # Attach the coordinates of the target node to every flow and drop
        # flows without a target (e.g. from the last x position).
        flows <- df |>
          dplyr::left_join(
            df |>
              dplyr::select(
                n_x, node,
                ymin_end = ymin, ymax_end = ymax,
                xmin_end = xmin, xmax_end = xmax
              ) |>
              dplyr::distinct(),
            by = c("n_next_x" = "n_x", "next_node" = "node")
          ) |>
          tidyr::drop_na(
            n_x, node, next_node,
            n_next_x, ymax_end, ymin_end,
            xmax_end, xmin_end
          ) |>
          dplyr::mutate(r = dplyr::row_number()) |>
          dplyr::arrange(n_x, -r) |>
          dplyr::select(-r) |>
          dplyr::group_by(n_x, node) |>
          dplyr::mutate(cum_flow_freq = cumsum(flow_freq) - flow_freq) |>
          dplyr::ungroup() |>
          dplyr::group_by(n_x, n_next_x, node, next_node) |>
          dplyr::mutate(
            flow_start_ymax = ymax - cum_flow_freq,
            flow_start_ymin = flow_start_ymax - flow_freq
          )

        # Stack the incoming flows on each target node from the top down.
        flows <- flows |>
          dplyr::arrange(n_x, n_next_x, next_node) |>
          dplyr::group_by(n_next_x, next_node) |>
          dplyr::mutate(cum_flow_freq_end = cumsum(flow_freq) - flow_freq) |>
          dplyr::mutate(
            flow_end_ymax = ymax_end - cum_flow_freq_end,
            flow_end_ymin = flow_end_ymax - flow_freq
          ) |>
          dplyr::ungroup()

        # Keep only what compute_group() needs; one group per flow.
        flows <- flows |>
          dplyr::select(
            -n_x, -node, -freq, -ymax, -ymin,
            -xmin, -n_next_x, -next_node, -flow_freq,
            -ymin_end, -ymax_end, -xmax_end,
            -cum_flow_freq, -cum_flow_freq_end
          ) |>
          dplyr::mutate(group = dplyr::row_number())

        flows |>
          dplyr::mutate(smooth = params$smooth) |>
          as.data.frame()
      }
    )
  },
  # The top edge runs forward (source -> target) and the bottom edge runs
  # backward, so the bound rows outline one closed polygon.
  compute_group = function(data, scales) {
    out1 <- sigmoid(
      data$xmax, data$xmin_end,
      data$flow_start_ymax, data$flow_end_ymax,
      smooth = data$smooth
    )
    out2 <- sigmoid(
      data$xmin_end, data$xmax,
      data$flow_end_ymin, data$flow_start_ymin,
      smooth = data$smooth
    )
    dplyr::bind_rows(out1, out2)
  }
)

# FLOW SANKEYBUMP LAYER ---------
# Stat for geom_sankey_bump(). Every (node, x) pair becomes its own key
# `nodes` = paste(node, x), linked to the same node at the next x. The keys
# are laid out like the sankey flows, then regrouped by the original node
# so each node is drawn as one ribbon across all x positions.
StatSankeyBumpFlow <- ggplot2::ggproto("StatSankeyBumpFlow", ggplot2::Stat,
  extra_params = c("na.rm", "type", "space", "smooth"),
  setup_data = function(data, params) {
    purrr::map_dfr(
      unique(data$PANEL),
      ~ {
        data <- data |> dplyr::filter(PANEL == .x)

        # Levels of `nodes` are ordered by `value` (fct_reorder), which sets
        # the stacking order below. This requires the `value` aesthetic.
        data <- data |>
          dplyr::mutate(nodes = paste(node, x)) |>
          dplyr::arrange(x, -value) |>
          dplyr::mutate(bbb = dplyr::row_number()) |>
          dplyr::arrange(bbb) |>
          dplyr::mutate(nodes = forcats::fct_reorder(nodes, value, mean)) |>
          dplyr::arrange(node, x) |>
          dplyr::group_by(node) |>
          dplyr::mutate(
            next_x = dplyr::lead(x),
            node = nodes,
            next_node = dplyr::lead(nodes)
          ) |>
          dplyr::ungroup() |>
          dplyr::arrange(x, node)

        data <- data |>
          dplyr::mutate(
            dplyr::across(c(x, next_x), ~ as.numeric(.), .names = "n_{.col}")
          )

        if (!("value" %in% names(data))) {
          flow_data <- data |>
            dplyr::mutate(group = 1) |>
            dplyr::group_by(n_x, node, n_next_x, next_node) |>
            dplyr::summarise(flow_freq = dplyr::n(), .groups = "keep") |>
            dplyr::ungroup()

          data <- data |>
            dplyr::mutate(group = 1) |>
            dplyr::select(-n_next_x, -next_node) |>
            dplyr::group_by_all() |>
            dplyr::summarise(freq = dplyr::n(), .groups = "keep") |>
            dplyr::ungroup()
        } else {
          flow_data <- data |>
            dplyr::mutate(group = 1) |>
            dplyr::group_by(n_x, node, n_next_x, next_node) |>
            dplyr::summarise(
              flow_freq = sum(value, na.rm = TRUE), .groups = "keep"
            ) |>
            dplyr::ungroup()

          data <- data |>
            dplyr::mutate(group = 1) |>
            dplyr::select(-n_next_x, -next_node) |>
            dplyr::group_by_at(dplyr::vars(dplyr::everything(), -value)) |>
            dplyr::summarise(
              freq = sum(value, na.rm = TRUE), .groups = "keep"
            ) |>
            dplyr::ungroup()
        }

        # This assignment creates a copy of `params` local to this panel's
        # function call, so the default gap is derived per panel.
        if (is.null(params$space)) {
          params$space <- find_default_space(data)
        }

        data <- data |>
          dplyr::group_by(n_x) |>
          dplyr::arrange(node) |>
          dplyr::mutate(
            ymax = cumsum(freq) + (dplyr::row_number() - 1) * params$space,
            ymin = ymax - freq
          ) |>
          dplyr::ungroup()

        # "sankey" centres every stack vertically around 0. Any other type
        # keeps the stacks starting at 0.
        if (params$type == "sankey") {
          data <- data |>
            dplyr::group_by(n_x) |>
            dplyr::mutate(
              ymin = ymin - max(ymax) / 2,
              ymax = ymax - max(ymax) / 2
            ) |>
            dplyr::ungroup()
        }

        # Bump nodes have no width: flows start and end exactly at n_x.
        data <- data |>
          dplyr::mutate(
            xmin = n_x,
            xmax = n_x
          )

        df <- data |>
          dplyr::left_join(flow_data, by = c("n_x", "node"))

        flows <- df |>
          dplyr::left_join(
            df |>
              dplyr::select(
                n_x, node,
                ymin_end = ymin, ymax_end = ymax,
                xmin_end = xmin, xmax_end = xmax,
                flow_freq_end = flow_freq
              ) |>
              dplyr::distinct(),
            by = c("n_next_x" = "n_x", "next_node" = "node")
          ) |>
          tidyr::drop_na(
            n_x, node, next_node,
            n_next_x, ymax_end, ymin_end,
            xmax_end, xmin_end
          ) |>
          dplyr::mutate(r = dplyr::row_number()) |>
          dplyr::arrange(n_x, -r) |>
          dplyr::select(-r) |>
          dplyr::group_by(n_x, node) |>
          dplyr::mutate(cum_flow_freq = cumsum(flow_freq) - flow_freq) |>
          dplyr::ungroup() |>
          dplyr::group_by(n_x, n_next_x, node, next_node) |>
          dplyr::mutate(
            flow_start_ymax = ymax - cum_flow_freq,
            flow_start_ymin = flow_start_ymax - flow_freq
          )

        # The end height uses the target node's own size (flow_freq_end).
        flows <- flows |>
          dplyr::arrange(n_x, n_next_x, next_node) |>
          dplyr::group_by(n_next_x, next_node) |>
          dplyr::mutate(cum_flow_freq_end = cumsum(flow_freq_end) - flow_freq_end) |>
          dplyr::mutate(
            flow_end_ymax = ymax_end - cum_flow_freq_end,
            flow_end_ymin = flow_end_ymax - flow_freq_end
          ) |>
          dplyr::ungroup()

        flows <- flows |>
          dplyr::select(
            -n_x, -node, -freq, -ymax, -ymin,
            -xmin, -n_next_x, -next_node,
            -flow_freq, -ymin_end, -ymax_end, -xmax_end,
            -cum_flow_freq, -cum_flow_freq_end
          ) |>
          dplyr::mutate(group = dplyr::row_number())

        # One polygon group per original node. `nodes` was built as
        # paste(node, x), so its trailing " <x>" is cut off by length. The
        # previous code did this with an unqualified rowwise() and removed
        # `x` as a regex anywhere in the string, which mis-grouped node
        # names containing digits.
        flows |>
          dplyr::mutate(
            ..groupqq = substr(
              as.character(nodes),
              1L,
              nchar(as.character(nodes)) - nchar(paste0(" ", x))
            )
          ) |>
          dplyr::group_by(..groupqq) |>
          dplyr::mutate(group = dplyr::cur_group_id()) |>
          dplyr::ungroup() |>
          dplyr::select(-..groupqq) |>
          dplyr::mutate(smooth = params$smooth) |>
          as.data.frame()
      }
    )
  },
  # Top edges of all segments left to right, then bottom edges right to
  # left, giving one closed outline per node group.
  compute_group = function(data, scales) {
    out1 <- purrr::map_dfr(seq_len(nrow(data)), ~ {
      datat <- data |> dplyr::slice(.x)
      sigmoid(datat$xmax, datat$xmin_end, datat$flow_start_ymax, datat$flow_end_ymax,
        smooth = datat$smooth
      )
    }) |>
      dplyr::arrange(x)
    out2 <- purrr::map_dfr(seq_len(nrow(data)), ~ {
      datat <- data |> dplyr::slice(.x)
      sigmoid(datat$xmin_end, datat$xmax, datat$flow_end_ymin, datat$flow_start_ymin,
        smooth = datat$smooth
      )
    }) |>
      dplyr::arrange(-x)

    dplyr::bind_rows(out1, out2)
  }
)

# TEXT LAYER -------
# Stat for the sankey/alluvial text and label layers. It lays out node boxes
# with the same steps as the node stat and then sets x/y to the centre of
# each box, so a text or label geom can be drawn there.
StatSankeyText <- ggplot2::ggproto("StatSankeyText", ggplot2::Stat,
  extra_params = c("n_grid", "na.rm", "type", "width", "space"),
  setup_data = function(data, params) {
    purrr::map_dfr(
      unique(data$PANEL),
      ~ {
        data <- data |> dplyr::filter(PANEL == .x)

        data <- data |>
          dplyr::mutate(
            dplyr::across(c(x, next_x), ~ as.numeric(.), .names = "n_{.col}")
          )

        # One row per node. `next_x` is dropped in both branches, as the
        # flow and node stats do, so a node is never split by its target
        # column (which would misplace labels relative to the boxes).
        if (!("value" %in% names(data))) {
          data <- data |>
            dplyr::mutate(group = 1) |>
            dplyr::select(-n_next_x, -next_node, -next_x) |>
            dplyr::group_by_all() |>
            dplyr::summarise(freq = dplyr::n(), .groups = "keep") |>
            dplyr::ungroup()
        } else {
          data <- data |>
            dplyr::mutate(group = 1) |>
            dplyr::select(-n_next_x, -next_node, -next_x) |>
            dplyr::group_by_at(dplyr::vars(dplyr::everything(), -value)) |>
            dplyr::summarise(freq = sum(value, na.rm = TRUE), .groups = "keep") |>
            dplyr::ungroup()
        }

        # This assignment creates a copy of `params` local to this panel's
        # function call, so the default gap is derived per panel.
        if (is.null(params$space)) {
          params$space <- find_default_space(data)
        }

        data <- data |>
          dplyr::group_by(n_x) |>
          dplyr::mutate(
            ymax = cumsum(freq) + (dplyr::row_number() - 1) * params$space,
            ymin = ymax - freq
          ) |>
          dplyr::ungroup()

        # "sankey" centres every stack vertically around 0. Any other type
        # keeps the stacks starting at 0.
        if (params$type == "sankey") {
          data <- data |>
            dplyr::group_by(n_x) |>
            dplyr::mutate(
              ymin = ymin - max(ymax) / 2,
              ymax = ymax - max(ymax) / 2
            ) |>
            dplyr::ungroup()
        }

        data <- data |>
          dplyr::mutate(
            xmin = n_x - params$width / 2,
            xmax = n_x + params$width / 2
          )

        # Label anchor: centre of the node box.
        data <- data |>
          dplyr::mutate(
            x = n_x,
            y = ymin + (ymax - ymin) / 2
          )

        # Apply `shift` to the y coordinates only (y, ymin, ymax). Matching
        # any column containing "y" would also hit e.g. `linetype`.
        if ("shift" %in% names(data)) {
          data <- data |>
            dplyr::mutate(dplyr::across(dplyr::starts_with("y"), ~ . + shift))
        }

        as.data.frame(data)
      }
    )
  },
  compute_group = function(data, scales) {
    data
  }
)

# NODE LAYER -------
# Stat for the node boxes of geom_sankey() / geom_alluvial(). Produces one
# row per node with xmin/xmax/ymin/ymax, which the rect geom draws.
StatSankeyNode <- ggplot2::ggproto("StatSankeyNode", ggplot2::Stat,
  extra_params = c("n_grid", "na.rm", "type", "width", "space", "smooth"),
  setup_data = function(data, params) {
    purrr::map_dfr(
      unique(data$PANEL),
      ~ {
        data <- data |> dplyr::filter(PANEL == .x)
        data <- data |>
          dplyr::mutate(
            dplyr::across(c(x, next_x), ~ as.numeric(.), .names = "n_{.col}")
          )

        # Node size: a row count, or the sum of `value` when it is mapped.
        if (!("value" %in% names(data))) {
          data <- data |>
            dplyr::mutate(group = 1) |>
            dplyr::select(-n_next_x, -next_node, -next_x) |>
            dplyr::group_by_all() |>
            dplyr::summarise(freq = dplyr::n(), .groups = "keep") |>
            dplyr::ungroup()
        } else {
          data <- data |>
            dplyr::mutate(group = 1) |>
            dplyr::select(-n_next_x, -next_node, -next_x) |>
            dplyr::group_by_at(dplyr::vars(dplyr::everything(), -value)) |>
            dplyr::summarise(freq = sum(value, na.rm = TRUE), .groups = "keep") |>
            dplyr::ungroup()
        }

        # This assignment creates a copy of `params` local to this panel's
        # function call, so the default gap is derived per panel.
        if (is.null(params$space)) {
          params$space <- find_default_space(data)
        }

        # Stack the nodes at each x position, `space` apart.
        data <- data |>
          dplyr::group_by(n_x) |>
          dplyr::mutate(
            ymax = cumsum(freq) + (dplyr::row_number() - 1) * params$space,
            ymin = ymax - freq
          ) |>
          dplyr::ungroup()

        # "sankey" centres every stack vertically around 0. Any other type
        # keeps the stacks starting at 0.
        if (params$type == "sankey") {
          data <- data |>
            dplyr::group_by(n_x) |>
            dplyr::mutate(
              ymin = ymin - max(ymax) / 2,
              ymax = ymax - max(ymax) / 2
            ) |>
            dplyr::ungroup()
        }

        data <- data |>
          dplyr::mutate(
            xmin = n_x - params$width / 2,
            xmax = n_x + params$width / 2
          )

        # Apply `shift` to ymin/ymax only. Matching any column containing
        # "y" would also hit aesthetics such as `linetype`.
        if ("shift" %in% names(data)) {
          data <- data |>
            dplyr::mutate(dplyr::across(dplyr::starts_with("y"), ~ . + shift))
        }

        as.data.frame(data)
      }
    )
  },
  compute_group = function(data, scales) {
    data
  }
)

#' @title geom_sankey
#'
#' @description Creates a sankey plot which visualizes flows between nodes.
#' Each observation needs to have a `x` aesthetic as well as a `next_x` column which declares where that observation should flow.
#' Also each observation should have a `node` and a `next_node` aesthetic which provide information about which group in the y-direction.
#' By default each row of the data frame is counted to calculate the size of flows. A manual flow value can be added with the `value` aesthetic.
#'
#' @param mapping Set of aesthetic mappings created by [ggplot2::aes()].
#' @param data The data to be displayed in this layer, typically the output
#'   of [make_long()].
#' @param position Position adjustment.
#' @param na.rm Remove missing values.
#' @param show.legend Show legend in plot.
#' @param space Space between nodes in the y-direction. If `NULL`, a default
#'   is derived from the data.
#' @param type Either `"sankey"` (node stacks centred vertically) or
#'   `"alluvial"` (node stacks starting at zero).
#' @param width Width of nodes.
#' @param smooth How smooth the curve should be. Larger values give a steeper curve.
#' @param inherit.aes Should the geom inherit aesthetics.
#' @param ... Other arguments passed to the layers. Names prefixed with
#'   `flow.` (e.g. `flow.alpha`) go only to the flows, names prefixed with
#'   `node.` go only to the nodes, and unprefixed arguments go to both.
#'
#' @section Aesthetics:
#' geom_sankey understands the following aesthetics (required aesthetics are
#' in bold):
#'
#' - **x**
#' - **next_x**
#' - **node**
#' - **next_node**
#' - value
#' - shift
#' - color
#' - fill
#' - linewidth
#' - linetype
#' - alpha
#'
#' @return A list of two ggplot layers: the flows (polygons) and the nodes
#'   (rectangles).
#'
#' @export
#' @examples
#' dat <- data.frame(
#'   Group = c("A", "A", "B", "B", "C"),
#'   Type = c("X", "Y", "X", "Y", "X")
#' )
#' long <- make_long(dat, Group, Type)
#' ggplot2::ggplot(
#'   long,
#'   ggplot2::aes(
#'     x = x,
#'     next_x = next_x,
#'     node = node,
#'     next_node = next_node,
#'     fill = node
#'   )
#' ) +
#'   geom_sankey() +
#'   theme_sankey()
geom_sankey <- function(
  mapping = NULL,
  data = NULL,
  position = "identity",
  na.rm = FALSE,
  show.legend = NA,
  space = NULL,
  type = "sankey",
  width = .1,
  smooth = 8,
  inherit.aes = TRUE,
  ...
) {
  # Fail early on a misspelt type. Otherwise, any value other than "sankey"
  # silently falls back to the uncentred (alluvial) layout in the stats.
  type <- match.arg(type, c("sankey", "alluvial"))
  params_list <- prepare_params(...)

  list(
    flow = ggplot2::layer(
      stat = StatSankeyFlow,
      data = data,
      mapping = mapping,
      geom = "polygon",
      position = position,
      show.legend = show.legend,
      inherit.aes = inherit.aes,
      params = purrr::flatten(
        list(
          na.rm = na.rm,
          width = width,
          space = space,
          smooth = smooth,
          type = type,
          params_list[[1]]
        )
      )
    ),
    node = ggplot2::layer(
      stat = StatSankeyNode,
      data = data,
      mapping = mapping,
      geom = ggplot2::GeomRect,
      position = position,
      show.legend = show.legend,
      inherit.aes = inherit.aes,
      params = purrr::flatten(
        list(
          na.rm = na.rm,
          width = width,
          space = space,
          smooth = smooth,
          type = type,
          params_list[[2]]
        )
      )
    )
  )
}

#' @title geom_sankey_label
#'
#' @name geom_sankey_label
#' @aliases geom_sankey_text
#'
#' @description Creates centered labels or text in nodes of your sankey plot.
#' Needs to have the exact same aesthetics (and the same `space`, `type` and
#' `width`) as the call to `geom_sankey` to work.
#'
#' @inheritParams geom_sankey
#' @param ... Other arguments passed to the label or text layer, such as
#'   `size` or `colour`.
#'
#' @return A list containing one ggplot layer.
#'
#' @rdname geom_sankey_label
#' @export
geom_sankey_label <- function(
  mapping = NULL,
  data = NULL,
  position = "identity",
  na.rm = FALSE,
  show.legend = NA,
  space = NULL,
  type = "sankey",
  width = .1,
  inherit.aes = TRUE,
  ...
) {
  # Same validation as geom_sankey(), so labels cannot silently use a
  # different layout from the nodes.
  type <- match.arg(type, c("sankey", "alluvial"))

  # Everything in `...` goes to the label layer unchanged.
  label.aes <- list(...)

  list(
    label = ggplot2::layer(
      stat = StatSankeyText,
      data = data,
      mapping = mapping,
      geom = "label",
      position = position,
      show.legend = show.legend,
      inherit.aes = inherit.aes,
      params = purrr::flatten(
        list(
          na.rm = na.rm,
          width = width,
          space = space,
          type = type,
          label.aes
        )
      )
    )
  )
}

#' @rdname geom_sankey_label
#' @export
geom_sankey_text <- function(
  mapping = NULL,
  data = NULL,
  position = "identity",
  na.rm = FALSE,
  show.legend = NA,
  space = NULL,
  type = "sankey",
  width = .1,
  inherit.aes = TRUE,
  ...
) {
  # Same validation as geom_sankey(), so text cannot silently use a
  # different layout from the nodes.
  type <- match.arg(type, c("sankey", "alluvial"))

  # Everything in `...` goes to the text layer unchanged.
  label.aes <- list(...)

  list(
    label = ggplot2::layer(
      stat = StatSankeyText,
      data = data,
      mapping = mapping,
      geom = "text",
      position = position,
      show.legend = show.legend,
      inherit.aes = inherit.aes,
      params = purrr::flatten(
        list(
          na.rm = na.rm,
          width = width,
          space = space,
          type = type,
          label.aes
        )
      )
    )
  )
}

#' @title geom_alluvial
#'
#' @description Creates an alluvial plot which visualizes flows between nodes.
#' Each observation needs to have a `x` aesthetic as well as a `next_x` column which declares where that observation should flow.
#' Also each observation should have a `node` and a `next_node` aesthetic which provide information about which group in the y-direction.
#' This is [geom_sankey()] with `type = "alluvial"` and, by default, no
#' space between nodes.
#'
#' @inheritParams geom_sankey
#'
#' @return A list of ggplot layers (flows and nodes), as returned by
#'   [geom_sankey()].
#'
#' @export
geom_alluvial <- function(
  mapping = NULL,
  data = NULL,
  position = "identity",
  na.rm = FALSE,
  show.legend = NA,
  space = 0,
  width = .1,
  smooth = 8,
  inherit.aes = TRUE,
  ...
) {
  # Delegate everything; only the layout type is fixed here.
  geom_sankey(
    type = "alluvial",
    mapping = mapping,
    data = data,
    position = position,
    inherit.aes = inherit.aes,
    show.legend = show.legend,
    na.rm = na.rm,
    smooth = smooth,
    width = width,
    space = space,
    ...
  )
}

#' @title geom_alluvial_label
#' @name geom_alluvial_label
#' @aliases geom_alluvial_text
#'
#' @description Creates centered labels or text in nodes of your alluvial plot.
#' Needs to have the exact same aesthetics as the call to `geom_alluvial` to work.
#'
#' @inheritParams geom_sankey
#'
#' @details Other important inputs are:
#' the `space` argument, which sets the space between nodes in the y-direction;
#' the `shift` aesthetic, which shifts nodes in the y-direction.
#'
#' @return A list containing one ggplot layer.
#'
#' @rdname geom_alluvial_label
#' @export
geom_alluvial_text <- function(
  mapping = NULL,
  data = NULL,
  position = "identity",
  na.rm = FALSE,
  show.legend = NA,
  space = 0,
  width = .1,
  inherit.aes = TRUE,
  ...
) {
  # Text variant of the sankey labels with the layout type fixed.
  geom_sankey_text(
    type = "alluvial",
    mapping = mapping,
    data = data,
    position = position,
    inherit.aes = inherit.aes,
    show.legend = show.legend,
    na.rm = na.rm,
    width = width,
    space = space,
    ...
  )
}

#' @rdname geom_alluvial_label
#' @export
geom_alluvial_label <- function(
  mapping = NULL,
  data = NULL,
  position = "identity",
  na.rm = FALSE,
  show.legend = NA,
  space = 0,
  width = .1,
  inherit.aes = TRUE,
  ...
) {
  # Label variant of the sankey labels with the layout type fixed.
  geom_sankey_label(
    type = "alluvial",
    mapping = mapping,
    data = data,
    position = position,
    inherit.aes = inherit.aes,
    show.legend = show.legend,
    na.rm = na.rm,
    width = width,
    space = space,
    ...
  )
}

#' @title geom_sankey_bump
#'
#' @description Creates a sankey bump plot which visualizes flows between nodes.
#' Each observation needs to have a `x` aesthetic as well as a `next_x` column which declares where that observation should flow.
#' Also, each observation should have a `node` and a `next_node` aesthetic which provide information about which group in the y-direction.
#' A `value` aesthetic is required: it sizes the flows and orders the nodes
#' at each x position.
#'
#' @inheritParams geom_sankey
#'
#' @details The space between nodes in the y-direction is derived from the
#' data; it can be overridden by passing `space` through `...`.
#'
#' @return A list containing one ggplot layer (the flows).
#'
#' @export
geom_sankey_bump <- function(
  mapping = NULL,
  data = NULL,
  position = "identity",
  na.rm = FALSE,
  show.legend = NA,
  smooth = 8,
  type = "sankey",
  inherit.aes = TRUE,
  ...
) {
  # Fail early on a misspelt type instead of silently using the uncentred
  # layout.
  type <- match.arg(type, c("sankey", "alluvial"))

  # Only the flow part of the `flow.` / `node.` split is used: a bump plot
  # has no node boxes.
  params_list <- prepare_params(...)

  list(
    flow = ggplot2::layer(
      stat = StatSankeyBumpFlow,
      data = data,
      mapping = mapping,
      geom = "polygon",
      position = position,
      show.legend = show.legend,
      inherit.aes = inherit.aes,
      params = purrr::flatten(list(
        na.rm = na.rm,
        type = type,
        smooth = smooth,
        params_list[[1]]
      ))
    )
  )
}

# Try the thisplot package in your browser
#
# Any scripts or data that you put into this service are public.
#
# thisplot documentation built on March 7, 2026, 5:07 p.m.