R/plot_mid_breakdown.R

Defines functions plot.mid.breakdown

Documented in plot.mid.breakdown

#' Plot MID Breakdowns
#'
#' @description
#' For "mid.breakdown" objects, \code{plot()} visualizes the breakdown of a prediction by component functions.
#'
#' @details
#' This is an S3 method for the \code{plot()} generic that produces a breakdown plot from a "mid.breakdown" object, visualizing the contribution of each component function to a single prediction.
#'
#' The \code{type} argument controls the visualization style.
#' The default, \code{type = "waterfall"}, creates a waterfall plot that shows how the prediction builds from the intercept, with each term's contribution sequentially added or subtracted.
#' The \code{type = "barplot"} option creates a standard bar plot where the length of each bar represents the magnitude of the term's contribution.
#' The \code{type = "dotchart"} option creates a dot plot showing the contribution of each term as a point connected to a zero baseline.
#'
#' @param x a "mid.breakdown" object to be visualized.
#' @param type the plotting style. One of "waterfall", "barplot" or "dotchart".
#' @param theme a character string or object defining the color theme. See \code{\link{color.theme}} for details.
#' @param terms an optional character vector specifying which terms to display, in the given order. Names that do not match any term are ignored, and the contributions of all terms not listed are summed into the "catchall" bar.
#' @param max.nterms the maximum number of bars to display in the plot, including the "catchall" bar. Less important terms will be grouped into the "catchall" category. Values smaller than 1 are treated as 1.
#' @param width a numeric value specifying the width of the bars.
#' @param vline logical. If \code{TRUE}, a vertical line is drawn at the zero or intercept line.
#' @param catchall a character string for the catchall label.
#' @param format a character string or character vector of length two to be used as the format of the axis labels. Use "\%t" for the term name (e.g., "carat") and "\%v" for the values (e.g., "0.23"). The first element is used for main effects and the second for interactions.
#' @param ... optional parameters passed on to the graphing function. Possible arguments are "col", "fill", "pch", "cex", "lty", "lwd" and aliases of them.
#'
#' @examples
#' data(diamonds, package = "ggplot2")
#' set.seed(42)
#' idx <- sample(nrow(diamonds), 1e4)
#' mid <- interpret(price ~ (carat + cut + color + clarity)^2, diamonds[idx, ])
#' mbd <- mid.breakdown(mid, diamonds[1L, ])
#'
#' # Create a waterfall plot
#' plot(mbd, type = "waterfall")
#'
#' # Create a bar plot with a different theme
#' plot(mbd, type = "barplot", theme = "highlight")
#'
#' # Create a dot chart
#' plot(mbd, type = "dotchart", size = 1.5)
#' @returns
#' \code{plot.mid.breakdown()} produces a plot as a side effect and returns \code{NULL} invisibly.
#'
#' @seealso \code{\link{mid.breakdown}}, \code{\link{ggmid.mid.breakdown}}
#'
#' @exportS3Method base::plot
#'
plot.mid.breakdown <- function(
    x, type = c("waterfall", "barplot", "dotchart"), theme = NULL,
    terms = NULL, max.nterms = 15L, width = NULL, vline = TRUE,
    catchall = "others", format = c("%t=%v", "%t"), ...) {
  dots <- list(...)
  type <- match.arg(type)
  # only fall back to the package color options when no theme was supplied;
  # an explicit `theme = NULL` is passed through unchanged
  if (missing(theme))
    theme <- getOption("midr.sequential", getOption("midr.qualitative", NULL))
  theme <- color.theme(theme)
  use.theme <- inherits(theme, "color.theme")
  bd <- x$breakdown
  bd$term <- as.character(bd$term)
  if (any(!grepl("%t", format) & !grepl("%v", format)))
    stop("all format strings must contain at least one of '%t' or '%v'")
  if (length(format) == 1L)
    format <- c(format, format)
  use.catchall <- FALSE
  if (!is.null(terms)) {
    # keep the requested terms (user order, each once); unknown names are
    # dropped so that they cannot break the negative-index residual below
    rowid <- match(terms, bd$term)
    rowid <- unique(rowid[!is.na(rowid)])
    resid <- sum(bd$mid[!(seq_len(nrow(bd)) %in% rowid)])
    bd <- bd[rowid, , drop = FALSE]
    # appended catchall row: only "mid" is set, other columns become NA
    bd[nrow(bd) + 1L, "mid"] <- resid
    use.catchall <- TRUE
  }
  # at least one bar is always drawn (possibly just the catchall)
  nmax <- max(1L, min(max.nterms, nrow(bd), na.rm = TRUE))
  if (nmax < nrow(bd)) {
    # rows nmax..end (including any catchall row from `terms`) are merged
    resid <- sum(bd[nmax:nrow(bd), "mid"])
    bd <- bd[seq_len(nmax - 1L), , drop = FALSE]
    bd[nmax, "mid"] <- resid
    use.catchall <- TRUE
  }
  # build axis labels for every non-catchall row; terms containing ":" are
  # treated as interactions and use the second format string
  for (i in seq_len(nrow(bd) - as.integer(use.catchall))) {
    term <- bd[i, "term"]
    fmt <- if (grepl(":", term)) format[2L] else format[1L]
    bd[i, "term"] <-
      gsub("%v", bd[i, "value"], gsub("%t", term, fmt))
  }
  if (use.catchall)
    bd[nrow(bd), "term"] <- catchall
  if (type == "barplot" || type == "dotchart") {
    args <- list(to = bd$mid, labels = bd$term,
                 horizontal = TRUE, xlab = "mid")
    cols <- if (use.theme) {
      if (theme$type == "qualitative")
        to.colors(bd$order, theme)
      else
        to.colors(bd$mid, theme)
    } else "gray35"
    if (type == "dotchart") {
      args$type <- "d"
      args$col <- cols
    } else {
      args$type <- "b"
      args$fill <- cols
      args$width <- ifnot.null(width, .8)
    }
    args <- override(args, dots)
    do.call(barplot2, args)
    if (vline)
      graphics::abline(v = 0)
  } else if (type == "waterfall") {
    # qualitative themes color by sign of the contribution
    cols <- if (use.theme) {
      if (theme$type == "qualitative")
        to.colors(bd$mid > 0, theme)
      else
        to.colors(bd$mid, theme)
    } else "gray35"
    width <- ifnot.null(width, .6)
    hw <- width / 2
    n <- nrow(bd)
    # each bar spans from the running total before the term to the one after
    cs <- cumsum(c(x$intercept, bd$mid))
    bd$xmin <- cs[seq_len(n)]
    bd$xmax <- cs[seq_len(n) + 1L]
    args <- list(to = bd$xmax, from = bd$xmin, labels = bd$term, type = "b",
                 fill = cols, horizontal = TRUE, xlab = "mid", width = width,
                 lty = 1L, lwd = 1L, col = NULL)
    args <- override(args, dots)
    do.call(barplot2, args)
    # connector segments from the end of each bar down to the next bar
    for (i in seq_len(n)) {
      graphics::lines.default(x = rep.int(bd[i, "xmax"], 2L),
                              y = c(n + 1 - i + hw, max(n - i - hw, 1 - hw)),
                              col = ifnot.null(args$col, 1L),
                              lty = args$lty, lwd = args$lwd)
    }
    if (vline)
      graphics::abline(v = x$intercept, lty = 3L)
  }
  invisible(NULL)
}

Try the midr package in your browser

Any scripts or data that you put into this service are public.

midr documentation built on Sept. 11, 2025, 1:07 a.m.