R/jsd_continuous.R

Defines functions jsd_continuous fixed_range

Documented in fixed_range jsd_continuous

#' Fixed integration range for continuous JSD
#'
#' Derives a single interval `[L, U]` from the pooled sample `c(x, y)`. The
#' inputs are first passed through `validate_xy()` (called with `min_n = 2`,
#' `na_rm = TRUE`, `finite_only = TRUE`). The interval starts from the pooled
#' quantiles at `qrange` (type 7) and is widened on each side by
#' `extend` times a spread measure.
#'
#' The spread measure is the pooled IQR; when that is non-finite or zero the
#' pooled standard deviation is used instead, and when that is also non-finite
#' or zero the value 1 is used. If the resulting bounds are non-finite or not
#' strictly increasing, the range falls back to the pooled minimum and maximum
#' padded by three standard deviations (or by 3 when the standard deviation is
#' non-finite or zero).
#'
#' @param x Numeric vector for group 1.
#' @param y Numeric vector for group 2.
#' @param qrange Length-two vector of probabilities giving the lower and upper
#'   quantiles of the pooled sample that anchor the range.
#' @param extend Multiplier applied to the spread measure to pad both ends.
#'
#' @return Named numeric vector with elements `L` and `U`.
#' @export
fixed_range <- function(x, y,
                        qrange = c(0.001, 0.999),
                        extend = 3) {
  checked <- validate_xy(x, y, min_n = 2, na_rm = TRUE, finite_only = TRUE)
  pooled <- c(checked$x, checked$y)

  # A spread value is unusable when it is NA/NaN/Inf or exactly zero.
  is_unusable <- function(v) {
    !is.finite(v) || v == 0
  }

  anchors <- stats::quantile(pooled, probs = qrange, names = FALSE, type = 7)

  # Spread measure: IQR first, then SD, then the constant 1.
  spread <- stats::IQR(pooled)
  if (is_unusable(spread)) {
    spread <- stats::sd(pooled)
    if (is_unusable(spread)) {
      spread <- 1
    }
  }

  pad <- extend * spread
  lower <- anchors[1] - pad
  upper <- anchors[2] + pad

  bounds_ok <- is.finite(lower) && is.finite(upper) && lower < upper
  if (!bounds_ok) {
    # Last resort: pad the observed extremes by three standard deviations.
    pooled_sd <- stats::sd(pooled)
    if (is_unusable(pooled_sd)) {
      pooled_sd <- 1
    }
    lower <- min(pooled) - 3 * pooled_sd
    upper <- max(pooled) + 3 * pooled_sd
  }

  c(L = lower, U = upper)
}


#' Estimate JSD for continuous variables
#'
#' Computes Jensen-Shannon divergence (JSD) between two numeric vectors using
#' kernel density estimation (KDE) and numerical integration. Both densities
#' are evaluated with [stats::density()] on one shared grid of `grid_n` points
#' between `L` and `U`, floored at `eps`, optionally renormalized over the
#' grid, and the two divergence terms against the mixture
#' `m = 0.5 * (p + q)` are integrated numerically over that grid.
#'
#' If only one of `L` and `U` is supplied, the supplied bound is kept and only
#' the missing bound is taken from [fixed_range()].
#'
#' @param x Numeric vector for group 1.
#' @param y Numeric vector for group 2.
#' @param L Optional lower integration bound (single finite number). When
#'   `NULL`, it is chosen by [fixed_range()].
#' @param U Optional upper integration bound (single finite number). When
#'   `NULL`, it is chosen by [fixed_range()].
#' @param base Logarithm base. Defaults to 2. Use `exp(1)` for nats.
#' @param bw Bandwidth passed to [stats::density()].
#' @param kernel Kernel passed to [stats::density()].
#' @param grid_n Number of grid points used for KDE.
#' @param qrange Quantile range used when `L` and/or `U` is not supplied.
#' @param extend Extension multiplier for the automatically chosen range.
#' @param eps Small constant for numerical stability; used as a lower floor
#'   for both estimated densities and for their mixture.
#' @param renormalize Logical; renormalize estimated densities over the grid?
#' @param na_rm Logical; remove missing values?
#'
#' @return An object of class `"jsd_estimate"`: a list with elements
#'   `estimate`, `type` (`"continuous"`), `method` (`"KDE"`), `base`, `n_x`,
#'   `n_y`, `range` (named vector with `L` and `U`), and `settings` (the
#'   `bw`, `kernel`, `grid_n`, `qrange`, `extend`, `eps` and `renormalize`
#'   values used).
#' @export
jsd_continuous <- function(x, y,
                           L = NULL,
                           U = NULL,
                           base = 2,
                           bw = "nrd0",
                           kernel = "gaussian",
                           grid_n = 4096,
                           qrange = c(0.001, 0.999),
                           extend = 3,
                           eps = 1e-12,
                           renormalize = TRUE,
                           na_rm = TRUE) {
  check_base(base)

  cleaned <- validate_xy(x, y, min_n = 2, na_rm = na_rm, finite_only = TRUE)
  x <- cleaned$x
  y <- cleaned$y

  # Fill in only the bound(s) the caller did not supply, so that a
  # user-provided L or U is never silently replaced by the automatic range.
  if (is.null(L) || is.null(U)) {
    auto_range <- fixed_range(x, y, qrange = qrange, extend = extend)
    if (is.null(L)) {
      L <- unname(auto_range["L"])
    }
    if (is.null(U)) {
      U <- unname(auto_range["U"])
    }
  }

  # Check length first so vector input yields the descriptive error below
  # instead of a length error from the scalar `&&` operator.
  bounds_ok <- length(L) == 1L && length(U) == 1L &&
    is.finite(L) && is.finite(U) && L < U
  if (!bounds_ok) {
    stop("L and U must be finite and satisfy L < U.")
  }

  # Drop any names carried by user-supplied bounds so `range` below is always
  # named exactly "L" and "U".
  L <- unname(L)
  U <- unname(U)

  # Same from/to/n for both calls so the two densities share one grid.
  dx <- stats::density(x, bw = bw, kernel = kernel, from = L, to = U, n = grid_n)
  dy <- stats::density(y, bw = bw, kernel = kernel, from = L, to = U, n = grid_n)

  grid <- dx$x
  # Floor the densities at eps so the log ratios below stay finite.
  p <- pmax(dx$y, eps)
  q <- pmax(dy$y, eps)

  if (renormalize) {
    p <- p / trapz_num(grid, p)
    q <- q / trapz_num(grid, q)
  }

  # Mixture density used as the reference in both divergence terms.
  m <- pmax(0.5 * (p + q), eps)

  estimate <- 0.5 * trapz_num(grid, p * safe_log_base(p / m, base = base)) +
    0.5 * trapz_num(grid, q * safe_log_base(q / m, base = base))

  out <- list(
    estimate = unname(estimate),
    type = "continuous",
    method = "KDE",
    base = base,
    n_x = length(x),
    n_y = length(y),
    range = c(L = L, U = U),
    settings = list(
      bw = bw,
      kernel = kernel,
      grid_n = grid_n,
      qrange = qrange,
      extend = extend,
      eps = eps,
      renormalize = renormalize
    )
  )

  class(out) <- c("jsd_estimate", "list")
  out
}

Try the jsdtools package in your browser

Any scripts or data that you put into this service are public.

jsdtools documentation built on March 31, 2026, 1:06 a.m.