#' Generation Time Distribution Options
#'
#' @description `r lifecycle::badge("stable")`
#' Returns generation time parameters in a format for lower level model use.
#'
#' @details Because the discretised renewal equation used in the package does
#' not support zero generation times, any distribution specified here will be
#' left-truncated at one, i.e. the first element of the nonparametric or
#' discretised probability distribution used for the generation time is set to
#' zero and the resulting distribution renormalised.
#' @rdname generation_time_opts
#' @param dist A delay distribution or series of delay distributions. If no
#' distribution is given a fixed generation time of 1 will be assumed (and a
#' warning issued). If passing a nonparametric distribution the first element
#' should be zero (see *Details* section).
#'
#' @param weight_prior Logical; if TRUE (default), any priors given in `dist`
#' will be weighted by the number of observation data points, in doing so
#' approximately placing an independent prior at each time step and usually
#' preventing the posteriors from shifting. If FALSE, no weight will be
#' applied, i.e. any parameters in `dist` will be treated as a single
#' parameter.
#' @inheritParams apply_default_cdf_cutoff
#' @importFrom cli cli_warn cli_abort col_blue
#' @return A `<generation_time_opts>` object summarising the input delay
#' distributions.
#' @seealso [convert_to_logmean()] [convert_to_logsd()]
#' [bootstrapped_dist_fit()] [Gamma()] [LogNormal()] [Fixed()]
#' @export
#' @examples
#' # default settings with a fixed generation time of 1
#' generation_time_opts()
#'
#' # A fixed gamma distributed generation time
#' generation_time_opts(Gamma(mean = 3, sd = 2, max = 14))
#'
#' # An uncertain gamma distributed generation time
#' generation_time_opts(
#'   Gamma(
#'     shape = Normal(mean = 3, sd = 1),
#'     rate = Normal(mean = 2, sd = 0.5),
#'     max = 14
#'   )
#' )
#'
#' # An example generation time
#' gt_opts(example_generation_time)
gt_opts <- function(dist = Fixed(1), default_cdf_cutoff = 0.001,
                    weight_prior = TRUE) {
  # Relying on the implicit default is flagged: with a 1-day generation time
  # the reproduction number equals the daily growth rate.
  if (missing(dist)) {
    cli_warn(
      c(
        "!" = "No generation time distribution given. Using a fixed generation
        time of 1 day, i.e. the reproduction number is the same as the daily
        growth rate.",
        "i" = "If this was intended then this warning can be
        silenced by setting {.var dist = Fixed(1)}."
      )
    )
  }
  ## apply default CDF cutoff if `dist` is unconstrained; the third argument
  ## tells the helper whether the cutoff was supplied explicitly by the user
  dist <- apply_default_cdf_cutoff(
    dist, default_cdf_cutoff, !missing(default_cdf_cutoff)
  )
  attr(dist, "weight_prior") <- weight_prior
  attr(dist, "class") <- c("generation_time_opts", class(dist))
  # validation is delegated to check_generation_time() (defined elsewhere)
  check_generation_time(dist)
  return(dist)
}
#' @rdname generation_time_opts
#' @export
generation_time_opts <- gt_opts
#' Secondary Reports Options
#'
#' @description `r lifecycle::badge("stable")`
#' Builds the list of switches that define the secondary model used in
#' [estimate_secondary()]. The model combines a convolution of previously
#' observed primary reports with current primary reports (each either additive
#' or subtractive), and can optionally be cumulative. The `type` presets cover
#' most use cases; see the returned values of [secondary_opts()] for every
#' supported option.
#'
#' @param type A character string indicating the type of observation the
#' secondary reports are. Options include:
#'
#' - "incidence": Assumes that secondary reports equal a convolution of
#' previously observed primary reported cases. An example application is deaths
#' from an infectious disease predicted by reported cases of that disease (or
#' estimated infections).
#'
#' - "prevalence": Assumes that secondary reports are cumulative and are
#' defined by currently observed primary reports minus a convolution of
#' previously observed primary reports. An example application is hospital bed
#' usage predicted by hospital admissions.
#'
#' @param ... Overwrite options defined by type. See the returned values for all
#' options that can be passed.
#' @importFrom rlang arg_match
#' @seealso [estimate_secondary()]
#' @return A `<secondary_opts>` object of binary options summarising secondary
#' model used in [estimate_secondary()]. Options returned are `cumulative`
#' (should the secondary report be cumulative), `historic` (should a
#' convolution of primary reported cases be used to predict secondary reported
#' cases), `primary_hist_additive` (should the historic convolution of primary
#' reported cases be additive or subtractive), `current` (should currently
#' observed primary reported cases contribute to current secondary reported
#' cases), `primary_current_additive` (should current primary reported cases be
#' additive or subtractive).
#'
#' @export
#' @examples
#' # incidence model
#' secondary_opts("incidence")
#'
#' # prevalence model
#' secondary_opts("prevalence")
secondary_opts <- function(type = c("incidence", "prevalence"), ...) {
  type <- arg_match(type)
  # One preset of 0/1 switches per supported observation type.
  presets <- list(
    incidence = list(
      cumulative = 0,
      historic = 1,
      primary_hist_additive = 1,
      current = 0,
      primary_current_additive = 0
    ),
    prevalence = list(
      cumulative = 1,
      historic = 1,
      primary_hist_additive = 0,
      current = 1,
      primary_current_additive = 1
    )
  )
  # Any switch supplied through `...` overrides the chosen preset.
  settings <- modifyList(presets[[type]], list(...))
  class(settings) <- c("secondary_opts", class(settings))
  settings
}
#' Delay Distribution Options
#'
#' @description `r lifecycle::badge("stable")`
#' Formats one or more delay distributions so they can be consumed by
#' downstream functions.
#' @param dist A delay distribution or series of delay distributions. Default is
#' a fixed distribution with all mass at 0, i.e. no delay.
#' @inheritParams generation_time_opts
#' @importFrom cli cli_abort
#' @return A `<delay_opts>` object summarising the input delay distributions.
#' @seealso [convert_to_logmean()] [convert_to_logsd()]
#' [bootstrapped_dist_fit()] \code{\link{Distributions}}
#' @export
#' @examples
#' # no delays
#' delay_opts()
#'
#' # A single delay that has uncertainty
#' delay <- LogNormal(
#'   meanlog = Normal(1, 0.2),
#'   sdlog = Normal(0.5, 0.1),
#'   max = 14
#' )
#' delay_opts(delay)
#'
#' # A single delay without uncertainty
#' delay <- LogNormal(meanlog = 1, sdlog = 0.5, max = 14)
#' delay_opts(delay)
#'
#' # Multiple delays (in this case twice the same)
#' delay_opts(delay + delay)
delay_opts <- function(dist = Fixed(0), default_cdf_cutoff = 0.001,
                       weight_prior = TRUE) {
  assert_class(dist, "dist_spec")
  # TRUE only when the caller passed `default_cdf_cutoff` explicitly.
  cutoff_supplied <- !missing(default_cdf_cutoff)
  ## apply default CDF cutoff if `dist` is unconstrained
  delays <- apply_default_cdf_cutoff(
    dist, default_cdf_cutoff, cutoff_supplied
  )
  attr(delays, "weight_prior") <- weight_prior
  class(delays) <- c("delay_opts", class(delays))
  check_stan_delay(delays)
  delays
}
#' Truncation Distribution Options
#'
#' @description `r lifecycle::badge("stable")`
#' Formats a truncation distribution so it can be consumed by downstream
#' functions. See [estimate_truncation()] for an approach to estimate these
#' distributions.
#'
#' @param dist A delay distribution or series of delay distributions reflecting
#' the truncation. It can be specified using the probability distributions
#' interface in `EpiNow2` (See `?EpiNow2::Distributions`) or estimated using
#' [estimate_truncation()], which returns a `dist` object, suited
#' for use here out-of-box. Default is a fixed distribution with maximum 0, i.e.
#' no truncation.
#' @param weight_prior Logical; if TRUE, the truncation prior will be weighted
#' by the number of observation data points, in doing so approximately placing
#' an independent prior at each time step and usually preventing the
#' posteriors from shifting. If FALSE (default), no weight will be applied,
#' i.e. the truncation distribution will be treated as a single parameter.
#'
#' @inheritParams gt_opts
#' @importFrom cli cli_abort
#' @return A `<trunc_opts>` object summarising the input truncation
#' distribution.
#'
#' @seealso [convert_to_logmean()] [convert_to_logsd()]
#' [bootstrapped_dist_fit()] \code{\link{Distributions}}
#' @export
#' @examples
#' # no truncation
#' trunc_opts()
#'
#' # truncation dist
#' trunc_opts(dist = LogNormal(mean = 3, sd = 2, max = 10))
trunc_opts <- function(dist = Fixed(0), default_cdf_cutoff = 0.001,
                       weight_prior = FALSE) {
  assert_class(dist, "dist_spec")
  # TRUE only when the caller passed `default_cdf_cutoff` explicitly.
  explicit_cutoff <- !missing(default_cdf_cutoff)
  ## apply default CDF cutoff if `dist` is unconstrained
  truncation <- apply_default_cdf_cutoff(
    dist, default_cdf_cutoff, explicit_cutoff
  )
  attr(truncation, "weight_prior") <- weight_prior
  class(truncation) <- c("trunc_opts", class(truncation))
  check_stan_delay(truncation)
  truncation
}
#' Time-Varying Reproduction Number Options
#'
#' @description `r lifecycle::badge("stable")`
#' Defines a list specifying the optional arguments for the time-varying
#' reproduction number. Custom settings can be supplied which override the
#' defaults.
#'
#' @param prior A `<dist_spec>` giving the prior of the initial reproduction
#' number. Ignored (with a warning if supplied) if `use_rt` is `FALSE`.
#' Defaults to a LogNormal distribution with mean of 1 and standard deviation
#' of 1: `LogNormal(mean = 1, sd = 1)`. A lower limit of 0 will be enforced
#' automatically. Specifying the prior as a plain list is no longer supported
#' and raises an error.
#'
#' @param use_rt Logical, defaults to `TRUE`. Should Rt be used to generate
#' infections and hence reported cases.
#'
#' @param rw Numeric step size of the random walk, defaults to 0. To specify a
#' weekly random walk set `rw = 7`. Any value above 0 also forces
#' `use_breakpoints = TRUE`. For more custom break point settings consider
#' passing in a `breakpoint` variable as described for `use_breakpoints`.
#'
#' @param use_breakpoints Logical, defaults to `TRUE`. Should break points be
#' used if present as a `breakpoint` variable in the input data. Break points
#' should be defined as 1 if present and otherwise 0. By default breakpoints
#' are fit jointly with a global non-parametric effect and so represent a
#' conservative estimate of break point changes (alter this by setting
#' `gp = NULL`).
#'
#' @param pop Integer, defaults to 0. Susceptible population initially present.
#' Used to adjust Rt estimates in the forecast horizon based on the
#' proportion of the population that is susceptible. When set to 0 no
#' population adjustment is done.
#'
#' @param gp_on Character string, defaulting to "R_t-1". Indicates how the
#' Gaussian process, if in use, should be applied to Rt. Currently supported
#' options are applying the Gaussian process to the last estimated Rt (i.e.
#' Rt = Rt-1 * GP), and applying the Gaussian process to a global mean (i.e.
#' Rt = R0 * GP). Both should produce comparable results when data is not
#' sparse but the method relying on a global mean will revert to this for real
#' time estimates, which may not be desirable.
#'
#' @return An `<rt_opts>` object with settings defining the time-varying
#' reproduction number.
#' @inheritParams create_future_rt
#' @importFrom rlang arg_match
#' @importFrom cli cli_abort cli_warn
#' @export
#' @examples
#' # default settings
#' rt_opts()
#'
#' # use a custom prior for the initial reproduction number
#' rt_opts(prior = LogNormal(mean = 2, sd = 1))
#'
#' # add a weekly random walk
#' rt_opts(rw = 7)
rt_opts <- function(prior = LogNormal(mean = 1, sd = 1),
                    use_rt = TRUE,
                    rw = 0,
                    use_breakpoints = TRUE,
                    future = "latest",
                    gp_on = c("R_t-1", "R0"),
                    pop = 0) {
  opts <- list(
    use_rt = use_rt,
    rw = rw,
    use_breakpoints = use_breakpoints,
    future = future,
    pop = pop,
    gp_on = arg_match(gp_on)
  )
  # a random walk requires break points to be switched on, whatever the user
  # passed for `use_breakpoints` (presumably because the walk is built from
  # break points downstream)
  if (opts$rw > 0) {
    opts$use_breakpoints <- TRUE
  }
  # list priors were the legacy interface and are now rejected
  if (is.list(prior) && !inherits(prior, "dist_spec")) {
    cli_abort(
      c(
        "!" = "Specifying {.var prior} as a list is deprecated.",
        "i" = "Use a {.cls dist_spec} instead."
      )
    )
  }
  if (opts$use_rt) {
    opts$prior <- prior
  } else if (!missing(prior)) {
    cli_warn(
      c(
        "!" = "Rt {.var prior} is ignored if {.var use_rt} is FALSE."
      )
    )
  }
  attr(opts, "class") <- c("rt_opts", class(opts))
  return(opts)
}
#' Back Calculation Options
#'
#' @description `r lifecycle::badge("stable")`
#' Collects the optional arguments controlling the back calculation of cases.
#' Only used if `rt = NULL`.
#'
#' @param prior A character string defaulting to "reports". Defines the prior
#' to use when deconvolving. Currently implemented options are to use smoothed
#' mean delay shifted reported cases ("reports"), to use the estimated
#' infections from the previous time step seeded for the first time step using
#' mean shifted reported cases ("infections"), or no prior ("none"). Using no
#' prior will result in poor real time performance. No prior and using
#' infections are only supported when a Gaussian process is present. If
#' observed data is not reliable then a sensible first step is to explore
#' increasing the `prior_window`, with a sensible second step being to no
#' longer use reported cases as a prior (i.e. set `prior = "none"`).
#'
#' @param prior_window Integer, defaults to 14 days. The mean centred smoothing
#' window to apply to mean shifted reports (used as a prior during back
#' calculation). 7 days is the minimum recommended setting as this smooths day
#' of the week effects, but the best choice depends on the quality of the data
#' and the amount of information users wish to use as a prior (higher values
#' equalling a less informative prior).
#'
#' @param rt_window Integer, defaults to 1. The size of the centred rolling
#' average to use when estimating Rt. This must be odd so that the central
#' estimate is included.
#' @importFrom rlang arg_match
#' @importFrom cli cli_abort
#'
#' @return A `<backcalc_opts>` object of back calculation settings.
#' @export
#' @examples
#' # default settings
#' backcalc_opts()
backcalc_opts <- function(prior = c("reports", "none", "infections"),
                          prior_window = 14, rt_window = 1) {
  chosen_prior <- arg_match(prior)
  settings <- list(
    prior = chosen_prior,
    prior_window = prior_window,
    rt_window = as.integer(rt_window)
  )
  # A centred window of even width has no single central day.
  window_is_even <- settings$rt_window %% 2 == 0
  if (window_is_even) {
    cli_abort(
      c(
        "!" = "{.var rt_window} must be odd in order to
include the current estimate.",
        "i" = "You have supplied an even number."
      )
    )
  }
  structure(settings, class = c("backcalc_opts", class(settings)))
}
#' Approximate Gaussian Process Settings
#'
#' @description `r lifecycle::badge("stable")`
#' Defines a list specifying the structure of the approximate Gaussian
#' process. Custom settings can be supplied which override the defaults.
#'
#' @param ls_mean Deprecated; use `ls` instead.
#'
#' @param ls_sd Deprecated; use `ls` instead.
#'
#' @param ls_min Deprecated; use `ls` instead.
#'
#' @param ls_max Deprecated; use `ls` instead.
#'
#' @param ls A `<dist_spec>` giving the prior distribution of the lengthscale
#' parameter of the Gaussian process kernel on the scale of days. Defaults to
#' a Lognormal distribution with mean 21 days, sd 7 days and maximum 60 days:
#' `LogNormal(mean = 21, sd = 7, max = 60)` (a lower limit of 0 will be
#' enforced automatically to ensure positivity)
#'
#' @param alpha A `<dist_spec>` giving the prior distribution of the magnitude
#' parameter of the Gaussian process kernel. Should be approximately the
#' expected standard deviation of the Gaussian process (logged Rt in case of
#' the renewal model, logged infections in case of the nonmechanistic model).
#' Defaults to a half-normal distribution with mean 0 and sd 0.01:
#' `Normal(mean = 0, sd = 0.01)` (a lower limit of 0 will be enforced
#' automatically to ensure positivity)
#'
#' @param alpha_mean Deprecated; use `alpha` instead.
#'
#' @param alpha_sd Deprecated; use `alpha` instead.
#'
#' @param kernel Character string, the type of kernel required. Currently
#' supporting the Matern kernel ("matern"), squared exponential kernel ("se"),
#' Ornstein-Uhlenbeck kernel ("ou"), and the periodic kernel ("periodic").
#'
#' @param matern_order Numeric, defaults to 3/2. Order of Matérn Kernel to use.
#' Common choices are 1/2, 3/2, and 5/2. If `kernel` is set to "ou",
#' `matern_order` will be automatically set to 1/2, and if it is set to "se"
#' it will be set to `Inf`. Only used if the kernel is set to "matern", in
#' which case uncommon orders trigger a warning.
#'
#' @param matern_type Deprecated; use `matern_order` instead.
#'
#' @param basis_prop Numeric, the proportion of time points to use as basis
#' functions. Defaults to 0.2. Decreasing this value results in a decrease in
#' accuracy but a faster compute time (with increasing it having the opposite
#' effect). In general smaller posterior length scales require a higher
#' proportion of basis functions. See (Riutort-Mayol et al. 2020
#' <https://arxiv.org/abs/2004.11408>) for advice on updating this default.
#'
#' @param boundary_scale Numeric, defaults to 1.5. Boundary scale of the
#' approximate Gaussian process. See (Riutort-Mayol et al. 2020
#' <https://arxiv.org/abs/2004.11408>) for advice on updating this default.
#'
#' @param w0 Numeric, defaults to 1.0. Fundamental frequency for the periodic
#' kernel. Only used if `kernel` is set to "periodic".
#'
#' @importFrom rlang arg_match
#' @importFrom cli cli_abort cli_warn
#' @return A `<gp_opts>` object of settings defining the Gaussian process
#' @export
#' @examples
#' # default settings
#' gp_opts()
#'
#' # add a custom length scale
#' gp_opts(ls = LogNormal(mean = 4, sd = 1, max = 20))
#'
#' # use a periodic kernel
#' gp_opts(kernel = "periodic")
gp_opts <- function(basis_prop = 0.2,
                    boundary_scale = 1.5,
                    ls_mean = 21,
                    ls_sd = 7,
                    ls_min = 0,
                    ls_max = 60,
                    ls = LogNormal(mean = 21, sd = 7, max = 60),
                    alpha = Normal(mean = 0, sd = 0.01),
                    kernel = c("matern", "se", "ou", "periodic"),
                    matern_order = 3 / 2,
                    matern_type,
                    w0 = 1.0,
                    alpha_mean, alpha_sd) {
  # Arguments removed in earlier releases raise an informative error.
  if (!missing(matern_type)) {
    lifecycle::deprecate_stop(
      "1.6.0", "gp_opts(matern_type)", "gp_opts(matern_order)"
    )
  }
  if (!missing(alpha_mean)) {
    lifecycle::deprecate_stop(
      "1.7.0", "gp_opts(alpha_mean)", "gp_opts(alpha)"
    )
  }
  if (!missing(alpha_sd)) {
    lifecycle::deprecate_stop(
      "1.7.0", "gp_opts(alpha_sd)", "gp_opts(alpha)"
    )
  }
  # The legacy lengthscale arguments are no longer converted into a prior;
  # any use of them is an error, with a specific message when combined with
  # the replacement `ls` argument.
  legacy_ls_used <- !missing(ls_mean) || !missing(ls_sd) ||
    !missing(ls_min) || !missing(ls_max)
  if (legacy_ls_used) {
    if (!missing(ls)) {
      cli_abort(
        c(
          "!" = "Both {.var ls} and at least one legacy argument
          ({.var ls_mean}, {.var ls_sd}, {.var ls_min}, {.var ls_max}) have been
          specified.",
          "i" = "Only one of them should be used."
        )
      )
    }
    cli_abort(c(
      "!" = "Specifying lengthscale priors via the {.var ls_mean}, {.var ls_sd},
      {.var ls_min}, and {.var ls_max} arguments is deprecated.",
      "i" = "Use the {.var ls} argument instead."
    ))
  }
  kernel <- arg_match(kernel)
  if (kernel == "se") {
    # the squared exponential kernel is the Matern limit of infinite order
    matern_order <- Inf
  } else if (kernel == "ou") {
    # the Ornstein-Uhlenbeck kernel is the Matern kernel of order 1/2
    matern_order <- 1 / 2
  } else if (
    kernel == "matern" &&
      !(is.infinite(matern_order) || matern_order %in% c(1 / 2, 3 / 2, 5 / 2))
  ) {
    # only warn where `matern_order` is actually used
    cli_warn(
      c(
        "!" = "Uncommon Matern kernel order supplied.",
        "i" = "Use one of `1 / 2`, `3 / 2`, or `5 / 2`" # nolint
      )
    )
  }
  gp <- list(
    basis_prop = basis_prop,
    boundary_scale = boundary_scale,
    ls = ls,
    alpha = alpha,
    kernel = kernel,
    matern_order = matern_order,
    w0 = w0
  )
  attr(gp, "class") <- c("gp_opts", class(gp))
  return(gp)
}
#' Observation Model Options
#'
#' @description `r lifecycle::badge("stable")`
#' Defines a list specifying the structure of the observation
#' model. Custom settings can be supplied which override the defaults.
#' @param family Character string defining the observation model. Options are
#' Negative binomial ("negbin"), the default, and Poisson ("poisson").
#' @param dispersion A `<dist_spec>` specifying a prior on the dispersion
#' parameter of the reporting process, used only if `family` is "negbin".
#' Internally parameterised such that this parameter is one over the square
#' root of the `phi` parameter for overdispersion of the
#' [negative binomial distribution](https://mc-stan.org/docs/functions-reference/unbounded_discrete_distributions.html#neg-binom-2-log). # nolint
#' Defaults to a half-normal distribution with mean of 0 and
#' standard deviation of 0.25: `Normal(mean = 0, sd = 0.25)`. A lower limit of
#' zero will be enforced automatically.
#' @param weight Numeric, defaults to 1. Weight to give the observed data in the
#' log density.
#' @param week_effect Logical defaulting to `TRUE`. Should a day of the week
#' effect be used in the observation model.
#' @param week_length Numeric assumed length of the week in days, defaulting to
#' 7 days. This can be modified if data aggregated over a period other than a
#' week or if data has a non-weekly periodicity.
#' @param scale A `<dist_spec>` specifying a prior on the scaling factor to be
#' applied to map latent infections (convolved to date of report). Defaults
#' to a fixed value of 1, i.e. no scaling: `Fixed(1)`. A lower limit of zero
#' will be enforced automatically. If setting to a prior distribution and no
#' overreporting is expected, it might be sensible to set a maximum of 1 via
#' the `max` option when declaring the distribution.
#' @param na Deprecated; use the [fill_missing()] function instead. Supplying
#' any value raises an error.
#' @param likelihood Logical, defaults to `TRUE`. Should the likelihood be
#' included in the model.
#' @param phi Deprecated; use `dispersion` instead. Supplying it raises an
#' error.
#' @param return_likelihood Logical, defaults to `FALSE`. Should the likelihood
#' be returned by the model.
#' @importFrom rlang arg_match
#' @importFrom cli cli_inform cli_abort
#' @return An `<obs_opts>` object of observation model settings.
#' @export
#' @examples
#' # default settings
#' obs_opts()
#'
#' # Turn off day of the week effect
#' obs_opts(week_effect = FALSE)
#'
#' # Scale reported data
#' obs_opts(scale = Normal(mean = 0.2, sd = 0.02))
obs_opts <- function(family = c("negbin", "poisson"),
                     dispersion = Normal(mean = 0, sd = 0.25),
                     weight = 1,
                     week_effect = TRUE,
                     week_length = 7,
                     scale = Fixed(1),
                     na = c("missing", "accumulate"),
                     likelihood = TRUE,
                     return_likelihood = FALSE,
                     phi) {
  # `phi` has been replaced by `dispersion`; both paths below raise an error.
  if (!missing(phi)) {
    if (!missing(dispersion)) {
      cli_abort(
        "Can't specify {.var dispersion} and {.var phi}."
      )
    }
    lifecycle::deprecate_stop(
      "1.7.0",
      "obs_opts(phi)",
      "obs_opts(dispersion)",
      details =
        "The meaning of the `phi` and `dispersion` arguments are the same."
    )
  }
  # `na` is defunct: any explicitly supplied value errors, so past this point
  # `na` always resolves to its default, "missing".
  na_default_used <- missing(na)
  if (!na_default_used) {
    lifecycle::deprecate_stop(
      "1.7.0",
      "obs_opts(na)",
      "fill_missing()",
      details = c(
        paste0(
          "If NA values are not to be treated as missing use the ",
          "`fill_missing()` function instead."
        ),
        "This argument will be removed in the next release of EpiNow2."
      )
    )
  }
  na <- arg_match(na)
  obs <- list(
    family = arg_match(family),
    dispersion = dispersion,
    weight = weight,
    week_effect = week_effect,
    week_length = week_length,
    scale = scale,
    accumulate = as.integer(na == "accumulate"),
    likelihood = likelihood,
    return_likelihood = return_likelihood,
    na_as_missing_default_used = na_default_used
  )
  # Numeric and plain-list priors were the legacy interface and are rejected;
  # anything else must be a `<dist_spec>`.
  for (param in c("dispersion", "scale")) {
    if (is.numeric(obs[[param]])) {
      cli_abort(
        c(
          "!" = "Specifying {.var {param}} as a numeric value is deprecated.",
          "i" = "Use a {.cls dist_spec} instead using {.fn Fixed}."
        )
      )
    } else if (is.list(obs[[param]]) && !inherits(obs[[param]], "dist_spec")) {
      cli_abort(
        c(
          "!" = "Specifying {.var {param}} as a list is deprecated.",
          "i" = "Use a {.cls dist_spec} instead."
        )
      )
    } else {
      assert_class(obs[[param]], "dist_spec")
    }
  }
  attr(obs, "class") <- c("obs_opts", class(obs))
  return(obs)
}
#' Stan Sampling Options
#'
#' @description `r lifecycle::badge("stable")`
#' Defines a list specifying the arguments passed to either [rstan::sampling()]
#' or [cmdstanr::sample()]. Custom settings can be supplied which override the
#' defaults.
#'
#' @param cores Number of cores to use when executing the chains in parallel,
#' which defaults to 1 but it is recommended to set the mc.cores option to be
#' as many processors as the hardware and RAM allow (up to the number of
#' chains).
#'
#' @param warmup Numeric, defaults to 250. Number of warmup samples per chain.
#'
#' @param samples Numeric, default 2000. Overall number of posterior samples.
#' When using multiple chains iterations per chain is samples / chains
#' (rounded up).
#'
#' @param chains Numeric, defaults to 4. Number of MCMC chains to use.
#'
#' @param control List, defaults to empty. control parameters to pass to
#' underlying `rstan` function. By default `adapt_delta = 0.9` and
#' `max_treedepth = 12` though these settings can be overwritten. With the
#' "cmdstanr" backend these are passed as top-level arguments.
#'
#' @param save_warmup Logical, defaults to FALSE. Should warmup progress be
#' saved.
#'
#' @param seed Numeric, defaults uniform random number between 1 and 1e8. Seed
#' of sampling process.
#'
#' @param future Logical, defaults to `FALSE`. Should stan chains be run in
#' parallel using `future`. This allows users to have chains fail gracefully
#' (i.e when combined with `max_execution_time`). Should be combined with a
#' call to [future::plan()].
#'
#' @param max_execution_time Numeric, defaults to Inf (seconds). If set will
#' kill off processing of each chain if not finished within the specified
#' timeout. When more than 2 chains finish successfully estimates will still be
#' returned. If less than 2 chains return within the allowed time then
#' estimation will fail with an informative error.
#'
#' @inheritParams stan_opts
#'
#' @param ... Additional parameters to pass to [rstan::sampling()] or
#' [cmdstanr::sample()]. `iter` and `iter_sampling` are ignored (with a
#' warning) as they are derived from `samples` and `warmup`.
#' @importFrom utils modifyList
#' @importFrom rlang arg_match
#' @importFrom cli cli_warn
#' @return A list of arguments to pass to [rstan::sampling()] or
#' [cmdstanr::sample()].
#' @export
#' @examples
#' stan_sampling_opts(samples = 2000)
stan_sampling_opts <- function(cores = getOption("mc.cores", 1L),
                               warmup = 250,
                               samples = 2000,
                               chains = 4,
                               control = list(),
                               save_warmup = FALSE,
                               seed = as.integer(runif(1, 1, 1e8)),
                               future = FALSE,
                               max_execution_time = Inf,
                               backend = c("rstan", "cmdstanr"),
                               ...) {
  dot_args <- list(...)
  backend <- arg_match(backend)
  opts <- list(
    chains = chains,
    save_warmup = save_warmup,
    seed = seed,
    future = future,
    max_execution_time = max_execution_time
  )
  # package defaults, overridable element-wise through `control`
  control_def <- modifyList(
    list(adapt_delta = 0.9, max_treedepth = 12), control
  )
  # iteration counts are derived from `samples` and `warmup`, so any
  # user-supplied ones are dropped
  if (any(c("iter", "iter_sampling") %in% names(dot_args))) {
    cli_warn(
      c(
        "!" = "Number of samples must be specified using the {.var samples}
        and {.var warmup} arguments rather than {.var iter} or
        {.var iter_sampling}.",
        "i" = "Supplied {.var iter} or {.var iter_sampling} will be ignored."
      )
    )
  }
  dot_args$iter <- NULL
  dot_args$iter_sampling <- NULL
  samples_per_chain <- ceiling(samples / opts$chains)
  if (backend == "rstan") {
    # rstan counts warmup as part of `iter` and takes `control` as a list
    opts <- c(opts, list(
      cores = cores,
      warmup = warmup,
      control = control_def,
      iter = samples_per_chain + warmup
    ))
  } else {
    # cmdstanr takes warmup/sampling counts and control settings separately
    opts <- c(opts, list(
      parallel_chains = cores,
      iter_warmup = warmup,
      iter_sampling = samples_per_chain
    ), control_def)
  }
  c(opts, dot_args)
}
#' Stan Variational Bayes Options
#'
#' @description `r lifecycle::badge("stable")`
#' Collects the arguments handed to [rstan::vb()] or
#' [cmdstanr::variational()]. Custom settings can be supplied which override the
#' defaults.
#'
#' @param samples Numeric, default 2000. Overall number of approximate posterior
#' samples.
#'
#' @param trials Numeric, defaults to 10. Number of attempts to use
#' [rstan::vb()] before failing.
#'
#' @param iter Numeric, defaulting to 10000. Number of iterations to use in
#' [rstan::vb()].
#'
#' @param ... Additional parameters to pass to [rstan::vb()] or
#' [cmdstanr::variational()], depending on the chosen backend.
#'
#' @return A list of arguments to pass to [rstan::vb()] or
#' [cmdstanr::variational()], depending on the chosen backend.
#' @export
#' @examples
#' stan_vb_opts(samples = 1000)
stan_vb_opts <- function(samples = 2000,
                         trials = 10,
                         iter = 10000, ...) {
  # `samples` is handed to the backend under the name `output_samples`; extra
  # arguments are appended via c(), which splices list-valued extras
  # element-wise.
  c(
    list(trials = trials, iter = iter, output_samples = samples),
    ...
  )
}
#' Stan Laplace algorithm Options
#'
#' @description `r lifecycle::badge("experimental")`
#' Defines a list specifying the arguments passed to [cmdstanr::laplace()].
#' Only available with the "cmdstanr" backend.
#'
#' @inheritParams stan_opts
#' @inheritParams stan_vb_opts
#' @param ... Additional parameters to pass to [cmdstanr::laplace()].
#' @importFrom cli cli_abort col_blue
#' @return A list of arguments to pass to [cmdstanr::laplace()].
#' @export
#' @examples
#' stan_laplace_opts()
stan_laplace_opts <- function(backend = "cmdstanr",
                              trials = 10,
                              ...) {
  # the Laplace algorithm is only supported through cmdstanr
  if (backend != "cmdstanr") {
    cli_abort(
      c(
        "!" = "Backend must be set to {col_blue(\"cmdstanr\")} to use
        the Laplace algorithm.",
        "i" = "Change {.var backend} to {col_blue(\"cmdstanr\")}."
      )
    )
  }
  # extra arguments are appended via c(), as for the other stan_*_opts()
  c(list(trials = trials), ...)
}
#' Stan pathfinder algorithm Options
#'
#' @description `r lifecycle::badge("experimental")`
#' Defines a list specifying the arguments passed to [cmdstanr::pathfinder()].
#' Only available with the "cmdstanr" backend.
#'
#' @inheritParams stan_opts
#' @inheritParams stan_vb_opts
#' @param ... Additional parameters to pass to [cmdstanr::pathfinder()].
#' @importFrom cli cli_abort col_blue
#' @return A list of arguments to pass to [cmdstanr::pathfinder()]. `samples`
#' is passed on as `draws`.
#' @export
#' @examples
#' stan_pathfinder_opts()
stan_pathfinder_opts <- function(backend = "cmdstanr",
                                 samples = 2000,
                                 trials = 10,
                                 ...) {
  # the pathfinder algorithm is only supported through cmdstanr
  if (backend != "cmdstanr") {
    cli_abort(
      c(
        "!" = "Backend must be set to {col_blue(\"cmdstanr\")} to use
        the pathfinder algorithm.",
        "i" = "Change {.var backend} to {col_blue(\"cmdstanr\")}."
      )
    )
  }
  opts <- list(
    trials = trials,
    draws = samples
  )
  c(opts, ...)
}
#' Stan Options
#'
#' @description `r lifecycle::badge("stable")`
#' Defines a list specifying the arguments passed to underlying stan
#' backend functions via [stan_sampling_opts()], [stan_vb_opts()],
#' [stan_laplace_opts()] or [stan_pathfinder_opts()], depending on `method`.
#' Custom settings can be supplied which override the defaults.
#'
#' @param object Stan model object. By default uses the compiled package
#' default if using the "rstan" backend, and the default model obtained using
#' [epinow2_cmdstan_model()] if using the "cmdstanr" backend. If supplied, the
#' backend is inferred from the class of the object and `backend` is ignored.
#'
#' @param samples Numeric, defaults to 2000. Number of posterior samples.
#' @param method A character string, defaulting to sampling. Currently supports
#' MCMC sampling ("sampling") or approximate posterior sampling via
#' variational inference ("vb") and, as experimental features if the
#' "cmdstanr" backend is used, approximate posterior sampling with the
#' laplace algorithm ("laplace") or pathfinder ("pathfinder").
#'
#' @param backend Character string indicating the backend to use for fitting
#' stan models. Supported arguments are "rstan" (default) or "cmdstanr".
#'
#' @param return_fit Logical, defaults to TRUE. Should the fit stan model be
#' returned.
#'
#' @param ... Additional parameters to pass to underlying option functions,
#' [stan_sampling_opts()], [stan_vb_opts()], [stan_laplace_opts()] or
#' [stan_pathfinder_opts()], depending on the method.
#'
#' @importFrom rlang arg_match
#' @importFrom cli cli_abort cli_warn col_blue
#' @return A `<stan_opts>` object of arguments to pass to the appropriate
#' stan backend functions.
#' @export
#' @seealso [stan_sampling_opts()] [stan_vb_opts()] [stan_laplace_opts()]
#' [stan_pathfinder_opts()]
#' @examples
#' # using default of [rstan::sampling()]
#' stan_opts(samples = 1000)
#'
#' # using vb
#' stan_opts(method = "vb")
stan_opts <- function(object = NULL,
                      samples = 2000,
                      method = c("sampling", "vb", "laplace", "pathfinder"),
                      backend = c("rstan", "cmdstanr"),
                      return_fit = TRUE,
                      ...) {
  method <- arg_match(method)
  backend_passed <- !missing(backend)
  backend <- arg_match(backend)
  if (backend == "cmdstanr" && !requireNamespace("cmdstanr", quietly = TRUE)) {
    cli_abort(
      c(
        "x" = "The {col_blue('cmdstanr')} R package is not installed.",
        "i" = "Install it from {.url https://github.com/stan-dev/cmdstanr}
        to use the {col_blue('cmdstanr')} backend."
      )
    )
  }
  opts <- list()
  if (!is.null(object)) {
    # a supplied model object determines the backend
    if (backend_passed) {
      cli_warn(
        c(
          "!" = "{.var backend} option will be ignored as a stan model
          object has been passed."
        )
      )
    }
    if (inherits(object, "stanmodel")) {
      backend <- "rstan"
    } else if (inherits(object, "CmdStanModel")) {
      backend <- "cmdstanr"
    } else {
      cli_abort(
        c(
          "!" = "{.var object} must be a stan model object."
        )
      )
    }
  } else {
    # `backend` was already validated by arg_match() above; it is recorded,
    # presumably so the default model for that backend can be chosen later
    opts <- c(opts, list(backend = backend))
  }
  opts <- c(opts, list(
    object = object,
    method = method
  ))
  opts <- switch(method,
    sampling = c(
      opts, stan_sampling_opts(samples = samples, backend = backend, ...)
    ),
    vb = c(
      opts, stan_vb_opts(samples = samples, ...)
    ),
    laplace = c(
      opts, stan_laplace_opts(backend = backend, ...)
    ),
    pathfinder = c(
      opts, stan_pathfinder_opts(samples = samples, backend = backend, ...)
    )
  )
  opts <- c(opts, list(return_fit = return_fit))
  attr(opts, "class") <- c("stan_opts", class(opts))
  return(opts)
}
#' Forecast options
#'
#' @description `r lifecycle::badge("stable")`
#' Defines a list specifying the forecast horizon and, optionally, the
#' accumulation of forecast observations.
#'
#' @param horizon Numeric, defaults to 7. Number of days into the future to
#' forecast.
#' @param accumulate Integer, the number of days to accumulate in forecasts, if
#' any. If not given and observations are accumulated at constant frequency in
#' the data used for fitting then the same accumulation will be used in
#' forecasts unless set explicitly here.
#' @return A `<forecast_opts>` object of forecast settings. It always contains
#' `horizon` and contains `accumulate` only if that argument was supplied.
#' @seealso [fill_missing()]
#' @export
#' @examples
#' forecast_opts(horizon = 28, accumulate = 7)
forecast_opts <- function(horizon = 7, accumulate) {
  opts <- list(horizon = horizon)
  # `accumulate` is only stored when supplied so that, when absent, the
  # accumulation found in the fitting data can be used instead (see docs).
  if (!missing(accumulate)) {
    opts$accumulate <- accumulate
  }
  class(opts) <- c("forecast_opts", class(opts))
  opts
}
#' Return an `_opts()` List per Region
#'
#' @description `r lifecycle::badge("maturing")`
#' Define a list of `_opts()` to pass to [regional_epinow()] `_opts()` accepting
#' arguments. This is useful when different settings are needed between regions
#' within a single [regional_epinow()] call. Using [opts_list()] the defaults
#' can be applied to all regions present with an override passed to regions as
#' necessary (either within [opts_list()] or externally).
#'
#' @param opts An `_opts()` function call such as [rt_opts()].
#'
#' @param reported_cases A data frame containing a `region` variable
#' indicating the target regions.
#'
#' @param ... Optional overrides for region defaults, given as arguments named
#' by region. See the examples for use case.
#'
#' @importFrom purrr list_assign
#'
#' @return A named list of options per region which can be passed to the
#' `_opts()` accepting arguments of [regional_epinow()].
#' @seealso [regional_epinow()] [rt_opts()]
#' @export
#' @examples
#' # uses example case vector
#' cases <- example_confirmed[1:40]
#' cases <- data.table::rbindlist(list(
#' data.table::copy(cases)[, region := "testland"],
#' cases[, region := "realland"]
#' ))
#'
#' # default settings
#' opts_list(rt_opts(), cases)
#'
#' # add a weekly random walk in realland
#' opts_list(rt_opts(), cases, realland = rt_opts(rw = 7))
#'
#' # add a weekly random walk externally
#' rt <- opts_list(rt_opts(), cases)
#' rt$realland$rw <- 7
#' rt
opts_list <- function(opts, reported_cases, ...) {
  # `[[` instead of `$`: `$` on a data.frame partially matches column names,
  # so e.g. a `region_code` column could be picked up by mistake.
  regions <- unique(reported_cases[["region"]])
  default <- rep(list(opts), length(regions))
  names(default) <- regions
  # Replace the default entry of each region named in `...` by its override.
  list_assign(default, ...)
}
#' Filter Options for a Target Region
#'
#' @description `r lifecycle::badge("maturing")`
#' Helper returning the settings for one region: if `opts` has an element
#' named after `region` that element is returned, otherwise `opts` itself is
#' returned as the settings shared by all regions.
#'
#' @param opts Either a list of calls to an `_opts()` function (named by
#' region) or a single call to an `_opts()` function.
#'
#' @param region A character string indicating a region of interest.
#'
#' @return A list of options
filter_opts <- function(opts, region) {
  has_region_entry <- region %in% names(opts)
  if (!has_region_entry) {
    # No region-specific entry: fall back to the shared settings.
    return(opts)
  }
  opts[[region]]
}
#' Apply default CDF cutoff to a <dist_spec> if it is unconstrained
#'
#' @param dist A <dist_spec>
#' @param default_cdf_cutoff Numeric; default CDF cutoff to be used if an
#' unconstrained distribution is passed as `dist`. If `dist` is already
#' constrained by having a maximum or CDF cutoff this is ignored. Note that
#' this can only be done for <dist_spec> objects with fixed parameters.
#' @param cdf_cutoff_set Logical; whether the default CDF cutoff has been set by
#' the user; if yes and the cutoff cannot be applied (because `dist` is
#' already constrained or does not have fixed parameters) a warning is issued.
#' @importFrom cli cli_inform cli_warn
#'
#' @return A <dist_spec> with the default CDF cutoff set if previously
#' unconstrained with fixed parameters; otherwise `dist` unchanged.
#' @keywords internal
apply_default_cdf_cutoff <- function(dist, default_cdf_cutoff, cdf_cutoff_set) {
  constrained <- is_constrained(dist)
  # `sd(dist)` is only evaluated for unconstrained distributions (the `&&`
  # short-circuits). An NA sd presumably signals non-fixed parameters, for
  # which a cutoff cannot be applied (see `default_cdf_cutoff`).
  if (!constrained && !anyNA(sd(dist))) {
    # nolint start: duplicate_argument_linter
    cli_inform(
      c(
        "i" = "Unconstrained distribution passed as a delay.",
        "i" = "Constraining with default CDF cutoff {default_cdf_cutoff}.",
        "i" = "To silence this message, specify delay distributions
        with {.var max} or {.var default_cdf_cutoff}."
      )
    )
    # nolint end
    attr(dist, "cdf_cutoff") <- default_cdf_cutoff
  } else if (cdf_cutoff_set) {
    # Report the actual reason the user-supplied cutoff is not used.
    reason <- if (constrained) {
      "Distribution is already constrained."
    } else {
      "A CDF cutoff can only be applied to distributions with fixed parameters."
    }
    cli_warn(
      c(
        "!" = "Ignoring given default CDF cutoff.",
        "i" = reason
      )
    )
  }
  dist
}
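# NOTE(review): stray web-page embed text was appended here; commented out so
# the file parses. Original text:
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.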