R/pdlm.R

Defines functions pdlm

Documented in pdlm

#' Build a Predictive Dynamic Linear Model (pdlm) for wastewater-based epidemiology
#'
#' Constructs a dynamic linear model (DLM) using the \pkg{dlm} package,
#' estimates its parameters by maximum likelihood with
#' \code{\link[dlm]{dlmMLE}} (method \code{"L-BFGS-B"}), and runs the Kalman
#' filter (\code{\link[dlm]{dlmFilter}}) on the data.
#'
#' @param data A data frame containing the variables in the model.
#' @param formula An object of class "formula" describing the model to be fitted.
#'   Every variable in it must be a numeric column of \code{data}; the first
#'   variable is treated as the response.
#' @param lags A single nonnegative integer indicating the lag of the latent
#'   state in the model.
#' @param log10 Logical; if TRUE, a log10 transformation is applied to all
#'   model variables (1 is added to the response before transforming).
#' @param date Optional date information, stored unchanged in the returned
#'   object. The example below passes the name of a date column of \code{data}.
#' @param prior An optional list with elements \code{m0} and/or \code{C0}
#'   overriding the prior mean vector and covariance matrix of the latent
#'   state. Elements not provided fall back to a naive prior.
#' @param equal.state.var Logical; if TRUE, the same variance is assumed across all state components.
#' @param equal.obs.var Logical; if TRUE, the same variance is assumed across all observation components.
#' @param init_params An optional list of initial parameters for the model. Should include Ft, Wt, and Vt:
#'   transition coefficients, state variance, and observation variance components respectively.
#'   Only used when \code{auto_init = FALSE}.
#' @param auto_init Logical; if TRUE, naive initial parameters are used
#'   (Ft = 0, Wt = 1, Vt = 1). Any supplied \code{init_params} are then
#'   ignored, with a warning.
#' @param control An optional list of control parameters for \code{optim()}.
#'
#' @return An object of class \code{"pdlm"}: a list with elements
#' \describe{
#'   \item{formula}{The fitted formula.}
#'   \item{log10}{Whether the log10 transformation was applied.}
#'   \item{lags}{Number of lags.}
#'   \item{data}{The input data.}
#'   \item{date}{The \code{date} argument, as supplied.}
#'   \item{parameters}{A list of estimated parameters (Ft, Wt, Vt), with the
#'     variance components on their natural (not log) scale.}
#'   \item{logLik}{Log-likelihood of the fitted model, i.e. minus the objective minimised by \code{dlmMLE}.}
#'   \item{aic}{Akaike Information Criterion.}
#'   \item{bic}{Bayesian Information Criterion.}
#'   \item{convergence}{The convergence code from \code{optim}.}
#'   \item{model}{The final \code{dlm} object.}
#'   \item{filter}{Output from \code{\link[dlm]{dlmFilter}}.}
#'   \item{ypred}{One-column matrix of predictions of the response,
#'     back-transformed to the original scale when \code{log10 = TRUE}.}
#'   \item{var.pred}{Variance of the predictions, on the modelling
#'     (possibly log10) scale.}
#'   \item{call}{The matched call.}
#' }
#'
#' @examples
#' \donttest{
#' data <- wastewaterhealthworker[wastewaterhealthworker$Code == "TC",]
#' data$SampleDate <- as.Date(data$SampleDate)
#' fit <- pdlm(
#'   formula=HealthWorkerCaseCount~WW.tuesday+WW.thursday,
#'   data = data,
#'   lags = 2,
#'   equal.state.var=FALSE,
#'   equal.obs.var=FALSE,
#'   log10=TRUE,
#'   date = "SampleDate")
#' summary(fit)
#' plot(fit, conf.int = TRUE)
#' }
#'
#' @export
pdlm <- function(data,
                 formula,
                 lags = 0,
                 log10 = TRUE,
                 date = NULL,
                 prior = list(),
                 equal.state.var = TRUE,
                 equal.obs.var = TRUE,
                 init_params = list(),
                 auto_init = TRUE,
                 control = list(maxit = 500)) {
  cl <- match.call()

  # lags feeds rep() and the parameter counts below, so it must be a whole
  # number, not just a nonnegative one.
  if (!is.numeric(lags) || length(lags) != 1L || is.na(lags) ||
      lags < 0 || lags != round(lags)) {
    stop("lags must be a single nonnegative integer")
  }

  if (!is.data.frame(data)) {
    stop("data must be a data frame.")
  }

  vars <- all.vars(formula)
  k <- length(vars)
  # Number of free variance parameters: 2 when a variance structure is shared,
  # otherwise one per model variable.
  nW <- if (equal.state.var) 2 else k
  nV <- if (equal.obs.var) 2 else k
  # Number of transition coefficients (Ft) consumed by build_pdlm().
  n.coef <- 1 + (lags + 1) * 2

  missing_vars <- setdiff(vars, names(data))
  if (length(missing_vars) > 0) {
    stop("The following variables are missing from the data: ",
         paste(missing_vars, collapse = ", "))
  }

  y <- as.matrix(data[vars])
  if (!is.numeric(y)) {
    stop("All model variables must be numeric.")
  }

  if (log10) {
    # Shift the response by 1 so that zero values map to 0 instead of -Inf.
    y[, 1] <- y[, 1] + 1
    y <- log(y, base = 10)
  }

  if (auto_init) {
    if (length(init_params) > 0) {
      warning("init_params is ignored because auto_init = TRUE.", call. = FALSE)
    }
    init_params <- list(Ft = rep(0, times = n.coef),
                        Wt = rep(1, times = nW),
                        Vt = rep(1, times = nV))
  } else if (!is.list(init_params)) {
    stop("init_params must be a list containing Ft, Wt and Vt components")
  } else {
    if (length(init_params$Ft) != n.coef) {
      stop(sprintf("Transition matrix coefficients require %d parameters, but received %d",
                   n.coef, length(init_params$Ft)))
    }
    if (length(init_params$Wt) != nW) {
      stop(sprintf("State covariance structure requires %d parameters, but received %d",
                   nW, length(init_params$Wt)))
    }
    if (length(init_params$Vt) != nV) {
      stop(sprintf("Observation covariance structure requires %d parameters, but received %d",
                   nV, length(init_params$Vt)))
    }
  }

  variances_ok <- vapply(
    init_params[c("Wt", "Vt")],
    function(v) is.numeric(v) && !anyNA(v) && all(v > 0),
    logical(1)
  )
  if (!all(variances_ok)) {
    stop("Variance component parameters must be positive.")
  }

  # Variances are optimised on the log scale so they stay positive.
  params <- c(init_params$Ft, log(init_params$Wt), log(init_params$Vt))

  # Default prior: mean is 1 followed by each variable's first observation
  # repeated lags + 1 times; covariance is built by getW() from 1e-6 entries.
  C0 <- getW(x = rep(1e-6, times = k), k, lags)
  C0[1, 1] <- 1e-6
  m0 <- c(1, rep(x = y[1, ], times = rep(x = lags + 1, k)))
  prior <- utils::modifyList(list(m0 = m0, C0 = C0), prior)
  if (anyNA(prior$m0)) {
    stop("The prior mean contains missing values (is the first row of the ",
         "model variables incomplete?); supply prior$m0 explicitly.")
  }

  build_model <- function(par) {
    build_pdlm(
      params = par,
      lags = lags,
      equal.state.var = equal.state.var,
      equal.obs.var = equal.obs.var,
      nW = nW,
      nV = nV,
      k = k,
      prior = prior
    )
  }

  fit <- dlm::dlmMLE(
    y = y,
    parm = params,
    method = "L-BFGS-B",
    build = build_model,
    control = control,
    debug = FALSE
  )

  final_model <- build_model(fit$par)
  filter_out <- dlm::dlmFilter(y, final_model)

  pars.est <- fit$par
  pars.list <- list(Ft = pars.est[seq_len(n.coef)],
                    Wt = exp(pars.est[n.coef + seq_len(nW)]),
                    Vt = exp(pars.est[n.coef + nW + seq_len(nV)]))

  # NOTE(review): filter_out$m has nrow(y) + 1 rows (row 1 is m0), so row t of
  # ypred is FF %*% GG %*% m_t, the forecast of y[t + 1] given data up to t.
  # var.pred[t] below uses R_t = Var(theta_t | y[1:(t - 1)]), which belongs to
  # the forecast of y[t]. Confirm which alignment is intended.
  # drop = FALSE keeps a matrix even when there is a single observation.
  m.filt <- filter_out$m[-1, , drop = FALSE]
  ypred <- t(final_model$FF %*% final_model$GG %*% t(m.filt))[, 1, drop = FALSE]

  # Prediction variance of the response: FF[1, ] R_t FF[1, ]' + V_1, where R_t
  # is rebuilt from its SVD factors (U.R, D.R) returned by dlmFilter().
  ff1 <- final_model$FF[1, , drop = FALSE]
  obs.var1 <- pars.list$Vt[1]
  var.pred <- vapply(seq_len(nrow(y)), function(i) {
    u <- filter_out$U.R[[i]]
    d <- filter_out$D.R[i, ]
    R.i <- u %*% diag(d^2, nrow = length(d)) %*% t(u)
    drop(ff1 %*% R.i %*% t(ff1)) + obs.var1
  }, numeric(1))

  if (log10) {
    # Only the point predictions are back-transformed; var.pred stays on the
    # log10 scale.
    ypred <- 10^ypred - 1
  }
  colnames(ypred) <- paste0(vars[1], "_pred")

  n.par <- length(fit$par)
  logLik_val <- -fit$value
  aic_val <- 2 * fit$value + 2 * n.par
  bic_val <- 2 * fit$value + log(nrow(y)) * n.par

  structure(
    list(
      formula = formula,
      log10 = log10,
      lags = lags,
      data = data,
      date = date,
      parameters = pars.list,
      logLik = logLik_val,
      aic = aic_val,
      bic = bic_val,
      convergence = fit$convergence,
      model = final_model,
      filter = filter_out,
      ypred = ypred,
      var.pred = var.pred,
      call = cl
    ),
    class = "pdlm"
  )
}

Try the dlmwwbe package in your browser

Any scripts or data that you put into this service are public.

dlmwwbe documentation built on June 8, 2025, 10:07 a.m.