R/jsd_discrete.R

Defines functions jsd_discrete

Documented in jsd_discrete

#' Estimate JSD for discrete variables
#'
#' Computes Jensen-Shannon divergence (JSD) between two discrete variables using
#' empirical probability mass functions.
#'
#' With `P` and `Q` the empirical PMFs over `support` and `M = (P + Q) / 2`,
#' the estimate is `KL(P || M) / 2 + KL(Q || M) / 2`. Cells with zero
#' probability contribute nothing to their KL term (the convention
#' `0 * log(0) = 0`), so no probability mass is invented for unobserved
#' values and the result is bounded by `log(2)` in the chosen base (1 bit
#' for `base = 2`).
#'
#' @param x Vector for group 1. Can be numeric, factor, character, or logical.
#' @param y Vector for group 2. Can be numeric, factor, character, or logical.
#' @param support Optional support values. If `NULL`, the union of observed
#'   values in `x` and `y` is used. Observations are matched against the
#'   support after conversion with `as.character()`; observations that match
#'   no support value are ignored with a warning.
#' @param base Logarithm base. Defaults to 2. Use `exp(1)` for nats.
#' @param eps Retained for backward compatibility and recorded in
#'   `settings`. Zero-probability cells are handled exactly, so it no longer
#'   affects the estimate.
#' @param add_smoothing Logical; add 1 to each cell count?
#' @param na_rm Logical; remove missing values?
#'
#' @return An object of class `"jsd_estimate"`: a list holding the
#'   `estimate`, `type`, `method`, `base`, sample sizes `n_x` and `n_y`
#'   (after cleaning), the `support`, the PMFs `p` and `q` and the (possibly
#'   smoothed) cell counts `counts_x` and `counts_y`, all named by support
#'   value, and the `settings` used.
#' @export
jsd_discrete <- function(x, y,
                         support = NULL,
                         base = 2,
                         eps = 1e-12,
                         add_smoothing = FALSE,
                         na_rm = TRUE) {
  check_base(base)

  cleaned <- validate_xy(x, y, min_n = 1, na_rm = na_rm, finite_only = FALSE)
  x <- cleaned$x
  y <- cleaned$y

  support <- make_support(x, y, support = support)

  x_fac <- factor(as.character(x), levels = support)
  y_fac <- factor(as.character(y), levels = support)

  # factor() maps values absent from `levels` to NA and table() drops NA by
  # default, so such observations would otherwise vanish from the counts.
  n_dropped_x <- sum(is.na(x_fac))
  n_dropped_y <- sum(is.na(y_fac))
  if (n_dropped_x > 0 || n_dropped_y > 0) {
    warning(
      "Ignoring ", n_dropped_x, " value(s) in `x` and ", n_dropped_y,
      " in `y` that are missing or not in `support`.",
      call. = FALSE
    )
  }

  counts_x <- as.numeric(table(x_fac))
  counts_y <- as.numeric(table(y_fac))

  if (add_smoothing) {
    counts_x <- counts_x + 1
    counts_y <- counts_y + 1
  }

  # Without any counted observation the PMF would be 0 / 0 = NaN.
  if (sum(counts_x) == 0 || sum(counts_y) == 0) {
    stop(
      "No observations of `x` or `y` fall within `support`.",
      call. = FALSE
    )
  }

  p <- counts_x / sum(counts_x)
  q <- counts_y / sum(counts_y)
  m <- 0.5 * (p + q)

  # KL(a || m) restricted to cells where a > 0 (0 * log(0) = 0). On those
  # cells m >= a / 2 > 0, so the ratio lies in (0, 2] and needs no clamping.
  kl_to_mixture <- function(a) {
    pos <- a > 0
    sum(a[pos] * safe_log_base(a[pos] / m[pos], base = base))
  }

  estimate <- 0.5 * kl_to_mixture(p) + 0.5 * kl_to_mixture(q)

  out <- list(
    estimate = unname(estimate),
    type = "discrete",
    method = "PMF",
    base = base,
    n_x = length(x),
    n_y = length(y),
    support = support,
    p = stats::setNames(p, support),
    q = stats::setNames(q, support),
    counts_x = stats::setNames(counts_x, support),
    counts_y = stats::setNames(counts_y, support),
    settings = list(
      eps = eps,
      add_smoothing = add_smoothing
    )
  )

  class(out) <- c("jsd_estimate", "list")
  out
}

Try the jsdtools package in your browser

Any scripts or data that you put into this service are public.

jsdtools documentation built on March 31, 2026, 1:06 a.m.