R/opts.R

Defines functions apply_default_cdf_cutoff filter_opts opts_list forecast_opts stan_opts stan_pathfinder_opts stan_laplace_opts stan_vb_opts stan_sampling_opts obs_opts gp_opts backcalc_opts rt_opts trunc_opts delay_opts secondary_opts gt_opts

Documented in apply_default_cdf_cutoff backcalc_opts delay_opts filter_opts forecast_opts gp_opts gt_opts obs_opts opts_list rt_opts secondary_opts stan_laplace_opts stan_opts stan_pathfinder_opts stan_sampling_opts stan_vb_opts trunc_opts

#' Generation Time Distribution Options
#'
#' @description `r lifecycle::badge("stable")`
#' Returns generation time parameters in a format for lower level model use.
#'
#' @details Because the discretised renewal equation used in the package does
#'   not support zero generation times, any distribution specified here will be
#'   left-truncated at one, i.e. the first element of the nonparametric or
#'   discretised probability distribution used for the generation time is set to
#'   zero and the resulting distribution renormalised.
#' @rdname generation_time_opts
#' @param dist A delay distribution or series of delay distributions. If no
#'   distribution is given a fixed generation time of 1 will be assumed. If
#'   passing a nonparametric distribution the first element should be zero (see
#'   *Details* section).
#'
#' @param weight_prior Logical; if TRUE (default), any priors given in `dist`
#'   will be weighted by the number of observation data points, in doing so
#'   approximately placing an independent prior at each time step and usually
#'   preventing the posteriors from shifting. If FALSE, no weight
#'   will be applied, i.e. each parameter in `dist` will be treated as a single
#'   parameter.
#' @inheritParams apply_default_cdf_cutoff
#' @importFrom cli cli_warn cli_abort col_blue
#' @return A `<generation_time_opts>` object summarising the input delay
#' distributions.
#' @seealso [convert_to_logmean()] [convert_to_logsd()]
#' [bootstrapped_dist_fit()] [Gamma()] [LogNormal()] [Fixed()]
#' @export
#' @examples
#' # default settings with a fixed generation time of 1
#' generation_time_opts()
#'
#' # A fixed gamma distributed generation time
#' generation_time_opts(Gamma(mean = 3, sd = 2, max = 14))
#'
#' # An uncertain gamma distributed generation time
#' generation_time_opts(
#'   Gamma(
#'     shape = Normal(mean = 3, sd = 1),
#'     rate = Normal(mean = 2, sd = 0.5),
#'     max = 14
#'   )
#' )
#'
#' # An example generation time
#' gt_opts(example_generation_time)
gt_opts <- function(dist = Fixed(1), default_cdf_cutoff = 0.001,
                    weight_prior = TRUE) {
  # Relying on the default is allowed but flagged: with a fixed generation
  # time of 1 the reproduction number equals the daily growth rate.
  if (missing(dist)) {
    cli_warn(
      c(
        "!" = "No generation time distribution given. Using a fixed generation
        time of 1 day, i.e. the reproduction number is the same as the daily
        growth rate.",
        "i" = "If this was intended then this warning can be
        silenced by setting {.var dist = Fixed(1)}."
      )
    )
  }
  ## apply default CDF cutoff if `dist` is unconstrained
  ## (the third argument tells the helper whether the cutoff was set by the
  ## caller rather than taken from the default)
  dist <- apply_default_cdf_cutoff(
    dist, default_cdf_cutoff, !missing(default_cdf_cutoff)
  )
  # store the prior-weighting choice on the object and tag it with the
  # options class ahead of any classes it already has
  attr(dist, "weight_prior") <- weight_prior
  class(dist) <- c("generation_time_opts", class(dist))
  # validate only after the class has been attached
  check_generation_time(dist)
  return(dist)
}

#' @rdname generation_time_opts
#' @export
generation_time_opts <- gt_opts # long-form alias for gt_opts()

#' Secondary Reports Options
#'
#' @description `r lifecycle::badge("stable")`
#' Returns a list of options defining the secondary model used in
#' [estimate_secondary()]. This model is a combination of a convolution of
#' previously observed primary reports combined with current primary reports
#' (either additive or subtractive). It can optionally be cumulative. See the
#' documentation of `type` for sensible options to cover most use cases and the
#' returned values of [secondary_opts()] for all currently supported options.
#'
#' @param type A character string indicating the type of observation the
#' secondary reports are. Options include:
#'
#' - "incidence": Assumes that secondary reports equal a convolution of
#' previously observed primary reported cases. An example application is deaths
#' from an infectious disease predicted by reported cases of that disease (or
#' estimated infections).
#'
#' - "prevalence": Assumes that secondary reports are cumulative and are
#' defined by currently observed primary reports minus a convolution of
#' secondary reports. An example application is hospital bed usage predicted by
#' hospital admissions.
#'
#' @param ... Overwrite options defined by type. See the returned values for all
#' options that can be passed.
#' @importFrom rlang arg_match
#' @seealso [estimate_secondary()]
#' @return A `<secondary_opts>` object of binary options summarising secondary
#' model used in [estimate_secondary()]. Options returned are `cumulative`
#' (should the secondary report be cumulative), `historic` (should a
#' convolution of primary reported cases be used to predict secondary reported
#' cases), `primary_hist_additive` (should the historic convolution of primary
#' reported cases be additive or subtractive), `current` (should currently
#' observed primary reported cases contribute to current secondary reported
#' cases), `primary_current_additive` (should current primary reported cases be
#' additive or subtractive).
#'
#' @export
#' @examples
#' # incidence model
#' secondary_opts("incidence")
#'
#' # prevalence model
#' secondary_opts("prevalence")
secondary_opts <- function(type = c("incidence", "prevalence"), ...) {
  type <- arg_match(type)
  # Flag presets for every supported model type; values are 0/1 switches.
  presets <- list(
    incidence = list(
      cumulative = 0,
      historic = 1,
      primary_hist_additive = 1,
      current = 0,
      primary_current_additive = 0
    ),
    prevalence = list(
      cumulative = 1,
      historic = 1,
      primary_hist_additive = 0,
      current = 1,
      primary_current_additive = 1
    )
  )
  # Anything passed through `...` overrides the matching preset entry.
  settings <- modifyList(presets[[type]], list(...))
  class(settings) <- c("secondary_opts", class(settings))
  settings
}

#' Delay Distribution Options
#'
#' @description `r lifecycle::badge("stable")`
#' Returns delay distributions formatted for usage by downstream
#' functions.
#' @param dist A delay distribution or series of delay distributions. Default is
#'   a fixed distribution with all mass at 0, i.e. no delay.
#' @inheritParams generation_time_opts
#' @importFrom cli cli_abort
#' @return A `<delay_opts>` object summarising the input delay distributions.
#' @seealso [convert_to_logmean()] [convert_to_logsd()]
#' [bootstrapped_dist_fit()] \code{\link{Distributions}}
#' @export
#' @examples
#' # no delays
#' delay_opts()
#'
#' # A single delay that has uncertainty
#' delay <- LogNormal(
#'   meanlog = Normal(1, 0.2),
#'   sdlog = Normal(0.5, 0.1),
#'   max = 14
#' )
#' delay_opts(delay)
#'
#' # A single delay without uncertainty
#' delay <- LogNormal(meanlog = 1, sdlog = 0.5, max = 14)
#' delay_opts(delay)
#'
#' # Multiple delays (in this case twice the same)
#' delay_opts(delay + delay)
delay_opts <- function(dist = Fixed(0), default_cdf_cutoff = 0.001,
                       weight_prior = TRUE) {
  assert_class(dist, "dist_spec")
  # Note whether the caller set the cutoff explicitly; the flag is passed on
  # as the third argument of apply_default_cdf_cutoff().
  cutoff_supplied <- !missing(default_cdf_cutoff)
  dist <- apply_default_cdf_cutoff(dist, default_cdf_cutoff, cutoff_supplied)
  # Keep the weighting choice on the object and prepend the options class.
  attr(dist, "weight_prior") <- weight_prior
  class(dist) <- c("delay_opts", class(dist))
  # Validation runs on the fully tagged object.
  check_stan_delay(dist)
  dist
}

#' Truncation Distribution Options
#'
#' @description `r lifecycle::badge("stable")`
#' Returns a truncation distribution formatted for usage by
#' downstream functions. See [estimate_truncation()] for an approach to
#' estimate these distributions.
#'
#' @param dist A delay distribution or series of delay distributions reflecting
#' the truncation. It can be specified using the probability distributions
#' interface in `EpiNow2` (See `?EpiNow2::Distributions`) or estimated using
#' [estimate_truncation()], which returns a `dist` object, suited
#' for use here out-of-box. Default is a fixed distribution with maximum 0, i.e.
#' no truncation.
#' @param weight_prior Logical; if TRUE, the truncation prior will be weighted
#'   by the number of observation data points, in doing so approximately placing
#'   an independent prior at each time step and usually preventing the
#'   posteriors from shifting. If FALSE (default), no weight will be applied,
#'   i.e. the truncation distribution will be treated as a single parameter.
#'
#' @inheritParams gt_opts
#' @importFrom cli cli_abort
#' @return A `<trunc_opts>` object summarising the input truncation
#' distribution.
#'
#' @seealso [convert_to_logmean()] [convert_to_logsd()]
#' [bootstrapped_dist_fit()] \code{\link{Distributions}}
#' @export
#' @examples
#' # no truncation
#' trunc_opts()
#'
#' # truncation dist
#' trunc_opts(dist = LogNormal(mean = 3, sd = 2, max = 10))
trunc_opts <- function(dist = Fixed(0), default_cdf_cutoff = 0.001,
                       weight_prior = FALSE) {
  assert_class(dist, "dist_spec")
  # Note whether the caller set the cutoff explicitly; the flag is passed on
  # as the third argument of apply_default_cdf_cutoff().
  cutoff_supplied <- !missing(default_cdf_cutoff)
  dist <- apply_default_cdf_cutoff(dist, default_cdf_cutoff, cutoff_supplied)
  # Unlike the delay options, prior weighting is off unless requested.
  attr(dist, "weight_prior") <- weight_prior
  class(dist) <- c("trunc_opts", class(dist))
  # Validation runs on the fully tagged object.
  check_stan_delay(dist)
  dist
}

#' Time-Varying Reproduction Number Options
#'
#' @description `r lifecycle::badge("stable")`
#' Defines a list specifying the optional arguments for the time-varying
#' reproduction number. Custom settings can be supplied which override the
#' defaults.
#'
#' @param prior A `<dist_spec>` giving the prior of the initial reproduction
#' number. Ignored if `use_rt` is `FALSE`. Defaults to a LogNormal distribution
#' with mean of 1 and standard deviation of 1: `LogNormal(mean = 1, sd = 1)`.
#' A lower limit of 0 will be enforced automatically.
#'
#' @param use_rt Logical, defaults to `TRUE`. Should Rt be used to generate
#' infections and hence reported cases.
#'
#' @param rw Numeric step size of the random walk, defaults to 0. To specify a
#' weekly random walk set `rw = 7`. For more custom break point settings
#' consider passing in a `breakpoints` variable as outlined in the next section.
#'
#' @param use_breakpoints Logical, defaults to `TRUE`. Should break points be
#' used if present as a `breakpoint` variable in the input data. Break points
#' should be defined as 1 if present and otherwise 0. By default breakpoints
#' are fit jointly with a global non-parametric effect and so represent a
#' conservative estimate of break point changes (alter this by setting
#' `gp = NULL`).
#'
#' @param pop Integer, defaults to 0. Susceptible population initially present.
#' Used to adjust Rt estimates in the forecast horizon based on the
#' proportion of the population that is susceptible. When set to 0 no
#' population adjustment is done.
#'
#' @param gp_on Character string, defaulting to "R_t-1". Indicates how the
#' Gaussian process, if in use, should be applied to Rt. Currently supported
#' options are applying the Gaussian process to the last estimated Rt (i.e
#' Rt = Rt-1 * GP), and applying the Gaussian process to a global mean (i.e Rt
#' = R0 * GP). Both should produce comparable results when data is not sparse
#' but the method relying on a global mean will revert to this for real time
#' estimates, which may not be desirable.
#'
#' @return An `<rt_opts>` object with settings defining the time-varying
#' reproduction number.
#' @inheritParams create_future_rt
#' @importFrom rlang arg_match
#' @importFrom cli cli_abort
#' @export
#' @examples
#' # default settings
#' rt_opts()
#'
#' # add a custom prior for the initial reproduction number
#' rt_opts(prior = LogNormal(mean = 2, sd = 1))
#'
#' # add a weekly random walk
#' rt_opts(rw = 7)
rt_opts <- function(prior = LogNormal(mean = 1, sd = 1),
                    use_rt = TRUE,
                    rw = 0,
                    use_breakpoints = TRUE,
                    future = "latest",
                    gp_on = c("R_t-1", "R0"),
                    pop = 0) {
  settings <- list(
    use_rt = use_rt,
    rw = rw,
    use_breakpoints = use_breakpoints,
    future = future,
    pop = pop,
    gp_on = arg_match(gp_on)
  )

  # a positive random walk step size always switches breakpoints on,
  # whatever was passed as `use_breakpoints`
  if (settings$rw > 0) {
    settings$use_breakpoints <- TRUE
  }

  # priors given as a plain list (the legacy form) are rejected
  prior_is_plain_list <- is.list(prior) && !is(prior, "dist_spec")
  if (prior_is_plain_list) {
    cli_abort(
      c(
        "!" = "Specifying {.var prior} as a list is deprecated.",
        "i" = "Use a {.cls dist_spec} instead."
      )
    )
  }

  # the prior is only stored when Rt is in use; otherwise an explicitly
  # supplied prior triggers a warning and is dropped
  if (settings$use_rt) {
    settings$prior <- prior
  } else if (!missing(prior)) {
    cli_warn(
      c("!" = "Rt {.var prior} is ignored if {.var use_rt} is FALSE.")
    )
  }

  class(settings) <- c("rt_opts", class(settings))
  settings
}

#' Back Calculation Options
#'
#' @description `r lifecycle::badge("stable")`
#' Defines a list specifying the optional arguments for the back calculation
#' of cases. Only used if `rt = NULL`.
#'
#' @param prior A character string defaulting to "reports". Defines the prior
#' to use when deconvolving. Currently implemented options are to use smoothed
#' mean delay shifted reported cases ("reports"), to use the estimated
#' infections from the previous time step seeded for the first time step using
#' mean shifted reported cases ("infections"), or no prior ("none"). Using no
#' prior will result in poor real time performance. No prior and using
#' infections are only supported when a Gaussian process is present. If
#' observed data is not reliable then a sensible first step is to explore
#' increasing the `prior_window` with a sensible second step being to no longer
#' use reported cases as a prior (i.e set `prior = "none"`).
#'
#' @param prior_window Integer, defaults to 14 days. The mean centred smoothing
#' window to apply to mean shifted reports (used as a prior during back
#' calculation). 7 days is the minimum recommended setting as this smooths day
#' of the week effects, but larger values may be appropriate depending on the
#' quality of the data and the amount of information users wish to use as a
#' prior (higher values equalling a less informative prior).
#'
#' @param rt_window Integer, defaults to 1. The size of the centred rolling
#' average to use when estimating Rt. This must be odd so that the central
#' estimate is included.
#' @importFrom rlang arg_match
#' @importFrom cli cli_abort
#'
#' @return A `<backcalc_opts>` object of back calculation settings.
#' @export
#' @examples
#' # default settings
#' backcalc_opts()
backcalc_opts <- function(prior = c("reports", "none", "infections"),
                          prior_window = 14, rt_window = 1) {
  # Assemble the classed settings object up front; `rt_window` is coerced
  # with as.integer() before it is checked.
  settings <- structure(
    list(
      prior = arg_match(prior),
      prior_window = prior_window,
      rt_window = as.integer(rt_window)
    ),
    class = c("backcalc_opts", "list")
  )
  # An even window has no central element, so it is rejected.
  window_is_even <- settings$rt_window %% 2 == 0
  if (window_is_even) {
    cli_abort(
      c(
        "!" = "{.var rt_window} must be odd in order to
        include the current estimate.",
        "i" = "You have supplied an even number."
      )
    )
  }
  settings
}

#' Approximate Gaussian Process Settings
#'
#' @description `r lifecycle::badge("stable")`
#' Defines a list specifying the structure of the approximate Gaussian
#' process. Custom settings can be supplied which override the defaults.
#'
#' @param ls_mean Deprecated; use `ls` instead.
#'
#' @param ls_sd Deprecated; use `ls` instead.
#'
#' @param ls_min Deprecated; use `ls` instead.
#'
#' @param ls_max Deprecated; use `ls` instead.
#'
#' @param ls A `<dist_spec>` giving the prior distribution of the lengthscale
#' parameter of the Gaussian process kernel on the scale of days. Defaults to
#' a Lognormal distribution with mean 21 days, sd 7 days and maximum 60 days:
#' `LogNormal(mean = 21, sd = 7, max = 60)` (a lower limit of 0 will be
#' enforced automatically to ensure positivity)
#'
#' @param alpha A `<dist_spec>` giving the prior distribution of the magnitude
#' parameter of the Gaussian process kernel. Should be approximately the
#' expected standard deviation of the Gaussian process (logged Rt in case of
#' the renewal model, logged infections in case of the nonmechanistic model).
#' Defaults to a half-normal distribution with mean 0 and sd 0.01:
#' `Normal(mean = 0, sd = 0.01)` (a lower limit of 0 will be enforced
#' automatically to ensure positivity)
#'
#' @param alpha_mean Deprecated; use `alpha` instead.
#'
#' @param alpha_sd Deprecated; use `alpha` instead.
#'
#' @param kernel Character string, the type of kernel required. Currently
#' supporting the Matern kernel ("matern"), squared exponential kernel ("se"),
#' Ornstein-Uhlenbeck kernel ("ou"), and the periodic kernel ("periodic").
#'
#' @param matern_order Numeric, defaults to 3/2. Order of Matérn Kernel to use.
#' Common choices are 1/2, 3/2, and 5/2. If `kernel` is set
#' to "ou", `matern_order` will be automatically set to 1/2. Only used if
#' the kernel is set to "matern".
#'
#' @param matern_type Deprecated; Numeric, defaults to 3/2. Order of Matérn
#' Kernel to use. Currently, the orders 1/2, 3/2, 5/2 and Inf are supported.
#'
#' @param basis_prop Numeric, the proportion of time points to use as basis
#' functions. Defaults to 0.2. Decreasing this value results in a decrease in
#' accuracy but a faster compute time (with increasing it having the opposite
#' effect). In general smaller posterior length scales require a higher
#' proportion of basis functions. See (Riutort-Mayol et al. 2020
#' <https://arxiv.org/abs/2004.11408>) for advice on updating this default.
#'
#' @param boundary_scale Numeric, defaults to 1.5. Boundary scale of the
#' approximate Gaussian process. See (Riutort-Mayol et al. 2020
#' <https://arxiv.org/abs/2004.11408>) for advice on updating this default.
#'
#' @param w0 Numeric, defaults to 1.0. Fundamental frequency for the periodic
#' kernel. It is only used if `kernel` is set to "periodic".
#'
#' @importFrom rlang arg_match
#' @importFrom cli cli_abort cli_warn
#' @return A `<gp_opts>` object of settings defining the Gaussian process
#' @export
#' @examples
#' # default settings
#' gp_opts()
#'
#' # add a custom length scale
#' gp_opts(ls = LogNormal(mean = 4, sd = 1, max = 20))
#'
#' # use periodic kernel
#' gp_opts(kernel = "periodic")
gp_opts <- function(basis_prop = 0.2,
                    boundary_scale = 1.5,
                    ls_mean = 21,
                    ls_sd = 7,
                    ls_min = 0,
                    ls_max = 60,
                    ls = LogNormal(mean = 21, sd = 7, max = 60),
                    alpha = Normal(mean = 0, sd = 0.01),
                    kernel = c("matern", "se", "ou", "periodic"),
                    matern_order = 3 / 2,
                    matern_type,
                    w0 = 1.0,
                    alpha_mean, alpha_sd) {
  # Defunct arguments: supplying any of them raises a lifecycle error that
  # names the replacement argument.
  if (!missing(matern_type)) {
    lifecycle::deprecate_stop(
      "1.6.0", "gp_opts(matern_type)", "gp_opts(matern_order)"
    )
  }
  if (!missing(alpha_mean)) {
    lifecycle::deprecate_stop(
      "1.7.0", "gp_opts(alpha_mean)", "gp_opts(alpha)"
    )
  }
  if (!missing(alpha_sd)) {
    lifecycle::deprecate_stop(
      "1.7.0", "gp_opts(alpha_sd)", "gp_opts(alpha)"
    )
  }

  # The legacy lengthscale arguments are no longer honoured. Both branches
  # abort; the first gives a more specific message when `ls` was passed too.
  legacy_ls_used <- !missing(ls_mean) || !missing(ls_sd) ||
    !missing(ls_min) || !missing(ls_max)
  if (legacy_ls_used) {
    if (!missing(ls)) {
      cli_abort(
        c(
          "!" = "Both {.var ls} and at least one legacy argument
          ({.var ls_mean}, {.var ls_sd}, {.var ls_min}, {.var ls_max}) have been
          specified.",
          "i" = "Only one of them should be used."
        )
      )
    }
    cli_abort(c(
      "!" = "Specifying lengthscale priors via the {.var ls_mean}, {.var ls_sd},
      {.var ls_min}, and {.var ls_max} arguments is deprecated.",
      "i" = "Use the {.var ls} argument instead."
    ))
  }

  # The "se" and "ou" kernels override `matern_order` with Inf and 1/2
  # respectively; for any other kernel an unusual order only triggers a
  # warning and is passed through unchanged.
  kernel <- arg_match(kernel)
  if (kernel == "se") {
    matern_order <- Inf
  } else if (kernel == "ou") {
    matern_order <- 1 / 2
  } else if (
    !(is.infinite(matern_order) || matern_order %in% c(1 / 2, 3 / 2, 5 / 2))
  ) {
    cli_warn(
      c(
        "!" = "Uncommon Matern kernel order supplied.",
        "i" = "Use one of `1 / 2`, `3 / 2`, or `5 / 2`" # nolint
      )
    )
  }

  gp <- list(
    basis_prop = basis_prop,
    boundary_scale = boundary_scale,
    ls = ls,
    alpha = alpha,
    kernel = kernel,
    matern_order = matern_order,
    w0 = w0
  )

  class(gp) <- c("gp_opts", class(gp))
  return(gp)
}

#' Observation Model Options
#'
#' @description `r lifecycle::badge("stable")`
#' Defines a list specifying the structure of the observation
#' model. Custom settings can be supplied which override the defaults.
#' @param family Character string defining the observation model. Options are
#'   Negative binomial ("negbin"), the default, and Poisson ("poisson").
#' @param dispersion A `<dist_spec>` specifying a prior on the dispersion
#'   parameter of the reporting process, used only if `family` is "negbin".
#'   Internally parameterised such that this parameter is one over the square
#'   root of the `phi` parameter for overdispersion of the
#'   [negative binomial distribution](https://mc-stan.org/docs/functions-reference/unbounded_discrete_distributions.html#neg-binom-2-log). # nolint
#'   Defaults to a half-normal distribution with mean of 0 and
#'   standard deviation of 0.25: `Normal(mean = 0, sd = 0.25)`. A lower limit of
#'   zero will be enforced automatically.
#' @param weight Numeric, defaults to 1. Weight to give the observed data in the
#'   log density.
#' @param week_effect Logical defaulting to `TRUE`. Should a day of the week
#'   effect be used in the observation model.
#' @param week_length Numeric assumed length of the week in days, defaulting to
#'   7 days. This can be modified if data aggregated over a period other than a
#'   week or if data has a non-weekly periodicity.
#' @param scale A `<dist_spec>` specifying a prior on the scaling factor to be
#'   applied to map latent infections (convolved to date of report).  Defaults
#'   to a fixed value of 1, i.e. no scaling: `Fixed(1)`. A lower limit of zero
#'   will be enforced automatically. If setting to a prior distribution and no
#'   overreporting is expected, it might be sensible to set a maximum of 1 via
#'   the `max` option when declaring the distribution.
#' @param na Deprecated; use the [fill_missing()] function instead
#' @param likelihood Logical, defaults to `TRUE`. Should the likelihood be
#'   included in the model.
#' @param phi deprecated; use `dispersion` instead
#' @param return_likelihood Logical, defaults to `FALSE`. Should the likelihood
#'   be returned by the model.
#' @importFrom rlang arg_match
#' @importFrom cli cli_inform cli_abort
#' @return An `<obs_opts>` object of observation model settings.
#' @export
#' @examples
#' # default settings
#' obs_opts()
#'
#' # Turn off day of the week effect
#' obs_opts(week_effect = FALSE)
#'
#' # Scale reported data
#' obs_opts(scale = Normal(mean = 0.2, sd = 0.02))
obs_opts <- function(family = c("negbin", "poisson"),
                     dispersion = Normal(mean = 0, sd = 0.25),
                     weight = 1,
                     week_effect = TRUE,
                     week_length = 7,
                     scale = Fixed(1),
                     na = c("missing", "accumulate"),
                     likelihood = TRUE,
                     return_likelihood = FALSE,
                     phi) {
  # `phi` is defunct: supplying it always ends in an error, with a more
  # specific message when `dispersion` was given alongside it.
  if (!missing(phi)) {
    if (!missing(dispersion)) {
      cli::cli_abort(
        "Can't specify {.var dispersion} and {.var phi}."
      )
    }
    lifecycle::deprecate_stop(
      "1.7.0",
      "obs_opts(phi)",
      "obs_opts(dispersion)",
      details =
        "The meaning of the `phi` and `dispersion` arguments are the same."
    )
  }
  # `na` is defunct too: any explicitly supplied value is an error.
  if (!missing(na)) {
    lifecycle::deprecate_stop(
      "1.7.0",
      "obs_opts(na)",
      "fill_missing()",
      details = c(
        paste0(
          "If NA values are not to be treated as missing use the ",
          "`fill_missing()` function instead."
        ),
        "This argument will be removed in the next release of EpiNow2."
      )
    )
  }

  obs <- list(
    family = arg_match(family),
    dispersion = dispersion,
    weight = weight,
    week_effect = week_effect,
    week_length = week_length,
    scale = scale,
    # `na` can only be at its default ("missing") here, so accumulation is
    # always off and the default is always the one in use; both fields are
    # kept so the shape of the returned object does not change
    accumulate = 0L,
    likelihood = likelihood,
    return_likelihood = return_likelihood,
    na_as_missing_default_used = TRUE
  )

  # `dispersion` and `scale` must be <dist_spec> objects; the legacy numeric
  # and plain-list forms get a dedicated error message.
  for (param in c("dispersion", "scale")) {
    if (is.numeric(obs[[param]])) {
      cli_abort(
        c(
          "!" = "Specifying {.var {param}} as a numeric value is deprecated.",
          "i" = "Use a {.cls dist_spec} instead using {.fn Fixed()}."
        )
      )
    }
    if (is.list(obs[[param]]) && !is(obs[[param]], "dist_spec")) {
      cli_abort(
        c(
          "!" = "Specifying {.var {param}} as a list is deprecated.",
          "i" = "Use a {.cls dist_spec} instead."
        )
      )
    }
    assert_class(obs[[param]], "dist_spec")
  }

  class(obs) <- c("obs_opts", class(obs))
  return(obs)
}

#' Stan Sampling Options
#'
#' @description `r lifecycle::badge("stable")`
#'  Defines a list specifying the arguments passed to either [rstan::sampling()]
#'  or [cmdstanr::sample()]. Custom settings can be supplied which override the
#'  defaults.
#'
#' @param cores Number of cores to use when executing the chains in parallel,
#'  which defaults to 1 but it is recommended to set the mc.cores option to be
#'  as many processors as the hardware and RAM allow (up to the number of
#'  chains).
#'
#' @param warmup Numeric, defaults to 250. Number of warmup samples per chain.
#'
#' @param samples Numeric, default 2000. Overall number of posterior samples.
#' When using multiple chains iterations per chain is samples / chains.
#'
#' @param chains Numeric, defaults to 4. Number of MCMC chains to use.
#'
#' @param control List, defaults to empty. control parameters to pass to
#' underlying `rstan` function. By default `adapt_delta = 0.9` and
#' `max_treedepth = 12` though these settings can be overwritten.
#'
#' @param save_warmup Logical, defaults to FALSE. Should warmup progress be
#' saved.
#'
#' @param seed Numeric, defaults uniform random number between 1 and 1e8. Seed
#' of sampling process.
#'
#' @param future Logical, defaults to `FALSE`. Should stan chains be run in
#' parallel using `future`. This allows users to have chains fail gracefully
#' (i.e when combined with `max_execution_time`). Should be combined with a
#' call to [future::plan()].
#'
#' @param max_execution_time Numeric, defaults to Inf (seconds). If set will
#' kill off processing of each chain if not finished within the specified
#' timeout. When more than 2 chains finish successfully estimates will still be
#' returned. If fewer than 2 chains return within the allowed time then
#' estimation will fail with an informative error.
#'
#' @inheritParams stan_opts
#'
#' @param ... Additional parameters to pass to [rstan::sampling()] or
#' [cmdstanr::sample()].
#' @importFrom utils modifyList
#' @importFrom cli cli_warn
#' @return A list of arguments to pass to [rstan::sampling()] or
#' [cmdstanr::sample()].
#' @export
#' @examples
#' stan_sampling_opts(samples = 2000)
stan_sampling_opts <- function(cores = getOption("mc.cores", 1L),
                               warmup = 250,
                               samples = 2000,
                               chains = 4,
                               control = list(),
                               save_warmup = FALSE,
                               seed = as.integer(runif(1, 1, 1e8)),
                               future = FALSE,
                               max_execution_time = Inf,
                               backend = c("rstan", "cmdstanr"),
                               ...) {
  dot_args <- list(...)
  backend <- arg_match(backend)
  # settings that are named identically for both backends
  opts <- list(
    chains = chains,
    save_warmup = save_warmup,
    seed = seed,
    future = future,
    max_execution_time = max_execution_time
  )
  # user-supplied control entries override these defaults
  control_def <- list(adapt_delta = 0.9, max_treedepth = 12)
  control_def <- modifyList(control_def, control)
  if (any(c("iter", "iter_sampling") %in% names(dot_args))) {
    cli_warn(
      c(
        "!" = "Number of samples must be specified using the {.var samples}
      and {.var warmup} arguments rather than {.var iter} or
      {.var iter_sampling}.",
        "i" = "Supplied {.var iter} or {.var iter_sampling} will be ignored."
      )
    )
  }
  # iteration counts passed through `...` are discarded; they are derived
  # from `samples`, `chains` and `warmup` below
  dot_args$iter <- NULL
  dot_args$iter_sampling <- NULL
  if (backend == "rstan") {
    # rstan: `iter` counts warmup plus sampling iterations per chain and the
    # control settings travel as a single `control` list
    opts <- c(opts, list(
      cores = cores,
      warmup = warmup,
      control = control_def,
      iter = ceiling(samples / opts$chains) + warmup
    ))
  } else if (backend == "cmdstanr") {
    # cmdstanr: warmup and sampling iterations are separate arguments and the
    # control settings are appended as individual top-level entries
    opts <- c(opts, list(
      parallel_chains = cores,
      iter_warmup = warmup,
      iter_sampling = ceiling(samples / opts$chains)
    ), control_def)
  }
  # remaining extra arguments are appended last, unchanged
  opts <- c(opts, dot_args)
  return(opts)
}

#' Stan Variational Bayes Options
#'
#' @description `r lifecycle::badge("stable")`
#' Defines a list specifying the arguments passed to [rstan::vb()] or
#' [cmdstanr::variational()]. Custom settings can be supplied which override the
#' defaults.
#'
#' @param samples Numeric, default 2000. Overall number of approximate posterior
#' samples.
#'
#' @param trials Numeric, defaults to 10. Number of attempts to use
#' [rstan::vb()] before failing.
#'
#' @param iter Numeric, defaulting to 10000. Number of iterations to use in
#' [rstan::vb()].
#'
#' @param ... Additional parameters to pass to [rstan::vb()] or
#' [cmdstanr::variational()], depending on the chosen backend.
#'
#' @return A list of arguments to pass to [rstan::vb()] or
#'   [cmdstanr::variational()], depending on the chosen backend.
#' @export
#' @examples
#' stan_vb_opts(samples = 1000)
stan_vb_opts <- function(samples = 2000,
                         trials = 10,
                         iter = 10000, ...) {
  # `samples` is stored under the name `output_samples`
  opts <- list(
    trials = trials,
    iter = iter,
    output_samples = samples
  )
  # Wrap the extra arguments in a list before appending. Passing `...`
  # straight to c() would splice a vector- or list-valued argument into
  # several separately named elements instead of keeping it as one entry.
  opts <- c(opts, list(...))
  return(opts)
}

#' Stan Laplace algorithm Options
#'
#' @description `r lifecycle::badge("experimental")`
#' Defines a list specifying the arguments passed to [cmdstanr::laplace()].
#' Only the "cmdstanr" backend is supported; any other backend is an error.
#'
#' @inheritParams stan_opts
#' @inheritParams stan_vb_opts
#' @param ... Additional parameters to pass to [cmdstanr::laplace()]. Each
#' argument is stored as a single named element of the returned list, so
#' vector- or list-valued arguments are passed on unchanged.
#' @importFrom cli cli_abort col_blue
#' @return A list of arguments to pass to [cmdstanr::laplace()].
#' @export
#' @examples
#' stan_laplace_opts()
stan_laplace_opts <- function(backend = "cmdstanr",
                              trials = 10,
                              ...) {
  if (backend != "cmdstanr") {
    cli_abort(
      c(
        "!" = "Backend must be set to {col_blue(\"cmdstanr\")} to use
        the Laplace algorithm.",
        "i" = "Change {.var backend} to {col_blue(\"cmdstanr\")}."
      )
    )
  }
  # Capture the dots as a list before combining: `c(opts, ...)` would splice
  # vector- or list-valued arguments element by element.
  c(list(trials = trials), list(...))
}

#' Stan pathfinder algorithm Options
#'
#' @description `r lifecycle::badge("experimental")`
#' Defines a list specifying the arguments passed to [cmdstanr::pathfinder()].
#' Only the "cmdstanr" backend is supported; any other backend is an error.
#'
#' @inheritParams stan_opts
#' @inheritParams stan_vb_opts
#' @param ... Additional parameters to pass to [cmdstanr::pathfinder()]. Each
#' argument is stored as a single named element of the returned list, so
#' vector- or list-valued arguments are passed on unchanged.
#' @importFrom cli cli_abort col_blue
#' @return A list of arguments to pass to [cmdstanr::pathfinder()]. The
#' requested number of `samples` is stored in the `draws` element.
#' @export
#' @examples
#' stan_pathfinder_opts()
stan_pathfinder_opts <- function(backend = "cmdstanr",
                                 samples = 2000,
                                 trials = 10,
                                 ...) {
  if (backend != "cmdstanr") {
    cli_abort(
      c(
        "!" = "Backend must be set to {col_blue(\"cmdstanr\")} to use
        the pathfinder algorithm.",
        "i" = "Change {.var backend} to {col_blue(\"cmdstanr\")}."
      )
    )
  }
  opts <- list(
    trials = trials,
    draws = samples
  )
  # Capture the dots as a list before combining: `c(opts, ...)` would splice
  # vector- or list-valued arguments element by element.
  c(opts, list(...))
}

#' Stan Options
#'
#' @description `r lifecycle::badge("stable")`
#' Defines a list specifying the arguments passed to underlying stan
#' backend functions. The method specific arguments are generated by
#' [stan_sampling_opts()], [stan_vb_opts()], [stan_laplace_opts()] or
#' [stan_pathfinder_opts()], depending on `method`. Custom settings can be
#' supplied which override the defaults.
#'
#' @param object Stan model object. By default uses the compiled package
#' default if using the "rstan" backend, and the default model obtained using
#' [epinow2_cmdstan_model()] if using the "cmdstanr" backend.
#'
#' @param samples Numeric, defaults to 2000. Number of posterior samples.
#' @param method A character string, defaulting to sampling. Currently supports
#' MCMC sampling ("sampling") or approximate posterior sampling via
#' variational inference ("vb") and, as experimental features if the
#' "cmdstanr" backend is used, approximate posterior sampling with the
#' laplace algorithm ("laplace") or pathfinder ("pathfinder").
#'
#' @param backend Character string indicating the backend to use for fitting
#' stan models. Supported arguments are "rstan" (default) or "cmdstanr". If
#' `object` is supplied the backend is instead derived from the class of
#' `object`, and a warning is issued if `backend` was also given.
#'
#' @param return_fit Logical, defaults to TRUE. Should the fit stan model be
#' returned.
#'
#' @param ... Additional parameters to pass to underlying option functions,
#'   [stan_sampling_opts()], [stan_vb_opts()], [stan_laplace_opts()] or
#'   [stan_pathfinder_opts()], depending on the method
#'
#' @importFrom rlang arg_match
#' @importFrom cli cli_abort cli_warn col_blue
#' @return A `<stan_opts>` object of arguments to pass to the appropriate
#' stan backend functions.
#' @export
#' @seealso [stan_sampling_opts()] [stan_vb_opts()] [stan_laplace_opts()]
#' [stan_pathfinder_opts()]
#' @examples
#' # using default of [rstan::sampling()]
#' stan_opts(samples = 1000)
#'
#' # using vb
#' stan_opts(method = "vb")
stan_opts <- function(object = NULL,
                      samples = 2000,
                      method = c("sampling", "vb", "laplace", "pathfinder"),
                      backend = c("rstan", "cmdstanr"),
                      return_fit = TRUE,
                      ...) {
  method <- arg_match(method)
  # must be recorded before `backend` is overwritten with the matched value
  backend_supplied <- !missing(backend)
  backend <- arg_match(backend)
  cmdstanr_unavailable <- backend == "cmdstanr" &&
    !requireNamespace("cmdstanr", quietly = TRUE)
  if (cmdstanr_unavailable) {
    cli_abort(
      c(
        "x" = "The {col_blue('cmdstanr')} R package is not installed.",
        "i" = "Install it from {.url https://github.com/stan-dev/cmdstanr}
        to use the {col_blue('cmdstanr')} backend."
      )
    )
  }

  if (is.null(object)) {
    # no model object given: the backend itself becomes part of the options
    backend <- arg_match(backend, values = c("rstan", "cmdstanr"))
    common_opts <- list(backend = backend)
  } else {
    if (backend_supplied) {
      cli_warn(
        c(
          "!" = "{.var backend} option will be ignored as a stan model
        object has been passed."
        )
      )
    }
    # a supplied model object determines the backend via its class
    if (inherits(object, "stanmodel")) {
      backend <- "rstan"
    } else if (inherits(object, "CmdStanModel")) {
      backend <- "cmdstanr"
    } else {
      cli_abort(
        c(
          "!" = "{.var object} must be a stan model object."
        )
      )
    }
    common_opts <- list()
  }
  common_opts <- c(common_opts, list(object = object, method = method))

  # arguments specific to the chosen fitting method
  method_opts <- switch(method,
    sampling = stan_sampling_opts(samples = samples, backend = backend, ...),
    vb = stan_vb_opts(samples = samples, ...),
    laplace = stan_laplace_opts(backend = backend, ...),
    pathfinder = stan_pathfinder_opts(
      samples = samples, backend = backend, ...
    )
  )

  opts <- c(common_opts, method_opts, list(return_fit = return_fit))
  class(opts) <- c("stan_opts", class(opts))
  opts
}

#' Forecast options
#' @description `r lifecycle::badge("stable")`
#' Defines a list specifying the forecast settings: the forecast horizon and,
#' if given, the number of days to accumulate in forecasts.
#'
#' @param horizon Numeric, defaults to 7. Number of days into the future to
#' forecast.
#' @param accumulate Integer, the number of days to accumulate in forecasts, if
#'   any. If not given and observations are accumulated at constant frequency in
#'   the data used for fitting then the same accumulation will be used in
#'   forecasts unless set explicitly here.
#' @return A `<forecast_opts>` object of forecast setting. The `accumulate`
#'   element is only present if the argument was supplied.
#' @seealso fill_missing
#' @export
#' @examples
#' forecast_opts(horizon = 28, accumulate = 7)
forecast_opts <- function(horizon = 7, accumulate) {
  settings <- list(horizon = horizon)
  # `accumulate` has no default; it is only stored when supplied by the caller
  if (!missing(accumulate)) {
    settings[["accumulate"]] <- accumulate
  }
  class(settings) <- c("forecast_opts", class(settings))
  settings
}

#' Return an `_opts` List per Region
#'
#' @description `r lifecycle::badge("maturing")`
#' Define a list of `_opts()` to pass to [regional_epinow()] `_opts()` accepting
#' arguments. This is useful when different settings are needed between regions
#' within a single [regional_epinow()] call. Using [opts_list()] the defaults
#' can be applied to all regions present with an override passed to regions as
#' necessary (either within [opts_list()] or externally).
#'
#' @param opts An `_opts()` function call such as [rt_opts()].
#'
#' @param reported_cases A data frame containing a `region` variable
#' indicating the target regions.
#'
#' @param ... Optional override for region defaults. See the examples
#' for use case.
#'
#' @importFrom purrr list_assign
#'
#' @return A named list of options per region which can be passed to the `_opt`
#' accepting arguments of `regional_epinow`.
#' @seealso [regional_epinow()] [rt_opts()]
#' @export
#' @examples
#' # uses example case vector
#' cases <- example_confirmed[1:40]
#' cases <- data.table::rbindlist(list(
#'   data.table::copy(cases)[, region := "testland"],
#'   cases[, region := "realland"]
#' ))
#'
#' # default settings
#' opts_list(rt_opts(), cases)
#'
#' # add a weekly random walk in realland
#' opts_list(rt_opts(), cases, realland = rt_opts(rw = 7))
#'
#' # add a weekly random walk externally
#' rt <- opts_list(rt_opts(), cases)
#' rt$realland$rw <- 7
#' rt
opts_list <- function(opts, reported_cases, ...) {
  region_names <- unique(reported_cases$region)
  # one copy of the default options for every distinct region
  per_region <- lapply(seq_along(region_names), function(i) opts)
  names(per_region) <- region_names
  # region specific overrides supplied via `...` replace the defaults
  list_assign(per_region, ...)
}

#' Filter Options for a Target Region
#'
#' @description `r lifecycle::badge("maturing")`
#' A helper function that allows the selection of region specific settings if
#' present and otherwise applies the overarching settings.
#'
#' @param opts Either a list of calls to an `_opts()` function or a single
#' call to an `_opts()` function.
#'
#' @param region A character string indicating a region of interest.
#'
#' @return A list of options: the element of `opts` named `region` if there is
#' one, otherwise `opts` unchanged.
filter_opts <- function(opts, region) {
  # no entry named after the region: the settings apply to all regions
  if (!(region %in% names(opts))) {
    return(opts)
  }
  opts[[region]]
}

#' Apply default CDF cutoff to a <dist_spec> if it is unconstrained
#'
#' @param dist A <dist_spec>
#' @param default_cdf_cutoff Numeric; default CDF cutoff to be used if an
#'   unconstrained distribution is passed as `dist`. If `dist` is already
#'   constrained by having a maximum or CDF cutoff this is ignored. Note that
#'   this can only be done for <dist_spec> objects with fixed parameters.
#' @param cdf_cutoff_set Logical; whether the default CDF cutoff has been set by
#'   the user; if yes and `dist` is constrained a warning is issued
#' @importFrom cli cli_inform cli_warn
#'
#' @return A <dist_spec> with the default CDF cutoff set if previously not
#'   constrained
#' @keywords internal
apply_default_cdf_cutoff <- function(dist, default_cdf_cutoff, cdf_cutoff_set) {
  # the default is only applied to unconstrained distributions whose `sd()`
  # contains no missing values
  needs_default <- !is_constrained(dist) && !anyNA(sd(dist))
  if (needs_default) {
    # nolint start: duplicate_argument_linter
    cli_inform(
      c(
        "i" = "Unconstrained distributon passed as a delay. ",
        "i" = "Constraining with default CDF cutoff {default_cdf_cutoff}.",
        "i" = "To silence this message, specify delay distributions
      with {.var max} or {.var default_cdf_cutoff}."
      )
    )
    # nolint end
    attr(dist, "cdf_cutoff") <- default_cdf_cutoff
    return(dist)
  }
  # default not applied: warn if the user explicitly asked for a cutoff
  if (cdf_cutoff_set) {
    cli_warn(
      c(
        "!" = "Ignoring given default CDF cutoff.",
        "i" = "Distribution is already constrained."
      )
    )
  }
  dist
}
# epiforecasts/EpiNow2 documentation built on June 9, 2025, 3:51 p.m.