R/bcat_plt_coef.R

Defines functions bcat_plt_coef

Documented in bcat_plt_coef

#' Coefficient plot (forest plot)
#'
#' Visualize regression coefficients and confidence intervals for one or more
#' models. Uses \code{broom::tidy()} to extract estimates.
#'
#' Terms are ordered on the y-axis by their estimate (averaged across models
#' when several are supplied). With more than one model, estimates are dodged
#' and coloured by model.
#'
#' @param models A single model object (any class for which
#'   \code{broom::tidy(conf.int = TRUE)} returns \code{term}, \code{estimate},
#'   \code{conf.low} and \code{conf.high}) or a plain list of model objects.
#'   Unnamed or blank-named list entries are labelled \code{"Model <i>"};
#'   duplicated names are made unique.
#' @param conf_level Numeric strictly between 0 and 1. Confidence level.
#'   Default is 0.95.
#' @param intercept Logical. Include intercept? Default is FALSE.
#' @param coef_rename Named character vector to rename coefficients; names are
#'   the original term names, values the display labels.
#'   If NULL (default), auto-cleans names to Title Case.
#' @param highlight Character vector of term names to highlight in UC Red.
#'   Matches either the original term names (as returned by
#'   \code{broom::tidy()}) or the displayed (renamed) names.
#' @param dodge_width Numeric. Dodge width for multi-model comparison.
#'   Default is 0.4.
#' @param x_lab Label for x-axis (coefficient values).
#' @param y_lab Label for y-axis (term names).
#' @param title Plot title.
#' @param subtitle Plot subtitle.
#' @param caption Plot caption.
#' @param legend_lab Legend title.
#' @param legend_position Legend position.
#' @param legend_hide Logical. Hide legend? The legend is always hidden when a
#'   single (non-list) model is supplied.
#' @param color_scale \code{scale_color_} function. Used only when more than
#'   one model is plotted.
#' @return A ggplot object.
#' @author Saannidhya Rawat
#' @family plots
#' @export
#'
#' @examples
#' library(ggplot2)
#'
#' m <- lm(mpg ~ wt + hp + cyl, data = mtcars)
#' bcat_plt_coef(m)
#'
#' m1 <- lm(mpg ~ wt + hp, data = mtcars)
#' m2 <- lm(mpg ~ wt + hp + cyl + disp, data = mtcars)
#' bcat_plt_coef(list("Base" = m1, "Full" = m2))
bcat_plt_coef <- function(models,
                          conf_level = 0.95,
                          intercept = FALSE,
                          coef_rename = NULL,
                          highlight = NULL,
                          dodge_width = 0.4,
                          x_lab = "Estimate",
                          y_lab = ggplot2::waiver(),
                          title = ggplot2::waiver(),
                          subtitle = ggplot2::waiver(),
                          caption = ggplot2::waiver(),
                          legend_lab = "Model",
                          legend_position = "bottom",
                          legend_hide = FALSE,
                          color_scale = scale_colour_UC()) {

  if (!is.numeric(conf_level) || length(conf_level) != 1L ||
      is.na(conf_level) || conf_level <= 0 || conf_level >= 1) {
    stop("`conf_level` must be a single number strictly between 0 and 1.",
         call. = FALSE)
  }

  # Wrap a single model. A *list of models* is a bare list (no class
  # attribute); any classed object is one model, even if it is a list
  # internally, as lm/glm and many other fitted-model objects are.
  single_model <- is.object(models) || !is.list(models)
  if (single_model) {
    models <- list("Model" = models)
  }
  if (length(models) == 0L) {
    stop("`models` must contain at least one model.", call. = FALSE)
  }

  # Label unnamed / blank-named entries "Model <i>" and make labels unique so
  # two distinct models are never merged under one legend entry / dodge group.
  model_names <- names(models)
  if (is.null(model_names)) {
    model_names <- rep("", length(models))
  }
  blank <- is.na(model_names) | model_names == ""
  model_names[blank] <- paste("Model", seq_along(models))[blank]
  model_names <- make.unique(model_names, sep = " ")

  # Tidy all models
  required_cols <- c("term", "estimate", "conf.low", "conf.high")
  tidy_list <- lapply(seq_along(models), function(i) {
    td <- broom::tidy(models[[i]], conf.int = TRUE, conf.level = conf_level)
    missing_cols <- setdiff(required_cols, names(td))
    if (length(missing_cols) > 0L) {
      stop("broom::tidy() output for model '", model_names[[i]],
           "' lacks column(s): ", paste(missing_cols, collapse = ", "),
           call. = FALSE)
    }
    td$model <- model_names[[i]]
    td
  })
  # Different model classes can return different tidy columns; bind only the
  # shared ones so rbind() does not fail on mismatched names.
  common_cols <- Reduce(intersect, lapply(tidy_list, names))
  tidy_df <- do.call(rbind, lapply(tidy_list, function(td) td[common_cols]))

  # Filter intercept
  if (!intercept) {
    tidy_df <- tidy_df[!(tidy_df$term %in% "(Intercept)"), , drop = FALSE]
  }

  # Keep the original term names so `highlight` can refer to either form.
  tidy_df$term_raw <- tidy_df$term

  # Rename coefficients
  if (is.null(coef_rename)) {
    tidy_df$term <- tools::toTitleCase(gsub("[_.]", " ", tidy_df$term))
  } else {
    idx <- match(tidy_df$term, names(coef_rename))
    matched <- !is.na(idx)
    tidy_df$term[matched] <- unname(coef_rename[idx[matched]])
  }

  # Order terms once, in the data, so every layer (including the highlight
  # subset) shares identical factor levels on the y-axis.
  tidy_df$term <- stats::reorder(factor(tidy_df$term), tidy_df$estimate)

  multi_model <- length(unique(tidy_df$model)) > 1L
  pos <- if (multi_model) {
    ggplot2::position_dodge(width = dodge_width)
  } else {
    ggplot2::position_identity()
  }

  if (multi_model) {
    p <- ggplot2::ggplot(tidy_df,
                         ggplot2::aes(x = estimate, y = term,
                                      xmin = conf.low, xmax = conf.high,
                                      color = model))
  } else {
    p <- ggplot2::ggplot(tidy_df,
                         ggplot2::aes(x = estimate, y = term,
                                      xmin = conf.low, xmax = conf.high))
  }

  p <- p +
    ggplot2::geom_vline(xintercept = 0, linetype = "dashed",
                        color = .uc_reference_color(), linewidth = 0.5) +
    ggplot2::geom_pointrange(position = pos, linewidth = 0.6, size = 0.4)

  if (multi_model) p <- p + color_scale

  # Highlight specific terms (matched on original or displayed names).
  if (!is.null(highlight)) {
    hl_rows <- tidy_df$term %in% highlight | tidy_df$term_raw %in% highlight
    hl_df <- tidy_df[hl_rows, , drop = FALSE]
    if (nrow(hl_df) > 0L) {
      # `group = model` keeps the red overlay dodged exactly like the main
      # layer in multi-model plots; the fixed colour overrides the mapping.
      p <- p + ggplot2::geom_pointrange(
        data = hl_df,
        ggplot2::aes(x = estimate, y = term,
                     xmin = conf.low, xmax = conf.high,
                     group = model),
        color = .uc_color("UC Red"),
        position = pos, linewidth = 0.8, size = 0.5,
        show.legend = FALSE,
        inherit.aes = FALSE
      )
    }
  }

  p <- p + ggplot2::labs(x = x_lab, y = y_lab, title = title,
                         subtitle = subtitle, caption = caption,
                         color = legend_lab)

  p + theme_UC_vgrid(legend_position = legend_position,
                     legend_hide = if (single_model) TRUE else legend_hide)
}

Try the Rbearcat package in your browser

Any scripts or data that you put into this service are public.

Rbearcat documentation built on March 21, 2026, 5:07 p.m.