R/outbreakHNB.R

Defines functions outbreakHNB

Documented in outbreakHNB

#' Fit a hurdle negative binomial outbreak detection model
#'
#' @param cases            Integer or numeric vector of observed case counts (length N).
#' @param pop              (Optional) Numeric vector of population offsets (length N). If NULL, offset = 1.
#' @param covariates_count (Optional) Data.frame or matrix of covariates for the count model (N x p_c).
#' @param covariates_zero  (Optional) Data.frame or matrix of covariates for the Hurdle model (N x p_z).
#' @param beta_init        (Optional) List of length `n_chains` giving initial values for beta (each a vector of length p_c+1).
#' @param delta_init       (Optional) List of length `n_chains` giving initial values for delta (each a vector of length p_z+1).
#' @param r_init           (Optional) Numeric vector of length `n_chains` giving initial values for the NB dispersion parameter.
#' @param beta_prior_mean  Prior mean for beta coefficients of the Negative binomial part (default = 0).
#' @param beta_prior_sd    Prior SD   for beta coefficients of the Negative binomial part (default = 10).
#' @param delta_prior_mean Prior mean for delta coefficients of the Hurdle part (default = 0).
#' @param delta_prior_sd   Prior SD   for delta coefficients of the Hurdle part (default = 10).
#' @param r_prior_shape    Shape parameter of the Gamma prior on r (default = 1).
#' @param r_prior_rate     Rate  parameter of the Gamma prior on r (default = 1).
#' @param p_priors         Common shape parameter of the symmetric Beta(p_priors, p_priors) priors on p00 and p11 (default = 1).
#' @param n_iter           Total number of MCMC iterations per chain (default = 100000).
#' @param n_burnin         Number of burn-in iterations (default = 10000).
#' @param n_chains         Number of MCMC chains (default = 3).
#' @param n_thin           Thinning interval for MCMC samples (default = 1).
#' @param save_params      Character vector of parameter names to save; "Z" is appended automatically if it is not included.
#' @param dates            (Optional) Vector of Date or POSIX dates for plotting Z; if NULL, uses index 1:N.
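#' @param plot_Z           Logical; if TRUE, the returned list also contains `plot_Z`, a ggplot2 object of the posterior mean of Z over time.
#'
#' @return A list with elements `mcmc_summary` (posterior summary without the Z rows),
#'   `mcmc_summary_full` (posterior summary including Z), `dic`, `waic`,
#'   `raw_output` (the fitted R2jags object) and, when `plot_Z = TRUE`, `plot_Z`.
#'
#' @export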
#' @examples
#' # ---- tiny example for users & CRAN (< 5s) ----
#' set.seed(13)
#' n <- 120
#' # baseline hurdle-like series: generate NB counts then zero out some days
#' base  <- rnbinom(n, size = 6, mu = 8)
#' zeros <- rbinom(n, 1, 0.30)
#' cases <- ifelse(zeros == 1, 0L, base)
#' # inject a small "outbreak" window
#' cases[70:74] <- cases[70:74] + rnbinom(5, size = 6, mu = 25)
#' dates <- as.Date("2020-01-01") + seq_len(n) - 1L
#'
#' \dontshow{
#' # checks that run on CRAN but are hidden from users
#' stopifnot(length(cases) == n, all(cases >= 0), inherits(dates, "Date"))
#' }
#'
#' # ---- actually run the detector, but only when JAGS is available ----
#' @examplesIf nzchar(Sys.which("jags")) && requireNamespace("R2jags", quietly = TRUE)
#' \donttest{
#' fit <- outbreakHNB(
#'   cases   = cases,
#'   dates   = dates,
#'   n_iter  = 10,   # keep fast for examples
#'   n_burnin= 1,
#'   n_chains= 1,
#'   n_thin  = 1,
#'   plot_Z  = FALSE  # avoid plotting in examples (rename/omit if not applicable)
#' )
#' print(fit)
#'}
#'
#' \donttest{
#' # ---- longer user-facing demo (skipped on checks) ----
#' # Increase iterations a bit for a stabler run (still JAGS-gated by @examplesIf above)
#' # fit2 <- outbreakHNB(
#' #   cases   = cases,
#' #   dates   = dates,
#' #   n_iter  = 1500,
#' #   n_burnin= 500,
#' #   n_chains= 2,
#' #   n_thin  = 2,
#' #   plot_Z  = FALSE
#' # )
#' # print(fit2)
#' }
#'
#' \dontrun{
#' # ---- time-consuming / full demo (not run anywhere) ----
#' # Here you might use larger MCMC and produce figures/tables of alerts.
#' # fit_full <- outbreakHNB(
#' #   cases   = cases,
#' #   dates   = dates,
#' #   n_iter  = 10000,
#' #   n_burnin= 5000,
#' #   n_chains= 4,
#' #   n_thin  = 1,
#' #   plot_Z  = TRUE
#' # )
#' # print(fit_full)
#' }
#'
#' if (interactive()) {
#'   # e.g., if a plot method exists:  # plot(fit)
#' }


outbreakHNB <- function(
    cases,
    pop = NULL,
    covariates_count = NULL,
    covariates_zero  = NULL,
    beta_init = NULL,
    delta_init = NULL,
    r_init = NULL,
    beta_prior_mean = 0,
    beta_prior_sd = 10,
    delta_prior_mean = 0,
    delta_prior_sd = 10,
    r_prior_shape = 1,
    r_prior_rate = 1,
    p_priors = 1,
    n_iter = 100000,
    n_burnin = 10000,
    n_chains = 3,
    n_thin = 1,
    save_params = c("beta", "delta", "r", "Z"),
    dates = NULL,      # optional date vector aligned with cases
    plot_Z = FALSE     # whether to build the Z plot
) {
  # Overview: validate inputs, build design matrices, write a JAGS model file,
  # run it through R2jags, summarise the draws, attempt a WAIC computation and
  # (optionally) build a ggplot of the posterior mean of Z. The result list is
  # returned whether or not a plot is requested.

  # ---- Input validation -----------------------------------------------------
  # Done before anything else so that bad calls fail immediately instead of
  # after a (potentially very long) MCMC run.
  if (!is.numeric(cases) || length(cases) == 0L) {
    stop("'cases' must be a non-empty numeric vector.", call. = FALSE)
  }
  N <- length(cases)

  if (!is.null(pop) && length(pop) != N) {
    stop("'pop' must have the same length as 'cases'.", call. = FALSE)
  }
  if (!is.null(beta_init) && length(beta_init) < n_chains) {
    stop("'beta_init' must supply one element per chain.", call. = FALSE)
  }
  if (!is.null(delta_init) && length(delta_init) < n_chains) {
    stop("'delta_init' must supply one element per chain.", call. = FALSE)
  }
  if (!is.null(r_init) && length(r_init) < n_chains) {
    stop("'r_init' must supply one value per chain.", call. = FALSE)
  }
  # Z has one entry per time point, so a usable 'dates' vector has length N.
  if (plot_Z && !is.null(dates) && length(dates) != N) {
    stop("Provided 'dates' must match length of Z (number of timepoints).")
  }

  # ---- Dependencies ---------------------------------------------------------
  if (!requireNamespace("R2jags", quietly = TRUE)) stop("Package R2jags is required.")
  if (!requireNamespace("coda", quietly = TRUE)) stop("Package coda is required.")
  if (!requireNamespace("rjags", quietly = TRUE)) stop("Package rjags is required for WAIC/dev calculations.")
  if (plot_Z && !requireNamespace("ggplot2", quietly = TRUE)) {
    stop("Package ggplot2 is required when plot_Z = TRUE.")
  }

  # Z is needed downstream (summary filtering and plotting), so always save it.
  if (!("Z" %in% save_params)) {
    save_params <- c(save_params, "Z")
  }

  # ---- Design matrices ------------------------------------------------------
  # Count-model covariate matrix with intercept
  if (!is.null(covariates_count)) {
    Xc1 <- as.matrix(covariates_count)
    if (nrow(Xc1) != N) stop("covariates_count must have length equal to cases.")
  } else {
    Xc1 <- matrix(0, nrow = N, ncol = 0)
  }
  Xc <- cbind(Intercept = 1, Xc1)
  Kc <- ncol(Xc)

  # Hurdle-model covariate matrix with intercept
  if (!is.null(covariates_zero)) {
    Xz1 <- as.matrix(covariates_zero)
    if (nrow(Xz1) != N) stop("covariates_zero must have length equal to cases.")
  } else {
    Xz1 <- matrix(0, nrow = N, ncol = 0)
  }
  Xz <- cbind(Intercept = 1, Xz1)
  Kz <- ncol(Xz)

  # ---- Offsets --------------------------------------------------------------
  # Without 'pop' no offset term is written into the model; 'pop' is still
  # passed to JAGS as a vector of ones.
  if (is.null(pop)) {
    pop_vec <- rep(1, N)
    off_str <- ""
    off1    <- ""
  } else {
    pop_vec <- as.numeric(pop)
    off_str <- "log(pop[t]) + "
    off1    <- "log(pop[1]) + "
  }
  # Pseudo-observations for the "zeros trick": each zeros[t] ~ dpois(-ll[t] + C)
  # contributes the custom log-likelihood ll[t] to the model.
  zeros <- rep(0, N)

  # ---- JAGS model -----------------------------------------------------------
  # NOTE(review): 'eta' below is given a prior but is not referenced anywhere
  # else in this model string.
  model_lines <- c(
    "model{",
    "  C <- 10000",
    "  # First time point",
    "  zeros[1] ~ dpois(-ll[1] + C)",
    "  u[1] <- 1 / (1 + r * mu[1])",
    "  LogTruncNB[1] <- 1/r * log(u[1]) + Y[1] * log(1 - u[1]) + loggam(Y[1] + 1/r) - loggam(1/r) - loggam(Y[1] + 1) - log(1 - (1 + r * mu[1])^(-1/r))",
    "  ind[1] <- step(Y[1] - 0.5)",
    "  l1[1] <- (1 - ind[1]) * log(1 - pi[1])",
    "  l2[1] <- ind[1] * (log(pi[1]) + LogTruncNB[1])",
    "  ll[1] <- l1[1] + l2[1]",
    "  mu[1] <- mu0[1] + Z[1] * mu1[1]",
    "  mu0[1] <- exp(lambda0[1])",
    "  mu1[1] <- 0",
    paste0("  lambda0[1] <- ", off1, "inprod(Xc[1,1:Kc], beta[1:Kc])"),
    paste0("  pi[1] <- ilogit(", off1, "inprod(Xz[1,1:Kz], delta[1:Kz]))"),
    "",
    "  # Subsequent time points",
    "  for(t in 2:N){",
    "    zeros[t] ~ dpois(-ll[t] + C)",
    "    u[t] <- 1 / (1 + r * mu[t])",
    "    LogTruncNB[t] <- 1/r * log(u[t]) + Y[t] * log(1 - u[t]) + loggam(Y[t] + 1/r) - loggam(1/r) - loggam(Y[t] + 1) - log(1 - (1 + r * mu[t])^(-1/r))",
    "    ind[t] <- step(Y[t] - 0.5)",
    "    l1[t] <- (1 - ind[t]) * log(1 - pi[t])",
    "    l2[t] <- ind[t] * (log(pi[t]) + LogTruncNB[t])",
    "    ll[t] <- l1[t] + l2[t]",
    "    mu[t] <- mu0[t] + Z[t] * mu1[t]",
    "    mu0[t] <- exp(lambda0[t])",
    "    mu1[t] <- mu[t-1]",
    paste0("    lambda0[t] <- ", off_str, "inprod(Xc[t,1:Kc], beta[1:Kc])"),
    paste0("    pi[t] <- ilogit(", off_str, "inprod(Xz[t,1:Kz], delta[1:Kz]))"),
    "  }",
    "",
    "  # Latent self-excitation indicator Z with Markov-like structure",
    "  Z[1] ~ dbern(p[1])",
    "  p[1] ~ dunif(0,1)",
    "  for(t in 2:N){",
    "    Z[t] ~ dbern(p[t])",
    "    p[t] <- p[1] * (p11 + p00 - 1)^(t-1) + (1 - p00) * ((1 - (p11 + p00 - 1)^(t-1)) / (2 - p11 - p00))",
    "  }",
    "",
    "  # Priors",
    paste0("  r ~ dgamma(", r_prior_shape, ", ", r_prior_rate, ")"),
    "  eta ~ dbeta(1,1)",
    paste0("  p00 ~ dbeta(", p_priors, ", ", p_priors, ")"),
    paste0("  p11 ~ dbeta(", p_priors, ", ", p_priors, ")"),
    paste0("  for(k in 1:Kc){ beta[k]  ~ dnorm(", beta_prior_mean, ", 1/(", beta_prior_sd, "^2)) }"),
    paste0("  for(k in 1:Kz){ delta[k] ~ dnorm(", delta_prior_mean, ", 1/(", delta_prior_sd, "^2)) }"),
    "}",
    ""
  )
  model_string <- paste(model_lines, collapse = "\n")

  # The model lives in a temp file that is removed when the function exits.
  model_file <- tempfile(fileext = ".bug")
  writeLines(model_string, model_file)
  on.exit(unlink(model_file), add = TRUE)

  # ---- Initial values -------------------------------------------------------
  chain_ids <- seq_len(n_chains)
  if (is.null(beta_init))  beta_init  <- lapply(chain_ids, function(i) rep(0, Kc))
  if (is.null(delta_init)) delta_init <- lapply(chain_ids, function(i) rep(0, Kz))
  # Default dispersion starts are spread out: 0.5, 1.0, 1.5, ... (one per chain).
  if (is.null(r_init))     r_init     <- seq(0.5, 0.5 + 0.5 * (n_chains - 1), length.out = n_chains)

  inits <- lapply(chain_ids, function(i) list(
    beta  = beta_init[[i]],
    delta = delta_init[[i]],
    r     = r_init[i]
  ))

  data4Jags <- list(
    Y  = cases,
    N  = N,
    Xc = Xc,
    Xz = Xz,
    pop = pop_vec,
    Kc = Kc,
    Kz = Kz,
    zeros = zeros
  )

  # ---- Sampling -------------------------------------------------------------
  jags.out <- R2jags::jags(
    data               = data4Jags,
    inits              = inits,
    parameters.to.save = save_params,
    model.file         = model_file,
    n.iter             = n_iter,
    n.burnin           = n_burnin,
    n.chains           = n_chains,
    n.thin             = n_thin
  )

  # ---- Summaries ------------------------------------------------------------
  # NOTE(review): the likelihood enters through the zeros trick with constant
  # C, so deviance-based criteria (DIC/WAIC) presumably carry an additive
  # constant depending on N and C -- confirm before comparing across models.
  dic_value <- jags.out$BUGSoutput$DIC
  full_summary <- as.data.frame(jags.out$BUGSoutput$summary)
  full_summary$dic <- dic_value

  # Drop the per-time-point Z[...] rows from the compact summary view.
  summary_df <- full_summary[!grepl("^Z\\[", rownames(full_summary)), , drop = FALSE]

  # WAIC is best-effort: if the monitors are unavailable the values are NA.
  s <- tryCatch({
    rjags::jags.samples(jags.out$model, c("WAIC", "deviance"), type = "mean", n.iter = 1000)
  }, error = function(e) NULL)

  if (!is.null(s) && !is.null(s$WAIC) && !is.null(s$deviance)) {
    p_waic <- sum(s$WAIC)
    dev    <- sum(s$deviance)
    waic_vals <- round(c(waic = dev + p_waic, p_waic = p_waic), 1)
  } else {
    waic_vals <- c(waic = NA, p_waic = NA)
  }

  ret <- list(
    mcmc_summary      = summary_df,
    mcmc_summary_full = full_summary,
    dic               = dic_value,
    waic              = waic_vals,
    raw_output        = jags.out
  )

  # ---- Optional Z plot ------------------------------------------------------
  if (plot_Z) {
    # Z is a length-N vector in the model above, so its posterior mean is
    # flattened to a plain numeric vector.
    z_mean <- as.numeric(jags.out$BUGSoutput$mean$Z)

    if (is.null(dates)) {
      x_vals <- seq_along(z_mean)
    } else {
      # Prefer Date coercion; fall back to the raw values if coercion fails.
      parsed <- tryCatch(as.Date(dates), error = function(e) NULL)
      if (!is.null(parsed) && length(parsed) == length(z_mean)) {
        x_vals <- parsed
      } else if (length(dates) == length(z_mean)) {
        x_vals <- dates
      } else {
        stop("Provided 'dates' must match length of Z (number of timepoints).")
      }
    }

    df_plot <- data.frame(date = x_vals, value = z_mean)

    ret$plot_Z <- ggplot2::ggplot(df_plot, ggplot2::aes(x = date, y = value)) +
      ggplot2::geom_point() +
      ggplot2::geom_line() +
      ggplot2::geom_hline(yintercept = 0.5, linetype = "dashed") +
      ggplot2::theme_minimal() +
      ggplot2::labs(title = "Posterior Mean of Latent Z over Time",
                    x = if (!is.null(dates)) "Date" else "Index",
                    y = expression(E[Z]))
  }

  # Always return the fit, with or without the plot.
  ret
}

Try the sparsesurv package in your browser

Any scripts or data that you put into this service are public.

sparsesurv documentation built on Sept. 11, 2025, 9:11 a.m.