R/BarPlotCat.R

Defines functions BarPlotCat

Documented in BarPlotCat

#' Plot Bar Charts of Forecast Category Probabilities
#'
#' This function creates bar plots for probabilistic forecasts split into 
#' categories (e.g., below-normal, normal, above-normal), optionally including
#' extreme categories (e.g., below P10, above P90). Probabilities are displayed 
#' on the y-axis as percentages. The function supports multi-panel plotting for
#' different time steps, skill-based transparency, a shared legend and axis
#' title, and output to file.
#'
#' @param probs A named 2D array with dimensions 'cat_dim' x 'panel_dim', 
#'  containing probabilities for each category (in relative units, summing to 1).
#'  Can also be a vector (reshaped internally to a single panel), in which case
#'  'lims' must be a vector too.
#' @param lims A named 2D array with category threshold values (e.g., tercile 
#'  cutoffs), or a vector. It must have one element less than 'probs' along
#'  'cat_dim' and the same length as 'probs' along 'panel_dim'.
#' @param extreme_probs (optional) A 2D array (or vector) with probabilities for
#'  the two extreme categories (e.g., below P10, above P90). It must have two
#'  elements along 'cat_dim' and the same 'panel_dim' length as 'probs'.
#' @param extreme_lims (optional) A 2D array (or vector) with limits corresponding
#'  to the extreme categories (lower first, upper second). Only used when 
#'  'extreme_probs' is provided.
#' @param skill (optional) A numeric vector indicating skill scores for each 
#'  panel. Panels with negative skill are drawn with semi-transparent bars and
#'  labelled 'Skill < 0'.
#' @param toptitle A character string with the main plot title.
#' @param legend_title A string with the title of the legend that shows the 
#'  categories.
#' @param cat_dim A string indicating the name of the category dimension. 
#'  Default is "cat".
#' @param panel_dim A string indicating the name of the panel dimension (e.g., 
#'  time). Default is "ftime".
#' @param color.set A string selecting the color palette to use. One of 
#'  \code{"s2s4e"}, \code{"ggplot"}, \code{"hydro"}, \code{"vitigeoss"}. Default
#'  is "s2s4e".
#' @param category_names A character vector with names for each forecast category.
#'  Default is \code{c("bn", "norn", "an")}. If its length does not match the
#'  number of categories in 'probs', generic names ("Cat 1", "Cat 2", ...) are
#'  used and a warning is issued.
#' @param panel_title A string or vector of strings with titles for each panel.
#'  If one value, it is repeated across panels.
#' @param panel_subtitle A string or vector with subtitles for each panel. If one
#'  value, it is repeated across panels.
#' @param panel_bottom_name A string or vector for the bottom axis label of each
#'  panel.
#' @param lims_pos A numeric value indicating the vertical position of threshold
#'  annotations. It is also the lower limit of the y-axis. Default is -4.
#' @param legend_width A numeric value (in cm) for the width of the legend area.
#'  Default is 3.5.
#' @param extreme_bars_width A numeric value between 0 and 1 defining the width
#'  of the bars for extreme categories. Default is 0.4.
#' @param xaxis_title A string for the shared x-axis title.
#' @param extreme_cat_names A character vector with names for the extreme
#'  categories (left and right). Default is \code{c("p10", "p90")}.
#' @param toptitle_size A numeric value for the font size of the top title.
#'  Default is 16.
#' @param toptitle_pos A string for justification of the title. One of 
#'  \code{"left"}, \code{"center"}, \code{"right"}. Default is "center".
#' @param fileout (optional) Path to save the resulting plot (e.g.,
#'  \code{"plot.png"}). If NULL, the plot is only drawn and returned.
#' @param width (optional) Width of the output figure, passed to 
#'  \code{ggsave()}. Units are defined by 'size_units'. Default is 8.
#' @param height (optional) Height of the output figure, passed to 
#'  \code{ggsave()}. Units are defined by 'size_units'. Default is 6.
#' @param size_units Units for \code{width} and \code{height}, passed to the
#'  \code{units} argument of \code{ggsave()}. Default is NULL, which leaves the
#'  \code{ggsave()} default in place.
#' @param res Plot resolution in dpi when saving to file. Default is 100.
#'
#' @return A grob object representing the composed bar plot, invisibly if written
#'  to file.
#'
#' @examples
#' # Basic example
#' probs <- array(rep(c(0.3, 0.4, 0.3), 4), c(cat = 3, ftime = 4))
#' lims <- array(rep(c(2, 4), 4), c(cat = 2, ftime = 4))
#' BarPlotCat(probs, lims, toptitle = "Example Forecast")
#'
#' @import ggplot2
#' @importFrom s2dv Reorder
#' @importFrom gridExtra grid.arrange arrangeGrob
#' @importFrom gtable gtable_filter
#' @importFrom grid textGrob unit unit.c gpar
#' @importFrom scales hue_pal
#' @export
BarPlotCat <- function(probs, lims, extreme_probs = NULL, extreme_lims = NULL,
                       skill = NULL, toptitle = '', legend_title = '', 
                       cat_dim = 'cat', panel_dim = 'ftime', color.set = 's2s4e',
                       category_names = c('bn', 'norn', 'an'),
                       panel_title = '', panel_subtitle = '',
                       panel_bottom_name = '', lims_pos = -4,
                       legend_width = 3.5, extreme_bars_width = 0.4,
                       xaxis_title = '', extreme_cat_names = c('p10', 'p90'),
                       toptitle_size = 16, toptitle_pos = "center",
                       fileout = NULL, width = 8, height = 6,
                       size_units = NULL, res = 100) {
  #------------------------
  # Input checks
  #------------------------
  # 'probs' and 'lims' must be of the same kind; mixing a vector with an array
  # used to fall through to an unrelated "2 dimensions maximum" error.
  if (is.array(probs) != is.array(lims)) {
    stop("Parameter 'probs' and 'lims' must be both vectors or arrays.")
  }
  # Accept vectors: reshape them into a single-panel 2D array.
  if (!is.array(probs)) {
    if (length(probs) != (length(lims) + 1)) {
      stop("Parameter 'probs' must have one more element than parameter 'lims'.")
    }
    dim(probs) <- c(length(probs), 1)
    dim(lims) <- c(length(lims), 1)
    names(dim(probs)) <- c(cat_dim, panel_dim)
    names(dim(lims)) <- c(cat_dim, panel_dim)
  }
  # Check dim names and bring both arrays to [cat_dim, panel_dim]:
  required_dims <- c(cat_dim, panel_dim)
  if (length(dim(lims)) == 1) {
    if (length(dim(probs)) != 1) {
      stop("Expected both 'lims' and 'probs' parameter have same dimensions.")
    }
    # identical() also copes with unnamed dims (names() is NULL), where '=='
    # would yield a zero-length condition and an opaque R error.
    if (!identical(names(dim(lims)), cat_dim)) {
      stop("'cat_dim' in parameter lims is needed.")
    }
    if (!identical(names(dim(probs)), cat_dim)) {
      stop("'cat_dim' in parameter 'probs' is needed.")
    }
    dim(lims) <- c(length(lims), 1)
    names(dim(lims)) <- required_dims
    dim(probs) <- c(length(probs), 1)
    names(dim(probs)) <- required_dims
  } else if (length(dim(lims)) == 2) {
    # Both dimension names are required, hence all() rather than any().
    if (!all(required_dims %in% names(dim(lims)))) {
      stop("Parameter 'lims' must have 'cat_dim' and 'panel_dim' dimensions.")
    }
    if (length(dim(probs)) != 2) {
      stop("Parameter 'probs' must have 2 dimensions maximum.")
    }
    if (!all(required_dims %in% names(dim(probs)))) {
      stop("Parameter 'probs' must have 'cat_dim' and 'panel_dim' dimensions.")
    }
    lims <- Reorder(lims, required_dims)
    probs <- Reorder(probs, required_dims)
  } else {
    stop("Parameter 'lims' must have 2 dimensions maximum.")
  }
  if (dim(probs)[1] != (dim(lims)[1] + 1)) {
    stop("Parameter 'probs' must have one more element than parameter 'lims' in 'cat_dim' dimension.")
  }
  if (dim(probs)[2] != dim(lims)[2]) {
    stop("Parameter 'probs' and 'lims' must have the same length for 'panel_dim' dimension.")
  }
  n_cat <- unname(dim(probs)[1])
  n_panel <- unname(dim(probs)[2])
  panel <- seq_len(n_panel)
  # Check skill. This must happen after 'probs' has been reordered, otherwise
  # the second dimension is not guaranteed to be 'panel_dim'.
  if (!is.null(skill)) {
    if (!is.numeric(skill)) {
      stop("Parameter 'skill' must be a numeric vector.")
    }
    skill <- as.vector(skill)
    if (length(skill) != n_panel) {
      stop("Parameter 'skill' must be of the length of 'panel_dim' dimension in 'probs'.")
    }
  }
  # Top title:
  if (!is.character(toptitle)) {
    stop("Parameter 'toptitle' must be a character string.")
  }
  # Legend title:
  if (!is.character(legend_title)) {
    stop("Parameter 'legend_title' must be a character string.")
  }
  legend_title <- paste0(legend_title, "  ") # Add buffer
  # Category names:
  if (!is.character(category_names)) {
    stop("Parameter 'category_names' must be a character vector.")
  }
  if (length(category_names) != n_cat) {
    category_names <- paste("Cat", seq_len(n_cat))
    warning("The number of 'category_names' do not match the 'cat_dim' in 'probs'.")
  }
  # Panel title:
  if (!is.character(panel_title)) {
    stop("Parameter 'panel_title' must be a character string or vector of strings.")
  }
  if (length(panel_title) == 1) {
    panel_title <- rep(panel_title, n_panel)
  } else if (length(panel_title) != n_panel) {
    stop("Length of 'panel_title' must be equal to dimension panel_dim in array objects.")
  }
  # Extremes:
  has_extremes <- !is.null(extreme_probs)
  if (has_extremes) {
    if (!is.array(extreme_probs)) {
      dim(extreme_probs) <- c(length(extreme_probs), 1)
      names(dim(extreme_probs)) <- required_dims
    }
    if (length(dim(extreme_probs)) == 2) {
      extreme_probs <- Reorder(extreme_probs, required_dims)
    } else {
      stop("Unnexpected number of dimensions in 'extreme_probs'.")
    }
    if (n_panel != dim(extreme_probs)[2]) {
      stop("'probs' dimensions do no match 'extreme_probs'.")
    }
    # Exactly one lower and one upper extreme are plotted; any other size
    # would be silently recycled into mislabelled bars further down.
    if (dim(extreme_probs)[1] != 2) {
      stop("Parameter 'extreme_probs' must have two elements in 'cat_dim' dimension.")
    }
  }
  if (!is.null(extreme_lims)) {
    if (!is.array(extreme_lims)) {
      dim(extreme_lims) <- c(length(extreme_lims), 1)
      names(dim(extreme_lims)) <- required_dims
    }
    if (length(dim(extreme_lims)) == 2) {
      extreme_lims <- Reorder(extreme_lims, required_dims)
    } else {
      stop("Unnexpected number of dimensions in 'extreme_lims'.")
    }
    if (n_panel != dim(extreme_lims)[2]) {
      stop("'probs' dimensions do no match 'extreme_lims'.")
    }
    # Rows 1 and 2 are read below as the lower and upper extreme limits.
    if (dim(extreme_lims)[1] < 2) {
      stop("Parameter 'extreme_lims' must have two elements in 'cat_dim' dimension.")
    }
  }
  # Panel subtitle:
  if (!is.character(panel_subtitle)) {
    stop("Parameter 'panel_subtitle' must be a character string or vector of strings.")
  }
  if (length(panel_subtitle) == 1) {
    panel_subtitle <- rep(panel_subtitle, n_panel)
  } else if (length(panel_subtitle) != n_panel) {
    stop("Length of 'panel_subtitle' must be equal to dimension panel_dim in array objects.")
  }
  # Panel bottom name:
  if (!is.character(panel_bottom_name)) {
    stop("Parameter 'panel_bottom_name' must be a character string or vector of strings.")
  }
  if (length(panel_bottom_name) == 1) {
    panel_bottom_name <- rep(panel_bottom_name, n_panel)
  } else if (length(panel_bottom_name) != n_panel) {
    stop("Length of 'panel_bottom_name' must be equal to dimension panel_dim in array objects.")
  }

  #------------------------
  # Define color sets
  #------------------------
  # colorFill: fill colors of the regular categories.
  # colorHatch: fill colors of the lower and upper extreme categories.
  color_set_error <- 
    "Parameter 'color.set' should be one of ggplot/s2s4e/hydro/vitigeoss"
  if (!is.character(color.set) || length(color.set) != 1 || is.na(color.set)) {
    stop(color_set_error)
  }
  if (color.set == "s2s4e") {
    colorFill <- rev(c("#FF764D", "#b5b5b5", "#33BFD1")) 
    colorHatch <- rev(c("indianred3", "deepskyblue3"))
  } else if (color.set == "hydro") {
    colorFill <- rev(c("#41CBC9", "#b5b5b5", "#FFAB38"))
    colorHatch <- rev(c("deepskyblue3", "darkorange1"))
  } else if (color.set == "ggplot") {
    colorFill <- hue_pal()(3)
    colorHatch <- c("indianred3", "deepskyblue3")
  } else if (color.set == "vitigeoss") {
    colorFill <- rev(c("#007be2", "#acb2b5", "#f40000"))
    colorHatch <- rev(c("#211b79", "#ae0003"))
  } else {
    stop(color_set_error)
  }
  if (n_cat != 3 || has_extremes) {
    colorFill <- c(colorHatch[1], colorFill, colorHatch[2])
  }

  #------------------------
  # Build plotting data
  #------------------------
  # Long data frame with one row per (panel, category), in percent.
  probs <- probs * 100
  dprobs <- data.frame(panel = rep(panel, each = n_cat),
                       cat = rep(category_names, n_panel),
                       probs = as.vector(probs))
  my_levels <- category_names
  # x positions of the threshold labels (boundaries between adjacent bars).
  ticks <- seq(2.5, n_cat + 1, 1)
  extremes_shift_pos <- 0
  if (has_extremes) {
    dextremes <- data.frame(
      panel = rep(panel, each = nrow(extreme_probs)),
      cat = rep(extreme_cat_names, n_panel),
      probs = as.vector(extreme_probs * 100))
    dprobs <- rbind(dprobs, dextremes)
    my_levels <- c(extreme_cat_names[1], category_names, extreme_cat_names[2])
    # Narrow extreme bars are shifted so that they touch the regular bars.
    extremes_shift_pos <- (1 - extreme_bars_width) / 2
    if (!is.null(extreme_lims)) {
      ticks <- seq(1.5, n_cat + 2, 1)
      lims <- rbind(extreme_lims[1, ], lims, extreme_lims[2, ])
    } 
  }
  dprobs$cat <- factor(dprobs$cat, levels = my_levels)
  # x coordinates of the percentage labels drawn above the bars. The x scale
  # always reserves slot 1 for the lower extreme, so regular bars start at 2.
  x_vals <- if (!has_extremes) {
    2:(length(my_levels) + 1)
  } else {
    c(1.5 - extreme_bars_width / 2, 
      2:(length(my_levels) - 1), 
      (length(my_levels) - 0.5) + extreme_bars_width / 2)
  }

  #------------------------
  # Individual panels
  #------------------------
  plot_list <- lapply(panel, function(p) {
    tmp <- dprobs[dprobs$panel == p, ]
    labskill <- ''
    # A missing skill value is treated as "no information" rather than
    # aborting the whole figure.
    if (!is.null(skill) && !is.na(skill[p]) && skill[p] < 0) {
      colorFill <- adjustcolor(colorFill, alpha.f = 0.5)
      labskill <- 'Skill < 0'
    }
    bar_values <- tmp$probs[match(my_levels, tmp$cat)]
    ggplot(tmp, aes(x = factor(cat, levels = my_levels), y = probs, fill = cat)) +
      # Regular (e.g., tercile) categories
      geom_bar(data = subset(tmp, cat %in% category_names),
               stat = "identity", position = "dodge",
               width = 1) +
      # Lower extreme category, shifted right towards the first regular bar
      geom_bar(data = subset(tmp, cat == extreme_cat_names[1]),
               aes(x = as.numeric(factor(cat, levels = my_levels)) + extremes_shift_pos,
                   y = probs, fill = cat),
               stat = "identity", width = extreme_bars_width,
               position = position_identity()) +
      # Upper extreme category, shifted left towards the last regular bar
      geom_bar(data = subset(tmp, cat == extreme_cat_names[2]),
               aes(x = as.numeric(factor(cat, levels = my_levels)) - extremes_shift_pos,
                   y = probs, fill = cat),
               stat = "identity", width = extreme_bars_width,
               position = position_identity()) +
      scale_x_discrete(limits = c(extreme_cat_names[1], category_names,
                                  extreme_cat_names[2])) + 
      theme_minimal() +
      scale_fill_manual(name = legend_title,
                        values = colorFill, drop = FALSE) +
      scale_y_continuous(expand = c(0, 0), limits = c(lims_pos, 100)) +  
      # Allow annotations to be drawn outside the panel area
      coord_cartesian(clip = "off") +
      labs(x = panel_bottom_name[p], 
           y = "", 
           title = panel_title[p],
           subtitle = panel_subtitle[p]) +
      theme(axis.text.x = element_blank(), 
            axis.title.x = element_blank(),
            axis.text.y = element_blank(), 
            axis.title.y = element_blank(), 
            legend.position = "none") + 
      # Threshold values below the bars
      annotate("text", x = ticks,
               label = lims[, p], vjust = 0, color = "black",
               size = 4.5, y = lims_pos) +
      # Percentage labels on top of the bars (capped to stay inside the axis)
      annotate("text", 
               x = x_vals,
               y = pmin(bar_values + 2, 99.9),
               label = paste0(" ", round(bar_values, 0), "% "), 
               color = "black", vjust = 0,
               size = 4.1) +
      annotate("text", x = 3, y = 15, label = labskill,
               color = "black", size = 12, fontface = 2) 
  })

  #------------------------
  # Shared elements
  #------------------------
  # Dummy plot that only contributes the common y-axis
  common_y_axis <- ggplot() +
    geom_bar(aes(x = factor(1), y = 0), stat = "identity", fill = "white") + 
    labs(x = "", y = "", title = "", subtitle = "") +
    scale_y_continuous(limits = c(lims_pos, 100), expand = c(0, 0)) +
    theme_minimal() +
    theme(axis.text.x = element_blank(), 
          axis.title.x = element_blank(),
          axis.text.y = element_text(color = "black", size = 12),
          axis.title.y = element_text(color = "black", size = 14),
          panel.grid = element_blank()) +
    labs(y = "Probability (%)")
  common_y_axis <- ggplotGrob(common_y_axis)
  # Reuse the row heights of a real panel so the axis lines up with the bars
  common_y_axis$heights <- ggplotGrob(plot_list[[1]])$heights 
  # Dummy plot that only contributes the common legend
  common_legend <- ggplot(dprobs[dprobs$panel == panel[1], ],
                          aes(x = cat, y = 1, fill = cat)) +
    geom_bar(position = "dodge", stat = "identity", width = 1) +
    scale_fill_manual(name = legend_title, values = colorFill) +
    theme_minimal() +
    theme(legend.position = "right",
          legend.title = element_text(size = 14),
          legend.text = element_text(size = 12),
          axis.text = element_blank(),
          axis.title = element_blank(),
          panel.grid = element_blank())
  common_legend <- gtable::gtable_filter(ggplotGrob(common_legend), "guide-box")
  # Top title, justified according to 'toptitle_pos'
  toptitle <- textGrob(toptitle,
                       gp = gpar(fontsize = toptitle_size, fontface = "bold"),
                       just = toptitle_pos)
  # Common x-axis title (slightly re-centred when there is a single panel)
  if (n_panel == 1) {
    xaxis_title <- textGrob(xaxis_title, gp = gpar(fontsize = 12), 
                            vjust = 0.5, x = unit(0.49, "npc"))
  } else {
    xaxis_title <- textGrob(xaxis_title, gp = gpar(fontsize = 12), vjust = 0.5)
  }

  #------------------------
  # Compose and output
  #------------------------
  # Panels share the available width equally; the legend has a fixed width.
  plot_widths <- rep(1, n_panel)
  # grid.arrange() draws on the current device and returns the gtable.
  plot <- grid.arrange(arrangeGrob(
    grobs = c(list(common_y_axis), plot_list, list(common_legend)), 
    ncol = n_panel + 2,
    top = toptitle,
    widths = unit.c(
      unit(if (n_panel == 1) 0.2 else 0.4, "null"),
      unit(plot_widths, "null"), 
      unit(legend_width, "cm"))),
    bottom = xaxis_title)
  if (!is.null(fileout)) {
    # Only forward 'units' when requested so that the ggsave() default is kept.
    if (is.null(size_units)) {
      ggsave(fileout, plot = plot, width = width, height = height, dpi = res)
    } else {
      ggsave(fileout, plot = plot, width = width, height = height, dpi = res,
             units = size_units)
    }
    return(invisible(plot))
  }
  plot
}

Try the esviz package in your browser

Any scripts or data that you put into this service are public.

esviz documentation built on Feb. 4, 2026, 5:13 p.m.