R/plot_discrete.R

Defines functions plot_discrete

Documented in plot_discrete

#' Plot two discrete distributions with overlap
#'
#' Tabulates `x` and `y` over a common set of support values, converts the
#' counts to within-group proportions, and draws them side by side: a shaded
#' bar for the shared mass (`pmin` of the two proportions) at each support
#' value, a vertical stem for each group, and a line-plus-points profile per
#' group. Optionally annotates the plot with the JSD estimate from `jsd()`.
#'
#' @param x Vector for group 1.
#' @param y Vector for group 2.
#' @param support Optional support values, passed to `make_support()` together
#'   with `x` and `y`. The resulting values define the x-axis categories and
#'   their order. Observations whose value is not in the support are not
#'   counted, so proportions are relative to the in-support observations.
#' @param group_names Group labels; the first two are used in the legend.
#' @param main Plot title.
#' @param xlab X-axis label.
#' @param ylab Y-axis label.
#' @param col_x Color for group 1. Currently not used by the plot; kept for
#'   backward compatibility.
#' @param col_y Color for group 2. Currently not used by the plot; kept for
#'   backward compatibility.
#' @param overlap_col Fill color for overlap bars (also used for the
#'   "Overlap" legend key).
#' @param line_col_x Line, stem, and point color for group 1.
#' @param line_col_y Line, stem, and point color for group 2.
#' @param lwd Line width of the per-group profile lines.
#' @param pch Point character.
#' @param cex_pt Point size.
#' @param las Axis label style for x-axis.
#' @param bar_width Width of overlap bars, in x-axis units (support values are
#'   drawn at positions `1, 2, ...`).
#' @param show_jsd Logical; whether to compute and display JSD on the plot.
#' @param jsd_digits Number of digits after the decimal point for displayed JSD.
#' @param na_rm Logical; remove missing values? Passed to `validate_xy()`.
#' @param stem_offset Horizontal distance of each group's stem from its support
#'   position (group 1 drawn to the left, group 2 to the right).
#' @param stem_lwd Line width of the per-group stems.
#'
#' @return Invisibly, a data frame with one row per support value and columns
#'   `support`, `p_x` and `p_y` (within-group proportions), `overlap`
#'   (`pmin(p_x, p_y)`), and `jsd` (the JSD estimate repeated on every row, or
#'   `NA` when it was not requested or could not be computed).
#' @export
#' @importFrom grDevices rgb adjustcolor
#' @importFrom graphics axis box hist legend lines par points rect segments text
#' @importFrom stats density
plot_discrete <- function(x, y,
                          support = NULL,
                          group_names = c("Group 1", "Group 2"),
                          main = "Two-group discrete distributions",
                          xlab = "Value",
                          ylab = "Proportion",
                          col_x = adjustcolor("#2F5FB3", alpha.f = 0.20),
                          col_y = adjustcolor("#CC3333", alpha.f = 0.20),
                          overlap_col = adjustcolor("grey55", alpha.f = 0.35),
                          line_col_x = "#2F5FB3",
                          line_col_y = "#CC3333",
                          lwd = 2,
                          pch = 16,
                          cex_pt = 1.1,
                          las = 1,
                          bar_width = 0.2,
                          show_jsd = TRUE,
                          jsd_digits = 3,
                          na_rm = TRUE,
                          stem_offset = 0.1,
                          stem_lwd = 1.8) {
  cleaned <- validate_xy(x, y, min_n = 1, na_rm = na_rm, finite_only = FALSE)
  x <- cleaned$x
  y <- cleaned$y

  support <- make_support(x, y, support = support)

  # Values not present in `support` become NA in the factor and are dropped
  # by table() (its default excludes NA), so they do not enter the counts.
  tx <- table(factor(as.character(x), levels = support))
  ty <- table(factor(as.character(y), levels = support))

  # Without this guard an empty group yields NaN proportions and plot() later
  # fails with an uninformative "need finite 'ylim' values" error.
  n_x <- sum(tx)
  n_y <- sum(ty)
  if (n_x == 0 || n_y == 0) {
    stop("No observations of ", if (n_x == 0) "`x`" else "`y`",
         " fall within `support`; cannot compute proportions.",
         call. = FALSE)
  }

  px <- as.numeric(tx) / n_x
  py <- as.numeric(ty) / n_y
  poverlap <- pmin(px, py)

  xpos <- seq_along(support)
  # Headroom above the tallest proportion for the legend and JSD label.
  ymax <- max(c(px, py)) * 1.18

  jsd_value <- NA_real_
  if (show_jsd) {
    # Best effort: an error, or anything other than a single numeric
    # estimate, leaves jsd_value as NA so no label is drawn.
    est <- tryCatch(jsd(x, y)$estimate, error = function(e) NULL)
    if (is.numeric(est) && length(est) == 1L) {
      jsd_value <- as.numeric(est)
    }
  }

  plot(xpos, px,
       type = "n",
       xaxt = "n",
       yaxt = "n",
       ylim = c(0, ymax),
       xlim = c(min(xpos) - 0.35, max(xpos) + 0.35),
       xlab = xlab,
       ylab = ylab,
       main = main)

  axis(2, las = 1)
  axis(1, at = xpos, labels = support, las = las)
  box(bty = "l")

  # Shared mass at each support value, centred on the support position.
  rect(xleft  = xpos - bar_width / 2,
       ybottom = 0,
       xright = xpos + bar_width / 2,
       ytop   = poverlap,
       col = overlap_col,
       border = NA)

  # Per-group stems, offset left (group 1) and right (group 2).
  segments(x0 = xpos - stem_offset, y0 = 0, x1 = xpos - stem_offset, y1 = px,
           col = line_col_x, lwd = stem_lwd)
  segments(x0 = xpos + stem_offset, y0 = 0, x1 = xpos + stem_offset, y1 = py,
           col = line_col_y, lwd = stem_lwd)

  lines(xpos, px, type = "l", lwd = lwd, col = line_col_x)
  lines(xpos, py, type = "l", lwd = lwd, col = line_col_y)

  points(xpos, px, pch = pch, cex = cex_pt, col = line_col_x)
  points(xpos, py, pch = pch, cex = cex_pt, col = line_col_y)

  legend("topright",
         legend = c(group_names[1], group_names[2], "Overlap"),
         col = c(line_col_x, line_col_y, overlap_col),
         lwd = c(lwd, lwd, NA),
         pch = c(pch, pch, 15),
         pt.cex = c(cex_pt, cex_pt, 2.2),
         bty = "n",
         inset = 0.01)

  if (show_jsd && is.finite(jsd_value)) {
    # Place the label in the top-left corner using current user coordinates.
    usr <- par("usr")
    text(x = usr[1] + 0.03 * (usr[2] - usr[1]),
         y = usr[4] - 0.05 * (usr[4] - usr[3]),
         labels = paste0("JSD = ", formatC(jsd_value, digits = jsd_digits, format = "f")),
         adj = c(0, 1),
         cex = 0.95)
  }

  invisible(data.frame(
    support = support,
    p_x = px,
    p_y = py,
    overlap = poverlap,
    jsd = jsd_value,
    stringsAsFactors = FALSE
  ))
}

Try the jsdtools package in your browser

Any scripts or data that you put into this service are public.

jsdtools documentation built on March 31, 2026, 1:06 a.m.