#' Fit a stochastic epidemic model using the linear noise approximation or
#' ordinary differential equations to approximate the latent epidemic process.
#'
#' @param stem_object a stochastic epidemic model object containing the dataset,
#'   model dynamics, and measurement process. If it carries a \code{restart}
#'   element (as produced by a previous call), the MCMC is restarted from the
#'   saved state.
#' @param method either "lna" or "ode"; any other value is an error.
#' @param mcmc_kern MCMC transition kernel generated by a call to the
#'   \code{mcmc_kernel} function.
#' @param iterations number of MCMC iterations.
#' @param initialization_attempts number of attempts made to find an initial
#'   latent path (unused when restarting).
#' @param ess_warmup number of preliminary elliptical slice sampling (ESS)
#'   iterations for the LNA path, initial conditions, and time-varying
#'   parameters prior to starting MCMC.
#' @param thinning_interval thinning interval for posterior samples; defaults to
#'   saving every 100th sample.
#' @param return_adapt_rec should MCMC samples drawn during adaptation be
#'   recorded and returned? Defaults to FALSE, in which case only iterations
#'   after the last adaptive iteration are recorded.
#' @param return_ess_rec should elliptical slice sampling steps and angles be
#'   returned? Defaults to FALSE.
#' @param print_progress interval (in iterations) at which to write progress to
#'   a text file. If 0 (default) progress is not written.
#' @param status_filename string to prepend to status file names; defaults to
#'   "LNA" or "ODE" depending on the method used.
#'
#' @return \code{stem_object}, augmented with a \code{results} element (runtime
#'   in hours, the posterior samples, and optionally the ESS record) and a
#'   \code{restart} element (latest path, parameter blocks, initial distribution
#'   objects, and time-varying parameters) that can be used to continue the
#'   MCMC. The final parameter values are also written back into
#'   \code{stem_object$dynamics}.
#' @export
fit_stem <- function(stem_object,
                     method,
                     mcmc_kern,
                     iterations,
                     initialization_attempts = 500,
                     ess_warmup = 50,
                     thinning_interval = 100,
                     return_adapt_rec = FALSE,
                     return_ess_rec = FALSE,
                     print_progress = 0,
                     status_filename = NULL) {

  # check that the data, dynamics and measurement process are all supplied
  if (is.null(stem_object$measurement_process$data) ||
      is.null(stem_object$dynamics) ||
      is.null(stem_object$measurement_process)) {
    stop("The dataset, dynamics, and measurement process must all be specified.")
  }

  # Validate the method up front: an unsupported value used to fall through
  # every method branch and fail much later with an unrelated error.
  if (!is.character(method) || length(method) != 1L ||
      !(method %in% c("lna", "ode"))) {
    stop('method must be either "lna" or "ode".')
  }

  # check that the appropriate model object is compiled
  if (method == "lna") {
    if (is.null(stem_object$dynamics$lna_pointers)) {
      stop("LNA is not compiled.")
    }
    if (is.null(status_filename)) status_filename <- "LNA"
  } else {
    if (is.null(stem_object$dynamics$ode_pointers)) {
      stop("ODE is not compiled.")
    }
    if (is.null(status_filename)) status_filename <- "ODE"
  }

  # if the MCMC is being restarted, reuse the saved path and time-varying
  # parameters; otherwise start from the model's time-varying parameters
  mcmc_restart <- !is.null(stem_object$restart)
  if (mcmc_restart) {
    path <- stem_object$restart$path
    tparam <- stem_object$restart$tparam
  } else {
    path <- NULL
    tparam <- stem_object$dynamics$tparam
  }

  # grab parameter names; initial-distribution parameters are the ones whose
  # names contain "_0"
  parameters <- stem_object$dynamics$parameters
  all_param_names <- names(stem_object$dynamics$param_codes)
  is_initdist_name <- grepl("_0", all_param_names)
  param_names_nat <- all_param_names[!is_initdist_name]
  initdist_names <- all_param_names[is_initdist_name]
  param_names_est <- c(sapply(mcmc_kern$parameter_blocks, function(x) x$pars_est))
  n_model_params <- length(param_names_est)

  # unpack mcmc kernel
  param_blocks <- mcmc_kern$parameter_blocks
  lna_ess_control <- mcmc_kern$lna_ess_control
  initdist_ess_control <- mcmc_kern$initdist_ess_control
  tparam_ess_control <- mcmc_kern$tparam_ess_control

  # parameter codes (columns of the parameter matrix) for the chosen method
  if (method == "lna") {
    param_codes <- stem_object$dynamics$lna_rates$lna_param_codes
  } else {
    param_codes <- stem_object$dynamics$ode_rates$ode_param_codes
  }
  n_pars_tot <- length(param_codes)

  # prepare param_blocks
  param_blocks <- prepare_param_blocks(
    param_blocks = param_blocks,
    parameters = parameters,
    param_codes = param_codes,
    iterations = iterations,
    mcmc_restart = mcmc_restart
  )

  # maximum number of adaptive iterations
  max_adaptation <- max(sapply(param_blocks, function(x) x$control$stop_adaptation))

  # Iterations 1..n_adapt_skip are never recorded: nothing is skipped when
  # return_adapt_rec is TRUE, otherwise the adaptive phase is skipped (capped at
  # the total number of iterations).
  n_adapt_skip <- if (return_adapt_rec) 0 else min(max_adaptation, iterations)

  # Number of posterior samples = number of multiples of thinning_interval in
  # (n_adapt_skip, iterations]. This is exactly the set of iterations recorded
  # in the main loop, so no slots are left unfilled and the count is never
  # negative.
  n_samples <- floor(iterations / thinning_interval) -
    floor(n_adapt_skip / thinning_interval)
  n_ess_recs <- n_samples

  # progress printing interval
  if (print_progress != 0) {
    progress_interval <- print_progress
    print_progress <- TRUE
  } else {
    progress_interval <- NULL
    print_progress <- FALSE
  }

  ### Unpack stem_object -------------------------------------------------------
  # dynamics that are not method specific
  censusmat <- stem_object$measurement_process$censusmat
  constants <- stem_object$dynamics$constants
  initializer <- stem_object$dynamics$initializer
  fixed_inits <- stem_object$dynamics$fixed_inits
  n_strata <- stem_object$dynamics$n_strata
  step_size <- stem_object$dynamics$dynamics_args$step_size

  # measurement process objects that are not method specific
  measproc_indmat <- stem_object$measurement_process$measproc_indmat
  dat <- stem_object$measurement_process$obsmat
  obstimes <- stem_object$measurement_process$obstimes

  # initialize the ess_record
  ess_record <- list(lna_ess_record = NULL,
                     initdist_ess_record = NULL,
                     tparam_ess_record = NULL)

  # dynamics and measurement process objects that are method specific
  if (method == "lna") {
    flow_matrix <- stem_object$dynamics$flow_matrix_lna
    stoich_matrix <- stem_object$dynamics$stoich_matrix_lna
    proc_pointer <- stem_object$dynamics$lna_pointers$lna_ptr
    set_pars_pointer <- stem_object$dynamics$lna_pointers$set_lna_params_ptr
    do_prevalence <- stem_object$measurement_process$lna_prevalence
    event_inds <- stem_object$measurement_process$incidence_codes_lna
    initdist_inds <- stem_object$dynamics$lna_initdist_inds
    approx_warmup <- lna_ess_control$approx_warmup
    # should initial concentrations be updated jointly with the LNA path
    joint_initdist_update <- !fixed_inits && lna_ess_control$joint_initdist_update
  } else {
    flow_matrix <- stem_object$dynamics$flow_matrix_ode
    stoich_matrix <- stem_object$dynamics$stoich_matrix_ode
    proc_pointer <- stem_object$dynamics$ode_pointers$ode_ptr
    set_pars_pointer <- stem_object$dynamics$ode_pointers$set_ode_params_ptr
    do_prevalence <- stem_object$measurement_process$ode_prevalence
    event_inds <- stem_object$measurement_process$incidence_codes_ode
    initdist_inds <- stem_object$dynamics$ode_initdist_inds
    # for the ODE the initial distribution always gets its own ESS update
    joint_initdist_update <- FALSE
  }
  n_compartments <- ncol(flow_matrix)
  n_rates <- nrow(flow_matrix)

  # measurement process density.
  # NOTE(review): both methods read the pointer from meas_pointers_lna --
  # confirm there is no ODE-specific measurement pointer.
  d_meas_pointer <- stem_object$measurement_process$meas_pointers_lna$d_measure_ptr

  # 0-based column indices of parameters, constants, and time-varying
  # covariates in the parameter matrix
  param_inds <- setdiff(stem_object$dynamics$param_codes, initdist_inds)
  const_inds <-
    length(stem_object$dynamics$param_codes) +
    seq_along(stem_object$dynamics$const_codes) - 1
  tcovar_inds <-
    length(stem_object$dynamics$param_codes) +
    length(const_inds) + seq_along(stem_object$dynamics$tcovar_codes) - 1

  # column codes of the time-varying parameters
  if (!is.null(tparam)) {
    tparam_inds <- param_codes[sapply(tparam, function(x) x$tparam_name)]
  } else {
    tparam_inds <- NULL
  }

  if (method == "lna") {
    # objects for computing the SVD of the LNA diffusion matrix
    svd_U <- diag(0.0, n_rates)
    svd_V <- diag(0.0, n_rates)
    svd_d <- rep(0.0, n_rates)

    # update schedule for the LNA elliptical slice sampler
    lna_ess_schedule <- prepare_lna_ess_schedule(
      stem_object = stem_object,
      initializer = initializer,
      lna_ess_control = lna_ess_control
    )

    if (return_ess_rec) {
      ess_dims <- c(lna_ess_control$n_updates,
                    length(lna_ess_schedule),
                    n_ess_recs)
      ess_record$lna_ess_record <-
        list(ess_steps = array(1.0, dim = ess_dims),
             ess_angles = array(1.0, dim = ess_dims))
    }
  } else {
    # LNA-only objects are not used by the ODE
    svd_U <- NULL
    svd_V <- NULL
    svd_d <- NULL
    lna_ess_schedule <- NULL
  }

  ### Initial distribution objects ---------------------------------------------
  if (n_strata == 1) {
    comp_size_vec <- constants["popsize"]
  } else {
    comp_size_vec <-
      constants[paste0("popsize_", sapply(initializer, "[[", "strata"))]
  }

  # list for initial compartment volume objects
  if (mcmc_restart) {
    initdist_objects <- stem_object$restart$initdist_objects
  } else {
    initdist_objects <-
      prepare_initdist_objects(initializer = initializer,
                               param_codes = param_codes,
                               comp_size_vec = comp_size_vec)
  }

  # initialize the ESS bookkeeping for separate initdist updates
  # (joint_initdist_update is always FALSE for the ODE, so this covers the ODE
  # and the non-joint LNA cases)
  if (!fixed_inits && !joint_initdist_update) {
    initdist_ess_control$steps <- rep(1.0, initdist_ess_control$n_updates)
    initdist_ess_control$angles <- rep(0.0, initdist_ess_control$n_updates)
    initdist_ess_control$angle_mean <- 0.0
    initdist_ess_control$angle_resid <- 0.0
    initdist_ess_control$angle_var <- 0.0

    if (return_ess_rec) {
      ess_record$initdist_ess_record <-
        list(ess_steps = matrix(1.0,
                                nrow = initdist_ess_control$n_updates,
                                ncol = n_ess_recs),
             ess_angles = matrix(0.0,
                                 nrow = initdist_ess_control$n_updates,
                                 ncol = n_ess_recs))
    }
  }

  # vector of census times: observation times, covariate change times, the
  # regular time grid, and the final time
  census_times <-
    sort(unique(c(obstimes,
                  stem_object$dynamics$tcovar[, 1],
                  seq(stem_object$dynamics$t0,
                      stem_object$dynamics$tmax,
                      by = stem_object$dynamics$timestep),
                  stem_object$dynamics$tmax)))
  census_indices <- unique(c(0, findInterval(obstimes, census_times) - 1))
  n_times <- length(census_times)

  # make sure no times are less than t0
  if (any(obstimes < stem_object$dynamics$t0)) {
    stop("Cannot have observations before time t0.")
  }
  if (any(stem_object$dynamics$tcovar[, 1] < stem_object$dynamics$t0)) {
    stop("Cannot have time-varying covariates specified before time t0.")
  }

  ### Set up parameter objects -------------------------------------------------
  parmat <- matrix(0.0,
                   nrow = n_times,
                   ncol = n_pars_tot,
                   dimnames = list(NULL, names(param_codes)))

  # insert parameters into the parameter matrix
  insert_params(parmat = parmat,
                param_blocks = param_blocks,
                nat = TRUE,
                prop = FALSE,
                rowind = 0)

  # insert initial conditions
  insert_initdist(parmat = parmat,
                  initdist_objects = initdist_objects,
                  prop = FALSE,
                  rowind = 0,
                  mcmc_rec = FALSE)

  # insert the constants (same value in every row)
  parmat[, const_inds + 1] <-
    matrix(stem_object$dynamics$constants,
           nrow = nrow(parmat),
           ncol = length(const_inds),
           byrow = TRUE)

  # generate forcing indices and other objects
  forcing_inds <- rep(FALSE, length(census_times))

  # insert time varying covariates
  if (!is.null(stem_object$dynamics$tcovar)) {
    tcovar_rowinds <-
      findInterval(census_times, stem_object$dynamics$tcovar[, 1])
    parmat[, tcovar_inds + 1] <-
      stem_object$dynamics$tcovar[tcovar_rowinds, -1]

    # zero out forcings if necessary
    if (!is.null(stem_object$dynamics$forcings)) {
      # get the forcing indices (supplied in the original tcovar matrix).
      # NOTE(review): assumes the tcovar matrix has one row per census time;
      # otherwise `|` silently recycles -- TODO confirm.
      for (f in seq_along(stem_object$dynamics$forcings)) {
        forcing_inds <-
          forcing_inds |
          stem_object$dynamics$tcovar[
            stem_object$dynamics$tcovar[, 1] %in% census_times,
            stem_object$dynamics$forcings[[f]]$tcovar_name] != 0
      }
      zero_inds <- !forcing_inds

      # zero out the tcovar elements corresponding to times with no forcings.
      # NOTE(review): this iterates dynamics_args$forcings while the indices
      # above use dynamics$forcings -- confirm both refer to the same list.
      for (l in seq_along(stem_object$dynamics$dynamics_args$forcings)) {
        parmat[zero_inds,
               stem_object$dynamics$dynamics_args$forcings[[l]]$tcovar_name] <- 0
      }
    }
  }

  # get indices for time-varying parameters
  if (!is.null(tparam)) {
    if (!mcmc_restart) {
      # generate the indices for updating the time-varying parameter and
      # initialize the values
      for (s in seq_along(tparam)) {
        # can get rid of the values slot
        tparam[[s]]$values <- NULL

        # indices: column in parmat, and for each census time the tparam
        # interval it falls in (R: 1-based, Cpp: 0-based)
        tparam[[s]]$col_ind <- param_codes[tparam[[s]]$tparam_name]
        tparam[[s]]$tpar_inds_R <-
          findInterval(census_times, tparam[[s]]$times, left.open = FALSE)
        # census times before the first tparam time use the first interval
        tparam[[s]]$tpar_inds_R[tparam[[s]]$tpar_inds_R == 0] <- 1
        tparam[[s]]$tpar_inds_Cpp <- tparam[[s]]$tpar_inds_R - 1

        # standard normal draws for the current, proposed, and ESS states
        tparam[[s]]$draws_cur <- rnorm(tparam[[s]]$n_draws)
        tparam[[s]]$draws_prop <- rnorm(tparam[[s]]$n_draws)
        tparam[[s]]$draws_ess <- rnorm(tparam[[s]]$n_draws)
        tparam[[s]]$log_lik <- sum(dnorm(tparam[[s]]$draws_cur, log = TRUE))

        # map the draws to parameter values at each census time
        tparam[[s]]$tpar_cur <-
          tparam[[s]]$draws2par(
            parameters = parmat[1, ],
            draws = tparam[[s]]$draws_cur)[tparam[[s]]$tpar_inds_R]

        # insert into the parameter matrix
        vec_2_mat(dest = parmat,
                  orig = tparam[[s]]$tpar_cur,
                  ind = tparam[[s]]$col_ind)

        # bracket width
        tparam[[s]]$bracket_width <- tparam_ess_control$bracket_width

        # ess counters
        tparam[[s]]$steps <- rep(1.0, tparam_ess_control$n_updates)
        tparam[[s]]$angles <- rep(0.0, tparam_ess_control$n_updates)

        # for tuning the bracket
        tparam[[s]]$angle_mean <- 0
        tparam[[s]]$angle_var <- pi^2 / 3
        tparam[[s]]$angle_resid <- 0
      }
    } else {
      # if restarting, just copy the values into the parameter matrices
      for (s in seq_along(tparam)) {
        vec_2_mat(dest = parmat,
                  orig = tparam[[s]]$tpar_cur,
                  ind = tparam[[s]]$col_ind)
      }
    }

    # check if tparam[[s]] depends on initial conditions
    tparam <- check_tpar_depends(tparam = tparam,
                                 initdist_objects = initdist_objects,
                                 parmat = parmat)

    if (return_ess_rec) {
      ess_dims <- c(tparam_ess_control$n_updates,
                    length(tparam),
                    n_ess_recs)
      ess_record$tparam_ess_record <-
        list(ess_steps = array(1.0, dim = ess_dims),
             ess_angles = array(1.0, dim = ess_dims))
    }
  } else {
    tparam <- NULL
  }

  # indices for when to update the parameters: the first time, covariate
  # change times, and time-varying parameter times
  param_update_inds <- rep(FALSE, length(census_times))
  param_update_inds[1] <- TRUE
  if (!is.null(stem_object$dynamics$tcovar)) {
    param_update_inds[census_times %in% stem_object$dynamics$tcovar[, 1]] <- TRUE
  }
  if (!is.null(tparam)) {
    for (s in seq_along(tparam)) {
      param_update_inds[census_times %in% tparam[[s]]$times] <- TRUE
    }
  }
  # guard against a missing change matrix: nrow(NULL) is NULL, and comparing
  # against it yields logical(0), which errors inside if()
  tcovar_changemat <- stem_object$dynamics$tcovar_changemat
  if (!is.null(tcovar_changemat) &&
      length(param_update_inds) == nrow(tcovar_changemat)) {
    param_update_inds <-
      param_update_inds | apply(tcovar_changemat, 1, any)
  }

  # generate forcing objects
  if (!is.null(stem_object$dynamics$forcings)) {
    forcings <- stem_object$dynamics$forcings

    # names and (0-based) indices
    forcing_tcovars <- sapply(forcings, function(x) x$tcovar_name)
    forcing_tcov_inds <- match(forcing_tcovars, colnames(parmat)) - 1
    forcing_events <- c(sapply(forcings, function(x) paste0(x$from, "2", x$to)))

    # matrix indicating which compartments are involved in which forcings in
    # and out
    forcings_out <-
      matrix(0.0,
             nrow = ncol(flow_matrix),
             ncol = length(forcings),
             dimnames = list(colnames(flow_matrix), forcing_tcovars))

    forcing_transfers <-
      array(0.0,
            dim = c(ncol(flow_matrix), ncol(flow_matrix), length(forcings)),
            dimnames = list(colnames(flow_matrix),
                            colnames(flow_matrix),
                            forcing_tcovars))

    for (s in seq_along(forcings)) {
      forcings_out[forcings[[s]]$from, s] <- 1
      for (t in seq_along(forcings[[s]]$from)) {
        forcing_transfers[forcings[[s]]$from[t], forcings[[s]]$from[t], s] <- -1
        forcing_transfers[forcings[[s]]$to[t], forcings[[s]]$from[t], s] <- 1
      }
    }
  } else {
    forcing_tcovars <- character(0L)
    forcing_tcov_inds <- integer(0L)
    forcing_events <- character(0L)
    forcings_out <- matrix(0.0, nrow = 0, ncol = 0)
    forcing_transfers <- array(0.0, dim = c(0, 0, 0))
  }

  # matrix in which to store the emission probabilities
  emitmat <- cbind(dat[, 1, drop = FALSE],
                   matrix(0.0,
                          nrow = nrow(measproc_indmat),
                          ncol = ncol(measproc_indmat),
                          dimnames = list(NULL, colnames(measproc_indmat))))

  pathmat_prop <- cbind(census_times,
                        matrix(0.0,
                               nrow = length(census_times),
                               ncol = nrow(flow_matrix),
                               dimnames = list(NULL, c(rownames(flow_matrix)))))

  # initialize the lna_param_vec for the measurement process
  param_vec <- parmat[1, ]

  # initialize the latent path
  if (mcmc_restart) {
    # recompute the data log likelihood; any error inside try() leaves
    # data_log_lik_prop NULL, which is treated as a failed restart below
    data_log_lik_prop <- NULL
    try({
      census_latent_path(
        path = path$latent_path,
        census_path = censusmat,
        census_inds = census_indices,
        event_inds = event_inds,
        flow_matrix = flow_matrix,
        do_prevalence = do_prevalence,
        parmat = parmat,
        initdist_inds = initdist_inds,
        forcing_inds = forcing_inds,
        forcing_tcov_inds = forcing_tcov_inds,
        forcings_out = forcings_out,
        forcing_transfers = forcing_transfers
      )

      # evaluate the density of the incidence counts
      evaluate_d_measure_LNA(
        emitmat = emitmat,
        obsmat = dat,
        censusmat = censusmat,
        measproc_indmat = measproc_indmat,
        parameters = parmat,
        param_inds = param_inds,
        const_inds = const_inds,
        tcovar_inds = tcovar_inds,
        param_update_inds = param_update_inds,
        census_indices = census_indices,
        param_vec = param_vec,
        d_meas_ptr = d_meas_pointer
      )

      # compute the data log likelihood
      data_log_lik_prop <- sum(emitmat[, -1][measproc_indmat])
      if (is.nan(data_log_lik_prop)) data_log_lik_prop <- -Inf
    }, silent = TRUE)

    if (is.null(data_log_lik_prop)) {
      stop("Restart attempted with data log likelihood of negative infinity.")
    } else {
      path$data_log_lik <- data_log_lik_prop
    }
  } else {
    if (method == "lna") {
      inits <- initialize_lna(
        dat = dat,
        parmat = parmat,
        param_blocks = param_blocks,
        tparam = tparam,
        censusmat = censusmat,
        emitmat = emitmat,
        stoich_matrix = stoich_matrix,
        census_times = census_times,
        proc_pointer = proc_pointer,
        set_pars_pointer = set_pars_pointer,
        d_meas_pointer = d_meas_pointer,
        param_vec = param_vec,
        param_inds = param_inds,
        const_inds = const_inds,
        tcovar_inds = tcovar_inds,
        initdist_inds = initdist_inds,
        param_update_inds = param_update_inds,
        census_indices = census_indices,
        event_inds = event_inds,
        measproc_indmat = measproc_indmat,
        do_prevalence = do_prevalence,
        forcing_inds = forcing_inds,
        forcing_tcov_inds = forcing_tcov_inds,
        forcings_out = forcings_out,
        forcing_transfers = forcing_transfers,
        initialization_attempts = initialization_attempts,
        step_size = step_size,
        initdist_objects = initdist_objects,
        ess_warmup = approx_warmup
      )
    } else {
      inits <- initialize_ode(
        dat = dat,
        parmat = parmat,
        param_blocks = param_blocks,
        tparam = tparam,
        censusmat = censusmat,
        emitmat = emitmat,
        stoich_matrix = stoich_matrix,
        proc_pointer = proc_pointer,
        set_pars_pointer = set_pars_pointer,
        d_meas_pointer = d_meas_pointer,
        census_times = census_times,
        param_vec = param_vec,
        param_inds = param_inds,
        const_inds = const_inds,
        tcovar_inds = tcovar_inds,
        initdist_inds = initdist_inds,
        param_update_inds = param_update_inds,
        census_indices = census_indices,
        event_inds = event_inds,
        measproc_indmat = measproc_indmat,
        do_prevalence = do_prevalence,
        forcing_inds = forcing_inds,
        forcing_tcov_inds = forcing_tcov_inds,
        forcings_out = forcings_out,
        forcing_transfers = forcing_transfers,
        initialization_attempts = initialization_attempts,
        step_size = step_size,
        initdist_objects = initdist_objects
      )
    }

    # grab the initial path, param_blocks, initdist_objects, and tparam
    path <- inits$path
    param_blocks <- inits$param_blocks
    initdist_objects <- inits$initdist_objects
    tparam <- inits$tparam
  }

  if (method == "lna") {
    # object for proposing new stochastic perturbations
    draws_prop <- matrix(0.0, nrow = nrow(path$draws), ncol = ncol(path$draws))
    copy_mat(draws_prop, path$draws)

    # instantiate matrix for elliptical slice sampling draws
    ess_draws_prop <- matrix(0.0, nrow = nrow(path$draws), ncol = ncol(path$draws))
    copy_mat(ess_draws_prop, path$draws)

    # elliptical slice sampling MCMC record
    for (s in seq_along(lna_ess_schedule)) {
      lna_ess_schedule[[s]]$steps <- rep(1.0, lna_ess_control$n_updates)
      lna_ess_schedule[[s]]$angles <- rep(0.0, lna_ess_control$n_updates)
    }
  } else {
    # the ODE has no perturbation draws; define these explicitly so the update
    # calls below never resolve the names from an enclosing environment
    draws_prop <- NULL
    ess_draws_prop <- NULL
  }

  # warmup the LNA, initial conditions, or time-varying parameters
  for (warmup in seq_len(ess_warmup)) {
    if (method == "lna") {
      lna_update(
        path = path,
        dat = dat,
        iter = 0,
        parmat = parmat,
        lna_ess_schedule = lna_ess_schedule,
        lna_ess_control = lna_ess_control,
        initdist_objects = initdist_objects,
        tparam = tparam,
        pathmat_prop = pathmat_prop,
        censusmat = censusmat,
        draws_prop = draws_prop,
        ess_draws_prop = ess_draws_prop,
        emitmat = emitmat,
        flow_matrix = flow_matrix,
        stoich_matrix = stoich_matrix,
        census_times = census_times,
        forcing_inds = forcing_inds,
        forcing_tcov_inds = forcing_tcov_inds,
        forcings_out = forcings_out,
        forcing_transfers = forcing_transfers,
        param_vec = param_vec,
        param_inds = param_inds,
        const_inds = const_inds,
        tcovar_inds = tcovar_inds,
        initdist_inds = initdist_inds,
        param_update_inds = param_update_inds,
        census_indices = census_indices,
        event_inds = event_inds,
        measproc_indmat = measproc_indmat,
        svd_d = svd_d,
        svd_U = svd_U,
        svd_V = svd_V,
        proc_pointer = proc_pointer,
        set_pars_pointer = set_pars_pointer,
        d_meas_pointer = d_meas_pointer,
        do_prevalence = do_prevalence,
        joint_initdist_update = joint_initdist_update,
        step_size = step_size
      )
    }

    if (!fixed_inits && !joint_initdist_update) {
      initdist_update(
        path = path,
        dat = dat,
        iter = 0,
        parmat = parmat,
        initdist_objects = initdist_objects,
        initdist_ess_control = initdist_ess_control,
        tparam = tparam,
        pathmat_prop = pathmat_prop,
        censusmat = censusmat,
        draws_prop = draws_prop,
        ess_draws_prop = ess_draws_prop,
        emitmat = emitmat,
        flow_matrix = flow_matrix,
        stoich_matrix = stoich_matrix,
        census_times = census_times,
        forcing_inds = forcing_inds,
        forcing_tcov_inds = forcing_tcov_inds,
        forcings_out = forcings_out,
        forcing_transfers = forcing_transfers,
        param_vec = param_vec,
        param_inds = param_inds,
        const_inds = const_inds,
        tcovar_inds = tcovar_inds,
        initdist_inds = initdist_inds,
        param_update_inds = param_update_inds,
        census_indices = census_indices,
        event_inds = event_inds,
        measproc_indmat = measproc_indmat,
        svd_d = svd_d,
        svd_U = svd_U,
        svd_V = svd_V,
        proc_pointer = proc_pointer,
        set_pars_pointer = set_pars_pointer,
        d_meas_pointer = d_meas_pointer,
        do_prevalence = do_prevalence,
        step_size = step_size
      )
    }

    if (!is.null(tparam)) {
      tparam_update(
        path = path,
        dat = dat,
        iter = 0,
        parmat = parmat,
        tparam = tparam,
        tparam_ess_control = tparam_ess_control,
        pathmat_prop = pathmat_prop,
        censusmat = censusmat,
        draws_prop = draws_prop,
        ess_draws_prop = ess_draws_prop,
        emitmat = emitmat,
        flow_matrix = flow_matrix,
        stoich_matrix = stoich_matrix,
        census_times = census_times,
        forcing_inds = forcing_inds,
        forcing_tcov_inds = forcing_tcov_inds,
        forcings_out = forcings_out,
        forcing_transfers = forcing_transfers,
        param_vec = param_vec,
        param_inds = param_inds,
        const_inds = const_inds,
        tcovar_inds = tcovar_inds,
        initdist_inds = initdist_inds,
        param_update_inds = param_update_inds,
        census_indices = census_indices,
        event_inds = event_inds,
        measproc_indmat = measproc_indmat,
        svd_d = svd_d,
        svd_U = svd_U,
        svd_V = svd_V,
        proc_pointer = proc_pointer,
        set_pars_pointer = set_pars_pointer,
        d_meas_pointer = d_meas_pointer,
        do_prevalence = do_prevalence,
        step_size = step_size
      )
    }
  }

  # objects to store the paths and likelihood terms
  mcmc_samples <- list(
    data_log_lik = rep(0.0, n_samples),
    params_log_prior = rep(0.0, n_samples),
    parameter_samples_nat = matrix(0.0, nrow = n_samples, ncol = n_model_params,
                                   dimnames = list(NULL, param_names_nat)),
    parameter_samples_est = matrix(0.0, nrow = n_samples, ncol = n_model_params,
                                   dimnames = list(NULL, param_names_est)),
    latent_paths =
      array(0.0, dim = c(length(census_times), 1 + n_rates, n_samples),
            dimnames = list(NULL, c("time", rownames(flow_matrix)), NULL))
  )

  if (method == "lna") {
    # vector for saving the log-likelihood of the LNA draws
    mcmc_samples$lna_log_lik <- rep(0.0, n_samples)

    # array for saving LNA draws
    mcmc_samples$lna_draws <-
      array(0.0, dim = c(n_rates, length(census_times) - 1, n_samples),
            dimnames = list(rownames(flow_matrix), NULL, NULL))
  }

  if (!fixed_inits) {
    mcmc_samples$initdist_log_lik <- rep(0.0, n_samples)
    mcmc_samples$initdist_samples <-
      matrix(0.0, nrow = n_samples, ncol = n_compartments,
             dimnames = list(NULL, initdist_names))
    mcmc_samples$initdist_draws <-
      vector("list", length = length(initdist_objects))
    for (i in seq_along(initdist_objects)) {
      mcmc_samples$initdist_draws[[i]] <-
        matrix(0.0,
               nrow = length(initdist_objects[[i]]$draws_cur),
               ncol = n_samples)
    }
  }

  if (!is.null(tparam)) {
    mcmc_samples$tparam_log_lik <- rep(0.0, n_samples)
    mcmc_samples$tparam_samples <-
      array(0.0, dim = c(n_times, length(tparam), n_samples))
    # one matrix of draws per time-varying parameter (previously this list
    # was mistakenly sized by the number of initdist objects)
    mcmc_samples$tparam_draws <- vector("list", length = length(tparam))
    for (i in seq_along(tparam)) {
      mcmc_samples$tparam_draws[[i]] <-
        matrix(0.0,
               nrow = tparam[[i]]$n_draws,
               ncol = n_samples)
    }
  }

  # record indices passed to the sample-saving helpers
  rec_ind <- 0
  if (return_ess_rec) ess_rec_ind <- 0

  # initialize the status file if status updates are required
  if (print_progress) {
    status_file <-
      paste0(status_filename,
             "_inference_status_",
             as.numeric(Sys.time()),
             ".txt")
    cat(
      "Beginning MCMC",
      file = status_file,
      sep = "\n",
      append = FALSE
    )
  }

  # begin the MCMC
  start_time <- Sys.time()

  for (iter in seq_len(iterations)) {

    # Sample new parameter values, visiting the blocks in random order
    block_order <- sample.int(length(param_blocks))
    for (ind in block_order) {
      if (param_blocks[[ind]]$alg == "mvnmh") {

        # save the covariance matrix if stopping adaptation
        if (iter == min(max_adaptation, iterations)) {
          param_blocks[[ind]]$sigma <-
            param_blocks[[ind]]$mvnmh_objects$proposal_scaling *
            param_blocks[[ind]]$kernel_cov
          colnames(param_blocks[[ind]]$sigma) <-
            rownames(param_blocks[[ind]]$sigma) <-
            param_blocks[[ind]]$param_names_est
          comp_chol(param_blocks[[ind]]$kernel_cov_chol,
                    param_blocks[[ind]]$sigma)
        }

        # sample new parameters
        mvnmh_update(
          param_blocks = param_blocks,
          ind = ind,
          iter = iter,
          parmat = parmat,
          dat = dat,
          path = path,
          pathmat_prop = pathmat_prop,
          tparam = tparam,
          census_times = census_times,
          flow_matrix = flow_matrix,
          stoich_matrix = stoich_matrix,
          censusmat = censusmat,
          emitmat = emitmat,
          param_vec = param_vec,
          param_inds = param_inds,
          const_inds = const_inds,
          tcovar_inds = tcovar_inds,
          param_update_inds = param_update_inds,
          initdist_inds = initdist_inds,
          census_indices = census_indices,
          event_inds = event_inds,
          measproc_indmat = measproc_indmat,
          forcing_inds = forcing_inds,
          forcing_tcov_inds = forcing_tcov_inds,
          forcings_out = forcings_out,
          forcing_transfers = forcing_transfers,
          proc_pointer = proc_pointer,
          d_meas_pointer = d_meas_pointer,
          set_pars_pointer = set_pars_pointer,
          do_prevalence = do_prevalence,
          step_size = step_size,
          svd_d = svd_d,
          svd_U = svd_U,
          svd_V = svd_V
        )

      } else if (param_blocks[[ind]]$alg == "mvnss") {

        # save the covariance matrix if stopping adaptation
        if (iter == min(max_adaptation, iterations)) {
          # covariance matrix
          param_blocks[[ind]]$sigma <- param_blocks[[ind]]$kernel_cov
          colnames(param_blocks[[ind]]$sigma) <-
            rownames(param_blocks[[ind]]$sigma) <-
            param_blocks[[ind]]$param_names_est
          # cholesky
          comp_chol(param_blocks[[ind]]$kernel_cov_chol,
                    param_blocks[[ind]]$sigma)
        }

        # sample new parameters
        mvnss_update(
          param_blocks = param_blocks,
          ind = ind,
          iter = iter,
          parmat = parmat,
          dat = dat,
          path = path,
          pathmat_prop = pathmat_prop,
          tparam = tparam,
          census_times = census_times,
          flow_matrix = flow_matrix,
          stoich_matrix = stoich_matrix,
          censusmat = censusmat,
          emitmat = emitmat,
          param_vec = param_vec,
          param_inds = param_inds,
          const_inds = const_inds,
          tcovar_inds = tcovar_inds,
          param_update_inds = param_update_inds,
          initdist_inds = initdist_inds,
          census_indices = census_indices,
          event_inds = event_inds,
          measproc_indmat = measproc_indmat,
          forcing_inds = forcing_inds,
          forcing_tcov_inds = forcing_tcov_inds,
          forcings_out = forcings_out,
          forcing_transfers = forcing_transfers,
          proc_pointer = proc_pointer,
          d_meas_pointer = d_meas_pointer,
          set_pars_pointer = set_pars_pointer,
          do_prevalence = do_prevalence,
          step_size = step_size,
          svd_d = svd_d,
          svd_U = svd_U,
          svd_V = svd_V
        )
      }
    }

    # update the initial compartment volumes
    if (!fixed_inits && !joint_initdist_update) {
      initdist_update(
        path = path,
        dat = dat,
        iter = iter,
        parmat = parmat,
        initdist_objects = initdist_objects,
        initdist_ess_control = initdist_ess_control,
        tparam = tparam,
        pathmat_prop = pathmat_prop,
        censusmat = censusmat,
        draws_prop = draws_prop,
        ess_draws_prop = ess_draws_prop,
        emitmat = emitmat,
        flow_matrix = flow_matrix,
        stoich_matrix = stoich_matrix,
        census_times = census_times,
        forcing_inds = forcing_inds,
        forcing_tcov_inds = forcing_tcov_inds,
        forcings_out = forcings_out,
        forcing_transfers = forcing_transfers,
        param_vec = param_vec,
        param_inds = param_inds,
        const_inds = const_inds,
        tcovar_inds = tcovar_inds,
        initdist_inds = initdist_inds,
        param_update_inds = param_update_inds,
        census_indices = census_indices,
        event_inds = event_inds,
        measproc_indmat = measproc_indmat,
        svd_d = svd_d,
        svd_U = svd_U,
        svd_V = svd_V,
        proc_pointer = proc_pointer,
        set_pars_pointer = set_pars_pointer,
        d_meas_pointer = d_meas_pointer,
        do_prevalence = do_prevalence,
        step_size = step_size
      )
    }

    # update the tparam draws
    if (!is.null(tparam)) {
      tparam_update(
        path = path,
        dat = dat,
        iter = iter,
        parmat = parmat,
        tparam = tparam,
        tparam_ess_control = tparam_ess_control,
        pathmat_prop = pathmat_prop,
        censusmat = censusmat,
        draws_prop = draws_prop,
        ess_draws_prop = ess_draws_prop,
        emitmat = emitmat,
        flow_matrix = flow_matrix,
        stoich_matrix = stoich_matrix,
        census_times = census_times,
        forcing_inds = forcing_inds,
        forcing_tcov_inds = forcing_tcov_inds,
        forcings_out = forcings_out,
        forcing_transfers = forcing_transfers,
        param_vec = param_vec,
        param_inds = param_inds,
        const_inds = const_inds,
        tcovar_inds = tcovar_inds,
        initdist_inds = initdist_inds,
        param_update_inds = param_update_inds,
        census_indices = census_indices,
        event_inds = event_inds,
        measproc_indmat = measproc_indmat,
        svd_d = svd_d,
        svd_U = svd_U,
        svd_V = svd_V,
        proc_pointer = proc_pointer,
        set_pars_pointer = set_pars_pointer,
        d_meas_pointer = d_meas_pointer,
        do_prevalence = do_prevalence,
        step_size = step_size
      )
    }

    # Update the path via elliptical slice sampling
    if (method == "lna") {
      lna_update(
        path = path,
        dat = dat,
        iter = iter,
        parmat = parmat,
        lna_ess_schedule = lna_ess_schedule,
        lna_ess_control = lna_ess_control,
        initdist_objects = initdist_objects,
        tparam = tparam,
        pathmat_prop = pathmat_prop,
        censusmat = censusmat,
        draws_prop = draws_prop,
        ess_draws_prop = ess_draws_prop,
        emitmat = emitmat,
        flow_matrix = flow_matrix,
        stoich_matrix = stoich_matrix,
        census_times = census_times,
        forcing_inds = forcing_inds,
        forcing_tcov_inds = forcing_tcov_inds,
        forcings_out = forcings_out,
        forcing_transfers = forcing_transfers,
        param_vec = param_vec,
        param_inds = param_inds,
        const_inds = const_inds,
        tcovar_inds = tcovar_inds,
        initdist_inds = initdist_inds,
        param_update_inds = param_update_inds,
        census_indices = census_indices,
        event_inds = event_inds,
        measproc_indmat = measproc_indmat,
        svd_d = svd_d,
        svd_U = svd_U,
        svd_V = svd_V,
        proc_pointer = proc_pointer,
        set_pars_pointer = set_pars_pointer,
        d_meas_pointer = d_meas_pointer,
        do_prevalence = do_prevalence,
        joint_initdist_update = joint_initdist_update,
        step_size = step_size
      )
    }

    # Save the MCMC sample if this iteration is past the skipped (adaptive)
    # phase and falls on the thinning grid; this is the same rule used to
    # compute n_samples above
    if (iter > n_adapt_skip && iter %% thinning_interval == 0) {
      save_mcmc_sample(
        mcmc_samples = mcmc_samples,
        rec_ind = rec_ind,
        path = path,
        parmat = parmat,
        param_blocks = param_blocks,
        initdist_objects = initdist_objects,
        tparam = tparam,
        tparam_inds = tparam_inds,
        method = method
      )

      if (return_ess_rec) {
        save_ess_rec(
          ess_record = ess_record,
          ess_rec_ind = ess_rec_ind,
          lna_ess_schedule = lna_ess_schedule,
          tparam = tparam,
          initdist_ess_control = initdist_ess_control
        )
      }
    }

    # print status messages if called for
    if (print_progress && iter %% progress_interval == 0) {
      cat(paste0("Iteration: ", iter),
          file = status_file,
          sep = "\n",
          append = TRUE)

      for (s in seq_along(param_blocks)) {
        if (param_blocks[[s]]$alg == "mvnmh") {
          cat(paste0("\t", "Parameter block: ", s),
              paste0("\t", "\t", "Accepted proposals: ",
                     param_blocks[[s]]$mvnmh_objects$acceptances),
              paste0("\t", "\t", "Acceptance rate: ",
                     param_blocks[[s]]$mvnmh_objects$acceptances / iter),
              paste0("\t", "\t", "Proposal scaling: ",
                     param_blocks[[s]]$mvnmh_objects$proposal_scaling),
              file = status_file,
              sep = "\n",
              append = TRUE)
        } else {
          cat(paste0("\t", "Parameter block: ", s),
              paste0("\t", "\t", "Contractions: ",
                     param_blocks[[s]]$mvnss_objects$n_contractions - 0.5),
              paste0("\t", "\t", "Expansions: ",
                     param_blocks[[s]]$mvnss_objects$n_expansions - 0.5),
              file = status_file,
              sep = "\n",
              append = TRUE)
        }
      }
    }
  }

  # record the end time
  end_time <- Sys.time()

  # compile the results
  stem_object$results <-
    list(runtime = difftime(end_time, start_time, units = "hours"),
         posterior = mcmc_samples)
  if (return_ess_rec) stem_object$results$ess_record <- ess_record

  # save inits for restart
  stem_object$dynamics$parameters <- parmat[1, param_inds + 1]
  if (!fixed_inits) {
    stem_object$dynamics$initdist_params <- parmat[1, initdist_inds + 1]
  }
  if (!is.null(tparam)) stem_object$dynamics$tparam <- tparam
  stem_object$restart <-
    list(path = path,
         param_blocks = param_blocks,
         initdist_objects = initdist_objects,
         tparam = tparam)

  # return stem_object
  return(stem_object)
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.