R/forecast.R

Defines functions forecast_n_strain forecast

Documented in forecast forecast_n_strain

#' Forecast using branching processes at a target date
#'
#' @param models A model as supplied by [fv_model()]. If not supplied the
#' default for that strain is used. If multiple strain models are being forecast
#' then `models` should be a list models.
#'
#' @param data_list A function that returns a list of data as ingested by the
#' `inits` and `fit` function. Must use arguments as defined in
#' [fv_as_data_list()]. If not supplied the package default [fv_as_data_list()]
#' is used.
#'
#' @param inits A function that returns a function to samples initial
#' conditions with the same arguments as [fv_inits()]. If not supplied the
#' package default [fv_inits()] is used.
#'
#' @param fit A function that fits the supplied model with the same arguments
#' and return values as [fv_sample()]. If not supplied the
#' package default [fv_sample()] is used which performs MCMC sampling using
#' [cmdstanr].
#'
#' @param posterior A function that summarises the output from the supplied
#' fitting function with the same arguments and return values (depending on
#' the requirement for downstream package functionality to function) as
#' [fv_tidy_posterior()]. If not supplied the package default
#' [fv_tidy_posterior()] is used.
#'
#' @param extract_forecast A function that extracts the forecast from
#' the summarised `posterior`. If not supplied the package default
#' [fv_extract_forecast()] is used.
#'
#' @param forecast_date Date at which to forecast. Defaults to the
#' maximum date in `obs`.
#'
#' @param strains Integer number of strains to use. Defaults to 2. Current
#' maximum is 2. A numeric vector can be passed if forecasts from multiple
#' strain models are desired.
#'
#' @param keep_fit Logical, defaults to `TRUE`. Should the stan model fit be
#' kept and returned. Dropping this can substantially reduce memory usage in
#' situations where multiple models are being fit.
#'
#' @param id ID to assign to this forecast. Defaults to 0.
#'
#' @param ... Additional parameters passed to [fv_sample()].
#'
#' @return A `data.frame` containing the output of [fv_sample()] in each row as
#' well as the summarised posterior, forecast and information about the
#' parameters specified.
#'
#' @family forecast
#' @inheritParams filter_by_availability
#' @inheritParams fv_as_data_list
#' @inheritParams fv_sample
#' @inheritParams fv_tidy_posterior
#' @export
#' @importFrom purrr map transpose reduce map_chr safely
#' @examplesIf interactive()
#' options(mc.cores = 4)
#'
#' forecasts <- forecast(
#'   germany_covid19_delta_obs,
#'   forecast_date = as.Date("2021-06-12"),
#'   horizon = 4,
#'   strains = c(1, 2),
#'   adapt_delta = 0.99,
#'   max_treedepth = 15,
#'   variant_relationship = "scaled"
#' )
#'
#' # inspect forecasts
#' forecasts
#'
#' # extract the model summary
#' summary(forecasts, type = "model")
#'
#' # plot case posterior predictions
#' plot(forecasts, log = TRUE)
#'
#' # plot voc posterior predictions
#' plot(forecasts, type = "voc_frac")
#'
#' # extract the case forecast
#' summary(forecasts, type = "cases", forecast = TRUE)
forecast <- function(obs,
                     forecast_date = max(obs$date),
                     seq_date = forecast_date, case_date = forecast_date,
                     data_list = forecast.vocs::fv_as_data_list,
                     inits = forecast.vocs::fv_inits,
                     fit = forecast.vocs::fv_sample,
                     posterior = forecast.vocs::fv_tidy_posterior,
                     extract_forecast = forecast.vocs::fv_extract_forecast,
                     horizon = 4, r_init = c(0, 0.25), r_step = 1,
                     r_forecast = TRUE, beta = c(0, 0.1), lkj = 0.5,
                     period = NULL, special_periods = c(),
                     voc_scale = c(0, 0.2), voc_label = "VOC", strains = 2,
                     variant_relationship = "correlated", overdispersion = TRUE,
                     models = NULL, likelihood = TRUE, output_loglik = FALSE,
                     debug = FALSE, keep_fit = TRUE, scale_r = 1, digits = 3,
                     timespan = 7, probs = c(0.05, 0.2, 0.8, 0.95), id = 0,
                     ...) {
  if (!is.null(models)) {
    if (length(models) == 1 & length(strains) == 1) {
      models <- list(models)
    }
    stopifnot(
      "Number of models supplied must be equal to the numer of strain
       forecasts specified." = length(models) == length(strains)
    )
  }
  stopifnot(
    "Strains is not a unique vector" =
      length(strains) == length(unique(strains))
  )

  # resolve  data availability
  target_obs <- filter_by_availability(
    obs,
    date = forecast_date,
    seq_date = seq_date,
    case_date = case_date
  )

  # format data and fit models
  data <- data_list(target_obs,
    horizon = horizon,
    r_init = r_init,
    r_step = r_step,
    r_forecast = r_forecast,
    beta = beta,
    lkj = lkj,
    voc_scale = voc_scale,
    period = period,
    special_periods = special_periods,
    variant_relationship = variant_relationship,
    overdispersion = overdispersion,
    likelihood = likelihood,
    output_loglik = output_loglik,
    debug = debug
  )

  out <- data.table(
    id = id,
    forecast_date = forecast_date,
    strains = strains,
    overdispersion = overdispersion,
    variant_relationship = variant_relationship,
    r_init = paste(r_init, collapse = ", ", sep = ", "),
    voc_scale = paste(voc_scale, collapse = ", ", sep = ", ")
  )

  # forecast required strain models
  safe_n_forecast <- purrr::safely(forecast_n_strain)
  forecasts <- purrr::map(
    seq_along(strains),
    function(strain, ...) {
      out <- out[strain, ]
      fit <-
        safe_n_forecast(
          model = models[[strain]],
          inits = inits,
          fit = fit,
          posterior = posterior,
          extract_forecast = extract_forecast,
          strains = strains[strain],
          data = data,
          probs = probs,
          scale_r = scale_r,
          digits = digits,
          voc_label = voc_label,
          timespan = timespan,
          ...
        )
      out <- out[, `:=`(results = list(fit$result), error = list(fit$error))]
      return(out)
    },
    ...
  )

  forecasts <- rbindlist(forecasts, fill = TRUE)
  cols <- setdiff(colnames(forecasts), "results")
  forecasts <- cbind(forecasts, rbindlist(forecasts$results, fill = TRUE))
  forecasts[, results := NULL]
  if (!keep_fit) {
    suppressWarnings(forecasts[, c("fit", "data", "fit_args") := NULL])
  }
  class(forecasts) <- c("fv_forecast", class(forecasts))
  return(forecasts[])
}

#' Forecast for a single model and summarise
#'
#' @family forecast
#' @inheritParams fv_inits
#' @inheritParams forecast
#' @inheritParams fv_sample
#' @inheritParams fv_tidy_posterior
forecast_n_strain <- function(data, model = NULL,
                              inits = forecast.vocs::fv_inits,
                              fit = forecast.vocs::fv_sample,
                              posterior = forecast.vocs::fv_tidy_posterior,
                              extract_forecast = forecast.vocs::fv_extract_forecast, # nolint
                              strains = 2, voc_label = "VOC",
                              probs = c(0.05, 0.2, 0.8, 0.95),
                              digits = 3, scale_r = 1, timespan = 7, ...) {
  inits <- inits(data, strains = strains)

  if (is.null(model)) {
    model <- fv_model(strains = strains)
  }

  # fit and summarise
  fit <- fit(
    model = model, data = data, init = inits, ...
  )

  fit$posterior <- list(posterior(
    fit,
    probs = probs, voc_label = voc_label, scale_r = scale_r,
    digits = digits, timespan = timespan
  ))
  fit$forecast <- list(extract_forecast(fit$posterior[[1]]))
  return(fit)
}
epiforecasts/forecast.vocs documentation built on May 14, 2023, 2 p.m.