R/get_prior.R

Defines functions get_prior

Documented in get_prior

#' @title Construct Prior Parameters for BKP/DKP Models
#'
#' @description
#' Computes prior parameters for Beta Kernel Process (BKP, binary)
#' or Dirichlet Kernel Process (DKP, multi-class) models.
#' Supports three prior strategies: noninformative, fixed, adaptive.
#'
#' @details
#' Prior strategies:
#' * `noninformative`: flat prior (Beta(1,1) or Dirichlet(1,...,1)).
#' * `fixed`: global constant prior, `r0 * p0` at every location.
#' * `adaptive`: kernel-smoothed local prior, estimated from observed data.
#'   Each row of `K` is normalised into weights that average the observed
#'   proportions, and the local precision is `r0` times the row sum of `K`.
#'
#' The result has one entry (BKP) or one row (DKP) per row of `K`. When `K`
#' is not supplied, the noninformative and fixed priors describe a single
#' location. Adaptive BKP parameters and all DKP parameters are floored at
#' `1e-2`.
#'
#' @param prior Character string; prior type.
#'   One of: `"noninformative"`, `"fixed"`, `"adaptive"`.
#' @param model Character string; model type.
#'   One of: `"BKP"` (binary), `"DKP"` (multi-class).
#' @param r0 Numeric; prior precision (positive scalar, default = 2).
#' @param p0 Numeric; global prior mean.
#'   BKP: scalar in (0,1); DKP: vector summing to 1 with one entry per class.
#' @param y Numeric vector; observed success counts (BKP only), one entry
#'   per column of `K`.
#' @param m Numeric vector; number of trials (BKP only, same length as `y`,
#'   or a single value shared by all observations). Must be positive.
#' @param Y Numeric matrix; observed class counts (DKP only, n x q), one row
#'   per column of `K`.
#' @param K Numeric matrix; precomputed kernel matrix. Columns correspond to
#'   observations; the prior is evaluated once per row. Required for the
#'   adaptive prior.
#'
#' @return
#' For BKP: a list with numeric vectors `alpha0` and `beta0`.
#' For DKP: a matrix `alpha0` of prior Dirichlet parameters with one column
#' per class.
#'
#' @examples
#' \donttest{
#' # BKP example
#' set.seed(123)
#' n <- 10
#' X <- matrix(runif(n*2), ncol = 2)
#' y <- rbinom(n, size = 5, prob = 0.6)
#' m <- rep(5, n)
#' K <- matrix(1, n, n)
#' prior_bkp <- get_prior(
#'   model = "BKP", prior = "adaptive",
#'   r0 = 2, y = y, m = m, K = K
#' )
#'
#' # DKP example: the same Dirichlet parameters at each of three locations
#' prior_dkp <- get_prior(
#'   model = "DKP", prior = "fixed",
#'   r0 = 2, p0 = c(0.2, 0.3, 0.5), K = diag(3)
#' )
#' }
#'
#' @export
get_prior <- function(prior = c("noninformative", "fixed", "adaptive"),
                      model = c("BKP", "DKP"),
                      r0 = 2, p0 = NULL, y = NULL, m = NULL, Y = NULL, K = NULL)
{
  model <- match.arg(model)
  prior <- match.arg(prior)

  # is.na() guard: a bare `NA <= 0` would make `if` fail with an opaque error.
  if (!is.numeric(r0) || length(r0) != 1 || is.na(r0) || r0 <= 0) {
    stop("'r0' must be a positive scalar.")
  }

  if (!is.null(K)) {
    if (length(dim(K)) != 2L) {
      stop("'K' must be a matrix.")
    }
    # Coerce so data-frame input also works with `%*%` below.
    K <- as.matrix(K)
  }
  # Number of locations at which the prior is evaluated.
  n_loc <- if (is.null(K)) 1L else nrow(K)

  if (model == "BKP") {
    if (prior == "noninformative") {
      alpha0 <- rep(1, n_loc)
      beta0 <- rep(1, n_loc)
    } else if (prior == "fixed") {
      if (!is.numeric(p0) || length(p0) != 1 || is.na(p0) ||
          p0 <= 0 || p0 >= 1) {
        stop("For fixed prior in BKP, 'p0' must be in (0,1).")
      }
      alpha0 <- rep(r0 * p0, n_loc)
      beta0 <- rep(r0 * (1 - p0), n_loc)
    } else {
      if (is.null(y) || is.null(m) || is.null(K)) {
        stop("For adaptive prior in BKP, 'y', 'm', and 'K' must be provided.")
      }
      y <- as.numeric(y)
      m <- as.numeric(m)
      if (length(y) != ncol(K)) {
        stop("'y' must have one entry per column of 'K'.")
      }
      # Guard against silent recycling of a partially matching 'm'.
      if (!(length(m) %in% c(1L, length(y)))) {
        stop("'m' must have length 1 or the same length as 'y'.")
      }
      # A zero trial count would turn y / m, and then every output, into NaN.
      if (anyNA(m) || any(m <= 0)) {
        stop("'m' must contain positive trial counts.")
      }

      row_mass <- rowSums(K)
      # Row-normalise K; the floor avoids division by zero for empty rows.
      W <- K / pmax(row_mass, 1e-6)
      p_hat <- as.vector(W %*% (y / m))
      r_hat <- r0 * pmax(row_mass, 1e-3)
      # Keep both Beta parameters strictly positive.
      alpha0 <- pmax(r_hat * p_hat, 1e-2)
      beta0 <- pmax(r_hat * (1 - p_hat), 1e-2)
    }
    return(list(alpha0 = alpha0, beta0 = beta0))
  }

  # ---- DKP ----
  # The number of classes comes from Y when available, otherwise from p0.
  if (!is.null(Y)) {
    Y <- as.matrix(Y)
    q <- ncol(Y)
  } else if (!is.null(p0)) {
    q <- length(p0)
  } else {
    stop("Either 'Y' or 'p0' must be provided.")
  }

  if (prior == "noninformative") {
    alpha0 <- matrix(1, nrow = n_loc, ncol = q)
  } else if (prior == "fixed") {
    if (is.null(p0)) stop("'p0' must be provided for fixed prior in DKP.")
    if (!is.numeric(p0) || length(p0) != q || anyNA(p0) || any(p0 < 0)) {
      stop(
        "For fixed prior in DKP, 'p0' must be a non-negative numeric ",
        "vector with one entry per class."
      )
    }
    # Every row is r0 * p0. Filling by row recycles the class vector across
    # locations; combining rep(each = ) with byrow = TRUE would scramble it.
    alpha0 <- matrix(r0 * p0, nrow = n_loc, ncol = q, byrow = TRUE)
  } else {
    if (is.null(Y) || is.null(K)) stop("'Y' and 'K' must be provided.")
    if (nrow(Y) != ncol(K)) {
      stop("'Y' must have one row per column of 'K'.")
    }
    totals <- rowSums(Y)
    # A zero-count row would make its proportions, and every output, NaN.
    if (anyNA(totals) || any(totals <= 0)) {
      stop("Each row of 'Y' must have a positive total count.")
    }

    row_mass <- rowSums(K)
    W <- K / pmax(row_mass, 1e-6)
    Pi_hat <- W %*% (Y / totals)
    r_hat <- r0 * pmax(row_mass, 1e-3)
    # r_hat recycles down columns, scaling row i of Pi_hat by r_hat[i].
    alpha0 <- Pi_hat * r_hat
  }

  # Keep all Dirichlet parameters strictly positive (dim is preserved).
  pmax(alpha0, 1e-2)
}

Try the NBKP package in your browser

Any scripts or data that you put into this service are public.

NBKP documentation built on June 18, 2026, 1:06 a.m.