#' Run a pMCMC sampler within the squire model setup (i.e. including the model parameters needed by the odin model to generate trajectories)
#'
#' @title Run a Particle MCMC Sampler within the Squire Framework
#'
#' @param data Data to fit to. This must be constructed with
#' \code{particle_filter_data}
#' @param squire_model A squire model to use
#' @param pars_obs list of parameters to use in comparison
#' with \code{compare}. Must be a list containing, e.g.
#' list(phi_cases = 0.1,
#' k_cases = 2,
#' phi_death = 1,
#' k_death = 2,
#' exp_noise = 1e6)
#' @param n_mcmc number of MCMC iterations to perform
#' @param pars_init named list of initial inputs for parameters being sampled
#' @param pars_min named list of lower reflecting boundaries for parameter proposals
#' @param pars_max named list of upper reflecting boundaries for parameter proposals
#' @param proposal_kernel named matrix of proposal covariance for parameters
#' @param scaling_factor numeric for starting scaling factor for covariance matrix. Default = 1
#' @param pars_discrete named list of logicals, indicating if proposed jump should be discrete
#' @param log_likelihood function to calculate the log likelihood. It must take a named parameter vector
#' as input and allow passing of implicit arguments corresponding to the main function arguments.
#' It must return a named list with entries:
#' - $log_likelihood, a single numeric
#' - $sample_state, a numeric vector corresponding to the state of a single particle, chosen at random,
#' at the final time point for which we have data.
#' If NULL, calculated using the function calc_loglikelihood.
#' @param log_prior function to calculate the log prior. It must take a named parameter vector as input
#' and return a single numeric. If NULL, a flat, uninformative prior that does not affect the posterior is used.
#' @param n_particles Number of particles (considered for both the PMCMC fit and sampling from posterior)
#' @param steps_per_day Number of steps per day
#' @param output_proposals Logical indicating whether proposed parameter jumps should be output along with results
#' @param n_chains number of MCMC chains to run
#' @inheritParams calibrate
#' @param date_vaccine_change Date that vaccine doses per day change.
#' Default = NULL.
#' @param baseline_max_vaccine Baseline vaccine doses per day. Default = NULL.
#' @param max_vaccine Time varying maximum vaccine doses per day. Default = NULL.
#' @param date_vaccine_efficacy_infection_change Date that vaccine efficacy
#' against infection changes. Default = NULL.
#' @param baseline_vaccine_efficacy_infection Baseline vaccine efficacy against infection.
#' Default = NULL.
#' @param vaccine_efficacy_infection Time varying vaccine efficacy against infection.
#' Default = NULL.
#' @param date_vaccine_efficacy_disease_change Date that vaccine efficacy
#' against disease changes. Default = NULL.
#' @param baseline_vaccine_efficacy_disease Baseline vaccine efficacy against disease.
#' Default = NULL.
#' @param vaccine_efficacy_disease Time varying vaccine efficacy against disease.
#' Default = NULL.
#' @param Rt_args List of arguments to be passed to \code{evaluate_Rt_pmcmc} for calculating Rt.
#' Current arguments are available in \code{Rt_args_list}
#' @param burnin number of iterations to discard from the start of MCMC run when sampling from the posterior for trajectories
#' @param replicates number of trajectories (replicates) to return, sampled from the posterior probability results produced by \code{run_mcmc_chain}.
#' For each parameter set sampled, the particle filter is run with \code{n_particles} and 1 trajectory is sampled
#' @param forecast Number of days to forecast forward. Default = 0
#' @param required_acceptance_ratio Desired MCMC acceptance ratio
#' @param start_adaptation Iteration number at which to begin Robbins-Monro optimisation of the scaling factor
#' @param gibbs_sampling Whether or not to use the Gibbs Sampler for start_date
#' @param gibbs_days Number of days either side of the start_date parameter to evaluate likelihood at
#' @param ... Further arguments for the model parameter function. If using the
#' \code{\link{explicit_model}} (default) this will be
#' \code{parameters_explicit_SEEIR}.
#'
#' @return squire_simulation \describe{
#' \item{output}{Trajectories from the sampled pMCMC parameter iterations.}
#' \item{parameters}{Model parameters used for squire}
#' \item{model}{Squire model used}
#' \item{inputs}{Inputs into the squire model for the pMCMC.}
#' \item{pMCMC_results}{An mcmc object generated from \code{pmcmc}, which contains:}
#' \describe{
#' \item{inputs}{List of inputs}
#' \item{chains}{List that includes:}
#' \describe{
#' \item{results}{Matrix of accepted parameter samples, rows = iterations
#' as well as log prior, (particle filter estimate of) log likelihood and log posterior}
#' \item{states}{Matrix of compartment states}
#' \item{acceptance_rate}{MCMC acceptance rate}
#' \item{ess}{MCMC chain effective sample size}
#' }
#' \item{rhat}{MCMC Diagnostics}
#' }
#' \item{interventions}{Contains the interventions that can be called with projections.}
#' \item{replicate_parameters}{Contains the parameter values for the sampled pMCMC parameter iterations
#' used to generate the \code{squire_model} trajectories}
#'}
#'
#' @description The user inputs initial parameter values for R0, Meff, and the start date.
#' The log prior of these parameters is calculated based on the user-defined
#' prior distributions.
#' The log likelihood of the data given the initial parameters is estimated using a particle filter,
#' which has two functions:
#' - Firstly, to generate a set of 'n_particles' samples of the model state space,
#' at time points corresponding to the data, one of which is
#' selected randomly to serve as the proposed state sequence sample at the final
#' data time point.
#' - Secondly, to produce an unbiased estimate of the likelihood of the data given the proposed parameters.
#' The log posterior of the initial parameters given the data is then estimated by adding the log prior and
#' log likelihood estimate.
#'
#' The pMCMC sampler then proceeds as follows, for n_mcmc iterations:
#' At each loop iteration the pMCMC sampler performs three steps:
#' 1. Propose new candidate samples for R0, Meff, Meff_pl, and start_date based on
#' the current samples, using the proposal distribution
#' (currently multivariate Gaussian with user-input covariance matrix (proposal_kernel), and reflecting boundaries defined by pars_min, pars_max)
#' 2. Calculate the log prior of the proposed parameters,
#' use the particle filter to estimate the log likelihood of the data given the proposed parameters, as described above,
#' as well as proposing a model state space, and
#' add the log prior and log likelihood estimate to estimate the log posterior of the proposed parameters given the data.
#' 3. Metropolis-Hastings step: The joint candidate sample (consisting of the proposed parameters
#' and state space) is then accepted with probability min(1, a), where the acceptance ratio a is
#' simply the ratio of the posterior density of the proposed parameters to the posterior density
#' of the current parameters. Note that by choosing symmetric proposal distributions with
#' reflecting boundaries, we avoid the need to include the proposal density in the MH ratio.
#'
#' If the proposed parameters and states are accepted then we update the current parameters and states
#' to match the proposal, otherwise the previous parameters/states are retained for the next iteration.
#'
#' After generating the pMCMC simulation, \code{replicates} iterations are sampled according to their
#' posterior probability. The parameters from those iterations are then used to generate new trajectories within
#' the \code{squire_model} framework.
#'
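#' @examples
#' \dontrun{
#' # A minimal illustrative sketch only: the death data, proposal kernel values
#' # and country below are placeholders/assumptions, not a tested configuration.
#' df <- data.frame(
#'   date = seq(as.Date("2020-03-01"), by = "day", length.out = 30),
#'   deaths = rpois(30, lambda = 5)
#' )
#' kernel <- diag(c(1, 0.1, 0.1, 0.1, 0.1)) / 10
#' rownames(kernel) <- colnames(kernel) <- c("start_date", "R0", "Meff",
#'                                           "Meff_pl", "R0_pl_shift")
#' fit <- pmcmc(data = df,
#'              n_mcmc = 100,
#'              proposal_kernel = kernel,
#'              country = "United Kingdom",
#'              n_particles = 10,
#'              replicates = 10)
#' }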
#' @export
#' @import coda
#' @importFrom stats rnorm plogis qnorm cov median
#' @importFrom mvtnorm rmvnorm
pmcmc <- function(data,
n_mcmc,
log_likelihood = NULL,
log_prior = NULL,
n_particles = 1e2,
steps_per_day = 4,
output_proposals = FALSE,
n_chains = 1,
squire_model = explicit_model(),
pars_obs = list(phi_cases = 1,
k_cases = 2,
phi_death = 1,
k_death = 2,
exp_noise = 1e6),
pars_init = list('start_date' = as.Date("2020-02-07"),
'R0' = 2.5,
'Meff' = 2,
'Meff_pl' = 3,
"R0_pl_shift" = 0),
pars_min = list('start_date' = as.Date("2020-02-01"),
'R0' = 0,
'Meff' = 1,
'Meff_pl' = 2,
"R0_pl_shift" = -2),
pars_max = list('start_date' = as.Date("2020-02-20"),
'R0' = 5,
'Meff' = 3,
'Meff_pl' = 4,
"R0_pl_shift" = 5),
pars_discrete = list('start_date' = TRUE,
'R0' = FALSE,
'Meff' = FALSE,
'Meff_pl' = FALSE,
"R0_pl_shift" = FALSE),
proposal_kernel = NULL,
scaling_factor = 1,
reporting_fraction = 1,
treated_deaths_only = FALSE,
country = NULL,
population = NULL,
contact_matrix_set = NULL,
baseline_contact_matrix = NULL,
date_contact_matrix_set_change = NULL,
R0_change = NULL,
date_R0_change = NULL,
hosp_bed_capacity = NULL,
baseline_hosp_bed_capacity = NULL,
date_hosp_bed_capacity_change = NULL,
ICU_bed_capacity = NULL,
baseline_ICU_bed_capacity = NULL,
date_ICU_bed_capacity_change = NULL,
date_vaccine_change = NULL,
baseline_max_vaccine = NULL,
max_vaccine = NULL,
date_vaccine_efficacy_infection_change = NULL,
baseline_vaccine_efficacy_infection = NULL,
vaccine_efficacy_infection = NULL,
date_vaccine_efficacy_disease_change = NULL,
baseline_vaccine_efficacy_disease = NULL,
vaccine_efficacy_disease = NULL,
Rt_args = NULL,
burnin = 0,
replicates = 100,
forecast = 0,
required_acceptance_ratio = 0.23,
start_adaptation = round(n_mcmc/2),
gibbs_sampling = FALSE,
gibbs_days = NULL,
...
) {
#------------------------------------------------------------
# Section 1 of pMCMC Wrapper: Checks & Setup
#------------------------------------------------------------
#--------------------
# assertions & checks
#--------------------
# if nimue keep to 1 step per day
if(inherits(squire_model, "nimue_model")) {
steps_per_day <- 1
}
# we work with pars_init being a list of initial conditions for starting each chain
if(any(c("start_date", "R0") %in% names(pars_init))) {
pars_init <- list(pars_init)
}
# make it same length as chains, which allows us to pass in multiple starting points
if(length(pars_init) != n_chains) {
pars_init <- rep(pars_init, n_chains)
pars_init <- pars_init[seq_len(n_chains)]
}
# data assertions
assert_dataframe(data)
assert_in("date", names(data))
assert_in("deaths", names(data))
assert_date(data$date)
assert_increasing(as.numeric(as.Date(data$date)),
message = "Dates must be in increasing order")
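# e.g. data is expected to look like (illustrative values):
#      date        deaths
#      2020-03-01       0
#      2020-03-02       2
# with one row per date and dates in increasing order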
# check input pars df
assert_list(pars_init)
assert_list(pars_init[[1]])
assert_list(pars_min)
assert_list(pars_max)
assert_list(pars_discrete)
assert_eq(names(pars_init[[1]]), names(pars_min))
assert_eq(names(pars_min), names(pars_max))
assert_eq(names(pars_max), names(pars_discrete))
assert_in(c("R0", "start_date"),names(pars_init[[1]]),
message = "Params to infer must include R0, start_date")
assert_date(pars_init[[1]]$start_date)
assert_date(pars_min$start_date)
assert_date(pars_max$start_date)
if (pars_max$start_date >= as.Date(data$date[1])-1) {
stop("Maximum start date must be at least 2 days before the first date in data")
}
# check date variables are as Date class
for(i in seq_along(pars_init)) {
pars_init[[i]]$start_date <- as.Date(pars_init[[i]]$start_date)
}
pars_min$start_date <- as.Date(pars_min$start_date)
pars_max$start_date <- as.Date(pars_max$start_date)
# check bounds
for(var in names(pars_init[[1]])) {
assert_bounded(as.numeric(pars_init[[1]][[var]]),
left = as.numeric(pars_min[[var]]),
right = as.numeric(pars_max[[var]]),
name = paste(var, "init"))
assert_single_numeric(as.numeric(pars_min[[var]]), name = paste(var, "min"))
assert_single_numeric(as.numeric(pars_max[[var]]), name = paste(var, "max"))
assert_single_numeric(as.numeric(pars_init[[1]][[var]]), name = paste(var, "init"))
}
# additional checks that R0 is positive as undefined otherwise
assert_pos(pars_min$R0)
assert_pos(pars_max$R0)
assert_pos(pars_init[[1]]$R0)
assert_bounded(pars_init[[1]]$R0, left = pars_min$R0, right = pars_max$R0)
# check proposal kernel
assert_matrix(proposal_kernel)
if (gibbs_sampling) {
assert_eq(colnames(proposal_kernel), names(pars_init[[1]][-1]))
assert_eq(rownames(proposal_kernel), names(pars_init[[1]][-1]))
} else {
assert_eq(colnames(proposal_kernel), names(pars_init[[1]]))
assert_eq(rownames(proposal_kernel), names(pars_init[[1]]))
}
# check likelihood items
if ( !(is.null(log_likelihood) | inherits(log_likelihood, "function")) ) {
stop("Log Likelihood (log_likelihood) must be null or a user specified function")
}
if ( !(is.null(log_prior) | inherits(log_prior, "function")) ) {
stop("Log Prior (log_prior) must be null or a user specified function")
}
assert_logical(unlist(pars_discrete))
assert_list(pars_obs)
assert_in(c("phi_cases", "k_cases", "phi_death", "k_death", "exp_noise"), names(pars_obs))
assert_numeric(unlist(pars_obs[c("phi_cases", "k_cases", "phi_death", "k_death", "exp_noise")]))
# mcmc items
assert_pos_int(n_mcmc)
assert_pos_int(n_chains)
assert_pos_int(n_particles)
assert_logical(output_proposals)
# squire and odin
assert_custom_class(squire_model, "squire_model")
assert_pos_int(steps_per_day)
assert_numeric(reporting_fraction)
assert_bounded(reporting_fraction, 0, 1, inclusive_left = FALSE, inclusive_right = TRUE)
assert_pos_int(replicates)
# date change items
assert_same_length(R0_change, date_R0_change)
# checks that dates are not in the future compared to our data
if (!is.null(date_R0_change)) {
assert_date(date_R0_change)
if(as.Date(tail(date_R0_change,1)) > as.Date(tail(data$date, 1))) {
stop("Last date in date_R0_change is greater than the last date in data")
}
}
# ------------------------------------
# checks on odin interacting variables
# ------------------------------------
if(!is.null(contact_matrix_set)) {
assert_list(contact_matrix_set)
}
assert_same_length(contact_matrix_set, date_contact_matrix_set_change)
assert_same_length(ICU_bed_capacity, date_ICU_bed_capacity_change)
assert_same_length(hosp_bed_capacity, date_hosp_bed_capacity_change)
assert_same_length(max_vaccine, date_vaccine_change)
assert_same_length(vaccine_efficacy_infection, date_vaccine_efficacy_infection_change)
assert_same_length(vaccine_efficacy_disease, date_vaccine_efficacy_disease_change)
# handle contact matrix changes
if(!is.null(date_contact_matrix_set_change)) {
assert_date(date_contact_matrix_set_change)
assert_list(contact_matrix_set)
if(is.null(baseline_contact_matrix)) {
stop("baseline_contact_matrix can't be NULL if date_contact_matrix_set_change is provided")
}
if(as.Date(tail(date_contact_matrix_set_change,1)) > as.Date(tail(data$date, 1))) {
stop("Last date in date_contact_matrix_set_change is greater than the last date in data")
}
# Get in correct format
if(is.matrix(baseline_contact_matrix)) {
baseline_contact_matrix <- list(baseline_contact_matrix)
}
tt_contact_matrix <- c(0, seq_len(length(date_contact_matrix_set_change)))
contact_matrix_set <- append(baseline_contact_matrix, contact_matrix_set)
} else {
tt_contact_matrix <- 0
contact_matrix_set <- baseline_contact_matrix
}
# handle ICU changes
if(!is.null(date_ICU_bed_capacity_change)) {
assert_date(date_ICU_bed_capacity_change)
assert_vector(ICU_bed_capacity)
assert_numeric(ICU_bed_capacity)
if(is.null(baseline_ICU_bed_capacity)) {
stop("baseline_ICU_bed_capacity can't be NULL if date_ICU_bed_capacity_change is provided")
}
assert_numeric(baseline_ICU_bed_capacity)
if(as.Date(tail(date_ICU_bed_capacity_change,1)) > as.Date(tail(data$date, 1))) {
stop("Last date in date_ICU_bed_capacity_change is greater than the last date in data")
}
tt_ICU_beds <- c(0, seq_len(length(date_ICU_bed_capacity_change)))
ICU_bed_capacity <- c(baseline_ICU_bed_capacity, ICU_bed_capacity)
} else {
tt_ICU_beds <- 0
ICU_bed_capacity <- baseline_ICU_bed_capacity
}
# handle vaccine changes
if(!is.null(date_vaccine_change)) {
assert_date(date_vaccine_change)
assert_vector(max_vaccine)
assert_numeric(max_vaccine)
if(is.null(baseline_max_vaccine)) {
stop("baseline_max_vaccine can't be NULL if date_vaccine_change is provided")
}
assert_numeric(baseline_max_vaccine)
if(as.Date(tail(date_vaccine_change,1)) > as.Date(tail(data$date, 1))) {
stop("Last date in date_vaccine_change is greater than the last date in data")
}
tt_vaccine <- c(0, seq_len(length(date_vaccine_change)))
max_vaccine <- c(baseline_max_vaccine, max_vaccine)
} else {
tt_vaccine <- 0
if(!is.null(baseline_max_vaccine)) {
max_vaccine <- baseline_max_vaccine
} else {
max_vaccine <- 0
}
}
# handle vaccine efficacy infection changes
if(!is.null(date_vaccine_efficacy_infection_change)) {
assert_date(date_vaccine_efficacy_infection_change)
if(!is.list(vaccine_efficacy_infection)) {
vaccine_efficacy_infection <- list(vaccine_efficacy_infection)
}
assert_vector(vaccine_efficacy_infection[[1]])
assert_numeric(vaccine_efficacy_infection[[1]])
if(is.null(baseline_vaccine_efficacy_infection)) {
stop("baseline_vaccine_efficacy_infection can't be NULL if date_vaccine_efficacy_infection_change is provided")
}
assert_numeric(baseline_vaccine_efficacy_infection)
if(as.Date(tail(date_vaccine_efficacy_infection_change,1)) > as.Date(tail(data$date, 1))) {
stop("Last date in date_vaccine_efficacy_infection_change is greater than the last date in data")
}
tt_vaccine_efficacy_infection <- c(0, seq_len(length(date_vaccine_efficacy_infection_change)))
vaccine_efficacy_infection <- c(list(baseline_vaccine_efficacy_infection), vaccine_efficacy_infection)
} else {
tt_vaccine_efficacy_infection <- 0
if(!is.null(baseline_vaccine_efficacy_infection)) {
vaccine_efficacy_infection <- baseline_vaccine_efficacy_infection
} else {
vaccine_efficacy_infection <- rep(0.8, 17)
}
}
# handle vaccine efficacy disease changes
if(!is.null(date_vaccine_efficacy_disease_change)) {
assert_date(date_vaccine_efficacy_disease_change)
if(!is.list(vaccine_efficacy_disease)) {
vaccine_efficacy_disease <- list(vaccine_efficacy_disease)
}
assert_vector(vaccine_efficacy_disease[[1]])
assert_numeric(vaccine_efficacy_disease[[1]])
if(is.null(baseline_vaccine_efficacy_disease)) {
stop("baseline_vaccine_efficacy_disease can't be NULL if date_vaccine_efficacy_disease_change is provided")
}
assert_numeric(baseline_vaccine_efficacy_disease)
if(as.Date(tail(date_vaccine_efficacy_disease_change,1)) > as.Date(tail(data$date, 1))) {
stop("Last date in date_vaccine_efficacy_disease_change is greater than the last date in data")
}
tt_vaccine_efficacy_disease <- c(0, seq_len(length(date_vaccine_efficacy_disease_change)))
vaccine_efficacy_disease <- c(list(baseline_vaccine_efficacy_disease), vaccine_efficacy_disease)
} else {
tt_vaccine_efficacy_disease <- 0
if(!is.null(baseline_vaccine_efficacy_disease)) {
vaccine_efficacy_disease <- baseline_vaccine_efficacy_disease
} else {
vaccine_efficacy_disease <- rep(0.95, 17)
}
}
# handle hosp bed changes
if(!is.null(date_hosp_bed_capacity_change)) {
assert_date(date_hosp_bed_capacity_change)
assert_vector(hosp_bed_capacity)
assert_numeric(hosp_bed_capacity)
if(is.null(baseline_hosp_bed_capacity)) {
stop("baseline_hosp_bed_capacity can't be NULL if date_hosp_bed_capacity_change is provided")
}
assert_numeric(baseline_hosp_bed_capacity)
if(as.Date(tail(date_hosp_bed_capacity_change,1)) > as.Date(tail(data$date, 1))) {
stop("Last date in date_hosp_bed_capacity_change is greater than the last date in data")
}
tt_hosp_beds <- c(0, seq_len(length(date_hosp_bed_capacity_change)))
hosp_bed_capacity <- c(baseline_hosp_bed_capacity, hosp_bed_capacity)
} else {
tt_hosp_beds <- 0
hosp_bed_capacity <- baseline_hosp_bed_capacity
}
#----------------
# Generate Odin items
#----------------
# make the date definitely a date
data$date <- as.Date(as.character(data$date))
# adjust for reporting fraction
pars_obs$phi_cases <- reporting_fraction
pars_obs$phi_death <- reporting_fraction
pars_obs$treated_deaths_only <- treated_deaths_only
# build model parameters
model_params <- squire_model$parameter_func(
country = country,
population = population,
dt = 1/steps_per_day,
contact_matrix_set = contact_matrix_set,
tt_contact_matrix = tt_contact_matrix,
hosp_bed_capacity = hosp_bed_capacity,
tt_hosp_beds = tt_hosp_beds,
ICU_bed_capacity = ICU_bed_capacity,
tt_ICU_beds = tt_ICU_beds,
max_vaccine = max_vaccine,
tt_vaccine = tt_vaccine,
vaccine_efficacy_infection = vaccine_efficacy_infection,
tt_vaccine_efficacy_infection = tt_vaccine_efficacy_infection,
vaccine_efficacy_disease = vaccine_efficacy_disease,
tt_vaccine_efficacy_disease = tt_vaccine_efficacy_disease,
...)
# collect interventions for odin model likelihood
interventions <- list(
R0_change = R0_change,
date_R0_change = date_R0_change,
date_contact_matrix_set_change = date_contact_matrix_set_change,
contact_matrix_set = contact_matrix_set,
date_ICU_bed_capacity_change = date_ICU_bed_capacity_change,
ICU_bed_capacity = ICU_bed_capacity,
date_hosp_bed_capacity_change = date_hosp_bed_capacity_change,
hosp_bed_capacity = hosp_bed_capacity,
date_vaccine_change = date_vaccine_change,
max_vaccine = max_vaccine,
date_vaccine_efficacy_disease_change = date_vaccine_efficacy_disease_change,
vaccine_efficacy_disease = vaccine_efficacy_disease,
date_vaccine_efficacy_infection_change = date_vaccine_efficacy_infection_change,
vaccine_efficacy_infection = vaccine_efficacy_infection
)
#----------------..
# Collect Odin and MCMC Inputs
#----------------..
inputs <- list(
data = data,
n_mcmc = n_mcmc,
model_params = model_params,
interventions = interventions,
pars_obs = pars_obs,
Rt_args = Rt_args,
squire_model = squire_model,
pars = list(pars_obs = pars_obs,
pars_init = pars_init,
pars_min = pars_min,
pars_max = pars_max,
proposal_kernel = proposal_kernel,
scaling_factor = scaling_factor,
pars_discrete = pars_discrete),
n_particles = n_particles)
#----------------
# create prior and likelihood functions given the inputs
#----------------
if(is.null(log_prior)) {
# set improper, uninformative prior
log_prior <- function(pars) log(1e-10)
}
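# (a hedged sketch of a user-supplied prior, with illustrative distributions:
#  it must take the named parameter vector and return a single finite numeric, e.g.
#  log_prior <- function(pars) {
#    dnorm(pars[["R0"]], mean = 3, sd = 1, log = TRUE) +
#      dnorm(pars[["Meff"]], mean = 0, sd = 3, log = TRUE)
#  })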
calc_lprior <- log_prior
if(is.null(log_likelihood)) {
log_likelihood <- calc_loglikelihood
} else if (!('...' %in% names(formals(log_likelihood)))){
stop('log_likelihood function must accept additional arguments via ...')
}
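# (a user-supplied log_likelihood must include ... in its formals and return a
#  list of the form list(log_likelihood = <single numeric>,
#  sample_state = <numeric vector of a single particle's final state>);
#  calc_loglikelihood below is the default implementation)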
# create shorthand function to calc_ll given main inputs
calc_ll <- function(pars) {
X <- log_likelihood(pars = pars,
data = data,
squire_model = squire_model,
model_params = model_params,
interventions = interventions,
pars_obs = pars_obs,
n_particles = n_particles,
forecast_days = 0,
Rt_args = Rt_args,
return = "ll"
)
X
}
#----------------
# create mcmc run functions depending on whether Gibbs Sampling
#----------------
if(gibbs_sampling) {
# checking gibbs days is specified and is an integer
if (is.null(gibbs_days)) {
stop("if gibbs_sampling == TRUE, gibbs_days must be specified")
}
assert_int(gibbs_days)
# create our gibbs run func wrapper
run_mcmc_func <- function(...) {
force(gibbs_days)
run_mcmc_chain_gibbs(..., gibbs_days = gibbs_days)
}
} else {
run_mcmc_func <- run_mcmc_chain
}
#----------------
# proposals
#----------------
# needs to be a vector to pass to reflecting boundary function
pars_min <- unlist(pars_min)
pars_max <- unlist(pars_max)
pars_discrete <- unlist(pars_discrete)
#--------------------------------------------------------
# Section 2 of pMCMC Wrapper: Run pMCMC
#--------------------------------------------------------
# Run the chains in parallel
message("Running pMCMC...")
if (Sys.getenv("SQUIRE_PARALLEL_DEBUG") == "TRUE") {
chains <- purrr::pmap(
.l = list(n_mcmc = rep(n_mcmc, n_chains),
curr_pars = pars_init),
.f = run_mcmc_func,
inputs = inputs,
calc_lprior = calc_lprior,
calc_ll = calc_ll,
first_data_date = data$date[1],
output_proposals = output_proposals,
required_acceptance_ratio = required_acceptance_ratio,
start_adaptation = start_adaptation,
proposal_kernel = proposal_kernel,
scaling_factor = scaling_factor,
pars_discrete = pars_discrete,
pars_min = pars_min,
pars_max = pars_max)
} else {
chains <- furrr::future_pmap(
.l = list(n_mcmc = rep(n_mcmc, n_chains),
curr_pars = pars_init),
.f = run_mcmc_func,
inputs = inputs,
calc_lprior = calc_lprior,
calc_ll = calc_ll,
first_data_date = data$date[1],
output_proposals = output_proposals,
required_acceptance_ratio = required_acceptance_ratio,
start_adaptation = start_adaptation,
proposal_kernel = proposal_kernel,
scaling_factor = scaling_factor,
pars_discrete = pars_discrete,
pars_min = pars_min,
pars_max = pars_max,
.progress = TRUE,
.options = furrr::furrr_options(seed = NULL))
}
#----------------
# MCMC diagnostics and tidy
#----------------
if (n_chains > 1) {
names(chains) <- paste0('chain', seq_len(n_chains))
# calculating rhat
# convert parallel chains to a coda-friendly format
chains_coda <- lapply(chains, function(x) {
traces <- x$results
if('start_date' %in% names(pars_init[[1]])) {
traces$start_date <- start_date_to_offset(data$date[1], traces$start_date)
}
coda::as.mcmc(traces[, names(pars_init[[1]])])
})
rhat <- tryCatch(expr = {
x <- coda::gelman.diag(chains_coda)
x
}, error = function(e) {
message('unable to calculate rhat')
})
pmcmc <- list(inputs = chains[[1]]$inputs,
rhat = rhat,
chains = lapply(chains, '[', -1))
class(pmcmc) <- 'squire_pmcmc_list'
} else {
pmcmc <- chains[[1]]
class(pmcmc) <- "squire_pmcmc"
}
#--------------------------------------------------------
# Section 3 of pMCMC Wrapper: Sample PMCMC Results
#--------------------------------------------------------
pmcmc_samples <- sample_pmcmc(pmcmc_results = pmcmc,
burnin = burnin,
n_chains = n_chains,
n_trajectories = replicates,
log_likelihood = log_likelihood,
n_particles = n_particles,
forecast_days = forecast)
#--------------------------------------------------------
# Section 4 of pMCMC Wrapper: Tidy Output
#--------------------------------------------------------
#----------------
# Pull Sampled results and "recreate" squire models
#----------------
# create a fake run object and fill in the required elements
r <- squire_model$run_func(country = country,
contact_matrix_set = contact_matrix_set,
tt_contact_matrix = tt_contact_matrix,
hosp_bed_capacity = hosp_bed_capacity,
tt_hosp_beds = tt_hosp_beds,
ICU_bed_capacity = ICU_bed_capacity,
tt_ICU_beds = tt_ICU_beds,
max_vaccine = max_vaccine,
tt_vaccine = tt_vaccine,
vaccine_efficacy_infection = vaccine_efficacy_infection,
tt_vaccine_efficacy_infection = tt_vaccine_efficacy_infection,
vaccine_efficacy_disease = vaccine_efficacy_disease,
tt_vaccine_efficacy_disease = tt_vaccine_efficacy_disease,
population = population,
replicates = 1,
day_return = TRUE,
time_period = nrow(pmcmc_samples$trajectories),
...)
# and add the parameters that changed between each simulation, i.e. posterior draws
r$replicate_parameters <- pmcmc_samples$sampled_PMCMC_Results
# as well as adding the pmcmc chains so it's easy to draw from the chains again in the future
r$pmcmc_results <- pmcmc
# then let's create the output that we are going to use
names(pmcmc_samples)[names(pmcmc_samples) == "trajectories"] <- "output"
dimnames(pmcmc_samples$output) <- list(dimnames(pmcmc_samples$output)[[1]], dimnames(r$output)[[2]], NULL)
r$output <- pmcmc_samples$output
# and adjust the time as before
full_row <- match(0, apply(r$output[,"time",],2,function(x) { sum(is.na(x)) }))
saved_full <- r$output[,"time",full_row]
for(i in seq_len(replicates)) {
na_pos <- which(is.na(r$output[,"time",i]))
full_to_place <- saved_full - which(rownames(r$output) == as.Date(max(data$date))) + 1L
if(length(na_pos) > 0) {
full_to_place[na_pos] <- NA
}
r$output[,"time",i] <- full_to_place
}
# second let's recreate the output
r$model <- pmcmc_samples$inputs$squire_model$odin_model(
user = pmcmc_samples$inputs$model_params, unused_user_action = "ignore"
)
# we will add the interventions here so that we know what times are needed for projection
r$interventions <- interventions
# and fix the replicates
r$parameters$replicates <- replicates
r$parameters$time_period <- as.numeric(diff(as.Date(range(rownames(r$output)))))
r$parameters$dt <- model_params$dt
#--------------------..
# out
#--------------------..
return(r)
}
#' Run a single pMCMC chain
#'
#' Helper function to run the particle filter with a
#' new R0 and start date for given interventions within the pmcmc
#'
#' @importFrom stats cov
#' @noRd
run_mcmc_chain <- function(inputs,
curr_pars,
calc_lprior,
calc_ll,
n_mcmc,
first_data_date,
output_proposals,
required_acceptance_ratio,
start_adaptation,
proposal_kernel,
scaling_factor,
pars_discrete,
pars_min,
pars_max) {
#----------------
# Set initial state
#----------------
# run particle filter on initial parameters
p_filter_est <- calc_ll(pars = curr_pars)
# NB: squire was originally set up to handle start_date as a Date; however,
# the proposal and log prior work with numerics, so convert start_date to a
# (negative) offset from the first data date
curr_pars[["start_date"]] <- -(start_date_to_offset(first_data_date, curr_pars[["start_date"]]))
curr_pars <- unlist(curr_pars)
## calculate initial prior
curr_lprior <- calc_lprior(pars = curr_pars)
#----------------..
# assertions and checks on log_prior and log_likelihood functions
#----------------..
if(length(curr_lprior) > 1) {
stop('log_prior must return a single numeric representing the log prior')
}
if(is.infinite(curr_lprior)) {
stop('initial parameters are not compatible with supplied prior')
}
if(length(p_filter_est) != 2) {
stop('log_likelihood function must return a list containing elements log_likelihood and sample_state')
}
if(!setequal(names(p_filter_est), c('log_likelihood', 'sample_state'))) {
stop('log_likelihood function must return a list containing elements log_likelihood and sample_state')
}
if(length(p_filter_est$log_likelihood) > 1) {
stop('log_likelihood must be a single numeric representing the estimated log likelihood')
}
assert_neg(x = p_filter_est$log_likelihood,
message1 = 'log_likelihood must be negative or zero')
#----------------..
# Create objects to store outputs
#----------------..
# extract loglikelihood estimate and sample state
# calculate posterior
curr_ll <- p_filter_est$log_likelihood
curr_lpost <- curr_lprior + curr_ll
curr_ss <- p_filter_est$sample_state
# initialise output arrays
res_init <- c(curr_pars,
'log_prior' = curr_lprior,
'log_likelihood' = curr_ll,
'log_posterior' = curr_lpost)
res <- matrix(data = NA,
nrow = n_mcmc + 1L,
ncol = length(res_init),
dimnames = list(NULL, names(res_init)))
states <- matrix(data = NA,
nrow = n_mcmc + 1L,
ncol = length(curr_ss))
# New storage arrays for Robbins-Monro optimisation
# storage for acceptances over time
acceptances <- vector(mode = "numeric", length = n_mcmc) # tracks acceptances over time
# storage for covariance matrices over time - only properly initialised if we're actually adapting
# i.e. in instances where start_adaptation < n_mcmc
if (n_mcmc - start_adaptation <= 0) {
covariance_matrix_storage <- vector(mode = "list", length = 1)
} else {
covariance_matrix_storage <- vector(mode = "list", length = (n_mcmc - start_adaptation + 1))
}
# storage for scaling factor over time - only properly initialised if we're actually adapting
# i.e. in instances where start_adaptation < n_mcmc
if (n_mcmc - start_adaptation <= 0) {
scaling_factor_storage <- vector(mode = "numeric", length = 1)
} else {
scaling_factor_storage <- vector(mode = "numeric", length = (n_mcmc - start_adaptation + 1))
}
if(output_proposals) {
proposals <- matrix(data = NA,
nrow = n_mcmc + 1L,
ncol = length(res_init) + 1L,
dimnames = list(NULL, c(names(res_init), 'accept_prob')))
}
## record initial results
res[1, ] <- res_init
states[1, ] <- curr_ss
# negative here because we are working backwards in time
pars_min[["start_date"]] <- -(start_date_to_offset(first_data_date, pars_min[["start_date"]]))
pars_max[["start_date"]] <- -(start_date_to_offset(first_data_date, pars_max[["start_date"]]))
#----------------
# main pmcmc loop
#----------------
for(iter in seq_len(n_mcmc) + 1L) {
prop_pars <- propose_parameters(curr_pars,
proposal_kernel * scaling_factor,
unlist(pars_discrete),
unlist(pars_min),
unlist(pars_max))
prop_for_eval <- prop_pars$for_eval
prop_for_chain <- prop_pars$for_chain
## calculate proposed prior / lhood / posterior
prop_lprior <- calc_lprior(pars = prop_for_eval)
prop_pars.squire <- as.list(prop_for_eval)
prop_pars.squire[["start_date"]] <- offset_to_start_date(first_data_date, prop_for_eval[["start_date"]]) # convert to date
p_filter_est <- calc_ll(pars = prop_pars.squire)
prop_ll <- p_filter_est$log_likelihood
prop_ss <- p_filter_est$sample_state
prop_lpost <- prop_lprior + prop_ll
# calculate probability of acceptance
accept_prob <- exp(prop_lpost - curr_lpost)
if(runif(1) < accept_prob) { # MH step
# update parameters and calculated likelihoods
curr_pars <- prop_for_chain
curr_lprior <- prop_lprior
curr_ll <- prop_ll
curr_lpost <- prop_lpost
curr_ss <- prop_ss
acceptances[iter] <- 1
}
# record results
res[iter, ] <- c(curr_pars,
curr_lprior,
curr_ll,
curr_lpost)
states[iter, ] <- curr_ss
# adapt and update covariance matrix
if (iter >= start_adaptation) {
timing_cov <- iter - start_adaptation + 1 # iteration relative to when covariance adaptation started
if (iter == start_adaptation) {
previous_mu <- matrix(colMeans(res[1:iter, seq_along(curr_pars)]), nrow = 1)
current_parameters <- matrix(curr_pars, nrow = 1)
temp <- jc_prop_update(acceptances[iter], timing_cov, scaling_factor, previous_mu,
curr_pars, proposal_kernel, required_acceptance_ratio)
scaling_factor <- temp$scaling_factor
proposal_kernel <- temp$covariance_matrix
previous_mu <- temp$mu
covariance_matrix_storage[[timing_cov]] <- proposal_kernel
scaling_factor_storage[[timing_cov]] <- scaling_factor
} else {
temp <- jc_prop_update(acceptances[iter], timing_cov, scaling_factor, previous_mu,
curr_pars, proposal_kernel, required_acceptance_ratio)
scaling_factor <- temp$scaling_factor
proposal_kernel <- temp$covariance_matrix
previous_mu <- temp$mu
covariance_matrix_storage[[timing_cov]] <- proposal_kernel
scaling_factor_storage[[timing_cov]] <- scaling_factor
}
}
if(output_proposals) {
proposals[iter, ] <- c(prop_for_chain,
prop_lprior,
prop_ll,
prop_lpost,
min(accept_prob, 1))
}
if (iter %% 100 == 0) {
print(c(round(scaling_factor, 3), round(sum(acceptances, na.rm = TRUE)/iter, 3), round(iter, 1)))
}
}
res <- as.data.frame(res)
coda_res <- coda::as.mcmc(res)
rejection_rate <- coda::rejectionRate(coda_res)
ess <- coda::effectiveSize(coda_res)
# res$start_date <- offset_to_start_date(first_data_date, res$start_date)
out <- list('inputs' = inputs,
'results' = as.data.frame(res),
'states' = states,
'acceptance_rate' = 1-rejection_rate,
"ess" = ess,
"scaling_factor" = scaling_factor_storage,
"covariance_matrix" = covariance_matrix_storage,
"acceptance_ratio" = mean(acceptances),
"acceptances" = acceptances)
if(output_proposals) {
proposals <- as.data.frame(proposals)
proposals$start_date <- offset_to_start_date(first_data_date, proposals$start_date)
out$proposals <- proposals
}
out
}
#' Run a single pMCMC chain with Gibbs sampling of start_date
#'
#' Helper function to run the particle filter with a
#' new R0 and start date for given interventions within the pmcmc,
#' updating start_date via a Gibbs step
#'
#' @importFrom stats cov
#' @noRd
run_mcmc_chain_gibbs <- function(inputs,
curr_pars,
calc_lprior,
calc_ll,
n_mcmc,
first_data_date,
output_proposals,
required_acceptance_ratio,
start_adaptation,
proposal_kernel,
scaling_factor,
pars_discrete,
pars_min,
pars_max,
gibbs_days) {
#----------------
# Set initial state
#----------------
# run particle filter on initial parameters
p_filter_est <- calc_ll(pars = curr_pars)
# NB: squire was originally set up to handle start_date as a Date; however,
# the proposal and log prior work with numerics, so convert start_date to a
# (negative) offset from the first data date
curr_pars[["start_date"]] <- -(start_date_to_offset(first_data_date, curr_pars[["start_date"]]))
curr_pars <- unlist(curr_pars)
## calculate initial prior
curr_lprior <- calc_lprior(pars = curr_pars)
#----------------..
# assertions and checks on log_prior and log_likelihood functions
#----------------..
if(length(curr_lprior) > 1) {
stop('log_prior must return a single numeric representing the log prior')
}
if(is.infinite(curr_lprior)) {
stop('initial parameters are not compatible with supplied prior')
}
if(length(p_filter_est) != 2) {
stop('log_likelihood function must return a list containing elements log_likelihood and sample_state')
}
if(!setequal(names(p_filter_est), c('log_likelihood', 'sample_state'))) {
stop('log_likelihood function must return a list containing elements log_likelihood and sample_state')
}
if(length(p_filter_est$log_likelihood) > 1) {
stop('log_likelihood must be a single numeric representing the estimated log likelihood')
}
assert_neg(x = p_filter_est$log_likelihood,
message1 = 'log_likelihood must be negative or zero')
#----------------
# Create objects to store outputs
#----------------
# extract loglikelihood estimate and sample state
# calculate posterior
curr_ll <- p_filter_est$log_likelihood
curr_lpost <- curr_lprior + curr_ll
curr_ss <- p_filter_est$sample_state
# initialise output arrays
res_init <- c(curr_pars,
'log_prior' = curr_lprior,
'log_likelihood' = curr_ll,
'log_posterior' = curr_lpost)
res <- matrix(data = NA,
nrow = n_mcmc + 1L,
ncol = length(res_init),
dimnames = list(NULL, names(res_init)))
states <- matrix(data = NA,
nrow = n_mcmc + 1L,
ncol = length(curr_ss))
# New storage arrays for Robbins-Monro optimisation
# storage for acceptances over time
acceptances <- vector(mode = "numeric", length = n_mcmc) # tracks acceptances over time
# storage for covariance matrices over time - only properly initialised if we're actually adapting
# i.e. in instances where start_adaptation < n_mcmc
if (n_mcmc - start_adaptation <= 0) {
covariance_matrix_storage <- vector(mode = "list", length = 1)
} else {
covariance_matrix_storage <- vector(mode = "list", length = (n_mcmc - start_adaptation + 1))
}
# storage for scaling factor over time - only properly initialised if we're actually adapting
# i.e. in instances where start_adaptation < n_mcmc
if (n_mcmc - start_adaptation <= 0) {
scaling_factor_storage <- vector(mode = "numeric", length = 1)
} else {
scaling_factor_storage <- vector(mode = "numeric", length = (n_mcmc - start_adaptation + 1))
}
if(output_proposals) {
proposals <- matrix(data = NA,
nrow = n_mcmc + 1L,
ncol = length(res_init) + 1L,
dimnames = list(NULL, c(names(res_init), 'accept_prob')))
}
## record initial results
res[1, ] <- res_init
states[1, ] <- curr_ss
# negative here because we are working backwards in time
pars_min[["start_date"]] <- -(start_date_to_offset(first_data_date, pars_min[["start_date"]]))
pars_max[["start_date"]] <- -(start_date_to_offset(first_data_date, pars_max[["start_date"]]))
#----------------
# main pmcmc loop
#----------------
for(iter in seq_len(n_mcmc) + 1L) {
# discrete parameter (start_date) update first
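# (conditional on the other parameters, the log posterior is evaluated at the
#  2 * gibbs_days + 1 candidate start dates centred on the current start_date,
#  and a new start_date is drawn with probability proportional to exp(log posterior);
#  e.g. with gibbs_days = 2 the window spans start_date - 2 to start_date + 2)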
current_start_date <- curr_pars["start_date"]
prop_pars <- curr_pars # return to this
prop_pars.squire <- as.list(prop_pars) # return to this
total_days <- 2 * gibbs_days + 1
gibbs_post <- vector(mode = "numeric", length = total_days)
for (i in 1:total_days) {
gibbs_start_date <- current_start_date - gibbs_days + i - 1
prop_pars[["start_date"]] <- gibbs_start_date
prop_lprior <- calc_lprior(pars = prop_pars)
if(is.infinite(prop_lprior)) {
gibbs_post[i] <- -Inf # check this or maybe just very small number
} else {
prop_pars.squire[["start_date"]] <- offset_to_start_date(first_data_date, gibbs_start_date)
p_filter_est <- calc_ll(pars = prop_pars.squire)
prop_ll <- p_filter_est$log_likelihood
# prop_ss <- p_filter_est$sample_state ignoring now as feasibly each MCMC row might have two states due to the blocking
prop_lpost <- prop_lprior + prop_ll
gibbs_post[i] <- prop_lpost
}
}
best <- max(gibbs_post)
probs <- exp(gibbs_post - best)
probs <- probs/sum(probs)
new_start_date <- current_start_date - gibbs_days + sample(1:total_days, 1, prob = probs) - 1
curr_pars["start_date"] <- new_start_date
# then continuous parameter updates
prop_pars <- propose_parameters(curr_pars[-1], # remove start date - return to this with a better way
proposal_kernel * scaling_factor,
unlist(pars_discrete)[-1],
unlist(pars_min)[-1],
unlist(pars_max)[-1])
prop_for_eval <- c(curr_pars["start_date"], prop_pars$for_eval) # these should be identical now
prop_for_chain <- c(curr_pars["start_date"], prop_pars$for_chain) # these should be identical now
## calculate proposed prior / lhood / posterior
prop_lprior <- calc_lprior(pars = prop_for_eval)
prop_pars.squire <- as.list(prop_for_eval)
prop_pars.squire[["start_date"]] <- offset_to_start_date(first_data_date, prop_for_eval[["start_date"]]) # convert to date
p_filter_est <- calc_ll(pars = prop_pars.squire)
prop_ll <- p_filter_est$log_likelihood
prop_ss <- p_filter_est$sample_state
prop_lpost <- prop_lprior + prop_ll
# calculate probability of acceptance
accept_prob <- exp(prop_lpost - curr_lpost)
if(runif(1) < accept_prob) { # MH step
# update parameters and calculated likelihoods
curr_pars <- prop_for_chain
curr_lprior <- prop_lprior
curr_ll <- prop_ll
curr_lpost <- prop_lpost
curr_ss <- prop_ss
acceptances[iter] <- 1
}
# record results
res[iter, ] <- c(curr_pars,
curr_lprior,
curr_ll,
curr_lpost)
states[iter, ] <- curr_ss
# adapt and update covariance matrix
if (iter >= start_adaptation) {
timing_cov <- iter - start_adaptation + 1 # iteration relative to when covariance adaptation started
if (iter == start_adaptation) {
previous_mu <- matrix(colMeans(res[1:iter, seq_along(curr_pars[-1])]), nrow = 1)
current_parameters <- matrix(curr_pars[-1], nrow = 1)
temp <- jc_prop_update(acceptances[iter], timing_cov, scaling_factor, previous_mu,
curr_pars[-1], proposal_kernel, required_acceptance_ratio)
scaling_factor <- temp$scaling_factor
proposal_kernel <- temp$covariance_matrix
previous_mu <- temp$mu
covariance_matrix_storage[[timing_cov]] <- proposal_kernel
scaling_factor_storage[[timing_cov]] <- scaling_factor
} else {
temp <- jc_prop_update(acceptances[iter], timing_cov, scaling_factor, previous_mu,
curr_pars[-1], proposal_kernel, required_acceptance_ratio)
scaling_factor <- temp$scaling_factor
proposal_kernel <- temp$covariance_matrix
previous_mu <- temp$mu
covariance_matrix_storage[[timing_cov]] <- proposal_kernel
scaling_factor_storage[[timing_cov]] <- scaling_factor
}
}
if(output_proposals) {
proposals[iter, ] <- c(prop_for_chain,
prop_lprior,
prop_ll,
prop_lpost,
min(accept_prob, 1))
}
if (iter %% 100 == 0) {
print(c(round(scaling_factor, 3), round(sum(acceptances, na.rm = TRUE)/iter, 3), round(iter, 1)))
}
}
res <- as.data.frame(res)
coda_res <- coda::as.mcmc(res)
rejection_rate <- coda::rejectionRate(coda_res)
ess <- coda::effectiveSize(coda_res)
# res$start_date <- offset_to_start_date(first_data_date, res$start_date)
out <- list('inputs' = inputs,
'results' = as.data.frame(res),
'states' = states,
'acceptance_rate' = 1-rejection_rate,
"ess" = ess,
"scaling_factor" = scaling_factor_storage,
"covariance_matrix" = covariance_matrix_storage,
"acceptance_ratio" = mean(acceptances),
"acceptances" = acceptances)
if(output_proposals) {
proposals <- as.data.frame(proposals)
proposals$start_date <- offset_to_start_date(first_data_date, proposals$start_date)
out$proposals <- proposals
}
out
}
# Run odin model to calculate the log-likelihood
# return: Set to 'll' to return just the log-likelihood (as used within the MCMC)
# or 'full' to also save particles and return the sampled trajectory
calc_loglikelihood <- function(pars, data, squire_model, model_params,
pars_obs, n_particles,
forecast_days = 0, return = "ll",
Rt_args,
interventions) {
#----------------..
# specify particle setup
#----------------..
switch(return,
"full" = {
save_particles <- TRUE
full_output <- TRUE
pf_return <- "sample"
},
"ll" = {
save_particles <- FALSE
forecast_days <- 0
full_output <- FALSE
pf_return <- "single"
},
{
stop("Unknown return type to calc_loglikelihood")
}
)
#----------------..
# (potentially redundant) assertion
#----------------..
assert_in(c("R0", "start_date"), names(pars),
message = "Must specify R0, start date to infer")
#----------------..
# unpack current params
#----------------..
R0 <- pars[["R0"]]
start_date <- pars[["start_date"]]
# reporting fraction par if in pars list
if("rf" %in% names(pars)) {
assert_numeric(pars[["rf"]])
pars_obs$phi_death <- pars[["rf"]]
}
#----------------..
# more assertions
#----------------..
assert_pos(R0)
assert_date(start_date)
#----------------..
# setup model based on inputs and interventions
#----------------..
R0_change <- interventions$R0_change
date_R0_change <- interventions$date_R0_change
date_contact_matrix_set_change <- interventions$date_contact_matrix_set_change
date_ICU_bed_capacity_change <- interventions$date_ICU_bed_capacity_change
date_hosp_bed_capacity_change <- interventions$date_hosp_bed_capacity_change
date_vaccine_change <- interventions$date_vaccine_change
date_vaccine_efficacy_infection_change <- interventions$date_vaccine_efficacy_infection_change
date_vaccine_efficacy_disease_change <- interventions$date_vaccine_efficacy_disease_change
# change betas
if (is.null(date_R0_change)) {
tt_beta <- 0
} else {
tt_list <- intervention_dates_for_odin(dates = date_R0_change,
change = R0_change,
start_date = start_date,
steps_per_day = round(1/model_params$dt),
starting_change = 1)
model_params$tt_beta <- tt_list$tt
R0_change <- tt_list$change
date_R0_change <- tt_list$dates
}
# and contact matrices
if (is.null(date_contact_matrix_set_change)) {
tt_contact_matrix <- 0
} else {
# here just provide positions for change and then use these to index mix_mat_set
tt_list <- intervention_dates_for_odin(dates = date_contact_matrix_set_change,
change = seq_along(interventions$contact_matrix_set)[-1],
start_date = start_date,
steps_per_day = round(1/model_params$dt),
starting_change = 1)
model_params$tt_matrix <- tt_list$tt
model_params$mix_mat_set <- model_params$mix_mat_set[tt_list$change,,]
}
# and icu beds
if (is.null(date_ICU_bed_capacity_change)) {
tt_ICU_beds <- 0
} else {
tt_list <- intervention_dates_for_odin(dates = date_ICU_bed_capacity_change,
change = interventions$ICU_bed_capacity[-1],
start_date = start_date,
steps_per_day = round(1/model_params$dt),
starting_change = interventions$ICU_bed_capacity[1])
model_params$tt_ICU_beds <- tt_list$tt
model_params$ICU_beds <- tt_list$change
}
# and hosp beds
if (is.null(date_hosp_bed_capacity_change)) {
tt_hosp_beds <- 0
} else {
tt_list <- intervention_dates_for_odin(dates = date_hosp_bed_capacity_change,
change = interventions$hosp_bed_capacity[-1],
start_date = start_date,
steps_per_day = round(1/model_params$dt),
starting_change = interventions$hosp_bed_capacity[1])
model_params$tt_hosp_beds <- tt_list$tt
model_params$hosp_beds <- tt_list$change
}
# and vaccine coverage
if (is.null(date_vaccine_change)) {
tt_vaccine <- 0
} else {
tt_list <- intervention_dates_for_odin(dates = date_vaccine_change,
change = interventions$max_vaccine[-1],
start_date = start_date,
steps_per_day = round(1/model_params$dt),
starting_change = interventions$max_vaccine[1])
model_params$tt_vaccine <- tt_list$tt
model_params$max_vaccine <- tt_list$change
}
# and vaccine efficacy infection
if (is.null(date_vaccine_efficacy_infection_change)) {
tt_vaccine_efficacy_infection <- 0
} else {
# here we just pass the change as a position vector as we need to then
# index the array of vaccine efficacies
tt_list <- intervention_dates_for_odin(dates = date_vaccine_efficacy_infection_change,
change = seq_along(interventions$vaccine_efficacy_infection)[-1],
start_date = start_date,
steps_per_day = round(1/model_params$dt),
starting_change = 1)
model_params$tt_vaccine_efficacy_infection <- tt_list$tt
# here we now index the array by the position vectors that are returned by intervention_dates_for_odin
model_params$vaccine_efficacy_infection <- model_params$vaccine_efficacy_infection[tt_list$change,,]
}
# and vaccine efficacy disease
if (is.null(date_vaccine_efficacy_disease_change)) {
tt_vaccine_efficacy_disease <- 0
} else {
# here we just pass the change as a position vector as we need to then
# index the array of vaccine efficacies
tt_list <- intervention_dates_for_odin(dates = date_vaccine_efficacy_disease_change,
change = seq_along(interventions$vaccine_efficacy_disease)[-1],
start_date = start_date,
steps_per_day = round(1/model_params$dt),
starting_change = 1)
model_params$tt_vaccine_efficacy_disease <- tt_list$tt
# here we now index the array by the position vectors that are returned by intervention_dates_for_odin
model_params$prob_hosp <- model_params$prob_hosp[tt_list$change,,]
}
#--------------------..
# update new R0s based on R0_change, date_R0_change and date_Meff_change
#--------------------..
# and now get the time-varying Rt values from R0
R0 <- evaluate_Rt_pmcmc(R0_change = R0_change,
R0 = R0,
date_R0_change = date_R0_change,
pars = pars,
Rt_args = Rt_args)
# which allows us to work out our beta
beta_set <- beta_est(squire_model = squire_model,
model_params = model_params,
R0 = R0)
#----------------..
# update the model params accordingly from new inputs
#----------------..
model_params$beta_set <- beta_set
#----------------..
# run the particle filter
#----------------..
if (inherits(squire_model, "stochastic")) {
pf_result <- run_particle_filter(data = data,
squire_model = squire_model,
model_params = model_params,
model_start_date = start_date,
obs_params = pars_obs,
n_particles = n_particles,
forecast_days = forecast_days,
save_particles = save_particles,
full_output = full_output,
return = pf_return)
} else if (inherits(squire_model, "deterministic")) {
pf_result <- run_deterministic_comparison(data = data,
squire_model = squire_model,
model_params = model_params,
model_start_date = start_date,
obs_params = pars_obs,
forecast_days = forecast_days,
save_history = save_particles,
return = pf_return)
}
# out
pf_result
}
#' @noRd
Rt_args_list <- function(plateau_duration = 7,
date_Meff_change = NULL,
scale_Meff_pl = FALSE,
Rt_shift_duration = 30,
Rt_rw_duration = 14) {
if(!is.null(date_Meff_change)) {
assert_date(date_Meff_change)
}
list(plateau_duration = plateau_duration,
date_Meff_change = date_Meff_change,
scale_Meff_pl = scale_Meff_pl,
Rt_shift_duration = Rt_shift_duration,
Rt_rw_duration = Rt_rw_duration)
}
# Evaluate Rt
#' @noRd
evaluate_Rt_pmcmc <- function(R0_change,
R0,
date_R0_change,
pars,
Rt_args) {
# unpack pars
Meff <- pars[["Meff"]]
Meff_pl <- pars[["Meff_pl"]]
Rt_shift <- pars[["Rt_shift"]]
Rt_shift_scale <- pars[["Rt_shift_scale"]]
# get random walk parameters
if(any(grepl("Rt_rw", names(pars)))) {
Rt_rw_bool <- TRUE
Rt_rws <- pars[grepl("Rt_rw", names(pars))]
} else {
Rt_rw_bool <- FALSE
}
# unpack Rt_args
plateau_duration <- Rt_args[["plateau_duration"]]
date_Meff_change <- Rt_args[["date_Meff_change"]]
scale_meff_pl <- Rt_args[["scale_Meff_pl"]]
Rt_shift_duration <- Rt_args[["Rt_shift_duration"]]
Rt_rw_duration <- Rt_args[["Rt_rw_duration"]]
# and now get the new Rt values from R0
if (!is.null(R0_change)) {
# if there is no Meff then we just do a linear transform
if (is.null(Meff)) {
Rt <- R0*R0_change
} else {
# if no date_Meff_change then all from R0_change and Meff
if (is.null(date_Meff_change)) {
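# (illustrative check of this transform: when R0_change == 1, plogis(0) = 0.5
#  so Rt == R0; with Meff = 2 and R0_change = 0.5, Rt = R0 * 2 * plogis(-1),
#  i.e. roughly 0.54 * R0)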
Rt <- R0 * (2 * plogis(-(R0_change - 1) * -Meff))
} else if (!is.null(date_Meff_change)) {
date_Meff_change <- as.Date(date_Meff_change)
date_R0_change <- as.Date(date_R0_change)
# scale Meff accordingly
if (!is.null(scale_meff_pl)) {
if (scale_meff_pl) {
Meff_pl <- Meff_pl*Meff
}
}
# if no shift then set to 0
if (is.null(Rt_shift) || is.null(Rt_shift_scale)) {
Rt_shift <- 0
Rt_shift_duration <- 1
} else {
Rt_shift_x <- seq(0, 1, length.out = max(2,Rt_shift_duration))
Rt_shift <- Rt_shift * (1 / (1 + (Rt_shift_x/(1-Rt_shift_x))^-Rt_shift_scale))
if(is.null(Rt_shift_duration)) {
stop("Rt_shift provided but no Rt_shift_duration")
}
}
# when does mobility change take place
if (date_Meff_change > tail(date_R0_change, 1)) {
Rt <- R0 * (2 * plogis(-(R0_change - 1) * -Meff))
} else {
# when is the switch in our data
swtchdates <- which(date_R0_change >= date_Meff_change)
min_d <- as.Date(date_R0_change[min(swtchdates)])
dates_to_median <- seq(min_d - floor(plateau_duration/2),min_d + floor(plateau_duration/2),1)
# Work out the mobility during this period
mob_pld <- median(R0_change[which(date_R0_change %in% dates_to_median)])
mob_up <- c(rep(0, swtchdates[1]-1),
R0_change[min(swtchdates):(length(R0_change))] - mob_pld)
# what does our shift look like
shift_dates <- seq.Date(date_R0_change[swtchdates[1]], (date_R0_change[swtchdates[1]]+Rt_shift_duration-1), 1)
Rt_pl_change <- c(rep(0, min(swtchdates[1]-1)), Rt_shift[shift_dates %in% date_R0_change])
Rt_pl_change <- c(Rt_pl_change,
rep(tail(Rt_shift,1), max(0, length(mob_up) - length(Rt_pl_change))))
Rt_pl_change <- head(Rt_pl_change, length(mob_up))
# if we have random walks
if (Rt_rw_bool) {
# if no duration have as default 14 days
if (is.null(Rt_rw_duration)) {
Rt_rw_duration <- 14
}
# 0 up until the end of the shift
Rt_rw_change <- rep(0, length(mob_up))
# append the rw params
for (i in seq_along(Rt_rws)) {
rw_dates <- date_R0_change[swtchdates[1]]+Rt_shift_duration + ((i-1)*Rt_rw_duration)
pos_i <- which(date_R0_change > rw_dates)
Rt_rw_change[pos_i] <- Rt_rw_change[pos_i] + Rt_rws[[i]]
}
# take the head if it overruns
Rt_rw_change <- head(Rt_rw_change, length(mob_up))
# fill the end if too short
if (length(Rt_rw_change) < length(Rt_pl_change)) {
Rt_rw_change <- c(Rt_rw_change, rep(tail(Rt_rw_change, 1), length(Rt_pl_change) - length(Rt_rw_change)))
}
} else {
Rt_rw_change <- rep(0, length(Rt_pl_change))
}
# now work out Rt forwards based on mobility increasing from this plateau
Rt <- R0 * 2*(plogis( -Meff * -(R0_change-1) - Meff_pl*(mob_up) - Rt_pl_change - Rt_rw_change))
}
}
}
} else {
Rt <- rep(R0, length(date_R0_change))
}
return(Rt)
}
# proposal for MCMC
propose_parameters <- function(pars, proposal_kernel, pars_discrete, pars_min, pars_max) {
assert_same_length(sum(pars_discrete == FALSE), length(pars))
assert_same_length(sum(pars_discrete == FALSE), dim(proposal_kernel)[1])
## proposed jumps are multivariate normal with mean pars and covariance proposal_kernel
proposed <- pars + drop(rmvnorm(n = 1, sigma = proposal_kernel))
for_chain <- reflect_proposal(x = proposed,
floor = pars_min,
cap = pars_max)
# discretise if necessary
for_eval <- proposed
#for_eval[pars_discrete] <- round(for_eval[pars_discrete])
for_eval <- reflect_proposal(x = for_eval,
floor = pars_min,
cap = pars_max)
return(list(for_eval = for_eval, for_chain = for_chain))
}
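# Example call (values are illustrative placeholders):
# propose_parameters(pars = c(start_date = -20, R0 = 2.5, Meff = 2),
#                    proposal_kernel = diag(c(1, 0.01, 0.01)),
#                    pars_discrete = c(start_date = TRUE, R0 = FALSE, Meff = FALSE),
#                    pars_min = c(start_date = -30, R0 = 0, Meff = 1),
#                    pars_max = c(start_date = -10, R0 = 5, Meff = 3))
# returns a list with elements for_eval and for_chain, both reflected back
# inside [pars_min, pars_max]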
## create function to reflect proposal boundaries at pars_min and pars_max
# this ensures the proposal is symmetrical and we can simplify the MH step
reflect_proposal <- function(x, floor, cap) {
interval <- cap - floor
abs((x + interval - floor) %% (2 * interval) - interval) + floor
}
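# Worked examples of the reflection (illustrative values):
# reflect_proposal(x = 6,  floor = 0, cap = 5)  # overshoots the cap by 1    -> 4
# reflect_proposal(x = -1, floor = 0, cap = 5)  # undershoots the floor by 1 -> 1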