Nothing
.static_prior_or <- function(x, alt) {
  # Null-coalescing helper: yield `x` unless it is NULL, in which case yield
  # `alt`. `alt` is only evaluated when it is actually needed.
  if (is.null(x)) {
    return(alt)
  }
  x
}
.static_is_rhs_family <- function(beta_prior_type) {
  # TRUE when the first element of `beta_prior_type` names an RHS-type prior
  # ("rhs" or "rhs_ns", case-insensitive). FALSE otherwise, including for
  # NULL / NA input.
  first_label <- tolower(as.character(beta_prior_type)[1])
  is.element(first_label, c("rhs", "rhs_ns"))
}
.static_match_beta_prior <- function(beta_prior) {
  # Normalise a user-facing beta-prior label to one of "ridge", "rhs", "rhs_ns".
  # Only the first element is used. Matching is case-insensitive and ignores
  # surrounding whitespace. NULL, NA, "" and the synonym "gaussian" all map to
  # the default "ridge". Unknown labels error via match.arg().
  prior <- tolower(trimws(as.character(beta_prior)[1]))
  # as.character(NULL)[1] is NA, so missing input reaches this check as NA.
  if (is.na(prior) || !nzchar(prior) || identical(prior, "gaussian")) {
    prior <- "ridge"
  }
  match.arg(prior, c("ridge", "rhs", "rhs_ns"))
}
# Parse, validate and normalise RHS / RHS-NS beta-prior controls.
#
# beta_prior_controls: NULL or a list of overrides, merged with
#   utils::modifyList() into the defaults below.
# prior_type: "rhs" or "rhs_ns".
# Returns the merged control list. Scalars are coerced and range-checked. The
# slab scale is resolved into nu / s / s2 (plus a_zeta / b_zeta for rhs_ns).
# The lambda / tau / c2 initial values are resolved. It errors on an invalid
# tau0, nu, slab scale, a_zeta / b_zeta, zeta2_fixed or initial value.
# Out-of-range tuning knobs fall back to their defaults.
#
# Slab-scale precedence: an explicitly supplied `s2` wins over `s`, with a
# warning if the two disagree. An explicit `s` is used when `s2` is not
# supplied. Otherwise s = s2 = 1 (s_source = "default"). Only values the
# caller actually supplied count, so the built-in s2 = 1 never overrides an
# explicit `s`.
#
# Legacy aliases (c2_fixed, lambda_bounds, tau_bounds, c2_bounds and the
# rhs_-prefixed schedule controls) are used only when the corresponding
# primary control was not supplied.
.static_parse_beta_prior_controls <- function(beta_prior_controls = NULL, prior_type = c("rhs", "rhs_ns")) {
  prior_type <- match.arg(prior_type)
  is_ns <- identical(prior_type, "rhs_ns")
  ctrl <- list(
    tau0 = 1,
    nu = 4,
    s = NULL,
    s2 = 1,
    a_zeta = if (is_ns) 2 else NULL,
    b_zeta = if (is_ns) 1 else NULL,
    zeta2_fixed = NULL,
    shrink_intercept = FALSE,
    intercept_prec = 1e-16,
    n_inner = if (is_ns) 2L else 1L,
    freeze_tau_iters = 50L,
    freeze_tau_warmup_iters = NULL,
    update_every = 1L,
    update_every_warmup = 1L,
    update_every_warmup_iters = 0L,
    force_tau_after_warmup = TRUE,
    collapse_tau_ratio_tol = 1e-6,
    collapse_beta_max_abs_tol = 1e-6,
    collapse_invV_med_tol = 1e8,
    collapse_beta_l2_tol = 1e-6,
    collapse_small_beta_frac_tol = 0.95,
    small_beta_abs_tol = 1e-4,
    warn_on_collapse = TRUE,
    eta_bounds = list(
      lambda = c(-40, 40),
      tau = c(-40, 40),
      c2 = c(-40, 40)
    ),
    var_floor = 1e-16,
    h_curv = 1e-16,
    verbose = FALSE,
    init_lambda = 1,
    init_log_lambda = NULL,
    init_tau = NULL,
    init_log_tau = NULL,
    init_c2 = NULL,
    init_log_c2 = NULL,
    slice_width = 1,
    slice_max_steps = 20L
  )
  if (!is.null(beta_prior_controls)) {
    if (!is.list(beta_prior_controls)) stop("beta_prior_controls must be a list.")
    ctrl <- utils::modifyList(ctrl, beta_prior_controls)
  }
  # Use the file's own null-coalescing helper rather than `%||%`, which is only
  # in base R from 4.4 onwards.
  user <- .static_prior_or(beta_prior_controls, list())
  input_names <- names(user)
  has_input <- function(nm) nm %in% input_names
  # Resolve a control that has a legacy alias. Precedence is a user-supplied
  # primary, then a user-supplied alias, then the built-in default of the
  # primary. `[[` is used to avoid `$` partial matching between similarly
  # named aliases.
  pick <- function(primary, alias) {
    supplied <- if (has_input(primary)) ctrl[[primary]] else NULL
    .static_prior_or(supplied, .static_prior_or(ctrl[[alias]], ctrl[[primary]]))
  }

  # ---- Global scale and slab scale ----
  tau0 <- as.numeric(ctrl$tau0)[1]
  nu <- as.numeric(ctrl$nu)[1]
  if (!is.finite(tau0) || tau0 <= 0) stop("beta_prior_controls$tau0 must be > 0.")
  if (!is.finite(nu) || nu <= 0) stop("beta_prior_controls$nu must be > 0.")
  # `s` has no default, so any non-NULL value came from the caller. `s2` has a
  # default (1), so it only counts when the caller supplied it explicitly.
  has_s <- !is.null(ctrl[["s"]])
  has_s2 <- has_input("s2") && !is.null(ctrl[["s2"]])
  s_val <- if (has_s) as.numeric(ctrl[["s"]])[1] else NA_real_
  s2_val <- if (has_s2) as.numeric(ctrl[["s2"]])[1] else NA_real_
  if (has_s2) {
    s2 <- s2_val
    s <- sqrt(s2)
    s_source <- "s2"
  } else if (has_s) {
    s <- s_val
    s2 <- s^2
    s_source <- "s"
  } else {
    s2 <- 1
    s <- 1
    s_source <- "default"
  }
  if (!is.finite(s2) || s2 <= 0) stop("beta_prior_controls$s2 (or s^2) must be > 0.")
  if (has_s && has_s2) {
    s_from_s2 <- sqrt(s2_val)
    if (is.finite(s_val) && is.finite(s_from_s2)) {
      rel <- abs(s_val - s_from_s2) / max(1, abs(s_val), abs(s_from_s2))
      if (rel > 1e-8) {
        warning("beta_prior_controls supplied both s and s2 inconsistently; using s2 and resetting s=sqrt(s2).", call. = FALSE)
      }
    }
  }

  # ---- RHS-NS slab: zeta^2 ~ IG(a_zeta, b_zeta) <-> (nu, s2) mapping ----
  if (is_ns) {
    a_zeta <- as.numeric(ctrl[["a_zeta"]])[1]
    b_zeta <- as.numeric(ctrl[["b_zeta"]])[1]
    explicit_ab <- has_input("a_zeta") || has_input("b_zeta")
    explicit_nu_s <- has_input("nu") || has_input("s") || has_input("s2")
    if (explicit_ab && (!is.finite(a_zeta) || a_zeta <= 0 || !is.finite(b_zeta) || b_zeta <= 0)) {
      stop("beta_prior_controls$a_zeta and beta_prior_controls$b_zeta must be finite and > 0 when provided.")
    }
    if (explicit_ab && !explicit_nu_s) {
      # Only (a_zeta, b_zeta) given: derive nu = 2a and s2 = b / a.
      nu <- 2 * a_zeta
      s2 <- b_zeta / a_zeta
      s <- sqrt(s2)
      s_source <- "rhs_ns_ab"
    } else {
      # (nu, s2) is authoritative: a = nu / 2 and b = nu * s2 / 2.
      a_from_rhs <- nu / 2
      b_from_rhs <- (nu * s2) / 2
      if (explicit_ab) {
        rel_a <- abs(a_zeta - a_from_rhs) / max(1, abs(a_zeta), abs(a_from_rhs))
        rel_b <- abs(b_zeta - b_from_rhs) / max(1, abs(b_zeta), abs(b_from_rhs))
        if (max(rel_a, rel_b) > 1e-8) {
          warning(
            "beta_prior_controls supplied both RHS and RHS_NS slab controls inconsistently; using nu/s/s2 and mapping to a_zeta/b_zeta.",
            call. = FALSE
          )
        }
      }
      a_zeta <- a_from_rhs
      b_zeta <- b_from_rhs
    }
    ctrl$a_zeta <- a_zeta
    ctrl$b_zeta <- b_zeta
  } else {
    ctrl$a_zeta <- NULL
    ctrl$b_zeta <- NULL
  }
  # A bounds pair must be two finite, strictly increasing numbers; anything
  # else falls back to `default`.
  parse_bounds <- function(x, default) {
    x <- as.numeric(.static_prior_or(x, default))
    if (length(x) != 2L || any(!is.finite(x)) || x[1] >= x[2]) default else x
  }
  ctrl$tau0 <- tau0
  ctrl$nu <- nu
  ctrl$s <- s
  ctrl$s2 <- s2
  ctrl$s_source <- s_source

  # ---- Fixed slab variance (legacy alias: c2_fixed) ----
  zeta2_fixed <- pick("zeta2_fixed", "c2_fixed")
  if (!is.null(zeta2_fixed)) {
    zeta2_fixed <- as.numeric(zeta2_fixed)[1]
    if (!is.finite(zeta2_fixed) || zeta2_fixed <= 0) {
      stop("beta_prior_controls$zeta2_fixed must be finite and > 0 when provided.")
    }
  }
  ctrl$zeta2_fixed <- zeta2_fixed
  ctrl$prior_type <- prior_type
  ctrl$shrink_intercept <- isTRUE(ctrl$shrink_intercept)
  ctrl$intercept_prec <- as.numeric(ctrl$intercept_prec)[1]
  if (!is.finite(ctrl$intercept_prec) || ctrl$intercept_prec <= 0) ctrl$intercept_prec <- 1e-16

  # ---- Log-scale bounds (legacy aliases: lambda_bounds / tau_bounds / c2_bounds) ----
  user_eta_bounds <- user[["eta_bounds"]]
  user_bound_names <- if (is.list(user_eta_bounds)) names(user_eta_bounds) else character(0)
  if (is.null(ctrl$eta_bounds) || !is.list(ctrl$eta_bounds)) ctrl$eta_bounds <- list()
  resolve_bounds <- function(key, alias) {
    supplied <- if (key %in% user_bound_names) ctrl$eta_bounds[[key]] else NULL
    parse_bounds(
      .static_prior_or(supplied, .static_prior_or(ctrl[[alias]], ctrl$eta_bounds[[key]])),
      c(-40, 40)
    )
  }
  ctrl$eta_bounds$lambda <- resolve_bounds("lambda", "lambda_bounds")
  ctrl$eta_bounds$tau <- resolve_bounds("tau", "tau_bounds")
  ctrl$eta_bounds$c2 <- resolve_bounds("c2", "c2_bounds")

  # ---- Numeric tuning knobs ----
  ctrl$n_inner <- suppressWarnings(as.integer(ctrl$n_inner)[1])
  if (!is.finite(ctrl$n_inner) || ctrl$n_inner < 1L) ctrl$n_inner <- 1L
  ctrl$var_floor <- as.numeric(ctrl$var_floor)[1]
  if (!is.finite(ctrl$var_floor) || ctrl$var_floor <= 0) ctrl$var_floor <- 1e-16
  ctrl$h_curv <- as.numeric(ctrl$h_curv)[1]
  if (!is.finite(ctrl$h_curv) || ctrl$h_curv <= 0) ctrl$h_curv <- 1e-16
  ctrl$verbose <- isTRUE(ctrl$verbose)

  # ---- Update schedule (each honours its rhs_-prefixed legacy alias) ----
  freeze_tau_iters <- suppressWarnings(as.integer(pick("freeze_tau_iters", "rhs_freeze_tau_iters"))[1])
  if (!is.finite(freeze_tau_iters) || freeze_tau_iters < 0L) freeze_tau_iters <- 50L
  freeze_tau_warmup_iters <- suppressWarnings(as.integer(
    .static_prior_or(pick("freeze_tau_warmup_iters", "rhs_freeze_tau_warmup_iters"), freeze_tau_iters)
  )[1])
  if (!is.finite(freeze_tau_warmup_iters) || freeze_tau_warmup_iters < 0L) {
    freeze_tau_warmup_iters <- freeze_tau_iters
  }
  update_every <- suppressWarnings(as.integer(pick("update_every", "rhs_update_every"))[1])
  if (!is.finite(update_every) || update_every < 1L) update_every <- 1L
  update_every_warmup <- suppressWarnings(as.integer(
    .static_prior_or(pick("update_every_warmup", "rhs_update_every_warmup"), update_every)
  )[1])
  if (!is.finite(update_every_warmup) || update_every_warmup < 1L) update_every_warmup <- update_every
  update_every_warmup_iters <- suppressWarnings(as.integer(
    pick("update_every_warmup_iters", "rhs_update_every_warmup_iters")
  )[1])
  if (!is.finite(update_every_warmup_iters) || update_every_warmup_iters < 0L) update_every_warmup_iters <- 0L
  ctrl$freeze_tau_iters <- freeze_tau_iters
  ctrl$freeze_tau_warmup_iters <- freeze_tau_warmup_iters
  ctrl$update_every <- update_every
  ctrl$update_every_warmup <- update_every_warmup
  ctrl$update_every_warmup_iters <- update_every_warmup_iters
  ctrl$force_tau_after_warmup <- isTRUE(pick("force_tau_after_warmup", "rhs_force_tau_after_warmup"))

  # ---- Collapse-diagnostic tolerances ----
  ctrl$collapse_tau_ratio_tol <- as.numeric(ctrl$collapse_tau_ratio_tol)[1]
  if (!is.finite(ctrl$collapse_tau_ratio_tol) || ctrl$collapse_tau_ratio_tol <= 0) ctrl$collapse_tau_ratio_tol <- 1e-6
  ctrl$collapse_beta_max_abs_tol <- as.numeric(ctrl$collapse_beta_max_abs_tol)[1]
  if (!is.finite(ctrl$collapse_beta_max_abs_tol) || ctrl$collapse_beta_max_abs_tol <= 0) ctrl$collapse_beta_max_abs_tol <- 1e-6
  ctrl$warn_on_collapse <- isTRUE(ctrl$warn_on_collapse)
  ctrl$collapse_invV_med_tol <- as.numeric(ctrl$collapse_invV_med_tol)[1]
  if (!is.finite(ctrl$collapse_invV_med_tol) || ctrl$collapse_invV_med_tol <= 0) ctrl$collapse_invV_med_tol <- 1e8
  ctrl$collapse_beta_l2_tol <- as.numeric(ctrl$collapse_beta_l2_tol)[1]
  if (!is.finite(ctrl$collapse_beta_l2_tol) || ctrl$collapse_beta_l2_tol <= 0) ctrl$collapse_beta_l2_tol <- 1e-6
  ctrl$collapse_small_beta_frac_tol <- as.numeric(ctrl$collapse_small_beta_frac_tol)[1]
  if (!is.finite(ctrl$collapse_small_beta_frac_tol) || ctrl$collapse_small_beta_frac_tol <= 0 || ctrl$collapse_small_beta_frac_tol > 1) {
    ctrl$collapse_small_beta_frac_tol <- 0.95
  }
  ctrl$small_beta_abs_tol <- as.numeric(ctrl$small_beta_abs_tol)[1]
  if (!is.finite(ctrl$small_beta_abs_tol) || ctrl$small_beta_abs_tol <= 0) ctrl$small_beta_abs_tol <- 1e-4

  # ---- Initial values ----
  init_lambda <- if (!is.null(ctrl$init_log_lambda)) .static_rhs_safe_exp(ctrl$init_log_lambda) else ctrl$init_lambda
  init_lambda <- as.numeric(init_lambda)
  if (!length(init_lambda) || any(!is.finite(init_lambda)) || any(init_lambda <= 0)) {
    stop("beta_prior_controls$init_lambda must be finite and > 0 (scalar or length p).")
  }
  ctrl$init_lambda <- init_lambda
  has_init_log_tau <- !is.null(ctrl$init_log_tau)
  has_init_tau <- !is.null(ctrl$init_tau)
  if (has_init_log_tau) {
    resolved_init_log_tau <- as.numeric(ctrl$init_log_tau)[1]
    if (!is.finite(resolved_init_log_tau)) stop("beta_prior_controls$init_log_tau must be finite when provided.")
    resolved_init_tau <- .static_rhs_safe_exp(resolved_init_log_tau)
    init_tau_source <- "init_log_tau"
  } else if (has_init_tau) {
    resolved_init_tau <- as.numeric(ctrl$init_tau)[1]
    if (!is.finite(resolved_init_tau) || resolved_init_tau <= 0) stop("beta_prior_controls$init_tau must be finite and > 0 when provided.")
    resolved_init_log_tau <- log(resolved_init_tau)
    init_tau_source <- "init_tau"
  } else {
    # Guardrail: unset/null tau init defaults to log(1)=0, not tau0.
    resolved_init_log_tau <- 0
    resolved_init_tau <- 1
    init_tau_source <- "default_log_tau_0"
  }
  if (!is.finite(resolved_init_log_tau) || !is.finite(resolved_init_tau) || resolved_init_tau <= 0) {
    stop("Resolved RHS tau initialization is invalid; check init_log_tau/init_tau controls.")
  }
  ctrl$init_log_tau <- resolved_init_log_tau
  ctrl$init_tau <- resolved_init_tau
  ctrl$init_tau_source <- init_tau_source
  init_c2 <- if (!is.null(ctrl$init_log_c2)) .static_rhs_safe_exp(ctrl$init_log_c2) else ctrl$init_c2
  if (!is.null(init_c2)) {
    init_c2 <- as.numeric(init_c2)[1]
    if (!is.finite(init_c2) || init_c2 <= 0) stop("beta_prior_controls$init_c2 must be finite and > 0 when provided.")
  }
  # A fixed slab variance also pins the c2 starting value.
  if (!is.null(ctrl$zeta2_fixed)) {
    init_c2 <- as.numeric(ctrl$zeta2_fixed)[1]
  }
  ctrl$init_c2 <- init_c2
  ctrl$init_log_c2 <- if (!is.null(init_c2)) log(init_c2) else NULL

  # ---- Slice-sampler tuning ----
  ctrl$slice_width <- as.numeric(ctrl$slice_width)[1]
  if (!is.finite(ctrl$slice_width) || ctrl$slice_width <= 0) ctrl$slice_width <- 1
  ctrl$slice_max_steps <- suppressWarnings(as.integer(ctrl$slice_max_steps)[1])
  if (!is.finite(ctrl$slice_max_steps) || ctrl$slice_max_steps < 1L) ctrl$slice_max_steps <- 20L
  ctrl
}
.static_rhs_opt_1d_mode <- function(f, lo, hi, eta0 = NULL) {
  # Find the maximiser of the scalar objective `f` on [lo, hi].
  # The primary path is stats::optimize(). If that errors or reports a
  # non-finite objective, the fallback scans a 31-point grid and then runs
  # BFGS from the best grid point. BFGS sees non-finite values as 1e100.
  # The result is always clamped to [lo, hi], except when the BFGS fallback
  # itself fails, in which case the best grid point is returned.
  # NOTE: `eta0` is sanitised below but not otherwise used by the search.
  if (!(is.finite(lo) && is.finite(hi) && lo < hi)) stop("Invalid optimization bounds.")
  if (is.null(eta0) || !is.finite(eta0)) eta0 <- 0
  clamp <- function(z) pmin(pmax(z, lo), hi)

  brent <- try(stats::optimize(f, interval = c(lo, hi), maximum = TRUE), silent = TRUE)
  brent_ok <- !inherits(brent, "try-error") && is.finite(brent$objective)
  if (brent_ok) {
    return(clamp(brent$maximum))
  }

  grid_pts <- seq(lo, hi, length.out = 31L)
  start <- grid_pts[which.max(vapply(grid_pts, f, numeric(1)))]
  neg_f <- function(z) {
    val <- f(z)
    if (is.finite(val)) -val else 1e100
  }
  bfgs <- try(
    stats::optim(par = start, fn = neg_f, method = "BFGS", control = list(maxit = 2000L)),
    silent = TRUE
  )
  if (inherits(bfgs, "try-error") || !is.finite(bfgs$value)) {
    return(start)
  }
  clamp(as.numeric(bfgs$par)[1])
}
.static_rhs_active_idx <- function(p, shrink_intercept) {
  # Indices of the coefficients that receive RHS shrinkage: all of 1..p when
  # the intercept is shrunk, otherwise 2..p (empty when p < 2).
  if (isTRUE(shrink_intercept)) {
    return(seq_len(p))
  }
  if (p < 2L) {
    return(integer(0))
  }
  2L:p
}
.static_rhs_safe_exp <- function(x) {
  # Overflow-safe elementwise exp(). The exponent is clamped to [-745, 709]
  # before exponentiating, so large inputs give a finite result.
  z <- as.numeric(x)
  z <- pmax(z, -745)
  z <- pmin(z, 709)
  exp(z)
}
.static_rhs_preflight_config <- function(ctrl) {
  # Snapshot the resolved RHS controls used by the preflight report, reduced
  # to scalars (missing a_zeta / b_zeta / zeta2_fixed become NA). Fail fast
  # when the resolved tau initialisation or the eta_tau bounds are unusable.
  first_num <- function(v) as.numeric(v)[1]
  cfg <- list(
    prior_type = as.character(.static_prior_or(ctrl$prior_type, "rhs"))[1],
    tau0 = first_num(ctrl$tau0),
    nu = first_num(ctrl$nu),
    s = first_num(ctrl$s),
    s2 = first_num(ctrl$s2),
    a_zeta = first_num(.static_prior_or(ctrl$a_zeta, NA_real_)),
    b_zeta = first_num(.static_prior_or(ctrl$b_zeta, NA_real_)),
    zeta2_fixed = first_num(.static_prior_or(ctrl$zeta2_fixed, NA_real_)),
    init_log_tau = first_num(ctrl$init_log_tau),
    init_tau = first_num(ctrl$init_tau),
    init_tau_source = as.character(.static_prior_or(ctrl$init_tau_source, "unknown"))[1],
    eta_bounds_tau = as.numeric(ctrl$eta_bounds$tau),
    shrink_intercept = isTRUE(ctrl$shrink_intercept)
  )

  tau_init_ok <- is.finite(cfg$init_log_tau) && is.finite(cfg$init_tau) && cfg$init_tau > 0
  if (!tau_init_ok) {
    stop("RHS preflight failed: resolved init_log_tau/init_tau must be finite with init_tau > 0.")
  }
  bt <- cfg$eta_bounds_tau
  bounds_ok <- length(bt) == 2L && all(is.finite(bt)) && bt[1] < bt[2]
  if (!bounds_ok) {
    stop("RHS preflight failed: eta_bounds$tau must be finite and ordered.")
  }
  cfg
}
.static_rhs_preflight_emit <- function(cfg, context = "rhs") {
  # Format a preflight config (as built by .static_rhs_preflight_config) into
  # one line, emit it via message(), and return the text invisibly.
  intercept_flag <- ifelse(cfg$shrink_intercept, "TRUE", "FALSE")
  tau_bounds <- cfg$eta_bounds_tau
  line <- sprintf(
    "[%s] RHS preflight | tau0=%.6g nu=%.6g s=%.6g s2=%.6g init_log_tau=%.6g init_tau=%.6g init_source=%s eta_bounds_tau=[%.6g, %.6g] shrink_intercept=%s",
    context,
    cfg$tau0, cfg$nu, cfg$s, cfg$s2,
    cfg$init_log_tau, cfg$init_tau, cfg$init_tau_source,
    tau_bounds[1], tau_bounds[2],
    intercept_flag
  )
  message(line)
  invisible(line)
}
.static_rhs_log1p_exp <- function(x) {
  # Numerically stable softplus, log(1 + exp(x)), evaluated elementwise.
  # Positive entries use x + log1p(exp(-x)), so exp() only sees arguments <= 0
  # on that branch.
  z <- as.numeric(x)
  res <- numeric(length(z))
  upper <- z > 0
  lower <- !upper
  res[upper] <- z[upper] + log1p(.static_rhs_safe_exp(-z[upper]))
  res[lower] <- log1p(.static_rhs_safe_exp(z[lower]))
  res
}
.static_rhs_logsumexp2 <- function(a, b) {
  # Elementwise log(exp(a) + exp(b)), computed relative to the larger of the
  # two arguments so that neither exponential overflows.
  hi <- pmax(a, b)
  ea <- .static_rhs_safe_exp(a - hi)
  eb <- .static_rhs_safe_exp(b - hi)
  hi + log(ea + eb)
}
.static_rhs_obj_eta <- function(eta_lambda, eta_tau, eta_c2, beta2, ctrl) {
  # Log objective (up to constants) of the RHS hyperparameters on the log
  # scale (eta_lambda = log lambda_j, eta_tau = log tau, eta_c2 = log c2),
  # given second moments `beta2` of the coefficients. Only active indices
  # contribute.
  #   likelihood: -0.5 * sum(log V_j + beta2_j / V_j),
  #               with 1 / V_j = 1 / c2 + 1 / (tau^2 * lambda_j^2)
  #   priors (including log-scale Jacobians): half-Cauchy(0, 1) on lambda_j,
  #               half-Cauchy(0, tau0) on tau, inverse-gamma(nu/2, nu*s2/2) on c2
  # A non-finite total is returned as -1e300 so optimisers and samplers always
  # see a finite value.
  active <- .static_rhs_active_idx(length(beta2), ctrl$shrink_intercept)
  lam <- eta_lambda[active]
  b2 <- beta2[active]
  log_tl2 <- 2 * eta_tau + 2 * lam
  prec <- .static_rhs_safe_exp(.static_rhs_logsumexp2(-eta_c2, -log_tl2))
  log_var <- eta_c2 + log_tl2 - .static_rhs_logsumexp2(eta_c2, log_tl2)
  quad <- b2 * prec
  quad[!is.finite(quad)] <- .Machine$double.xmax
  loglik <- -0.5 * sum(log_var + quad)

  prior_lambda <- sum(lam - .static_rhs_log1p_exp(2 * lam))
  prior_tau <- eta_tau - .static_rhs_log1p_exp(2 * (eta_tau - log(ctrl$tau0)))
  prior_c2 <- -(ctrl$nu / 2) * eta_c2 - (ctrl$nu * ctrl$s2) / (2 * .static_rhs_safe_exp(eta_c2))

  total <- loglik + prior_lambda + prior_tau + prior_c2
  if (is.finite(total)) total else -1e300
}
.static_rhs_d2_lambda_j <- function(eta_lambda_j, eta_tau, eta_c2, beta2_j) {
  # Second derivative of the RHS log objective with respect to one
  # eta_lambda_j. It is the likelihood curvature plus the curvature of the
  # log-scale half-Cauchy(0, 1) prior on lambda_j.
  log_tl2 <- 2 * (eta_tau + eta_lambda_j)
  w <- .static_rhs_safe_exp(log_tl2 - .static_rhs_logsumexp2(eta_c2, log_tl2))
  w1w <- w * (1 - w)
  inv_tl2 <- .static_rhs_safe_exp(-log_tl2)
  curv_like <- 2 * w1w - 2 * beta2_j * inv_tl2
  q <- stats::plogis(2 * eta_lambda_j)
  curv_prior <- -4 * q * (1 - q)
  curv_like + curv_prior
}
.static_rhs_d2_tau <- function(eta_lambda_use, eta_tau, eta_c2, beta2_use, tau0) {
  # Second derivative of the RHS log objective with respect to eta_tau. The
  # likelihood curvature is summed over the supplied active coefficients, and
  # the log-scale half-Cauchy(0, tau0) prior curvature on tau is added once.
  log_tl2 <- 2 * eta_tau + 2 * eta_lambda_use
  w <- .static_rhs_safe_exp(log_tl2 - .static_rhs_logsumexp2(eta_c2, log_tl2))
  w1w <- w * (1 - w)
  inv_tl2 <- .static_rhs_safe_exp(-log_tl2)
  curv_like <- sum(2 * w1w - 2 * beta2_use * inv_tl2)
  q <- stats::plogis(2 * (eta_tau - log(tau0)))
  curv_prior <- -4 * q * (1 - q)
  curv_like + curv_prior
}
.static_rhs_d2_c2 <- function(eta_lambda_use, eta_tau, eta_c2, beta2_use, nu, s2) {
  # Second derivative of the RHS log objective with respect to eta_c2 = log c2.
  # The likelihood curvature is summed over the supplied active coefficients,
  # and the curvature of the inverse-gamma(nu/2, nu*s2/2) prior on c2 is added.
  log_tl2 <- 2 * eta_tau + 2 * eta_lambda_use
  w <- .static_rhs_safe_exp(log_tl2 - .static_rhs_logsumexp2(eta_c2, log_tl2))
  w1w <- w * (1 - w)
  inv_c2 <- .static_rhs_safe_exp(-eta_c2)
  curv_like <- 0.5 * sum(w1w) - 0.5 * inv_c2 * sum(beta2_use)
  curv_prior <- -(nu * s2) / 2 * inv_c2
  curv_like + curv_prior
}
# Hessian of .static_rhs_obj_eta restricted to the active hyperparameters.
# Coordinates are ordered (eta_lambda[idx], eta_tau, eta_c2), giving a
# (k + 2) x (k + 2) matrix with k = length(idx).
# With t = exp(-2 * (eta_lambda_j + eta_tau)) and r = exp(-eta_c2), so that
# 1 / V_j = t + r, each active j contributes:
#   * "_log" terms: curvature of 0.5 * log(t + r), written in terms of the
#     weights w_t = t / (t + r) and w_r = r / (t + r);
#   * "_g" terms:   curvature of -0.5 * beta2_j * (t + r);
#   * the curvature of the log-scale half-Cauchy(0, 1) prior on lambda_j.
# The tau and c2 prior curvatures are added once, after the loop.
# Returns list(H = symmetrised Hessian, idx = active coefficient indices).
.static_rhs_hess_active <- function(eta_lambda, eta_tau, eta_c2, beta2, ctrl) {
idx <- .static_rhs_active_idx(length(beta2), ctrl$shrink_intercept)
k <- length(idx)
d <- k + 2L
H <- matrix(0, d, d)
# Positions of the tau and c2 ("kappa") coordinates within H.
itau <- k + 1L
ikap <- k + 2L
logr <- -eta_c2
r <- .static_rhs_safe_exp(logr)
# Prior curvatures: log-scale half-Cauchy(0, tau0) on tau, and
# inverse-gamma(nu/2, nu*s2/2) on c2 (on the log scale).
s_tau <- stats::plogis(2 * (eta_tau - log(ctrl$tau0)))
d2_tau_prior <- -4 * s_tau * (1 - s_tau)
d2_kap_prior <- -(ctrl$nu * ctrl$s2) / 2 * r
for (a in seq_len(k)) {
j <- idx[a]
uj <- eta_lambda[j]
Sj <- beta2[j]
logt <- -2 * (uj + eta_tau)
# NOTE: the local `t` shadows base::t as a value only. The later call t(H)
# still resolves to the transpose function, because R skips non-function
# bindings when looking up a function name.
t <- .static_rhs_safe_exp(logt)
logg <- .static_rhs_logsumexp2(logt, logr)
w_t <- .static_rhs_safe_exp(logt - logg)
w_r <- .static_rhs_safe_exp(logr - logg)
a_wr <- w_t * w_r
# Curvature of 0.5 * log(t + r); indices 1/2/3 = lambda_j / tau / c2.
h11_log <- 2 * a_wr
h12_log <- 2 * a_wr
h13_log <- -a_wr
h22_log <- 2 * a_wr
h23_log <- -a_wr
h33_log <- 0.5 * a_wr
# Curvature of -0.5 * Sj * (t + r); t'' = 4t in lambda_j / tau, r'' = r in c2.
h11_g <- -0.5 * Sj * (4 * t)
h12_g <- -0.5 * Sj * (4 * t)
h22_g <- -0.5 * Sj * (4 * t)
h33_g <- -0.5 * Sj * r
# Log-scale half-Cauchy(0, 1) prior curvature on lambda_j.
sj <- stats::plogis(2 * uj)
d2_lam_prior <- -4 * sj * (1 - sj)
ia <- a
H[ia, ia] <- H[ia, ia] + h11_log + h11_g + d2_lam_prior
H[ia, itau] <- H[ia, itau] + h12_log + h12_g
H[itau, ia] <- H[ia, itau]
H[ia, ikap] <- H[ia, ikap] + h13_log
H[ikap, ia] <- H[ia, ikap]
# The tau / c2 block accumulates contributions from every active j.
H[itau, itau] <- H[itau, itau] + h22_log + h22_g
H[itau, ikap] <- H[itau, ikap] + h23_log
H[ikap, itau] <- H[itau, ikap]
H[ikap, ikap] <- H[ikap, ikap] + h33_log + h33_g
}
H[itau, itau] <- H[itau, itau] + d2_tau_prior
H[ikap, ikap] <- H[ikap, ikap] + d2_kap_prior
# Explicit symmetrisation guards against floating-point asymmetry.
H <- 0.5 * (H + t(H))
list(H = H, idx = idx)
}
.static_rhs_inv_spd_with_jitter <- function(K, var_floor, max_tries = 12L) {
  # Invert a (presumably) symmetric positive-definite matrix K and return
  # list(inv, logdet).
  # Cholesky is tried on K as-is first. Later attempts add a diagonal ridge
  # that starts at max(var_floor, 1e-16) and grows tenfold per retry, up to
  # `max_tries` attempts in total.
  # If every attempt fails, the fallback uses solve() on K + max(1e-16, var_floor) * I.
  # Its log-determinant comes from that matrix's eigenvalues, clamped below
  # at 1e-300.
  # NOTE(review): in the fallback the returned inverse is not guaranteed to be
  # positive definite, and the log-determinant uses the clamped spectrum
  # rather than the one actually inverted. Confirm this is acceptable.
  dim_k <- nrow(K)
  ridge <- 0
  for (attempt in seq_len(max_tries)) {
    K_try <- K
    if (ridge > 0) diag(K_try) <- diag(K_try) + ridge
    chol_try <- try(chol(K_try), silent = TRUE)
    if (!inherits(chol_try, "try-error")) {
      return(list(inv = chol2inv(chol_try), logdet = 2 * sum(log(diag(chol_try)))))
    }
    ridge <- if (attempt == 1L) max(var_floor, 1e-16) else 10 * ridge
  }
  K_reg <- K + diag(max(1e-16, var_floor), dim_k)
  eig_vals <- eigen(K_reg, symmetric = TRUE, only.values = TRUE)$values
  eig_vals <- pmax(eig_vals, 1e-300)
  list(inv = solve(K_reg), logdet = sum(log(eig_vals)))
}
.static_rhs_embed_sigma_full <- function(Sigma_active, idx, p, var_floor) {
  # Scatter an active-block covariance into the full (p + 2) x (p + 2) layout
  # ordered as (eta_lambda_1..p, eta_tau, eta_c2).
  # Sigma_active rows and columns are (active lambdas in `idx` order, tau, c2).
  # Inactive lambdas keep `var_floor` on the diagonal and zero covariance.
  n_act <- length(idx)
  a_tau <- n_act + 1L
  a_c2 <- n_act + 2L
  f_tau <- p + 1L
  f_c2 <- p + 2L
  out <- diag(var_floor, p + 2L)
  if (n_act > 0) {
    lam_rows <- seq_len(n_act)
    out[idx, idx] <- Sigma_active[lam_rows, lam_rows, drop = FALSE]
    out[idx, f_tau] <- Sigma_active[lam_rows, a_tau]
    out[f_tau, idx] <- Sigma_active[a_tau, lam_rows]
    out[idx, f_c2] <- Sigma_active[lam_rows, a_c2]
    out[f_c2, idx] <- Sigma_active[a_c2, lam_rows]
  }
  hyper_full <- c(f_tau, f_c2)
  hyper_act <- c(a_tau, a_c2)
  out[hyper_full, hyper_full] <- Sigma_active[hyper_act, hyper_act]
  out
}
.static_rhs_init_vb_state <- function(p, ctrl) {
  # Build the initial VB state for the RHS hyperparameters. Point estimates
  # are stored on the log scale: eta_lambda_hat (length p), eta_tau_hat and
  # eta_c_hat. The covariance starts as var_floor * I of size p + 2, and the
  # schedule / collapse bookkeeping starts empty.
  # init_lambda may be a scalar (recycled to length p) or length p; c2 starts
  # at ctrl$init_c2 when set, otherwise at ctrl$s2.
  lambda_init <- ctrl$init_lambda
  if (length(lambda_init) == 1L) lambda_init <- rep(lambda_init, p)
  lambda_init <- pmax(as.numeric(lambda_init), 1e-16)
  if (length(lambda_init) != p) stop("beta_prior_controls$init_lambda must be scalar or length p.")
  tau_init <- as.numeric(ctrl$init_tau)[1]
  c2_init <- if (is.null(ctrl$init_c2)) ctrl$s2 else ctrl$init_c2
  floor_var <- ctrl$var_floor
  n_eta <- p + 2L
  list(
    p = p,
    shrink_intercept = ctrl$shrink_intercept,
    intercept_prec = ctrl$intercept_prec,
    eta_lambda_hat = log(lambda_init),
    eta_tau_hat = log(pmax(tau_init, 1e-16)),
    eta_c_hat = log(pmax(c2_init, 1e-16)),
    Sigma_full = diag(floor_var, n_eta),
    Sigma_diag = rep(floor_var, n_eta),
    iter = 0L,
    freeze_tau = FALSE,
    update_tau_only = FALSE,
    tau_update_count = 0L,
    has_post_warmup_tau_update = FALSE,
    last_schedule = list(),
    collapse_diag = NULL
  )
}
.static_rhs_init_mcmc_state <- function(p, ctrl) {
  # Build the initial MCMC state for the RHS hyperparameters on the natural
  # (not log) scale: lambda (length p), tau and c2, each floored at 1e-16.
  # init_lambda may be a scalar (recycled to length p) or length p; c2 starts
  # at ctrl$init_c2 when set, otherwise at ctrl$s2.
  lambda_init <- ctrl$init_lambda
  if (length(lambda_init) == 1L) lambda_init <- rep(lambda_init, p)
  lambda_init <- pmax(as.numeric(lambda_init), 1e-16)
  if (length(lambda_init) != p) stop("beta_prior_controls$init_lambda must be scalar or length p.")
  tau_init <- as.numeric(ctrl$init_tau)[1]
  c2_init <- if (is.null(ctrl$init_c2)) ctrl$s2 else ctrl$init_c2
  list(
    p = p,
    shrink_intercept = ctrl$shrink_intercept,
    intercept_prec = ctrl$intercept_prec,
    lambda = lambda_init,
    tau = pmax(as.numeric(tau_init)[1], 1e-16),
    c2 = pmax(as.numeric(c2_init)[1], 1e-16)
  )
}
.static_rhs_expected_prec_vb <- function(state, ctrl) {
  # Approximate E[1 / V_j] under the VB Gaussian on the log-scale
  # hyperparameters, vectorised over j = 1..p.
  #   1 / V_j = exp(-2 * (eta_lambda_j + eta_tau)) + exp(-eta_c2)
  # A second-order (delta-method) correction uses Var(eta_lambda_j + eta_tau)
  # and Var(eta_c2) from Sigma_full. When the intercept is not shrunk, entry 1
  # is replaced by ctrl$intercept_prec. All entries are floored at 1e-16.
  p <- state$p
  mu_lam <- as.numeric(state$eta_lambda_hat)
  mu_tau <- as.numeric(state$eta_tau_hat)
  mu_c <- as.numeric(state$eta_c_hat)
  Sigma_full <- state$Sigma_full
  if (is.null(Sigma_full) || !all(dim(Sigma_full) == c(p + 2L, p + 2L))) {
    Sigma_full <- diag(pmax(state$Sigma_diag, ctrl$var_floor), p + 2L)
  }
  i_tau <- p + 1L
  var_kap <- max(Sigma_full[p + 2L, p + 2L], 0)
  r_hat <- .static_rhs_safe_exp(-mu_c)

  lam_ix <- seq_len(p)
  t_hat <- .static_rhs_safe_exp(-2 * (mu_lam[lam_ix] + mu_tau))
  # Var(eta_lambda_j + eta_tau) = Var_j + Var_tau + 2 * Cov_j,tau (floored at 0).
  v_sum <- Sigma_full[cbind(lam_ix, lam_ix)] + Sigma_full[i_tau, i_tau] + 2 * Sigma_full[lam_ix, i_tau]
  v_sum <- pmax(v_sum, 0)
  correction <- 0.5 * (4 * t_hat * v_sum + r_hat * var_kap)
  prec <- t_hat + r_hat + correction
  if (!isTRUE(ctrl$shrink_intercept) && p >= 1L) {
    prec[1L] <- ctrl$intercept_prec
  }
  pmax(prec, 1e-16)
}
.static_rhs_vb_schedule <- function(state, ctrl, iter_now) {
  # Decide what the RHS VB step should do at iteration `iter_now`.
  #   tau_warmup: tau is frozen for the first freeze_tau_warmup_iters iterations.
  #   cadence: update_every_warmup applies during the first
  #     update_every_warmup_iters iterations, update_every afterwards.
  #   force_tau_now: once warmup is over, force one tau update if none has
  #     happened yet (when force_tau_after_warmup is set). If that falls
  #     between cadence steps, only tau is updated.
  # Returns a list describing the decision and a short `reason` label.
  warmup_len <- ctrl$freeze_tau_warmup_iters
  in_tau_warmup <- isTRUE(warmup_len > 0L && iter_now <= warmup_len)
  in_cadence_warmup <- ctrl$update_every_warmup_iters > 0L && iter_now <= ctrl$update_every_warmup_iters
  cadence <- if (in_cadence_warmup) ctrl$update_every_warmup else ctrl$update_every
  on_cadence <- (cadence <= 1L) || ((iter_now %% cadence) == 0L)

  force_tau <- !in_tau_warmup &&
    isTRUE(ctrl$force_tau_after_warmup) &&
    isTRUE(warmup_len > 0L) &&
    !isTRUE(state$has_post_warmup_tau_update)
  tau_only <- isTRUE(force_tau && !on_cadence)

  reason <- "rhs_update_skipped"
  if (in_tau_warmup) {
    reason <- "warmup"
  } else if (tau_only) {
    reason <- "force_after_warmup"
  } else if (on_cadence) {
    reason <- "scheduled"
  }

  list(
    iter = iter_now,
    tau_warmup = in_tau_warmup,
    update_every_eff = cadence,
    scheduled_rhs_update = on_cadence,
    force_tau_now = force_tau,
    do_update = isTRUE(on_cadence || force_tau),
    update_tau_only = tau_only,
    reason = reason
  )
}
.static_rhs_collapse_diag <- function(state, qbeta, ctrl) {
  # Diagnose RHS shrinkage collapse for the VB state. There are two triggers:
  #   (a) tau is tiny relative to tau0, or eta_tau sits within 1e-6 of its
  #       lower bound, AND every active posterior-mean coefficient is within
  #       collapse_beta_max_abs_tol of zero;
  #   (b) the median expected precision of the active coefficients is large,
  #       their L2 norm is tiny, and most of them lie within small_beta_abs_tol
  #       of zero.
  # Returns the flags, the statistics behind them, and a warning message when
  # collapse is flagged (NA_character_ otherwise).
  active <- .static_rhs_active_idx(state$p, ctrl$shrink_intercept)
  has_active <- length(active) > 0
  b_act <- if (has_active) as.numeric(qbeta$m)[active] else numeric(0)

  tau <- .static_rhs_safe_exp(state$eta_tau_hat)
  log_tau <- as.numeric(state$eta_tau_hat)[1]
  tau0 <- max(as.numeric(ctrl$tau0)[1], 1e-16)
  tau_ratio <- tau / tau0

  small_beta_abs_tol <- as.numeric(ctrl$small_beta_abs_tol)[1]
  slope_l2 <- NA_real_
  slope_max_abs <- NA_real_
  small_beta_frac <- NA_real_
  if (has_active) {
    slope_l2 <- sqrt(sum(b_act^2))
    slope_max_abs <- max(abs(b_act))
    small_beta_frac <- mean(abs(b_act) <= small_beta_abs_tol)
  }

  prec_all <- .static_rhs_expected_prec_vb(state, ctrl)
  prec_act <- if (has_active) as.numeric(prec_all)[active] else numeric(0)
  finite_prec <- prec_act[is.finite(prec_act)]
  E_invV_med <- if (length(finite_prec) > 0) stats::median(finite_prec) else NA_real_

  tau_lower <- ctrl$eta_bounds$tau[1]
  tau_near_zero <- isTRUE(is.finite(tau_ratio) && tau_ratio <= ctrl$collapse_tau_ratio_tol) ||
    isTRUE(abs(as.numeric(state$eta_tau_hat) - tau_lower) <= 1e-6)
  slope_collapse <- isTRUE(is.finite(slope_max_abs) && slope_max_abs <= ctrl$collapse_beta_max_abs_tol)
  precision_beta_pattern <- isTRUE(
    is.finite(E_invV_med) && E_invV_med >= ctrl$collapse_invV_med_tol &&
      is.finite(slope_l2) && slope_l2 <= ctrl$collapse_beta_l2_tol &&
      is.finite(small_beta_frac) && small_beta_frac >= ctrl$collapse_small_beta_frac_tol
  )
  scale_and_slope <- tau_near_zero && slope_collapse
  collapse_flag <- isTRUE(scale_and_slope || precision_beta_pattern)

  warning_msg <- NA_character_
  if (collapse_flag) {
    warning_msg <- if (isTRUE(precision_beta_pattern) && !isTRUE(scale_and_slope)) {
      paste(
        "RHS shrinkage-collapse detected from precision/beta pattern",
        "(large E_invV + tiny beta norm + high near-zero-beta fraction).",
        "Consider revising tau initialization and RHS controls."
      )
    } else {
      paste(
        "RHS global scale collapsed near zero and active coefficients collapsed toward zero.",
        "Consider a larger tau0 and/or RHS tau warmup/freeze tuning."
      )
    }
  }

  list(
    collapse_flag = collapse_flag,
    precision_beta_pattern = precision_beta_pattern,
    tau_near_zero = tau_near_zero,
    slope_collapse = slope_collapse,
    beta_collapse = slope_collapse,
    tau = tau,
    log_tau = log_tau,
    tau0 = tau0,
    tau_ratio = tau_ratio,
    E_invV_med = E_invV_med,
    slope_l2 = slope_l2,
    beta_l2 = slope_l2,
    slope_max_abs = slope_max_abs,
    small_beta_frac = small_beta_frac,
    small_beta_abs_tol = small_beta_abs_tol,
    active_count = length(active),
    eta_tau = as.numeric(state$eta_tau_hat)[1],
    eta_tau_lower = tau_lower,
    at_tau_lower_bound = isTRUE(abs(as.numeric(state$eta_tau_hat)[1] - tau_lower) <= 1e-6),
    warning = warning_msg
  )
}
.static_rhs_collapse_diag_mcmc <- function(state, beta, ctrl) {
  # MCMC counterpart of the VB collapse diagnostic, evaluated on the current
  # draw. There are two triggers:
  #   (a) tau is tiny relative to tau0, or log(tau) is at most 1e-6 above the
  #       eta_tau lower bound, AND every active coefficient is within
  #       collapse_beta_max_abs_tol of zero;
  #   (b) the median precision (from .static_rhs_prec_mcmc) of the active
  #       coefficients is large, their L2 norm is tiny, and most of them lie
  #       within small_beta_abs_tol of zero.
  # Returns the flags, the statistics behind them, and a warning message when
  # collapse is flagged (NA_character_ otherwise).
  active <- .static_rhs_active_idx(state$p, ctrl$shrink_intercept)
  has_active <- length(active) > 0
  b_act <- if (has_active) as.numeric(beta)[active] else numeric(0)

  tau <- pmax(as.numeric(state$tau)[1], 1e-16)
  log_tau <- log(tau)
  tau0 <- max(as.numeric(ctrl$tau0)[1], 1e-16)
  tau_ratio <- tau / tau0

  small_beta_abs_tol <- as.numeric(ctrl$small_beta_abs_tol)[1]
  slope_l2 <- NA_real_
  slope_max_abs <- NA_real_
  small_beta_frac <- NA_real_
  if (has_active) {
    slope_l2 <- sqrt(sum(b_act^2))
    slope_max_abs <- max(abs(b_act))
    small_beta_frac <- mean(abs(b_act) <= small_beta_abs_tol)
  }

  prec_all <- .static_rhs_prec_mcmc(state, ctrl)
  prec_act <- if (has_active) as.numeric(prec_all)[active] else numeric(0)
  finite_prec <- prec_act[is.finite(prec_act)]
  E_invV_med <- if (length(finite_prec) > 0) stats::median(finite_prec) else NA_real_

  tau_lower <- ctrl$eta_bounds$tau[1]
  tau_near_zero <- isTRUE(is.finite(tau_ratio) && tau_ratio <= ctrl$collapse_tau_ratio_tol) ||
    isTRUE(log_tau <= (tau_lower + 1e-6))
  slope_collapse <- isTRUE(is.finite(slope_max_abs) && slope_max_abs <= ctrl$collapse_beta_max_abs_tol)
  precision_beta_pattern <- isTRUE(
    is.finite(E_invV_med) && E_invV_med >= ctrl$collapse_invV_med_tol &&
      is.finite(slope_l2) && slope_l2 <= ctrl$collapse_beta_l2_tol &&
      is.finite(small_beta_frac) && small_beta_frac >= ctrl$collapse_small_beta_frac_tol
  )
  scale_and_slope <- tau_near_zero && slope_collapse
  collapse_flag <- isTRUE(scale_and_slope || precision_beta_pattern)

  warning_msg <- NA_character_
  if (collapse_flag) {
    warning_msg <- if (isTRUE(precision_beta_pattern) && !isTRUE(scale_and_slope)) {
      paste(
        "RHS shrinkage-collapse detected from precision/beta pattern",
        "(large E_invV + tiny beta norm + high near-zero-beta fraction).",
        "Consider revising tau initialization and RHS controls."
      )
    } else {
      paste(
        "RHS global scale collapsed near zero and active coefficients collapsed toward zero.",
        "Consider a larger tau0 and/or RHS tau warmup/freeze tuning."
      )
    }
  }

  list(
    collapse_flag = collapse_flag,
    precision_beta_pattern = precision_beta_pattern,
    tau_near_zero = tau_near_zero,
    slope_collapse = slope_collapse,
    beta_collapse = slope_collapse,
    tau = tau,
    log_tau = log_tau,
    tau0 = tau0,
    tau_ratio = tau_ratio,
    E_invV_med = E_invV_med,
    slope_l2 = slope_l2,
    beta_l2 = slope_l2,
    slope_max_abs = slope_max_abs,
    small_beta_frac = small_beta_frac,
    small_beta_abs_tol = small_beta_abs_tol,
    active_count = length(active),
    eta_tau_lower = tau_lower,
    at_tau_lower_bound = isTRUE(log_tau <= (tau_lower + 1e-6)),
    warning = warning_msg
  )
}
.static_rhs_maybe_warn_collapse <- function(rhs_summary, ctrl) {
  # Raise the collapse warning carried by `rhs_summary` when warnings are
  # enabled and collapse was flagged. Returns invisibly whether a warning was
  # raised.
  should_warn <- isTRUE(ctrl$warn_on_collapse) &&
    !is.null(rhs_summary) &&
    isTRUE(rhs_summary$collapse_flag)
  if (should_warn) {
    warning(rhs_summary$warning, call. = FALSE)
  }
  invisible(should_warn)
}
# One VB update of the RHS hyperparameter state, given the current Gaussian
# coefficient approximation `qbeta`. Judging from qbeta$m^2 + diag(qbeta$V)
# below, qbeta$m is the mean vector and qbeta$V the covariance matrix
# (TODO confirm against callers).
# Steps:
#   1. Ask .static_rhs_vb_schedule() what this iteration should update.
#   2. If anything is due, run ctrl$n_inner sweeps of coordinate-wise mode
#      finding on the log scale: each active eta_lambda_j, then eta_tau
#      (skipped during tau warmup), then eta_c2 (skipped for tau-only updates
#      or when zeta2_fixed is set). Afterwards, set the covariance to the
#      inverse of the negative Hessian at the mode, embedded into the full
#      (p + 2) x (p + 2) layout.
#   3. Always store the point estimates, advance state$iter, and record the
#      schedule, tau-update bookkeeping and a collapse diagnostic.
# Returns the updated state.
.static_rhs_update_vb <- function(state, qbeta, ctrl) {
p <- state$p
# Second moments E[beta_j^2] = m_j^2 + V_jj.
beta2 <- as.numeric(qbeta$m^2 + diag(qbeta$V))
eta_lam <- as.numeric(state$eta_lambda_hat)
eta_tau <- as.numeric(state$eta_tau_hat)
eta_c2 <- as.numeric(state$eta_c_hat)
if (!is.null(ctrl$zeta2_fixed)) eta_c2 <- log(as.numeric(ctrl$zeta2_fixed)[1])
active_idx <- .static_rhs_active_idx(p, ctrl$shrink_intercept)
iter_now <- as.integer(.static_prior_or(state$iter, 0L)) + 1L
sched <- .static_rhs_vb_schedule(state, ctrl, iter_now)
tau_updated <- FALSE
if (isTRUE(sched$do_update)) {
update_tau_only <- isTRUE(sched$update_tau_only)
for (inner in seq_len(ctrl$n_inner)) {
if (!isTRUE(update_tau_only)) {
for (j in active_idx) {
# The objective for coordinate j reads the current eta_lam / eta_tau /
# eta_c2 from this function's environment, so each coordinate uses the
# freshest values of the others.
f_j <- function(eta_j) {
et <- eta_lam
et[j] <- eta_j
.static_rhs_obj_eta(et, eta_tau, eta_c2, beta2, ctrl)
}
eta_lam[j] <- .static_rhs_opt_1d_mode(
f_j,
lo = ctrl$eta_bounds$lambda[1],
hi = ctrl$eta_bounds$lambda[2],
eta0 = eta_lam[j]
)
}
}
# tau stays frozen during the warmup window chosen by the schedule.
if (!isTRUE(sched$tau_warmup)) {
f_tau <- function(etau) .static_rhs_obj_eta(eta_lam, etau, eta_c2, beta2, ctrl)
eta_tau <- .static_rhs_opt_1d_mode(
f_tau,
lo = ctrl$eta_bounds$tau[1],
hi = ctrl$eta_bounds$tau[2],
eta0 = eta_tau
)
tau_updated <- TRUE
}
if (!isTRUE(update_tau_only) && is.null(ctrl$zeta2_fixed)) {
f_c2 <- function(ec) .static_rhs_obj_eta(eta_lam, eta_tau, ec, beta2, ctrl)
eta_c2 <- .static_rhs_opt_1d_mode(
f_c2,
lo = ctrl$eta_bounds$c2[1],
hi = ctrl$eta_bounds$c2[2],
eta0 = eta_c2
)
}
}
# Covariance = (-Hessian)^{-1} at the mode, over the active coordinates only;
# inactive lambdas get var_floor in the full matrix.
hess <- .static_rhs_hess_active(eta_lam, eta_tau, eta_c2, beta2, ctrl)
invK <- .static_rhs_inv_spd_with_jitter(-hess$H, ctrl$var_floor)
Sigma_full <- .static_rhs_embed_sigma_full(invK$inv, idx = hess$idx, p = p, var_floor = ctrl$var_floor)
state$Sigma_full <- Sigma_full
state$Sigma_diag <- diag(Sigma_full)
}
state$eta_lambda_hat <- eta_lam
state$eta_tau_hat <- eta_tau
state$eta_c_hat <- if (!is.null(ctrl$zeta2_fixed)) log(as.numeric(ctrl$zeta2_fixed)[1]) else eta_c2
state$iter <- iter_now
state$freeze_tau <- isTRUE(sched$tau_warmup)
state$update_tau_only <- isTRUE(sched$update_tau_only)
state$tau_update_count <- as.integer(.static_prior_or(state$tau_update_count, 0L)) + if (tau_updated) 1L else 0L
# Once any tau update has happened (which only occurs outside tau warmup),
# the schedule stops forcing a post-warmup tau update.
state$has_post_warmup_tau_update <- isTRUE(.static_prior_or(state$has_post_warmup_tau_update, FALSE) || tau_updated)
state$last_schedule <- c(
sched,
list(
tau_updated = tau_updated,
tau_update_count = state$tau_update_count
)
)
state$collapse_diag <- .static_rhs_collapse_diag(state, qbeta, ctrl)
state
}
.static_rhs_elbo_vb <- function(state, qbeta, ctrl) {
  # ELBO contribution of the RHS hyperparameter block under the VB Gaussian
  # approximation. It is a second-order expansion around the mode plus the
  # Gaussian entropy:
  #   f(mode) + 0.5 * tr(H %*% Sigma) + 0.5 * (d * (1 + log(2 * pi)) + log|Sigma|)
  # Both H and Sigma are restricted to the active coordinates (active
  # lambdas, tau, c2).
  p <- state$p
  beta2 <- as.numeric(qbeta$m^2 + diag(qbeta$V))
  eta_lam <- as.numeric(state$eta_lambda_hat)
  eta_tau <- as.numeric(state$eta_tau_hat)
  eta_c2 <- as.numeric(state$eta_c_hat)

  log_joint <- .static_rhs_obj_eta(eta_lam, eta_tau, eta_c2, beta2, ctrl)
  hess <- .static_rhs_hess_active(eta_lam, eta_tau, eta_c2, beta2, ctrl)
  keep <- c(hess$idx, p + 1L, p + 2L)

  Sigma_full <- state$Sigma_full
  if (is.null(Sigma_full) || !all(dim(Sigma_full) == c(p + 2L, p + 2L))) {
    Sigma_full <- diag(pmax(state$Sigma_diag, ctrl$var_floor), p + 2L)
  }
  S_keep <- Sigma_full[keep, keep, drop = FALSE]

  # tr(H %*% S) for symmetric S, computed as an elementwise sum.
  expected_curv <- 0.5 * sum(hess$H * S_keep)
  logdet_S <- .static_rhs_inv_spd_with_jitter(S_keep, ctrl$var_floor)$logdet
  entropy <- 0.5 * (nrow(S_keep) * (1 + log(2 * pi)) + logdet_S)
  log_joint + expected_curv + entropy
}
.static_rhs_prec_mcmc <- function(state, ctrl) {
  # Coefficient prior precisions for the current MCMC draw:
  #   1 / V_j = 1 / (tau^2 * lambda_j^2) + 1 / c2, floored at 1e-16.
  # When the intercept is not shrunk, entry 1 is ctrl$intercept_prec instead.
  local_scale2 <- state$tau^2 * state$lambda^2
  prec <- pmax(1 / local_scale2 + 1 / state$c2, 1e-16)
  if (!isTRUE(ctrl$shrink_intercept) && state$p >= 1L) {
    prec[1] <- ctrl$intercept_prec
  }
  prec
}
.static_rhs_logtarget_eta <- function(eta_lambda, eta_tau, eta_c2, beta2, ctrl) {
  # Log target density used by the MCMC slice sampler. It is the same log
  # objective that the VB mode search maximises.
  .static_rhs_obj_eta(
    eta_lambda = eta_lambda,
    eta_tau = eta_tau,
    eta_c2 = eta_c2,
    beta2 = beta2,
    ctrl = ctrl
  )
}
.static_rhs_update_mcmc <- function(state, beta, ctrl,
                                    slice_width = .static_prior_or(ctrl[["slice_width"]], 1),
                                    slice_max_steps = .static_prior_or(ctrl[["slice_max_steps"]], 20)) {
  # One Gibbs sweep over the RHS hyperparameters, given coefficient draw `beta`.
  # Each log-scale coordinate is updated with a bounded univariate slice step
  # (.exdqlm_uni_slice_bounded): every active eta_lambda_j, then eta_tau, then
  # eta_c2. eta_c2 is skipped, and c2 held at zeta2_fixed, when
  # ctrl$zeta2_fixed is set.
  # slice_width / slice_max_steps default to the parsed controls
  # (ctrl$slice_width / ctrl$slice_max_steps). If those are absent they fall
  # back to 1 and 20, so configuring them in beta_prior_controls takes effect
  # without every caller forwarding them. Explicit arguments still win.
  # Returns `state` with lambda, tau and c2 updated on the natural scale.
  beta2 <- as.numeric(beta)^2
  fixed_c2 <- ctrl$zeta2_fixed
  c2_is_fixed <- !is.null(fixed_c2)
  eta_lam <- log(pmax(state$lambda, 1e-16))
  eta_tau <- log(pmax(state$tau, 1e-16))
  eta_c2 <- if (c2_is_fixed) log(as.numeric(fixed_c2)[1]) else log(pmax(state$c2, 1e-16))

  # One bounded slice-sampling step for a log-scale coordinate.
  slice_step <- function(x0, log_density, bounds) {
    .exdqlm_uni_slice_bounded(
      x0 = x0,
      log_density = log_density,
      w = slice_width,
      m = slice_max_steps,
      lower = bounds[1],
      upper = bounds[2]
    )$value
  }

  for (j in .static_rhs_active_idx(length(beta2), ctrl$shrink_intercept)) {
    # Conditional for coordinate j; it reads the current eta_lam / eta_tau /
    # eta_c2 from this function's environment.
    log_density_j <- function(eta_j) {
      et <- eta_lam
      et[j] <- eta_j
      .static_rhs_logtarget_eta(et, eta_tau, eta_c2, beta2, ctrl)
    }
    eta_lam[j] <- slice_step(eta_lam[j], log_density_j, ctrl$eta_bounds$lambda)
  }
  eta_tau <- slice_step(
    eta_tau,
    function(etau) .static_rhs_logtarget_eta(eta_lam, etau, eta_c2, beta2, ctrl),
    ctrl$eta_bounds$tau
  )
  if (!c2_is_fixed) {
    eta_c2 <- slice_step(
      eta_c2,
      function(ec) .static_rhs_logtarget_eta(eta_lam, eta_tau, ec, beta2, ctrl),
      ctrl$eta_bounds$c2
    )
  }

  state$lambda <- .static_rhs_safe_exp(eta_lam)
  state$tau <- .static_rhs_safe_exp(eta_tau)
  state$c2 <- if (c2_is_fixed) as.numeric(fixed_c2)[1] else .static_rhs_safe_exp(eta_c2)
  state
}
.static_rhs_ig_entropy <- function(a, b) {
  # Differential entropy of an inverse-gamma distribution with shape `a` and
  # scale `b` (vectorised):
  #   H = a + log(b) + lgamma(a) - (a + 1) * digamma(a)
  # Both inputs are coerced to numeric and floored at 1e-12 beforehand.
  shp <- pmax(as.numeric(a), 1e-12)
  scl <- pmax(as.numeric(b), 1e-12)
  out <- shp + log(scl)
  out <- out + lgamma(shp)
  out - (shp + 1) * digamma(shp)
}
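# RHS-NS closed-form hierarchy used by the static VB and MCMC updates.
# IG(a, b) denotes an inverse-gamma with shape a and scale b, so E[1 / x] = a / b:
#   lambda_j^2 | nu_j ~ IG(1/2, 1/nu_j),   nu_j ~ IG(1/2, 1)
#   tau^2 | xi ~ IG(1/2, 1/xi),            xi ~ IG(1/2, 1/tau0^2)
#   zeta^2 ~ IG(a_zeta, b_zeta)            (or fixed via zeta2_fixed)
# Each active beta_j receives a horseshoe factor N(0, tau^2 * lambda_j^2) and
# a slab factor N(0, zeta^2), so the induced coefficient precision is
#   invV_j = 1 / (tau^2 * lambda_j^2) + 1 / zeta^2
# When shrink_intercept = FALSE, the intercept (beta_1) is instead given the
# fixed precision intercept_prec.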
.static_rhs_ns_recompute_moments <- function(state, ctrl) {
  # Re-derive the RHS-NS point scales and expected inverse moments.
  #
  # Every scale (lambda2, nu, tau2, xi, zeta2) is floored at
  # max(ctrl$var_floor, 1e-16); a fixed zeta^2 is taken from
  # state$zeta2_fixed. Each E[1 / x] is a / b from the stored inverse-gamma
  # pair when that pair is present (and, for the length-p vectors, of length
  # p); otherwise it falls back to 1 / point value. A fixed zeta^2 always
  # uses 1 / zeta2. Finally lambda, tau and c2 are refreshed from the squared
  # scales.
  lo <- max(as.numeric(ctrl$var_floor)[1], 1e-16)
  p <- as.integer(state$p)

  # a / b for a length-p pair, or 1 / `fallback` when the pair is unusable.
  inv_mean_vec <- function(a, b, fallback) {
    if (is.null(a) || is.null(b)) {
      return(1 / fallback)
    }
    a <- pmax(as.numeric(a), lo)
    b <- pmax(as.numeric(b), lo)
    if (length(a) == p && length(b) == p) pmax(a / b, lo) else 1 / fallback
  }
  # a / b for a scalar pair, or 1 / `fallback` when the pair is missing.
  inv_mean_scalar <- function(a, b, fallback) {
    if (is.null(a) || is.null(b)) {
      return(1 / fallback)
    }
    max(max(as.numeric(a)[1], lo) / max(as.numeric(b)[1], lo), lo)
  }

  state$lambda2 <- pmax(as.numeric(state$lambda2), lo)
  state$nu <- pmax(as.numeric(state$nu), lo)
  state$tau2 <- max(as.numeric(state$tau2)[1], lo)
  state$xi <- max(as.numeric(state$xi)[1], lo)
  zeta_fixed <- isTRUE(state$zeta2_is_fixed)
  zeta_src <- if (zeta_fixed) state$zeta2_fixed else state$zeta2
  state$zeta2 <- max(as.numeric(zeta_src)[1], lo)

  state$E_inv_lambda2 <- inv_mean_vec(state$a_lambda, state$b_lambda, state$lambda2)
  state$E_inv_nu <- inv_mean_vec(state$a_nu, state$b_nu, state$nu)
  state$E_inv_tau2 <- inv_mean_scalar(state$a_tau, state$b_tau, state$tau2)
  state$E_inv_xi <- inv_mean_scalar(state$a_xi, state$b_xi, state$xi)
  use_zeta_q <- !zeta_fixed &&
    !is.null(state$a_zeta) && !is.null(state$b_zeta) &&
    is.finite(state$a_zeta) && is.finite(state$b_zeta)
  state$E_inv_zeta2 <- if (use_zeta_q) {
    inv_mean_scalar(state$a_zeta, state$b_zeta, state$zeta2)
  } else {
    1 / state$zeta2
  }

  state$lambda <- sqrt(pmax(state$lambda2, lo))
  state$tau <- sqrt(max(state$tau2, lo))
  state$c2 <- state$zeta2
  state
}
.static_rhs_ns_init_vb_state <- function(p, ctrl) {
  # Build the initial RHS-NS state used by both the VB and MCMC paths.
  #
  # Args:
  #   p: number of coefficients (intercept included).
  #   ctrl: parsed RHS-NS controls (var_floor, init_lambda, init_tau, tau0,
  #     init_c2, s2, zeta2_fixed, a_zeta, b_zeta, shrink_intercept,
  #     intercept_prec).
  # Returns: a state list with point values, inverse-gamma parameters
  #   (shape a, scale b) and the moments derived by
  #   .static_rhs_ns_recompute_moments().
  # Errors: if ctrl$init_lambda is neither a scalar nor of length p.
  #
  # The RHS-NS updates read E[1 / x] as a / b, so every scale is initialised
  # as b = a * x0. The initial expected inverse moments then agree with the
  # initial point values lambda_j^2 = init_lambda^2 and tau^2 = init_tau^2,
  # which matters while tau is frozen during the VB warm-up.
  floor <- max(as.numeric(ctrl$var_floor)[1], 1e-16)
  lam0 <- ctrl$init_lambda
  if (length(lam0) == 1L) lam0 <- rep(lam0, p)
  lam0 <- pmax(as.numeric(lam0), floor)
  if (length(lam0) != p) stop("beta_prior_controls$init_lambda must be scalar or length p.")
  # Fall back to the prior global scale when init_tau was not resolved.
  tau_init <- as.numeric(.static_prior_or(ctrl$init_tau, ctrl$tau0))[1]
  if (!isTRUE(is.finite(tau_init))) tau_init <- as.numeric(ctrl$tau0)[1]
  tau0 <- pmax(tau_init, floor)
  zeta0 <- if (!is.null(ctrl$zeta2_fixed)) {
    pmax(as.numeric(ctrl$zeta2_fixed)[1], floor)
  } else if (!is.null(ctrl$init_c2)) {
    pmax(as.numeric(ctrl$init_c2)[1], floor)
  } else {
    pmax(as.numeric(ctrl$s2)[1], floor)
  }
  idx <- .static_rhs_active_idx(p, ctrl$shrink_intercept)
  m_active <- length(idx)
  # Posterior shapes implied by the conjugate hierarchy.
  a_tau <- max((m_active + 1) / 2, floor)
  a_xi <- 1
  a_lambda <- rep(1, p)
  a_nu <- rep(1, p)
  xi0 <- 1
  nu0 <- rep(1, p)
  lam2_0 <- pmax(lam0^2, floor)
  tau2_0 <- pmax(tau0^2, floor)
  state <- list(
    p = p,
    prior_type = "rhs_ns",
    shrink_intercept = ctrl$shrink_intercept,
    intercept_prec = ctrl$intercept_prec,
    zeta2_is_fixed = !is.null(ctrl$zeta2_fixed),
    zeta2_fixed = as.numeric(.static_prior_or(ctrl$zeta2_fixed, NA_real_))[1],
    lambda2 = lam2_0,
    nu = nu0,
    tau2 = tau0^2,
    xi = xi0,
    zeta2 = zeta0,
    a_lambda = a_lambda,
    b_lambda = a_lambda * lam2_0,
    a_nu = a_nu,
    b_nu = a_nu * nu0,
    a_tau = a_tau,
    b_tau = a_tau * tau2_0,
    a_xi = a_xi,
    b_xi = a_xi * xi0,
    # q(zeta^2) starts at its prior unless zeta^2 is fixed.
    a_zeta = if (!is.null(ctrl$zeta2_fixed)) NA_real_ else as.numeric(ctrl$a_zeta)[1],
    b_zeta = if (!is.null(ctrl$zeta2_fixed)) NA_real_ else as.numeric(ctrl$b_zeta)[1],
    iter = 0L,
    freeze_tau = FALSE,
    update_tau_only = FALSE,
    tau_update_count = 0L,
    has_post_warmup_tau_update = FALSE,
    last_schedule = list(),
    collapse_diag = NULL
  )
  .static_rhs_ns_recompute_moments(state, ctrl)
}
.static_rhs_ns_init_mcmc_state <- function(p, ctrl) {
  # The MCMC path starts from exactly the same state as the VB path.
  init_state <- .static_rhs_ns_init_vb_state(p, ctrl)
  init_state
}
.static_rhs_ns_expected_prec_vb <- function(state, ctrl) {
  # Expected diagonal prior precision of beta under the VB factors.
  #
  # Active slots get E[1 / tau^2] * E[1 / lambda_j^2] + E[1 / zeta^2]. Every
  # other slot keeps state$intercept_prec. The result is floored at 1e-16.
  st <- .static_rhs_ns_recompute_moments(state, ctrl)
  n <- as.integer(st$p)
  active <- .static_rhs_active_idx(n, ctrl$shrink_intercept)
  out <- rep(as.numeric(st$intercept_prec)[1], n)
  if (length(active) > 0L) {
    out[active] <- st$E_inv_tau2 * st$E_inv_lambda2[active] + st$E_inv_zeta2
  } else if (isTRUE(ctrl$shrink_intercept) && n > 0L) {
    # Defensive path: everything is shrunk but no active index was reported.
    out <- st$E_inv_tau2 * st$E_inv_lambda2 + st$E_inv_zeta2
  }
  pmax(as.numeric(out), 1e-16)
}
.static_rhs_ns_prec_mcmc <- function(state, ctrl) {
  # Conditional prior precision of beta given the current MCMC draw.
  #
  # Each slot is 1 / (tau^2 * lambda_j^2) + 1 / zeta^2, floored at 1e-16.
  # Slot 1 is replaced by ctrl$intercept_prec when the intercept is not
  # shrunk.
  cur <- .static_rhs_ns_recompute_moments(state, ctrl)
  n <- as.integer(cur$p)
  prec <- 1 / (cur$tau2 * cur$lambda2) + 1 / cur$zeta2
  prec <- pmax(prec, 1e-16)
  if (isTRUE(ctrl$shrink_intercept) || !(n >= 1L)) {
    return(prec)
  }
  prec[1L] <- ctrl$intercept_prec
  prec
}
# Shared shrinkage-collapse diagnostic for the RHS-NS prior (VB and MCMC).
#
# Args:
#   state: RHS-NS state already passed through .static_rhs_ns_recompute_moments().
#   ctrl: parsed controls (tau0, small_beta_abs_tol, collapse_* tolerances,
#     eta_bounds$tau).
#   beta_use: active coefficient values (VB posterior means or an MCMC draw).
#   invV_use: prior precisions for the same active coefficients.
#   active_count: number of active (shrunk) coefficients.
# Returns: a named list of flags, summary statistics and a warning message
#   (NA_character_ when no collapse is detected).
#
# Collapse is flagged when either
#   (a) tau is near zero (tau / tau0 below tolerance, or log(tau) at the lower
#       bound) AND every active |beta_j| is tiny; or
#   (b) the median finite precision is huge, the beta norm is tiny, and most
#       betas are near zero.
.static_rhs_ns_collapse_diag_core <- function(state, ctrl, beta_use, invV_use, active_count) {
  tau <- sqrt(pmax(state$tau2, 1e-16))
  log_tau <- log(tau)
  tau0 <- max(as.numeric(ctrl$tau0)[1], 1e-16)
  tau_ratio <- tau / tau0
  has_beta <- length(beta_use) > 0L
  slope_l2 <- if (has_beta) sqrt(sum(beta_use^2)) else NA_real_
  slope_max_abs <- if (has_beta) max(abs(beta_use)) else NA_real_
  small_beta_abs_tol <- as.numeric(ctrl$small_beta_abs_tol)[1]
  small_beta_frac <- if (has_beta) mean(abs(beta_use) <= small_beta_abs_tol) else NA_real_
  finite_invV <- invV_use[is.finite(invV_use)]
  E_invV_med <- if (length(finite_invV) > 0L) stats::median(finite_invV) else NA_real_
  eta_tau_lower <- ctrl$eta_bounds$tau[1]
  at_tau_lower_bound <- isTRUE(log_tau <= (eta_tau_lower + 1e-6))
  tau_near_zero <- isTRUE(is.finite(tau_ratio) && tau_ratio <= ctrl$collapse_tau_ratio_tol) ||
    at_tau_lower_bound
  slope_collapse <- isTRUE(is.finite(slope_max_abs) && slope_max_abs <= ctrl$collapse_beta_max_abs_tol)
  precision_beta_pattern <- isTRUE(
    is.finite(E_invV_med) && E_invV_med >= ctrl$collapse_invV_med_tol &&
      is.finite(slope_l2) && slope_l2 <= ctrl$collapse_beta_l2_tol &&
      is.finite(small_beta_frac) && small_beta_frac >= ctrl$collapse_small_beta_frac_tol
  )
  scale_collapse <- isTRUE(tau_near_zero && slope_collapse)
  collapse_flag <- isTRUE(scale_collapse || precision_beta_pattern)
  warning_msg <- if (!collapse_flag) {
    NA_character_
  } else if (isTRUE(precision_beta_pattern) && !scale_collapse) {
    paste(
      "RHS shrinkage-collapse detected from precision/beta pattern",
      "(large E_invV + tiny beta norm + high near-zero-beta fraction).",
      "Consider revising tau initialization and RHS controls."
    )
  } else {
    paste(
      "RHS global scale collapsed near zero and active coefficients collapsed toward zero.",
      "Consider a larger tau0 and/or RHS tau warmup/freeze tuning."
    )
  }
  list(
    collapse_flag = collapse_flag,
    precision_beta_pattern = precision_beta_pattern,
    tau_near_zero = tau_near_zero,
    slope_collapse = slope_collapse,
    beta_collapse = slope_collapse,
    tau = tau,
    log_tau = log_tau,
    tau0 = tau0,
    tau_ratio = tau_ratio,
    E_invV_med = E_invV_med,
    slope_l2 = slope_l2,
    beta_l2 = slope_l2,
    slope_max_abs = slope_max_abs,
    small_beta_frac = small_beta_frac,
    small_beta_abs_tol = small_beta_abs_tol,
    active_count = active_count,
    eta_tau_lower = eta_tau_lower,
    at_tau_lower_bound = at_tau_lower_bound,
    warning = warning_msg
  )
}

# Collapse diagnostics for VB, based on the posterior means qbeta$m and the
# expected prior precisions of the active coefficients.
.static_rhs_ns_collapse_diag_vb <- function(state, qbeta, ctrl) {
  state <- .static_rhs_ns_recompute_moments(state, ctrl)
  idx <- .static_rhs_active_idx(state$p, ctrl$shrink_intercept)
  has_idx <- length(idx) > 0L
  beta_use <- if (has_idx) as.numeric(qbeta$m)[idx] else numeric(0)
  E_invV <- .static_rhs_ns_expected_prec_vb(state, ctrl)
  E_invV_use <- if (has_idx) as.numeric(E_invV)[idx] else numeric(0)
  .static_rhs_ns_collapse_diag_core(state, ctrl, beta_use, E_invV_use, length(idx))
}

# Collapse diagnostics for MCMC, based on the current draw `beta` and the
# conditional prior precisions of the active coefficients.
.static_rhs_ns_collapse_diag_mcmc <- function(state, beta, ctrl) {
  state <- .static_rhs_ns_recompute_moments(state, ctrl)
  idx <- .static_rhs_active_idx(state$p, ctrl$shrink_intercept)
  has_idx <- length(idx) > 0L
  beta_use <- if (has_idx) as.numeric(beta)[idx] else numeric(0)
  invV <- .static_rhs_ns_prec_mcmc(state, ctrl)
  invV_use <- if (has_idx) as.numeric(invV)[idx] else numeric(0)
  .static_rhs_ns_collapse_diag_core(state, ctrl, beta_use, invV_use, length(idx))
}
.static_rhs_ns_update_vb <- function(state, qbeta, ctrl) {
  # One round of closed-form mean-field updates for the RHS-NS
  # inverse-gamma factors, given the Gaussian q(beta) with mean qbeta$m and
  # covariance qbeta$V. Only E[beta_j^2] = m_j^2 + V_jj is used.
  #
  # The schedule from .static_rhs_vb_schedule() controls three things:
  #   - `do_update`: whether anything is updated this iteration;
  #   - `update_tau_only`: whether only (tau^2, xi) are refreshed;
  #   - `tau_warmup`: whether tau is frozen.
  # Each factor is stored as (shape a, scale b) with E[1 / x] = a / b.
  # Returns the updated state, including schedule bookkeeping and fresh
  # collapse diagnostics.
  floor <- max(as.numeric(ctrl$var_floor)[1], 1e-16)
  state <- .static_rhs_ns_recompute_moments(state, ctrl)
  p <- as.integer(state$p)
  beta2 <- as.numeric(qbeta$m^2 + diag(qbeta$V))
  active_idx <- .static_rhs_active_idx(p, ctrl$shrink_intercept)
  m_active <- length(active_idx)
  iter_now <- as.integer(.static_prior_or(state$iter, 0L)) + 1L
  sched <- .static_rhs_vb_schedule(state, ctrl, iter_now)
  tau_updated <- FALSE
  a_lambda <- pmax(as.numeric(state$a_lambda), floor)
  b_lambda <- pmax(as.numeric(state$b_lambda), floor)
  a_nu <- pmax(as.numeric(state$a_nu), floor)
  b_nu <- pmax(as.numeric(state$b_nu), floor)
  # Posterior shapes are fixed by the conjugate hierarchy.
  a_tau <- max((m_active + 1) / 2, floor)
  b_tau <- max(as.numeric(state$b_tau)[1], floor)
  a_xi <- 1
  b_xi <- max(as.numeric(state$b_xi)[1], floor)
  if (isTRUE(state$zeta2_is_fixed)) {
    a_zeta <- NA_real_
    b_zeta <- NA_real_
  } else {
    a_zeta <- max(as.numeric(.static_prior_or(state$a_zeta, ctrl$a_zeta))[1], floor)
    b_zeta <- max(as.numeric(.static_prior_or(state$b_zeta, ctrl$b_zeta))[1], floor)
  }
  # When the schedule skips this iteration, or there are no active
  # coefficients, every factor is carried over unchanged.
  if (isTRUE(sched$do_update) && m_active > 0L) {
    update_tau_only <- isTRUE(sched$update_tau_only)
    for (inner in seq_len(ctrl$n_inner)) {
      if (!isTRUE(update_tau_only)) {
        # Local scales: q(lambda_j^2), then its auxiliary q(nu_j).
        e_inv_tau <- a_tau / b_tau
        e_inv_nu <- a_nu / b_nu
        b_lambda[active_idx] <- pmax(0.5 * beta2[active_idx] * e_inv_tau + e_inv_nu[active_idx], floor)
        e_inv_lambda <- a_lambda / b_lambda
        b_nu[active_idx] <- pmax(1 + e_inv_lambda[active_idx], floor)
        if (!isTRUE(state$zeta2_is_fixed)) {
          # Slab: q(zeta^2) = IG(a_zeta + m/2, b_zeta + sum E[beta_j^2] / 2).
          a_zeta <- max(as.numeric(ctrl$a_zeta)[1] + m_active / 2, floor)
          b_zeta <- pmax(as.numeric(ctrl$b_zeta)[1] + 0.5 * sum(beta2[active_idx]), floor)
        }
      }
      if (!isTRUE(sched$tau_warmup)) {
        # Global scale: q(tau^2), then its auxiliary q(xi). Skipped while tau
        # is frozen for warm-up.
        e_inv_lambda <- a_lambda / b_lambda
        e_inv_xi <- a_xi / b_xi
        b_tau <- pmax(0.5 * sum(beta2[active_idx] * e_inv_lambda[active_idx]) + e_inv_xi, floor)
        e_inv_tau <- a_tau / b_tau
        b_xi <- pmax((1 / (ctrl$tau0^2)) + e_inv_tau, floor)
        tau_updated <- TRUE
      }
    }
  }
  state$a_lambda <- a_lambda
  state$b_lambda <- b_lambda
  state$a_nu <- a_nu
  state$b_nu <- b_nu
  state$a_tau <- a_tau
  state$b_tau <- b_tau
  state$a_xi <- a_xi
  state$b_xi <- b_xi
  state$a_zeta <- if (isTRUE(state$zeta2_is_fixed)) NA_real_ else a_zeta
  state$b_zeta <- if (isTRUE(state$zeta2_is_fixed)) NA_real_ else b_zeta
  state$E_inv_lambda2 <- pmax(a_lambda / b_lambda, floor)
  state$E_inv_nu <- pmax(a_nu / b_nu, floor)
  state$E_inv_tau2 <- max(a_tau / b_tau, floor)
  state$E_inv_xi <- max(a_xi / b_xi, floor)
  if (isTRUE(state$zeta2_is_fixed)) {
    state$E_inv_zeta2 <- 1 / max(as.numeric(state$zeta2_fixed)[1], floor)
    state$zeta2 <- max(as.numeric(state$zeta2_fixed)[1], floor)
  } else {
    state$E_inv_zeta2 <- max(a_zeta / b_zeta, floor)
    state$zeta2 <- 1 / state$E_inv_zeta2
  }
  # Point values are reported as 1 / E[1 / x].
  state$lambda2 <- 1 / pmax(state$E_inv_lambda2, floor)
  state$nu <- 1 / pmax(state$E_inv_nu, floor)
  state$tau2 <- 1 / max(state$E_inv_tau2, floor)
  state$xi <- 1 / max(state$E_inv_xi, floor)
  state <- .static_rhs_ns_recompute_moments(state, ctrl)
  state$iter <- iter_now
  state$freeze_tau <- isTRUE(sched$tau_warmup)
  state$update_tau_only <- isTRUE(sched$update_tau_only)
  state$tau_update_count <- as.integer(.static_prior_or(state$tau_update_count, 0L)) + if (tau_updated) 1L else 0L
  state$has_post_warmup_tau_update <- isTRUE(.static_prior_or(state$has_post_warmup_tau_update, FALSE) || tau_updated)
  state$last_schedule <- c(
    sched,
    list(
      tau_updated = tau_updated,
      tau_update_count = state$tau_update_count
    )
  )
  state$collapse_diag <- .static_rhs_ns_collapse_diag_vb(state, qbeta, ctrl)
  state
}
# Prior-side ELBO contribution of the RHS-NS beta prior.
#
# Returns the scalar sum of:
#   - E_q[log p(beta_active | lambda2, tau2, zeta2)];
#   - E_q[log p(lambda2 | nu) + log p(nu) + log p(tau2 | xi) + log p(xi)
#     + log p(zeta2)];
#   - the entropies of q(lambda2), q(nu), q(tau2), q(xi) and, if zeta2 is
#     not fixed, q(zeta2);
#   - E_q[log N(beta_1 | 0, 1 / intercept_prec)], for an unshrunk intercept
#     only.
# Each q(.) is inverse-gamma with the shape a and scale b stored in `state`.
# q(beta) enters only through E[beta_j^2] = m_j^2 + V_jj. The entropy of
# q(beta) itself is not added here; presumably the caller accounts for it
# (TODO confirm).
.static_rhs_ns_elbo_vb <- function(state, qbeta, ctrl) {
floor <- max(as.numeric(ctrl$var_floor)[1], 1e-16)
state <- .static_rhs_ns_recompute_moments(state, ctrl)
p <- as.integer(state$p)
beta2 <- as.numeric(qbeta$m^2 + diag(qbeta$V))
active_idx <- .static_rhs_active_idx(p, ctrl$shrink_intercept)
a_lambda <- pmax(as.numeric(state$a_lambda), floor)
b_lambda <- pmax(as.numeric(state$b_lambda), floor)
a_nu <- pmax(as.numeric(state$a_nu), floor)
b_nu <- pmax(as.numeric(state$b_nu), floor)
a_tau <- max(as.numeric(state$a_tau)[1], floor)
b_tau <- max(as.numeric(state$b_tau)[1], floor)
a_xi <- max(as.numeric(state$a_xi)[1], floor)
b_xi <- max(as.numeric(state$b_xi)[1], floor)
# Inverse-gamma moments: E[log x] = log(b) - digamma(a), E[1 / x] = a / b.
e_log_lambda <- log(b_lambda) - digamma(a_lambda)
e_log_nu <- log(b_nu) - digamma(a_nu)
e_inv_lambda <- a_lambda / b_lambda
e_inv_nu <- a_nu / b_nu
e_log_tau <- log(b_tau) - digamma(a_tau)
e_inv_tau <- a_tau / b_tau
e_log_xi <- log(b_xi) - digamma(a_xi)
e_inv_xi <- a_xi / b_xi
if (isTRUE(state$zeta2_is_fixed)) {
# A fixed zeta^2 is a constant, so it contributes no prior or entropy term.
e_log_zeta <- log(pmax(as.numeric(state$zeta2_fixed)[1], floor))
e_inv_zeta <- 1 / pmax(as.numeric(state$zeta2_fixed)[1], floor)
e_log_p_zeta <- 0
h_zeta <- 0
} else {
a_zeta <- max(as.numeric(state$a_zeta)[1], floor)
b_zeta <- max(as.numeric(state$b_zeta)[1], floor)
e_log_zeta <- log(b_zeta) - digamma(a_zeta)
e_inv_zeta <- a_zeta / b_zeta
# E_q[log IG(zeta2 | ctrl$a_zeta, ctrl$b_zeta)] under the prior hyperparameters.
e_log_p_zeta <- as.numeric(ctrl$a_zeta)[1] * log(as.numeric(ctrl$b_zeta)[1]) -
lgamma(as.numeric(ctrl$a_zeta)[1]) -
(as.numeric(ctrl$a_zeta)[1] + 1) * e_log_zeta -
as.numeric(ctrl$b_zeta)[1] * e_inv_zeta
h_zeta <- .static_rhs_ig_entropy(a_zeta, b_zeta)
}
if (length(active_idx)) {
# Shape 1/2 shared by all auxiliary inverse-gamma priors in the hierarchy.
k_half <- 0.5
log2pi_half <- -0.5 * log(2 * pi)
# Horseshoe factor N(beta_j | 0, tau^2 * lambda_j^2) for each active beta_j.
e_log_p_beta_hs <- sum(
log2pi_half -
0.5 * e_log_tau -
0.5 * e_log_lambda[active_idx] -
0.5 * beta2[active_idx] * e_inv_tau * e_inv_lambda[active_idx]
)
# Slab factor N(beta_j | 0, zeta^2) for each active beta_j.
e_log_p_beta_slab <- sum(
log2pi_half -
0.5 * e_log_zeta -
0.5 * beta2[active_idx] * e_inv_zeta
)
# lambda_j^2 | nu_j ~ IG(1/2, 1/nu_j); the rate's log enters as -E[log nu_j].
e_log_p_lambda_given_nu <- sum(
k_half * (-e_log_nu[active_idx]) -
lgamma(k_half) -
(k_half + 1) * e_log_lambda[active_idx] -
e_inv_nu[active_idx] * e_inv_lambda[active_idx]
)
# nu_j ~ IG(1/2, 1); the log(1) rate term is zero.
e_log_p_nu <- sum(
-lgamma(k_half) -
(k_half + 1) * e_log_nu[active_idx] -
e_inv_nu[active_idx]
)
# tau^2 | xi ~ IG(1/2, 1/xi).
e_log_p_tau_given_xi <- (
k_half * (-e_log_xi) -
lgamma(k_half) -
(k_half + 1) * e_log_tau -
e_inv_xi * e_inv_tau
)
# xi ~ IG(1/2, 1/tau0^2).
e_log_p_xi <- (
k_half * log(1 / (ctrl$tau0^2)) -
lgamma(k_half) -
(k_half + 1) * e_log_xi -
(1 / (ctrl$tau0^2)) * e_inv_xi
)
e_log_joint <- e_log_p_beta_hs + e_log_p_beta_slab +
e_log_p_lambda_given_nu + e_log_p_nu +
e_log_p_tau_given_xi + e_log_p_xi + e_log_p_zeta
h_latent <- sum(.static_rhs_ig_entropy(a_lambda[active_idx], b_lambda[active_idx])) +
sum(.static_rhs_ig_entropy(a_nu[active_idx], b_nu[active_idx])) +
.static_rhs_ig_entropy(a_tau, b_tau) +
.static_rhs_ig_entropy(a_xi, b_xi) +
h_zeta
} else {
# No active coefficients: only the zeta^2 terms remain.
e_log_joint <- e_log_p_zeta
h_latent <- h_zeta
}
# Fixed-precision Gaussian prior on the unshrunk intercept beta_1.
e_log_intercept <- 0
if (!isTRUE(state$shrink_intercept) && p >= 1L) {
prec0 <- as.numeric(state$intercept_prec)[1]
prec0 <- if (is.finite(prec0) && prec0 > 0) prec0 else 1e-16
e_log_intercept <- 0.5 * (log(prec0) - log(2 * pi)) - 0.5 * prec0 * beta2[1L]
}
as.numeric(e_log_joint + h_latent + e_log_intercept)
}
.static_rhs_ns_update_mcmc <- function(state, beta, ctrl) {
  # One Gibbs sweep of the RHS-NS hyperparameters given a draw of beta.
  #
  # Full conditionals for the m active coefficients (IG = shape, scale):
  #   lambda_j^2 | .  ~ IG(1, 1/nu_j + beta_j^2 / (2 tau^2))
  #   nu_j | .        ~ IG(1, 1 + 1/lambda_j^2)
  #   tau^2 | .       ~ IG((m+1)/2, 1/xi + sum beta_j^2 / (2 lambda_j^2))
  #   xi | .          ~ IG(1, 1/tau0^2 + 1/tau^2)
  #   zeta^2 | .      ~ IG(a_zeta + m/2, b_zeta + sum beta_j^2 / 2)
  # The tau^2/xi step is skipped during the first
  # ctrl$freeze_tau_warmup_iters iterations, and zeta^2 is pinned when fixed.
  #
  # After sampling, each (shape a, scale b) pair is stored with b = a * x, so
  # the a / b moments read elsewhere equal 1 / x for the current draw.
  # Errors: if length(beta) differs from state$p.
  floor <- max(as.numeric(ctrl$var_floor)[1], 1e-16)
  state <- .static_rhs_ns_recompute_moments(state, ctrl)
  beta <- as.numeric(beta)
  p <- as.integer(state$p)
  if (length(beta) != p) stop("rhs_ns mcmc update: beta length mismatch.")
  active_idx <- .static_rhs_active_idx(p, ctrl$shrink_intercept)
  m_active <- length(active_idx)
  beta2 <- beta^2
  lambda2 <- pmax(as.numeric(state$lambda2), floor)
  nu <- pmax(as.numeric(state$nu), floor)
  tau2 <- max(as.numeric(state$tau2)[1], floor)
  xi <- max(as.numeric(state$xi)[1], floor)
  zeta2 <- max(as.numeric(state$zeta2)[1], floor)
  iter_now <- as.integer(.static_prior_or(state$iter, 0L)) + 1L
  # A missing or NA warm-up length means "no warm-up", rather than a
  # zero-length operand to `&&`.
  warmup_iters <- as.numeric(.static_prior_or(ctrl$freeze_tau_warmup_iters, 0))[1]
  tau_warmup <- isTRUE(!is.na(warmup_iters) && warmup_iters > 0 && iter_now <= warmup_iters)
  tau_updated <- FALSE
  if (m_active > 0L) {
    for (j in active_idx) {
      rate_lambda <- max(1 / nu[j] + 0.5 * beta2[j] / tau2, floor)
      lambda2[j] <- 1 / stats::rgamma(1L, shape = 1, rate = rate_lambda)
      lambda2[j] <- max(lambda2[j], floor)
      rate_nu <- max(1 + 1 / lambda2[j], floor)
      nu[j] <- 1 / stats::rgamma(1L, shape = 1, rate = rate_nu)
      nu[j] <- max(nu[j], floor)
    }
    if (!isTRUE(tau_warmup)) {
      shape_tau <- (m_active + 1) / 2
      rate_tau <- max(1 / xi + 0.5 * sum(beta2[active_idx] / lambda2[active_idx]), floor)
      tau2 <- 1 / stats::rgamma(1L, shape = shape_tau, rate = rate_tau)
      tau2 <- max(tau2, floor)
      rate_xi <- max(1 / (ctrl$tau0^2) + 1 / tau2, floor)
      xi <- 1 / stats::rgamma(1L, shape = 1, rate = rate_xi)
      xi <- max(xi, floor)
      tau_updated <- TRUE
    }
    if (!isTRUE(state$zeta2_is_fixed)) {
      shape_zeta <- max(as.numeric(ctrl$a_zeta)[1] + m_active / 2, floor)
      rate_zeta <- max(as.numeric(ctrl$b_zeta)[1] + 0.5 * sum(beta2[active_idx]), floor)
      zeta2 <- 1 / stats::rgamma(1L, shape = shape_zeta, rate = rate_zeta)
      zeta2 <- max(zeta2, floor)
    } else {
      zeta2 <- max(as.numeric(state$zeta2_fixed)[1], floor)
    }
  }
  state$lambda2 <- lambda2
  state$nu <- nu
  state$tau2 <- tau2
  state$xi <- xi
  state$zeta2 <- zeta2
  # Store b = a * x so that E[1 / x] = a / b = 1 / x for the current draw.
  state$a_lambda <- rep(1, p)
  state$b_lambda <- state$a_lambda * pmax(lambda2, floor)
  state$a_nu <- rep(1, p)
  state$b_nu <- state$a_nu * pmax(nu, floor)
  state$a_tau <- max((m_active + 1) / 2, floor)
  state$b_tau <- state$a_tau * pmax(tau2, floor)
  state$a_xi <- 1
  state$b_xi <- state$a_xi * pmax(xi, floor)
  if (isTRUE(state$zeta2_is_fixed)) {
    state$a_zeta <- NA_real_
    state$b_zeta <- NA_real_
  } else {
    state$a_zeta <- max(as.numeric(ctrl$a_zeta)[1] + m_active / 2, floor)
    state$b_zeta <- state$a_zeta * pmax(zeta2, floor)
  }
  state <- .static_rhs_ns_recompute_moments(state, ctrl)
  state$iter <- iter_now
  state$freeze_tau <- isTRUE(tau_warmup)
  state$update_tau_only <- FALSE
  state$tau_update_count <- as.integer(.static_prior_or(state$tau_update_count, 0L)) + if (tau_updated) 1L else 0L
  state$has_post_warmup_tau_update <- isTRUE(.static_prior_or(state$has_post_warmup_tau_update, FALSE) || tau_updated)
  state$last_schedule <- list(
    iter = iter_now,
    tau_warmup = tau_warmup,
    reason = if (tau_warmup) "warmup" else "scheduled",
    tau_updated = tau_updated,
    tau_update_count = state$tau_update_count
  )
  state$collapse_diag <- .static_rhs_ns_collapse_diag_mcmc(state, beta, ctrl)
  state
}
.static_rhs_ns_summary_vb <- function(state, ctrl) {
  # Reporting summary of the VB state under the RHS-NS prior.
  #
  # Returns a named list with:
  #   - the current scales (lambda, lambda2 and nu restricted to the active
  #     coefficients) and lambda range statistics;
  #   - the prior hyperparameters from `ctrl`;
  #   - bookkeeping from the last schedule (state$last_schedule);
  #   - the last collapse diagnostics (state$collapse_diag), with FALSE/NA
  #     when absent;
  #   - the resolved tau initialisation and the log-tau bounds.
  cur <- .static_rhs_ns_recompute_moments(state, ctrl)
  keep <- .static_rhs_active_idx(cur$p, ctrl$shrink_intercept)
  lam <- cur$lambda[keep]
  cd <- .static_prior_or(cur$collapse_diag, list())
  last <- cur$last_schedule
  tau_bounds <- ctrl$eta_bounds$tau
  lam_or_na <- function(fun) if (length(lam)) fun(lam) else NA_real_
  hyper <- function(x) as.numeric(.static_prior_or(x, NA_real_))[1]
  first_num <- function(x) as.numeric(x)[1]
  list(
    tau = cur$tau,
    log_tau = log(pmax(cur$tau, 1e-16)),
    c2 = cur$c2,
    zeta2 = cur$zeta2,
    tau2 = cur$tau2,
    xi = cur$xi,
    lambda = lam,
    lambda2 = cur$lambda2[keep],
    nu = cur$nu[keep],
    lambda_mean = lam_or_na(mean),
    lambda_min = lam_or_na(min),
    lambda_max = lam_or_na(max),
    shrink_intercept = ctrl$shrink_intercept,
    tau0 = ctrl$tau0,
    nu_hyper = ctrl$nu,
    s = ctrl$s,
    s2 = ctrl$s2,
    a_zeta = hyper(ctrl$a_zeta),
    b_zeta = hyper(ctrl$b_zeta),
    zeta2_fixed = hyper(ctrl$zeta2_fixed),
    rhs_iter = as.integer(.static_prior_or(cur$iter, NA_integer_)),
    rhs_tau_update_count = as.integer(.static_prior_or(cur$tau_update_count, NA_integer_)),
    rhs_tau_warmup_last = isTRUE(.static_prior_or(last$tau_warmup, FALSE)),
    rhs_update_reason_last = as.character(.static_prior_or(last$reason, NA_character_))[1],
    rhs_update_every_last = as.integer(.static_prior_or(last$update_every_eff, NA_integer_)),
    collapse_flag = isTRUE(cd$collapse_flag),
    collapse_pattern = isTRUE(cd$precision_beta_pattern),
    collapse_tau_near_zero = isTRUE(cd$tau_near_zero),
    collapse_beta = isTRUE(cd$slope_collapse),
    collapse_tau_ratio = first_num(cd$tau_ratio),
    collapse_E_invV_med = first_num(cd$E_invV_med),
    collapse_beta_l2 = first_num(cd$beta_l2),
    collapse_small_beta_frac = first_num(cd$small_beta_frac),
    collapse_small_beta_abs_tol = first_num(cd$small_beta_abs_tol),
    collapse_slope_l2 = first_num(cd$slope_l2),
    collapse_slope_max_abs = first_num(cd$slope_max_abs),
    collapse_warning = as.character(.static_prior_or(cd$warning, NA_character_))[1],
    init_log_tau_resolved = first_num(ctrl$init_log_tau),
    init_tau_resolved = first_num(ctrl$init_tau),
    init_tau_source = as.character(ctrl$init_tau_source)[1],
    eta_tau_lower = as.numeric(tau_bounds[1]),
    eta_tau_upper = as.numeric(tau_bounds[2])
  )
}
.static_rhs_ns_summary_mcmc <- function(state, ctrl, beta = NULL) {
  # Reporting summary of an MCMC state under the RHS-NS prior.
  #
  # The scale and hyperparameter fields are laid out as in the VB summary.
  # When `beta` is supplied, the collapse diagnostics are recomputed from it.
  # Otherwise every collapse_* field is NA, except
  # collapse_small_beta_abs_tol, which falls back to ctrl$small_beta_abs_tol.
  cur <- .static_rhs_ns_recompute_moments(state, ctrl)
  keep <- .static_rhs_active_idx(cur$p, ctrl$shrink_intercept)
  lam <- cur$lambda[keep]
  cd <- list()
  if (!is.null(beta)) {
    cd <- .static_rhs_ns_collapse_diag_mcmc(cur, beta, ctrl)
  }
  tau_bounds <- ctrl$eta_bounds$tau
  lam_or_na <- function(fun) if (length(lam)) fun(lam) else NA_real_
  num_or <- function(x, alt = NA_real_) as.numeric(.static_prior_or(x, alt))[1]
  flag_or_na <- function(x) if (is.null(x)) NA else isTRUE(x)
  list(
    tau = cur$tau,
    log_tau = log(pmax(cur$tau, 1e-16)),
    c2 = cur$c2,
    zeta2 = cur$zeta2,
    tau2 = cur$tau2,
    xi = cur$xi,
    lambda = lam,
    lambda2 = cur$lambda2[keep],
    nu = cur$nu[keep],
    lambda_mean = lam_or_na(mean),
    lambda_min = lam_or_na(min),
    lambda_max = lam_or_na(max),
    shrink_intercept = ctrl$shrink_intercept,
    tau0 = ctrl$tau0,
    nu_hyper = ctrl$nu,
    s = ctrl$s,
    s2 = ctrl$s2,
    a_zeta = num_or(ctrl$a_zeta),
    b_zeta = num_or(ctrl$b_zeta),
    zeta2_fixed = num_or(ctrl$zeta2_fixed),
    collapse_flag = flag_or_na(cd$collapse_flag),
    collapse_pattern = flag_or_na(cd$precision_beta_pattern),
    collapse_tau_near_zero = flag_or_na(cd$tau_near_zero),
    collapse_beta = flag_or_na(cd$slope_collapse),
    collapse_tau_ratio = num_or(cd$tau_ratio),
    collapse_E_invV_med = num_or(cd$E_invV_med),
    collapse_beta_l2 = num_or(cd$beta_l2),
    collapse_small_beta_frac = num_or(cd$small_beta_frac),
    collapse_small_beta_abs_tol = num_or(cd$small_beta_abs_tol, ctrl$small_beta_abs_tol),
    collapse_warning = as.character(.static_prior_or(cd$warning, NA_character_))[1],
    init_log_tau_resolved = as.numeric(ctrl$init_log_tau)[1],
    init_tau_resolved = as.numeric(ctrl$init_tau)[1],
    init_tau_source = as.character(ctrl$init_tau_source)[1],
    eta_tau_lower = as.numeric(tau_bounds[1]),
    eta_tau_upper = as.numeric(tau_bounds[2])
  )
}
.static_beta_prior_make <- function(beta_prior = c("ridge", "rhs", "rhs_ns"), p, b0, V0, beta_prior_controls = NULL,
warn_rhs_b0 = FALSE, warn_rhs_V0 = FALSE) {
beta_prior <- .static_match_beta_prior(beta_prior)
V0_inv <- tryCatch(solve(V0), error = function(e) solve(V0 + 1e-8 * diag(ncol(V0))))
logdetV0 <- as.numeric(determinant(V0, logarithm = TRUE)$modulus)
if (identical(beta_prior, "ridge")) {
return(list(
type = "ridge",
controls = NULL,
init_vb = function() list(),
init_mcmc = function() list(),
beta_system_vb = function(state) list(Prec = V0_inv, h = drop(V0_inv %*% b0)),
beta_system_mcmc = function(state) list(Prec = V0_inv, h = drop(V0_inv %*% b0)),
update_vb = function(state, qbeta) state,
update_mcmc = function(state, beta, ...) state,
elbo_vb = function(state, qbeta) {
- (p / 2) * log(2 * pi) - 0.5 * logdetV0 -
0.5 * (sum(V0_inv * qbeta$V) + drop(crossprod(qbeta$m - b0, V0_inv %*% (qbeta$m - b0))))
},
summary_vb = function(state) NULL,
summary_mcmc = function(state, beta = NULL) NULL
))
}
ctrl <- .static_parse_beta_prior_controls(beta_prior_controls, prior_type = beta_prior)
if (warn_rhs_b0 || warn_rhs_V0) {
warning(
sprintf(
"beta_prior = '%s' ignores b0/V0 for the shrunk coefficients; they are only retained for backward-compatible ridge behavior.",
beta_prior
),
call. = FALSE
)
}
if (identical(beta_prior, "rhs_ns")) {
return(list(
type = beta_prior,
controls = ctrl,
init_vb = function() .static_rhs_ns_init_vb_state(p, ctrl),
init_mcmc = function() .static_rhs_ns_init_mcmc_state(p, ctrl),
beta_system_vb = function(state) {
prec <- .static_rhs_ns_expected_prec_vb(state, ctrl)
list(Prec = diag(prec, p), h = rep(0, p), prec_diag = prec)
},
beta_system_mcmc = function(state) {
prec <- .static_rhs_ns_prec_mcmc(state, ctrl)
list(Prec = diag(prec, p), h = rep(0, p), prec_diag = prec)
},
update_vb = function(state, qbeta) .static_rhs_ns_update_vb(state, qbeta, ctrl),
update_mcmc = function(state, beta, ...) .static_rhs_ns_update_mcmc(state, beta, ctrl),
elbo_vb = function(state, qbeta) .static_rhs_ns_elbo_vb(state, qbeta, ctrl),
summary_vb = function(state) .static_rhs_ns_summary_vb(state, ctrl),
summary_mcmc = function(state, beta = NULL) .static_rhs_ns_summary_mcmc(state, ctrl, beta = beta)
))
}
list(
type = beta_prior,
controls = ctrl,
init_vb = function() .static_rhs_init_vb_state(p, ctrl),
init_mcmc = function() .static_rhs_init_mcmc_state(p, ctrl),
beta_system_vb = function(state) {
prec <- .static_rhs_expected_prec_vb(state, ctrl)
list(Prec = diag(prec, p), h = rep(0, p), prec_diag = prec)
},
beta_system_mcmc = function(state) {
prec <- .static_rhs_prec_mcmc(state, ctrl)
list(Prec = diag(prec, p), h = rep(0, p), prec_diag = prec)
},
update_vb = function(state, qbeta) .static_rhs_update_vb(state, qbeta, ctrl),
update_mcmc = function(state, beta, slice_width = NULL, slice_max_steps = NULL) {
.static_rhs_update_mcmc(
state,
beta,
ctrl,
slice_width = .static_prior_or(slice_width, ctrl$slice_width),
slice_max_steps = .static_prior_or(slice_max_steps, ctrl$slice_max_steps)
)
},
elbo_vb = function(state, qbeta) .static_rhs_elbo_vb(state, qbeta, ctrl),
summary_vb = function(state) {
idx <- .static_rhs_active_idx(state$p, ctrl$shrink_intercept)
lam <- .static_rhs_safe_exp(state$eta_lambda_hat[idx])
collapse_diag <- .static_prior_or(state$collapse_diag, list())
list(
tau = .static_rhs_safe_exp(state$eta_tau_hat),
log_tau = as.numeric(state$eta_tau_hat)[1],
c2 = .static_rhs_safe_exp(state$eta_c_hat),
zeta2 = .static_rhs_safe_exp(state$eta_c_hat),
lambda = lam,
lambda_mean = if (length(lam)) mean(lam) else NA_real_,
lambda_min = if (length(lam)) min(lam) else NA_real_,
lambda_max = if (length(lam)) max(lam) else NA_real_,
shrink_intercept = ctrl$shrink_intercept,
tau0 = ctrl$tau0,
nu = ctrl$nu,
s = ctrl$s,
s2 = ctrl$s2,
a_zeta = as.numeric(.static_prior_or(ctrl$a_zeta, NA_real_))[1],
b_zeta = as.numeric(.static_prior_or(ctrl$b_zeta, NA_real_))[1],
zeta2_fixed = as.numeric(.static_prior_or(ctrl$zeta2_fixed, NA_real_))[1],
rhs_iter = as.integer(.static_prior_or(state$iter, NA_integer_)),
rhs_tau_update_count = as.integer(.static_prior_or(state$tau_update_count, NA_integer_)),
rhs_tau_warmup_last = isTRUE(.static_prior_or(state$last_schedule$tau_warmup, FALSE)),
rhs_update_reason_last = as.character(.static_prior_or(state$last_schedule$reason, NA_character_))[1],
rhs_update_every_last = as.integer(.static_prior_or(state$last_schedule$update_every_eff, NA_integer_)),
collapse_flag = isTRUE(collapse_diag$collapse_flag),
collapse_pattern = isTRUE(collapse_diag$precision_beta_pattern),
collapse_tau_near_zero = isTRUE(collapse_diag$tau_near_zero),
collapse_beta = isTRUE(collapse_diag$slope_collapse),
collapse_tau_ratio = as.numeric(collapse_diag$tau_ratio)[1],
collapse_E_invV_med = as.numeric(collapse_diag$E_invV_med)[1],
collapse_beta_l2 = as.numeric(collapse_diag$beta_l2)[1],
collapse_small_beta_frac = as.numeric(collapse_diag$small_beta_frac)[1],
collapse_small_beta_abs_tol = as.numeric(collapse_diag$small_beta_abs_tol)[1],
collapse_slope_l2 = as.numeric(collapse_diag$slope_l2)[1],
collapse_slope_max_abs = as.numeric(collapse_diag$slope_max_abs)[1],
collapse_warning = as.character(.static_prior_or(collapse_diag$warning, NA_character_))[1],
init_log_tau_resolved = as.numeric(ctrl$init_log_tau)[1],
init_tau_resolved = as.numeric(ctrl$init_tau)[1],
init_tau_source = as.character(ctrl$init_tau_source)[1],
eta_tau_lower = as.numeric(ctrl$eta_bounds$tau[1]),
eta_tau_upper = as.numeric(ctrl$eta_bounds$tau[2])
)
},
summary_mcmc = function(state, beta = NULL) {
idx <- .static_rhs_active_idx(state$p, ctrl$shrink_intercept)
lam <- state$lambda[idx]
collapse_diag <- if (!is.null(beta)) {
.static_rhs_collapse_diag_mcmc(state, beta, ctrl)
} else {
list()
}
list(
tau = state$tau,
log_tau = log(pmax(as.numeric(state$tau)[1], 1e-16)),
c2 = state$c2,
zeta2 = state$c2,
lambda = lam,
lambda_mean = if (length(lam)) mean(lam) else NA_real_,
lambda_min = if (length(lam)) min(lam) else NA_real_,
lambda_max = if (length(lam)) max(lam) else NA_real_,
shrink_intercept = ctrl$shrink_intercept,
tau0 = ctrl$tau0,
nu = ctrl$nu,
s = ctrl$s,
s2 = ctrl$s2,
a_zeta = as.numeric(.static_prior_or(ctrl$a_zeta, NA_real_))[1],
b_zeta = as.numeric(.static_prior_or(ctrl$b_zeta, NA_real_))[1],
zeta2_fixed = as.numeric(.static_prior_or(ctrl$zeta2_fixed, NA_real_))[1],
collapse_flag = if (!is.null(collapse_diag$collapse_flag)) isTRUE(collapse_diag$collapse_flag) else NA,
collapse_pattern = if (!is.null(collapse_diag$precision_beta_pattern)) isTRUE(collapse_diag$precision_beta_pattern) else NA,
collapse_tau_near_zero = if (!is.null(collapse_diag$tau_near_zero)) isTRUE(collapse_diag$tau_near_zero) else NA,
collapse_beta = if (!is.null(collapse_diag$slope_collapse)) isTRUE(collapse_diag$slope_collapse) else NA,
collapse_tau_ratio = as.numeric(.static_prior_or(collapse_diag$tau_ratio, NA_real_))[1],
collapse_E_invV_med = as.numeric(.static_prior_or(collapse_diag$E_invV_med, NA_real_))[1],
collapse_beta_l2 = as.numeric(.static_prior_or(collapse_diag$beta_l2, NA_real_))[1],
collapse_small_beta_frac = as.numeric(.static_prior_or(collapse_diag$small_beta_frac, NA_real_))[1],
collapse_small_beta_abs_tol = as.numeric(.static_prior_or(collapse_diag$small_beta_abs_tol, ctrl$small_beta_abs_tol))[1],
collapse_warning = as.character(.static_prior_or(collapse_diag$warning, NA_character_))[1],
init_log_tau_resolved = as.numeric(ctrl$init_log_tau)[1],
init_tau_resolved = as.numeric(ctrl$init_tau)[1],
init_tau_source = as.character(ctrl$init_tau_source)[1],
eta_tau_lower = as.numeric(ctrl$eta_bounds$tau[1]),
eta_tau_upper = as.numeric(ctrl$eta_bounds$tau[2])
)
}
)
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.