R/dataviz.R

Defines functions sankey_ly

Documented in sankey_ly

#' Sankey Plot with Plotly
#'
#' @description Aggregates a data.frame along a sequence of categorical
#' columns and plots the flows between consecutive columns as a Plotly
#' sankey diagram.
#'
#' @details For every pair of consecutive columns in \code{cat_cols} the
#' function sums \code{num_col} (with \code{na.rm = TRUE}) for each observed
#' combination of values and draws one link per combination. Nodes are
#' created per column, in the order of \code{cat_cols} and in order of first
#' appearance of each value, so a value that occurs in more than one column
#' is drawn as a separate node for each of those columns. Categorical values
#' are converted to character before they are used as node labels.
#'
#' @param x A data.frame input, must have at least two categorical columns
#' and one numeric column
#' @param cat_cols A character vector with the names of at least two
#' categorical columns; the order of the names sets the left-to-right order
#' of the diagram
#' @param num_col A single numeric column name
#' @param title Optional, string to pass to plotly layout title function
#' @return A plotly object with a single trace of type "sankey"
#' @examples
#' data("sfo_passengers")
#'
#' library(dplyr)
#'
#' d <- sfo_passengers %>%
#'   filter(activity_period >= 202201 & activity_period < 202301)
#'
#' head(d)
#'
#' d %>%
#'   filter(operating_airline == "United Airlines") %>%
#'   mutate(terminal = ifelse(terminal == "International", "international", terminal)) %>%
#'   group_by(operating_airline,activity_type_code, geo_summary, geo_region,  terminal) %>%
#'   summarise(total = sum(passenger_count), .groups = "drop") %>%
#'   sankey_ly(cat_cols = c("operating_airline", "terminal","geo_summary",
#'                          "geo_region", "activity_type_code"),
#'             num_col = "total",
#'             title = "Distribution of United Airlines Passengers at SFO During 2022")
#' @export

sankey_ly <- function(x, cat_cols, num_col, title = NULL) {
  # Error handling
  if (!is.data.frame(x)) {
    stop("The input object is not a valid data.frame")
  } else if (!all(c(cat_cols, num_col) %in% names(x))) {
    stop("One or more column names does not matching to the names of the input object")
  } else if (length(num_col) != 1) {
    stop("Cannot define more than one numeric column with the num_col argument")
  } else if (length(cat_cols) < 2) {
    stop("Must define at least two categorical columns with the cat_cols argument")
  } else if (!is.null(title) && !is.character(title)) {
    stop("The title argument is not valid character object")
  }

  # Work on a plain data.frame so that `[[` behaves uniformly and any grouping
  # carried by the input does not leak into the aggregation below.
  x <- as.data.frame(x)

  # Node labels are collected per column. Keeping the columns apart (instead
  # of pooling all labels into one lookup table) means that a label shared by
  # two columns gets one node per column rather than a single ambiguous node.
  node_labels <- lapply(cat_cols, function(col) {
    unique(as.character(x[[col]]))
  })

  # Zero-based index of the first node of each column in the flattened node
  # list; plotly expects zero-based node indices in the link definition.
  node_offset <- cumsum(c(0L, lengths(node_labels)))

  links <- lapply(seq_len(length(cat_cols) - 1L), function(i) {
    # Copy the two adjacent categorical columns and the numeric column into a
    # frame with fixed names, so arbitrary (even non-syntactic) column names
    # in `x` never have to be parsed or evaluated as code.
    flows <- data.frame(
      s = as.character(x[[cat_cols[i]]]),
      t = as.character(x[[cat_cols[i + 1L]]]),
      value = x[[num_col]],
      stringsAsFactors = FALSE
    )

    flows <- dplyr::group_by(flows, s = .data$s, t = .data$t)
    flows <- dplyr::summarise(
      flows,
      total = sum(.data$value, na.rm = TRUE),
      .groups = "drop"
    )

    # Translate labels into node indices, looking each label up only among
    # the nodes of its own column.
    flows$source <- node_offset[i] + match(flows$s, node_labels[[i]]) - 1L
    flows$target <- node_offset[i + 1L] +
      match(flows$t, node_labels[[i + 1L]]) - 1L

    flows
  })
  links <- dplyr::bind_rows(links)

  p <- plotly::plot_ly(
    type = "sankey",
    orientation = "h",
    valueformat = ".0f",
    node = list(
      label = unlist(node_labels, use.names = FALSE),
      pad = 15,
      thickness = 30,
      line = list(
        color = "black",
        width = 0.5
      )
    ),
    link = list(
      source = links$source,
      target = links$target,
      value = links$total
    )
  )

  if (!is.null(title)) {
    p <- plotly::layout(p, title = title)
  }

  return(p)
}

Try the sfo package in your browser

Any scripts or data that you put into this service are public.

sfo documentation built on March 31, 2023, 8:32 p.m.