R/format.R

Defines functions format_samples_with_dates, combine_tv_and_static_params, format_simulation_output, format_fit

Documented in combine_tv_and_static_params, format_fit, format_samples_with_dates, format_simulation_output

#' Format Posterior Samples
#'
#' @description
#' Summarises posterior samples and adds additional custom variables.
#'
#' All samples are stacked into one long table, a `strat` column is added when
#' none is present, and each row is labelled with a `type`: `"forecast"` for
#' dates within the last `horizon` days, `"estimate based on partial data"` for
#' the `shift` days before that, `NA` for rows without a date, and
#' `"estimate"` for everything else.
#'
#' @param posterior_samples A list of posterior samples as returned by
#' [format_simulation_output()]. The list names become the `variable` column
#' of the combined table.
#'
#' @param horizon Numeric, forecast horizon.
#'
#' @param shift Numeric, the shift to apply to estimates.
#'
#' @inheritParams calc_summary_measures
#' @importFrom data.table fcase rbindlist
#' @importFrom lubridate days
#' @importFrom futile.logger flog.info
#' @return A list with elements `samples` (all samples in long format) and
#' `summarised` (the output of [calc_summary_measures()] grouped by date,
#' variable, strat and type).
#' @keywords internal
format_fit <- function(posterior_samples, horizon, shift, CrIs) {
  # Stack the per-parameter tables; list names are recorded in `variable`
  samples <- data.table::rbindlist(
    posterior_samples,
    idcol = "variable", fill = TRUE
  )

  # The summary below groups by `strat`, so the column must exist
  if (is.null(samples$strat)) {
    samples <- samples[, strat := NA]
  }

  # Classify each row by how close its date is to the last available date
  samples <- samples[
    ,
    type := data.table::fcase(
      date > (max(date, na.rm = TRUE) - horizon),
      "forecast",
      date > (max(date, na.rm = TRUE) - horizon - shift),
      "estimate based on partial data",
      is.na(date), NA_character_,
      default = "estimate"
    )
  ]

  result <- list(samples = samples)
  result$summarised <- calc_summary_measures(
    samples,
    summarise_by = c("date", "variable", "strat", "type"),
    order_by = c("variable", "date"),
    CrIs = CrIs
  )
  result
}

#' Format Simulation Output from Stan
#'
#' @description
#' Formats simulation output from Stan models into structured data.tables with
#' dates. This is an internal function used by [simulate_infections()] and
#' [forecast_infections()] to process simulation results.
#'
#' This differs from [get_samples()] in that it's designed for simulation
#' outputs which have different array structures (especially with
#' `drop_length_1 = TRUE`) and need different date ranges for different
#' parameters.
#'
#' @param data A list of the data supplied to the simulation. Any element of
#' `data` whose name is not already present in the extracted samples is copied
#' into them so that the extractors below can use it.
#'
#' @param reported_dates A vector of dates to report estimates for.
#'
#' @param imputed_dates A vector of dates to report imputed reports for.
#'
#' @param reported_inf_dates A vector of dates to report infection estimates
#' for.
#'
#' @param drop_length_1 Logical; drop dimensions of length 1 in arrays extracted
#' from the stan fit. Used in simulations where there's only 1 realization.
#'
#' @param merge Logical. Currently not used: the output is always a named list
#' with one element per parameter, whatever the value of `merge`. The argument
#' is kept so that existing calls passing it continue to work.
#'
#' @inheritParams extract_samples
#' @return A named list of `<data.table>`s, each containing the simulated
#' trajectories (or draws) of one parameter.
#' @importFrom rstan extract
#' @importFrom data.table data.table
#' @keywords internal
format_simulation_output <- function(stan_fit, data, reported_dates,
                                     imputed_dates, reported_inf_dates,
                                     drop_length_1 = FALSE, merge = FALSE) {
  # extract sample from stan object
  samples <- extract_samples(stan_fit)

  ## drop initial length 1 dimensions if requested
  if (drop_length_1) {
    samples <- lapply(samples, function(x) {
      if (length(dim(x)) > 1 && dim(x)[1] == 1) dim(x) <- dim(x)[-1]
      x
    })
  }

  # Make data values visible to the extractors without overwriting any
  # sampled quantity that has the same name
  for (data_name in setdiff(names(data), names(samples))) {
    samples[[data_name]] <- data[[data_name]]
  }

  # construct reporting list
  out <- list()
  # report infections, and R
  out$infections <- extract_latent_state(
    "infections",
    samples,
    reported_inf_dates
  )
  # infections include the seeding period; only keep reported dates onwards
  out$infections <- out$infections[date >= min(reported_dates)]
  out$reported_cases <- extract_latent_state(
    "imputed_reports",
    samples,
    imputed_dates
  )
  if ("estimate_r" %in% names(data)) {
    if (data$estimate_r == 1) {
      out$R <- extract_latent_state(
        "R",
        samples,
        reported_dates
      )
      if (data$bp_n > 0) {
        out$breakpoints <- extract_latent_state(
          "bp_effects",
          samples,
          seq_len(data$bp_n)
        )
        # breakpoint effects are indexed by breakpoint, not by date:
        # keep the index as `strat` and drop the time/date columns
        out$breakpoints <- out$breakpoints[
          ,
          strat := date
        ][, c("time", "date") := NULL]
      }
    } else {
      out$R <- extract_latent_state(
        "gen_R",
        samples,
        reported_dates
      )
    }
  }
  # the growth rate has no value for the first reported date
  out$growth_rate <- extract_latent_state(
    "r",
    samples,
    reported_dates[-1]
  )
  # blank out every sample on a date where any sample is missing
  incomplete_dates <- unique(out$growth_rate[is.na(value), ][["date"]])
  out$growth_rate[date %in% incomplete_dates, value := NA]
  if (data$week_effect > 1) {
    out$day_of_week <- extract_latent_state(
      "day_of_week_simplex",
      samples,
      seq_len(data$week_effect)
    )
    # rescale the simplex by the number of days in the week effect
    out$day_of_week <- out$day_of_week[, value := value * data$week_effect]
    out$day_of_week <- out$day_of_week[, strat := date][
      ,
      c("time", "date") := NULL
    ]
  }
  if (data$delay_n_p > 0) {
    out$delay_params <- extract_latent_state(
      "delay_params", samples, seq_len(data$delay_params_length)
    )
    out$delay_params <-
      out$delay_params[, strat := as.character(time)][, time := NULL][
        ,
        date := NULL
      ]
  }
  # Auto-detect and extract all static parameters from params matrix
  all_params <- extract_parameters(samples, args = data)
  if (!is.null(all_params)) {
    # one list element per static parameter, named after its variable
    for (var in unique(all_params$variable)) {
      result <- all_params[variable == var]
      if (nrow(result) > 0) {
        out[[var]] <- result
      }
    }
  }
  out
}

#' Combine time-varying and static parameters
#'
#' @description Internal helper that stacks time-varying parameters (whose
#'   `variable` column is taken from the names of `time_varying_list`) on top
#'   of the delay and static parameters produced by `extract_delays()` and
#'   `extract_parameters()`, which already carry their own `variable` column.
#'
#' @param time_varying_list Named list of time-varying parameter data.tables
#' @param raw_samples Raw samples from extract_samples()
#' @param args Model arguments containing delay and parameter specifications
#'
#' @return A `data.table` combining all parameters with a `variable` column;
#'   columns absent from some of the inputs are filled with `NA`.
#' @keywords internal
combine_tv_and_static_params <- function(time_varying_list, raw_samples, args) {
  time_varying <- data.table::rbindlist(
    time_varying_list,
    idcol = "variable",
    fill = TRUE
  )
  delays <- extract_delays(raw_samples, args = args)
  static <- extract_parameters(raw_samples, args = args)

  data.table::rbindlist(list(time_varying, delays, static), fill = TRUE)
}

#' Format raw Stan samples with dates and metadata
#'
#' @description Internal helper that extracts Stan parameters, adds dates to
#'   time-varying parameters, and combines into a single long-format data.table.
#'   Each row is labelled with a `type`: `"forecast"` for dates within the
#'   forecast horizon, `"estimate based on partial data"` for the seeding-time
#'   window before that, `NA` for rows without a date, and `"estimate"`
#'   otherwise.
#'
#' @param raw_samples Raw samples from extract_samples()
#' @param args Model arguments (from object$args). A missing `horizon` is
#'   treated as 0 and a missing `use_pop` as no population adjustment.
#' @param observations Observation data with dates
#'
#' @return A `data.table` in long format with dates and metadata
#' @importFrom rlang %||%
#' @keywords internal
format_samples_with_dates <- function(raw_samples, args, observations) {
  # Resolve optional settings once so every use below sees the same value.
  # The horizon default must be applied before the dates are built: with a
  # NULL horizon, `max(observations$date) + NULL` is zero-length and seq()
  # fails.
  horizon <- args$horizon %||% 0
  seeding_time <- args$seeding_time
  # `NULL > 0` is zero-length and would make the `&&` below fail
  use_pop <- args$use_pop %||% 0

  # Reported dates cover the observation period plus any forecast horizon
  reported_dates <- seq(
    min(observations$date),
    max(observations$date) + horizon,
    by = "days"
  )
  # Full dates include the seeding period before observations
  dates <- seq(
    min(observations$date) - seeding_time,
    max(reported_dates),
    by = "days"
  )

  # Extract each parameter into a data.table
  out <- list()

  # Infections (for estimate_infections) - extract full time series first
  infections <- extract_latent_state("infections", raw_samples, dates)

  # Reported cases (for estimate_infections)
  out$reported_cases <- extract_latent_state(
    "imputed_reports", raw_samples, reported_dates[args$imputed_times]
  )

  # R (reproduction number) - try R first (renewal model), then gen_R (backcalc)
  R_unadjusted <- extract_latent_state("R", raw_samples, reported_dates)
  using_renewal_model <- !is.null(R_unadjusted)
  if (!using_renewal_model) {
    R_unadjusted <- extract_latent_state("gen_R", raw_samples, reported_dates)
  }

  # Extract adjusted Rt if pop adjustment enabled and using renewal model
  # (R_adj only calculated in Stan when estimate_r > 0 && use_pop > 0)
  if (using_renewal_model && use_pop > 0) {
    out$R <- extract_latent_state("R_adj", raw_samples, reported_dates)
    out$R_unadjusted <- R_unadjusted
  } else {
    out$R <- R_unadjusted
  }

  # Trim infections to reported dates
  if (!is.null(infections)) {
    out$infections <- infections[date >= min(reported_dates)]
  }

  # Breakpoints (if present in model): indexed by breakpoint, kept as `strat`
  if (args$bp_n > 0) {
    breakpoints <- extract_latent_state(
      "bp_effects", raw_samples, seq_len(args$bp_n)
    )
    if (!is.null(breakpoints)) {
      out$breakpoints <- breakpoints[, strat := date][
        , c("time", "date") := NULL
      ]
    }
  }

  # Growth rate (undefined for the first reported date); blank out every
  # sample on a date where any sample is missing
  growth_rate <- extract_latent_state("r", raw_samples, reported_dates[-1])
  if (!is.null(growth_rate)) {
    incomplete_dates <- unique(growth_rate[is.na(value), ][["date"]])
    growth_rate[date %in% incomplete_dates, value := NA]
  }
  out$growth_rate <- growth_rate

  # Day of week effect, rescaled by the number of days in the effect
  if (args$week_effect > 1) {
    day_of_week <- extract_latent_state(
      "day_of_week_simplex", raw_samples, seq_len(args$week_effect)
    )
    if (!is.null(day_of_week)) {
      day_of_week <- day_of_week[, value := value * args$week_effect]
      out$day_of_week <- day_of_week[, strat := date][
        , c("time", "date") := NULL
      ]
    }
  }

  # Combine time-varying and static parameters
  combined <- combine_tv_and_static_params(out, raw_samples, args)

  # Add strat column if missing
  if (is.null(combined$strat)) {
    combined <- combined[, strat := NA]
  }

  # Add type column based on horizon and seeding time
  combined <- combined[
    ,
    type := data.table::fcase(
      date > (max(date, na.rm = TRUE) - horizon),
      "forecast",
      date > (max(date, na.rm = TRUE) - horizon - seeding_time),
      "estimate based on partial data",
      is.na(date), NA_character_,
      default = "estimate"
    )
  ]

  combined[]
}

Try the EpiNow2 package in your browser

Any scripts or data that you put into this service are public.

EpiNow2 documentation built on June 17, 2026, 1:07 a.m.