
### minweke style test for Metropolis-Hastings with ODEs
## Procedure: alternate between the following
# 1) Update parameters via MH, 
# 2) Sample new data
# Compare sampled parameters with their priors.


# total population size
popsize <- 1e4

# data-generating parameter values on their natural scales:
#   R0     - basic reproduction number
#   mu_inv - infectious period duration (2 days)
#   rho    - case detection rate
#   phi    - negative binomial overdispersion
true_pars <- setNames(c(1.5, 2, 0.5, 10),
                      c("R0", "mu_inv", "rho", "phi"))

# model structure: no strata, three compartments
strata <- NULL
compartments <- c("S", "I", "R")

# rates initialized as a list of rate lists; one entry per transition.
# Logical flags are spelled TRUE/FALSE because T and F are ordinary,
# reassignable variables.
rates <-
  list(rate(rate = "beta * I",  # individual level rate (unlumped)
            from = "S",         # source compartment
            to   = "I",         # destination compartment
            incidence = TRUE),  # compute incidence of S2I transitions, required for simulating incidence data
       rate(rate = "mu",        # individual level rate
            from = "I",         # source compartment
            to   = "R",         # destination compartment
            incidence = TRUE))  # compute incidence of I2R transitions (not required for simulating data)

# list used for simulation/inference for the initial state, initial counts fixed.
# state initializer a list of stem_initializer lists.
# NOTE(review): the right-hand side below starts with bare named arguments and
# closes two parentheses that are never opened, so the line that opens the
# call(s) appears to be missing -- presumably `list(stem_initializer(`, going
# by the comment above. Confirm before running.
state_initializer <-
    init_states = c(S = popsize-10, I = 10, R = 0), # must match compartment names
    fixed = T)) # initial state fixed for simulation, we'll change this later

# set the parameter values - must be a named vector
# NOTE(review): the c(...) call below is never closed and only its first
# element (the per-capita contact rate beta) is visible, although four names
# are assigned afterwards. The lines supplying the remaining three elements
# and the closing parenthesis appear to be missing.
parameters =
  c(true_pars["R0"] / popsize / true_pars["mu_inv"], # R0 = beta * P / mu
names(parameters) <- c("beta", "mu", "rho", "phi")

# the initial time is handed to the model as a named constant
constants <- c(t0 = 0)

# start and end of the simulation window
t0 <- 0
tmax <- 40

# compile the model
# NOTE(review): the assignment below is followed directly by named arguments,
# with no opening call and no closing parenthesis. The line naming the
# model-compiling function (and its closing `)`) appears to be missing.
dynamics <-
    rates = rates,
    tmax = tmax,
    parameters = parameters,
    state_initializer = state_initializer,
    compartments = compartments,
    constants = constants,
    compile_ode = T,   # compile ODE functions
    compile_rates = F, # do not compile MJP functions for Gillespie simulation
    compile_lna = T,   # compile LNA functions
    messages = F       # don't print messages

# list of emission distribution lists (analogous to rate specification)
emissions <-
  list(emission(meas_var = "S2I", # transition or compartment being measured (S->I transitions)
                distribution    = "negbinomial",         # emission distribution
                emission_params = c("phi", "S2I * rho"), # distribution pars, here overdispersion and mean
                incidence       = TRUE,                  # is the data incidence
                obstimes        = seq(1, tmax, by = 1))) # vector of observation times

# compile the measurement process
# (FALSE is spelled out because F is an ordinary, reassignable variable)
measurement_process <-
  stem_measure(emissions = emissions,
               dynamics  = dynamics,
               messages  = FALSE)

# put it all together into a stochastic epidemic model object
stem_object <-
  make_stem(dynamics = dynamics,
            measurement_process = measurement_process)

# simulate a dataset from the deterministic (ODE) model
sim_ode <- simulate_stem(stem_object = stem_object, method = "ode")

# recompile the stem object, this time with the initial compartment counts
# treated as unknown and with the simulated dataset attached
# NOTE(review): as in the first definitions of `state_initializer` and
# `dynamics`, the two assignments below are followed by bare named arguments;
# the lines that open the calls (and the closing `)` of the second) appear to
# be missing.
state_initializer <-
      init_states = c(S = popsize-10, I = 10, R = 0), # must match compartment names
      fixed = FALSE,
      prior = c(popsize, 10, 0)/10)) # we now do inference on the initial compartment counts

dynamics <-
    rates = rates,
    tmax = tmax,
    parameters = parameters,
    state_initializer = state_initializer,
    compartments = compartments,
    constants = constants,
    compile_ode = T,   # compile ODE functions
    compile_rates = F, # do not compile MJP functions for Gillespie simulation
    compile_lna = T,   # compile LNA functions
    messages = F       # don't print messages

# measurement process rebuilt with the first simulated dataset as the data
measurement_process <-
  stem_measure(emissions = emissions,
               dynamics = dynamics,
               data = sim_ode$datasets[[1]])

# stem object used for inference
stem_object <-
  make_stem(dynamics = dynamics, 
            measurement_process = measurement_process)

# Map parameters from their natural scales to the estimation scales.
#
# params_nat: vector ordered as (beta, mu, rho, phi).
# Returns a vector ordered as (log(R0 - 1), log(mu), logit(rho), log(phi)),
# where R0 = beta * popsize / mu.
# Relies on `popsize` and `logit()` being available where this is called.
to_estimation_scale <- function(params_nat) {
  c(log(params_nat[1] * popsize / params_nat[2] - 1), # (beta, mu, N) -> log(R0 - 1)
    log(params_nat[2]),                               # mu -> log(mu)
    logit(params_nat[3]),                             # rho -> logit(rho)
    log(params_nat[4]))                               # phi -> log(phi)
}

# Map parameters from the estimation scales back to their natural scales.
#
# params_est: vector ordered as (log(R0 - 1), log(mu), logit(rho), log(phi)).
# Returns a vector ordered as (beta, mu, rho, phi).
# The first element is log(R0 - 1), not log(R0): R0 is recovered as
# exp(params_est[1]) + 1 and then beta = R0 * mu / popsize.
# Relies on `popsize` and `expit()` being available where this is called.
from_estimation_scale <- function(params_est) {
  c(exp(log(exp(params_est[1]) + 1) + params_est[2] - log(popsize)), # beta = exp(log(R0) + log(mu) - log(N))
    exp(params_est[2]),   # log(mu) -> mu
    expit(params_est[3]), # logit(rho) -> rho
    exp(params_est[4]))   # log(phi) -> phi
}

# Log prior density of the parameters on their estimation scales.
#
# params_est: vector ordered as (log(R0 - 1), log(mu), logit(rho), log(phi)).
# Returns the sum of four independent normal log densities. Each prior is
# placed directly on the estimation-scale quantity, so no Jacobian term is
# added here.
logprior <- function(params_est) {
  sum(dnorm(params_est[1], 0, 0.5, log = TRUE),
      dnorm(params_est[2], -0.7, 0.35, log = TRUE),
      dnorm(params_est[3], 0, 0.5, log = TRUE),
      dnorm(params_est[4], log(10), 1, log = TRUE))
}

# bundle the log prior and the two scale transformations into one named list
priors <- setNames(
  list(logprior, to_estimation_scale, from_estimation_scale),
  c("logprior", "to_estimation_scale", "from_estimation_scale")
)

#' We now specify the MCMC transition kernel. In this simple example, we'll update
#' the model hyperparameters using a multivariate Metropolis algorithm. The
#' algorithm can be tuned using a global adaptive scheme (algorithm 4 in Andrieu
#' and Thoms), and the parameters can be initialized at random values, which is
#' done by replacing the vector of parameters in the stem object with a function
#' that returns a named vector of parameters. In this script both are switched
#' off: stop_adaptation is set to 0 and the random initializer is commented out.
## ----mcmc_kern, echo = TRUE----------------------------------------------
# specify the initial proposal covariance matrix with row and column names
# corresponding to parameters on their estimation scales

# optional random initializer (currently disabled): jitters the parameters on
# the estimation scale and maps them back to the natural scale
# par_initializer = function() {
# priors$from_estimation_scale(priors$to_estimation_scale(parameters) + rnorm(4, 0, 0.1))
# }

# specify the kernel
# NOTE(review): `mcmc_kern <-` and `parameter_blocks =` are each followed by
# bare named arguments, and the parentheses closed after `mvnmh_control(...)`
# and at the end of the statement are never opened. The lines opening the
# kernel call and the parameter-block call(s) appear to be missing.
# NOTE(review): the first estimation-scale name is "log_R0", while
# to_estimation_scale() computes log(R0 - 1) for that element.
mcmc_kern <-
    parameter_blocks = 
        pars_nat = c("beta", "mu", "rho", "phi"),
        pars_est = c("log_R0", "log_mu", "logit_rho", "log_phi"),
        priors = priors,
        # alg = "mvnss",
        alg = "mvnmh",
        sigma = diag(0.01, 4),
        # initializer = par_initializer,
        control = 
          # mvnss_control(stop_adaptation = 5e4))),
          mvnmh_control(stop_adaptation = 0))),
    lna_ess_control = lna_control(bracket_update_iter = 1e4,
                                  joint_initdist_update = FALSE),
    initdist_ess_control = initdist_control(bracket_update_iter = 1e4))

### start minweke
# settings for the run; these stand in for the arguments of a fitting function
method                  <- "ode"  # which approximation to use ("ode" or "lna")
iterations              <- 1e5    # number of MCMC iterations
initialization_attempts <- 500    # attempts allowed when initializing the path
ess_warmup              <- 50     # number of warmup sweeps
thinning_interval       <- 100    # keep every 100th iteration
return_adapt_rec        <- FALSE  # keep samples from the adaptation phase?
return_ess_rec          <- TRUE   # keep the elliptical slice sampling record?
print_progress          <- 0      # progress interval; 0 disables printing
status_filename         <- NULL   # filled in from `method` when left NULL

# check that the data, dynamics and measurement process are all supplied
if (is.null(stem_object$measurement_process$data) ||
    is.null(stem_object$dynamics) ||
    is.null(stem_object$measurement_process)) {
  stop("The dataset, dynamics, and measurement process must all be specified.")
}

# check that the appropriate model object is compiled, and derive the default
# status file name from the method when none was given
if (method == "lna") {
  if (is.null(stem_object$dynamics$lna_pointers)) {
    stop("LNA is not compiled.")
  }
  if (is.null(status_filename)) status_filename <- "LNA"
} else if (method == "ode") {
  if (is.null(stem_object$dynamics$ode_pointers)) {
    stop("ODE is not compiled.")
  }
  if (is.null(status_filename)) status_filename <- "ODE"
}

# the run counts as a restart when the stem object carries a `restart` element
mcmc_restart <- !is.null(stem_object$restart)

# grab the latent path and time-varying parameters: from the saved restart
# state when restarting, otherwise from the model dynamics (no path yet)
if (mcmc_restart) {
  path   <- stem_object$restart$path
  tparam <- stem_object$restart$tparam
} else {
  path   <- NULL
  tparam <- stem_object$dynamics$tparam
}

# grab parameter names
# natural-scale names are the parameter codes whose names do not contain "_0";
# names containing "_0" are taken to be initial-distribution parameters
parameters      <- stem_object$dynamics$parameters
param_names_nat <- names(stem_object$dynamics$param_codes)[!grepl("_0", names(stem_object$dynamics$param_codes))]
initdist_names  <- names(stem_object$dynamics$param_codes)[grepl("_0", names(stem_object$dynamics$param_codes))]
# estimation-scale names are collected across all parameter blocks
param_names_est <- c(sapply(mcmc_kern$parameter_blocks, function(x) x$pars_est))
n_model_params  <- length(param_names_est)

# unpack mcmc kernel
# (tparam_ess_control is NULL when the kernel has no such element)
param_blocks         <- mcmc_kern$parameter_blocks
lna_ess_control      <- mcmc_kern$lna_ess_control
initdist_ess_control <- mcmc_kern$initdist_ess_control
tparam_ess_control   <- mcmc_kern$tparam_ess_control

# parameter codes: taken from the LNA rates for method "lna", otherwise from
# the ODE rates; n_pars_tot is the number of codes
if (method == "lna") {
  n_pars_tot  <- length(stem_object$dynamics$lna_rates$lna_param_codes)
  param_codes <- stem_object$dynamics$lna_rates$lna_param_codes
} else {
  n_pars_tot  <- length(stem_object$dynamics$ode_rates$ode_param_codes)
  param_codes <- stem_object$dynamics$ode_rates$ode_param_codes
}

# prepare param_blocks
# NOTE(review): the assignment is followed by bare named arguments and a
# closing parenthesis that is never opened; the line naming the function that
# prepares the blocks appears to be missing.
param_blocks <-
    param_blocks = param_blocks, 
    parameters   = parameters,
    param_codes  = param_codes,
    iterations   = iterations,
    mcmc_restart = mcmc_restart)

# maximum number of adaptive iterations across all parameter blocks
max_adaptation <- max(sapply(param_blocks, function(x) x$control$stop_adaptation))

# number of posterior samples kept after thinning; the adaptation phase is
# only counted when its record is returned
n_samples <-
  if (return_adapt_rec) {
    floor(iterations / thinning_interval)
  } else {
    floor((iterations - max_adaptation) / thinning_interval)
  }

# one ESS record per retained sample
n_ess_recs <- n_samples
record_sample <- FALSE

# progress printing interval: a non-zero print_progress is used as the
# interval and the variable is then turned into an on/off flag
if (print_progress != 0) {
  progress_interval <- print_progress
  print_progress <- TRUE
} else {
  progress_interval <- NULL
  print_progress <- FALSE
}

### Unpack stem_object-------------------------
# Copies the pieces of stem_object used below into local variables.

# dynamics that are not method specific
# (note: `constants` defined earlier in the script is overwritten here)
censusmat   <- stem_object$measurement_process$censusmat
constants   <- stem_object$dynamics$constants
initializer <- stem_object$dynamics$initializer
fixed_inits <- stem_object$dynamics$fixed_inits
n_strata    <- stem_object$dynamics$n_strata
step_size   <- stem_object$dynamics$dynamics_args$step_size

# measurement process objects that are not method specific
measproc_indmat <- stem_object$measurement_process$measproc_indmat
dat             <- stem_object$measurement_process$obsmat
obstimes        <- stem_object$measurement_process$obstimes

# initialize the ess_record; each slot stays NULL unless filled in later
ess_record <- list(lna_ess_record = NULL,
                   initdist_ess_record = NULL,
                   tparam_ess_record = NULL)

# dynamics and measurement process objects that are method specific
# NOTE(review): several statements in this section look truncated (an
# assignment with no right-hand side of its own, argument lists with no
# opening call, stray closing brackets, missing closing braces). Each spot is
# flagged below; the code itself is left exactly as found.
if(method == "lna") {
  # extract objects from dynamics
  flow_matrix      <- stem_object$dynamics$flow_matrix_lna
  n_compartments   <- ncol(flow_matrix)
  n_rates          <- nrow(flow_matrix)
  stoich_matrix    <- stem_object$dynamics$stoich_matrix_lna
  proc_pointer     <- stem_object$dynamics$lna_pointers$lna_ptr
  set_pars_pointer <- stem_object$dynamics$lna_pointers$set_lna_params_ptr
  do_prevalence    <- stem_object$measurement_process$lna_prevalence
  event_inds       <- stem_object$measurement_process$incidence_codes_lna
  initdist_inds    <- stem_object$dynamics$lna_initdist_inds
  approx_warmup    <- lna_ess_control$approx_warmup
  # measurement process
  d_meas_pointer   <- stem_object$measurement_process$meas_pointers_lna$d_measure_ptr
  # indices of parameters, constants, and time-varying covariates
  # (offsets are zero-based: note the trailing `- 1`)
  # NOTE(review): `param_inds <-` has no expression of its own, so as written
  # it chains into the `const_inds` assignment and both names receive the same
  # value. A line appears to be missing here.
  param_inds <- 
  const_inds <- 
    length(stem_object$dynamics$param_codes) + 
    seq_along(stem_object$dynamics$const_codes) - 1
  tcovar_inds <- 
    length(stem_object$dynamics$param_codes) + 
    length(const_inds) + seq_along(stem_object$dynamics$tcovar_codes) - 1
  # should initial concentrations be updated jointly with the LNA path
  joint_initdist_update <- !fixed_inits & lna_ess_control$joint_initdist_update
  # objects for computing the SVD of the LNA diffusion matrix
  svd_U <- diag(0.0, n_rates)
  svd_V <- diag(0.0, n_rates)
  svd_d <- rep(0.0, n_rates)
  # grab the tparam indices and update scheme
  # NOTE(review): the `]` closing the next statement is never opened; the line
  # holding the object being indexed appears to be missing.
  if(!is.null(tparam)) {
    tparam_inds <-
        sapply(tparam, function(x) x$tparam_name)]
  } else {
    tparam_inds <- NULL
  # NOTE(review): closing brace of the if/else above is missing, and the call
  # that builds lna_ess_schedule has lost its opening line.
  lna_ess_schedule <- 
      stem_object = stem_object,
      initializer = initializer,
      lna_ess_control = lna_ess_control)
  # NOTE(review): both record entries below are cut off after their first
  # `dim` component; the constructor call and the rest of `dim` are missing.
  if(return_ess_rec) {
    ess_record$lna_ess_record <- 
      list(ess_steps  = 
                   dim = c(lna_ess_control$n_updates,
           ess_angles = 
                   dim = c(lna_ess_control$n_updates,
} else if(method == "ode") {
  # extract objects from dynamics
  flow_matrix         <- stem_object$dynamics$flow_matrix_ode
  n_compartments      <- ncol(flow_matrix)
  n_rates             <- nrow(flow_matrix)
  stoich_matrix       <- stem_object$dynamics$stoich_matrix_ode
  proc_pointer        <- stem_object$dynamics$ode_pointers$ode_ptr
  set_pars_pointer    <- stem_object$dynamics$ode_pointers$set_ode_params_ptr
  do_prevalence       <- stem_object$measurement_process$ode_prevalence
  event_inds          <- stem_object$measurement_process$incidence_codes_ode
  initdist_inds       <- stem_object$dynamics$ode_initdist_inds
  # measurement process
  # NOTE(review): this ODE branch reads the measurement pointer from
  # `meas_pointers_lna`; confirm that is intended and not a copy-paste slip.
  d_meas_pointer <- stem_object$measurement_process$meas_pointers_lna$d_measure_ptr
  # indices of parameters, constants, and time-varying covariates
  # NOTE(review): same truncated `param_inds <-` as in the LNA branch.
  param_inds <- 
  const_inds <- 
    length(stem_object$dynamics$param_codes) + 
    seq_along(stem_object$dynamics$const_codes) - 1
  tcovar_inds <- 
    length(stem_object$dynamics$param_codes) + 
    length(const_inds) + seq_along(stem_object$dynamics$tcovar_codes) - 1
  # the initial distribution is never updated jointly with the path for the ODE
  joint_initdist_update <- FALSE
  # grab the tparam indices
  # NOTE(review): same unopened `]` as in the LNA branch.
  if(!is.null(tparam)) {
    tparam_inds <-
        sapply(tparam, function(x) x$tparam_name)]
  } else {
    tparam_inds <- NULL
  # set some LNA objects to NULL
  svd_U <- NULL
  svd_V <- NULL
  svd_d <- NULL
  lna_ess_schedule <- NULL

### Initial distribution objects --------------------------------------------
# population size(s) used for the initial distribution: a single "popsize"
# constant with one stratum, otherwise one "popsize_<strata>" constant per
# initializer
if (n_strata == 1) {
  comp_size_vec <- constants["popsize"]
} else {
  comp_size_vec <- constants[paste0("popsize_", sapply(initializer, "[[", "strata"))]
}

# list for initial compartment volume objects: reused from the restart state
# when restarting, otherwise built fresh
if (mcmc_restart) {
  initdist_objects <- stem_object$restart$initdist_objects
} else {
  initdist_objects <-
    prepare_initdist_objects(initializer = initializer,
                             param_codes = param_codes,
                             comp_size_vec = comp_size_vec)
}

# initialize initdist draws
# Only relevant when the initial state is not fixed.
# NOTE(review): the closing braces of every `if` in this section are missing,
# and both record entries have lost the constructor call (and fill value) that
# should precede their `nrow`/`ncol` arguments.
if(!fixed_inits) {
  # for recording the ess initdist updates
  # (counters are only reset when the initial distribution has its own update)
  if(!joint_initdist_update) {
    initdist_ess_control$steps  <- rep(1.0, initdist_ess_control$n_updates)
    initdist_ess_control$angles <- rep(0.0, initdist_ess_control$n_updates)
    initdist_ess_control$angle_mean  <- 0.0
    initdist_ess_control$angle_resid <- 0.0
    initdist_ess_control$angle_var   <- 0.0
  if(return_ess_rec) {
    if(method != "lna" | !joint_initdist_update) {
      ess_record$initdist_ess_record <- 
        list(ess_steps = 
                      nrow = initdist_ess_control$n_updates, 
                      ncol = n_ess_recs),
             ess_angles = 
                      nrow = initdist_ess_control$n_updates, 
                      ncol = n_ess_recs))

# vector of census times
# NOTE(review): the right-hand side of `census_times` is truncated: only two
# argument fragments and a trailing comma survive, so the calls that combine
# the time-varying covariate times with a regular grid (spaced by
# stem_object$dynamics$timestep) cannot be read off from here.
census_times <- 
                stem_object$dynamics$tcovar[, 1],
                    by = stem_object$dynamics$timestep),
# zero-based positions of the observation times within census_times, always
# including position 0
census_indices <- unique(c(0, findInterval(obstimes, census_times) - 1))
n_times        <- length(census_times)

# make sure no observation times are earlier than t0
if (any(obstimes < stem_object$dynamics$t0)) {
  stop("Cannot have observations before time t0.")
}

# likewise for the times in the first column of the covariate matrix
if (any(stem_object$dynamics$tcovar[, 1] < stem_object$dynamics$t0)) {
  stop("Cannot have time-varying covariates specified before time t0.")
}

### Set up parameter objects ---------------------------------
# parmat has one row per census time and one named column per parameter code.
# NOTE(review): the constructor call (and fill value) that should precede the
# `nrow`/`ncol`/`dimnames` arguments is missing.
parmat <-  
         nrow = n_times,
         ncol = n_pars_tot,
         dimnames = list(NULL, names(param_codes)))

# insert parameters into the parameter matrix
# (the return value is discarded, so this presumably fills parmat in place --
# TODO confirm against insert_params)
insert_params(parmat       = parmat, 
              param_blocks = param_blocks, 
              nat          = TRUE,
              prop         = FALSE,
              rowind       = 0)

# insert initial conditions
# (return value discarded here as well)
insert_initdist(parmat           = parmat,
                initdist_objects = initdist_objects, 
                prop             = FALSE,
                rowind           = 0, 
                mcmc_rec         = FALSE)

# insert the constants
# (const_inds is zero-based, hence the + 1 for R column indexing)
# NOTE(review): the constructor call and the values being written are
# missing; only the `nrow`/`ncol`/`byrow` arguments remain.
parmat[, const_inds + 1] <-  
         nrow = nrow(parmat),
         ncol = length(const_inds), byrow = T)

# generate forcing indices: one flag per census time, TRUE where any forcing
# covariate is non-zero (all FALSE when there are no forcings)
forcing_inds <- rep(FALSE, length(census_times))

# insert time varying covariates
if (!is.null(stem_object$dynamics$tcovar)) {
  # row of the covariate matrix in effect at each census time
  tcovar_rowinds <-
    findInterval(census_times, stem_object$dynamics$tcovar[, 1],
                 left.open = FALSE, all.inside = TRUE)
  # tcovar_inds is zero-based, hence the + 1 for R column indexing
  parmat[, tcovar_inds + 1] <-
    stem_object$dynamics$tcovar[tcovar_rowinds, -1]

  # zero out forcings if necessary
  if (!is.null(stem_object$dynamics$forcings)) {
    # get the forcing indices (supplied in the original tcovar matrix)
    for (f in seq_along(stem_object$dynamics$forcings)) {
      forcing_inds <-
        forcing_inds |
        stem_object$dynamics$tcovar[stem_object$dynamics$tcovar[, 1] %in% census_times,
                                    stem_object$dynamics$forcings[[f]]$tcovar_name] != 0
    }
    zero_inds <- !forcing_inds

    # zero out the tcovar elements corresponding to times with no forcings
    # NOTE(review): this loop reads forcings from `dynamics$dynamics_args`,
    # while the loop above reads them from `dynamics`; confirm both paths hold
    # the same list.
    for (l in seq_along(stem_object$dynamics$dynamics_args$forcings)) {
      parmat[zero_inds, stem_object$dynamics$dynamics_args$forcings[[l]]$tcovar_name] <- 0
    }
  }
}

# get indices for time-varying parameters
# For a fresh run each time-varying parameter gets its index vectors, standard
# normal draws and ESS bookkeeping; on a restart the saved values are copied
# into parmat.
# NOTE(review): the closing braces of the loops and branches in this section
# are missing, and two statements have lost the line that opens their call
# (flagged below). Code is left exactly as found.
if (!is.null(tparam)) {
  # verify whether the mcmc is being restarted
  if (!mcmc_restart) {
    # generate the indices for updating the time-varying parameter and initialize the values
    for (s in seq_along(tparam)) {
      # can get rid of the values slot
      tparam[[s]]$values <- NULL
      # indices
      # (one-based row indices for R, with 0 bumped to 1, and a zero-based
      # copy named for use from C++)
      tparam[[s]]$col_ind <- param_codes[tparam[[s]]$tparam_name]
      tparam[[s]]$tpar_inds_R <- findInterval(census_times, tparam[[s]]$times, left.open = F)
      tparam[[s]]$tpar_inds_R[tparam[[s]]$tpar_inds_R == 0] <- 1
      tparam[[s]]$tpar_inds_Cpp <- tparam[[s]]$tpar_inds_R - 1
      # values
      # (current, proposed and ESS draws are independent standard normals;
      # log_lik is the standard normal log density of the current draws)
      tparam[[s]]$draws_cur  <- rnorm(tparam[[s]]$n_draws)
      tparam[[s]]$draws_prop <- rnorm(tparam[[s]]$n_draws)
      tparam[[s]]$draws_ess  <- rnorm(tparam[[s]]$n_draws)
      tparam[[s]]$log_lik <- sum(dnorm(tparam[[s]]$draws_cur, log = TRUE))
      # get values
      # NOTE(review): the call that maps draws to parameter values is missing
      # its opening line; only its arguments and the trailing index remain.
      tparam[[s]]$tpar_cur <- 
          parameters = parmat[1,],
          draws = tparam[[s]]$draws_cur)[tparam[[s]]$tpar_inds_R] 
      # insert into the parameter matrix
      vec_2_mat(dest = parmat,
                orig = tparam[[s]]$tpar_cur,
                ind  = tparam[[s]]$col_ind)
      # bracket width
      tparam[[s]]$bracket_width <- tparam_ess_control$bracket_width
      # ess counters 
      tparam[[s]]$steps  <- rep(1.0, tparam_ess_control$n_updates)
      tparam[[s]]$angles <- rep(0.0, tparam_ess_control$n_updates)
      # for tuning the bracket
      tparam[[s]]$angle_mean  <- 0
      tparam[[s]]$angle_var   <- pi^2/3
      tparam[[s]]$angle_resid <- 0
  } else {
    # if restarting, just copy the values into the parameter matrices
    for (s in seq_along(tparam)) {
      # insert into the parameter matrix
      vec_2_mat(dest = parmat,
                orig = tparam[[s]]$tpar_cur,
                ind  = tparam[[s]]$col_ind)
  # check if tparam[[s]] depends on initial conditions
  tparam <- check_tpar_depends(tparam = tparam, 
                               initdist_objects = initdist_objects,
                               parmat = parmat)
  # NOTE(review): both record entries below are cut off after their first
  # `dim` component; the constructor call and the rest of `dim` are missing.
  if(return_ess_rec) {
    ess_record$tparam_ess_record <- 
      list(ess_steps  = 
                   dim = c(tparam_ess_control$n_updates,
           ess_angles = 
                   dim = c(tparam_ess_control$n_updates,
} else {
  tparam <- NULL

# indices for when to update the parameters: one flag per census time,
# always TRUE at the first time
param_update_inds <- rep(FALSE, length(census_times))
param_update_inds[1] <- TRUE

# census times that coincide with a covariate time
if (!is.null(stem_object$dynamics$tcovar)) {
  param_update_inds[census_times %in% stem_object$dynamics$tcovar[, 1]] <- TRUE
}

# census times that coincide with a time-varying parameter time
if (!is.null(tparam)) {
  for (s in seq_along(tparam)) {
    param_update_inds[census_times %in% tparam[[s]]$times] <- TRUE
  }
}

# census times flagged in any column of the change matrix, when its rows line
# up with the census times. isTRUE() keeps the test a length-one logical when
# tcovar_changemat is NULL (nrow(NULL) is NULL, which would otherwise make
# `if` fail on a zero-length condition).
if (isTRUE(length(param_update_inds) == nrow(stem_object$dynamics$tcovar_changemat))) {
  param_update_inds <-
    param_update_inds | apply(stem_object$dynamics$tcovar_changemat, 1, any)
}

# generate forcing objects
# With forcings present, build their covariate names, zero-based column
# offsets into parmat, event names ("<from>2<to>") and indicator arrays;
# otherwise create empty objects of the matching types.
# NOTE(review): the constructor calls for `forcings_out` and
# `forcing_transfers` are missing, their `dimnames`/`dim` arguments are cut
# short, and the closing braces of both loops are missing.
if(!is.null(stem_object$dynamics$forcings)) {
  forcings <- stem_object$dynamics$forcings
  # names and indices
  forcing_tcovars   <- sapply(forcings, function(x) x$tcovar_name)
  forcing_tcov_inds <- match(forcing_tcovars, colnames(parmat)) - 1
  forcing_events    <- c(sapply(forcings, function(x) paste0(x$from, "2", x$to)))
  # matrix indicating which compartments are involved in which forcings in and out
  forcings_out <-
           nrow = ncol(flow_matrix), 
           ncol = length(forcings),
           dimnames = list(colnames(flow_matrix), 
  forcing_transfers <- 
          dim = c(ncol(flow_matrix),
          dimnames = list(colnames(flow_matrix),
  for(s in seq_along(forcings)) {
    # mark the source compartment(s) of forcing s
    forcings_out[forcings[[s]]$from, s] <- 1
    # for each from/to pair: -1 on the source's own entry, +1 on the
    # destination's entry in the source's column
    for(t in seq_along(forcings[[s]]$from)) {
      forcing_transfers[forcings[[s]]$from[t], forcings[[s]]$from[t], s] <- -1
      forcing_transfers[forcings[[s]]$to[t], forcings[[s]]$from[t], s] <- 1
} else {
  forcing_tcovars   <- character(0L)
  forcing_tcov_inds <- integer(0L)
  forcing_events    <- character(0L)
  forcings_out      <- matrix(0.0, nrow = 0, ncol = 0)
  forcing_transfers <- array(0.0, dim = c(0,0,0))

# matrix in which to store the emission probabilities
# (first column of the data, followed by columns shaped and named like
# measproc_indmat)
# NOTE(review): the constructor call (and fill value) for the second cbind
# argument is missing; only its `nrow`/`ncol`/`dimnames` arguments remain.
emitmat <- cbind(dat[, 1, drop = F],
                        nrow = nrow(measproc_indmat),
                        ncol = ncol(measproc_indmat),
                        dimnames = list(NULL, colnames(measproc_indmat))))

# matrix for a proposed path: census times followed by one column per row of
# flow_matrix
# NOTE(review): same missing constructor call as above.
pathmat_prop <- cbind(census_times,
                             nrow = length(census_times),
                             ncol = nrow(flow_matrix),
                             dimnames = list(NULL, c(rownames(flow_matrix)))))

# initialize the parameter vector for the measurement process from the first
# row of parmat
param_vec <- parmat[1,]

# initialize the latent path
# On a restart the data log likelihood of the saved path is recomputed;
# otherwise a path is initialized with initialize_lna() or initialize_ode().
# NOTE(review): the restart branch is truncated. The `}, silent = TRUE)` at
# its end suggests the two argument lists were calls wrapped in `try({ ... })`,
# but the lines opening `try` and both calls, and the calls' closing
# parentheses, are missing. Closing braces of the branches are missing too.
if (mcmc_restart) {
  # recompute the data log likelihood
  data_log_lik_prop <- NULL
      path                = path$latent_path,
      census_path         = censusmat,
      census_inds         = census_indices,
      event_inds          = event_inds,
      flow_matrix         = flow_matrix,
      do_prevalence       = do_prevalence,
      parmat              = parmat,
      initdist_inds       = initdist_inds,
      forcing_inds        = forcing_inds,
      forcing_tcov_inds   = forcing_tcov_inds,
      forcings_out        = forcings_out,
      forcing_transfers   = forcing_transfers
    # evaluate the density of the incidence counts
      emitmat           = emitmat,
      obsmat            = dat,
      censusmat         = censusmat,
      measproc_indmat   = measproc_indmat,
      parameters        = parmat,
      param_inds        = param_inds,
      const_inds        = const_inds,
      tcovar_inds       = tcovar_inds,
      param_update_inds = param_update_inds,
      census_indices    = census_indices,
      param_vec         = param_vec,
      d_meas_ptr        = d_meas_pointer
    # compute the data log likelihood
    # (sum of the emitmat entries selected by measproc_indmat; NaN becomes -Inf)
    data_log_lik_prop <- sum(emitmat[, -1][measproc_indmat])
    if (is.nan(data_log_lik_prop)) data_log_lik_prop <- -Inf
  }, silent = TRUE)
  # NOTE(review): the stop below fires only when data_log_lik_prop is still
  # NULL, i.e. when the recomputation never assigned it; a value of -Inf is
  # accepted even though the message mentions negative infinity.
  if (is.null(data_log_lik_prop)) {
    stop("Restart attempted with data log likelihood of negative infinity.")
  } else {
    path$data_log_lik <- data_log_lik_prop
} else {
  if(method == "lna") {
    # the LNA initializer additionally receives ess_warmup = approx_warmup
    inits = initialize_lna(
      dat                     = dat,
      parmat                  = parmat,
      param_blocks            = param_blocks,
      tparam                  = tparam,
      censusmat               = censusmat,
      emitmat                 = emitmat,
      stoich_matrix           = stoich_matrix,
      census_times            = census_times,
      proc_pointer            = proc_pointer,
      set_pars_pointer        = set_pars_pointer,
      d_meas_pointer          = d_meas_pointer,
      param_vec               = param_vec,
      param_inds              = param_inds,
      const_inds              = const_inds,
      tcovar_inds             = tcovar_inds,
      initdist_inds           = initdist_inds,
      param_update_inds       = param_update_inds,
      census_indices          = census_indices,
      event_inds              = event_inds,
      measproc_indmat         = measproc_indmat,
      do_prevalence           = do_prevalence,
      forcing_inds            = forcing_inds,
      forcing_tcov_inds       = forcing_tcov_inds,
      forcings_out            = forcings_out,
      forcing_transfers       = forcing_transfers,
      initialization_attempts = initialization_attempts,
      step_size               = step_size,
      initdist_objects        = initdist_objects,
      ess_warmup              = approx_warmup)      
  } else {
    inits = initialize_ode(
      dat                     = dat,
      parmat                  = parmat,
      param_blocks            = param_blocks,
      tparam                  = tparam,
      censusmat               = censusmat,
      emitmat                 = emitmat,
      stoich_matrix           = stoich_matrix,
      proc_pointer            = proc_pointer,
      set_pars_pointer        = set_pars_pointer,
      d_meas_pointer          = d_meas_pointer,
      census_times            = census_times,
      param_vec               = param_vec,
      param_inds              = param_inds,
      const_inds              = const_inds,
      tcovar_inds             = tcovar_inds,
      initdist_inds           = initdist_inds,
      param_update_inds       = param_update_inds,
      census_indices          = census_indices,
      event_inds              = event_inds,
      measproc_indmat         = measproc_indmat,
      do_prevalence           = do_prevalence,
      forcing_inds            = forcing_inds,
      forcing_tcov_inds       = forcing_tcov_inds,
      forcings_out            = forcings_out,
      forcing_transfers       = forcing_transfers,
      initialization_attempts = initialization_attempts,
      step_size               = step_size,
      initdist_objects        = initdist_objects)    
  # grab the initial path, param_blocks, initdist_objects, and tparam
  path             <- inits$path
  param_blocks     <- inits$param_blocks
  initdist_objects <- inits$initdist_objects
  tparam           <- inits$tparam

# LNA only: working copies of the path draws and reset ESS counters
if (method == "lna") {
  # object for proposing new stochastic perturbations
  draws_prop <- matrix(0.0, nrow = nrow(path$draws), ncol = ncol(path$draws))
  copy_mat(draws_prop, path$draws)

  # instantiate matrix for elliptical slice sampling draws
  ess_draws_prop <- matrix(0.0, nrow = nrow(path$draws), ncol = ncol(path$draws))
  copy_mat(ess_draws_prop, path$draws)

  # elliptical slice sampling MCMC record: one step counter and one angle per
  # update, for every entry of the schedule
  for (s in seq_along(lna_ess_schedule)) {
    lna_ess_schedule[[s]]$steps  <- rep(1.0, lna_ess_control$n_updates)
    lna_ess_schedule[[s]]$angles <- rep(0.0, lna_ess_control$n_updates)
  }
}

# warmup the LNA, initial conditions, or time-varying parameters
# Runs ess_warmup sweeps; each sweep has up to three conditional updates:
# the LNA path (method "lna"), the initial distribution (when not fixed and
# not updated jointly), and the time-varying parameters (when present).
# NOTE(review): each of the three argument lists below has lost the line that
# opens its call as well as its closing parenthesis, and the closing braces of
# the `if` blocks and of the `for` loop are missing. Code is left as found.
# NOTE(review): draws_prop and ess_draws_prop are created only under
# method == "lna" earlier in the script, yet they are passed in the second and
# third argument lists regardless of method.
for(warmup in seq_len(ess_warmup)) {
  # LNA path update
  if(method == "lna") {      
      path                  = path,
      dat                   = dat,
      iter                  = 0,
      parmat                = parmat,
      lna_ess_schedule      = lna_ess_schedule,
      lna_ess_control       = lna_ess_control,
      initdist_objects      = initdist_objects,
      tparam                = tparam,
      pathmat_prop          = pathmat_prop,
      censusmat             = censusmat,
      draws_prop            = draws_prop,
      ess_draws_prop        = ess_draws_prop,
      emitmat               = emitmat,
      flow_matrix           = flow_matrix,
      stoich_matrix         = stoich_matrix,
      census_times          = census_times,
      forcing_inds          = forcing_inds,
      forcing_tcov_inds     = forcing_tcov_inds,
      forcings_out          = forcings_out,
      forcing_transfers     = forcing_transfers,
      param_vec             = param_vec,
      param_inds            = param_inds,
      const_inds            = const_inds,
      tcovar_inds           = tcovar_inds,
      initdist_inds         = initdist_inds,
      param_update_inds     = param_update_inds,
      census_indices        = census_indices,
      event_inds            = event_inds,
      measproc_indmat       = measproc_indmat,
      svd_d                 = svd_d,
      svd_U                 = svd_U,
      svd_V                 = svd_V,
      proc_pointer          = proc_pointer,
      set_pars_pointer      = set_pars_pointer,
      d_meas_pointer        = d_meas_pointer,
      do_prevalence         = do_prevalence,
      joint_initdist_update = joint_initdist_update,
      step_size             = step_size
  # initial distribution update
  if(!fixed_inits && !joint_initdist_update) {
      path                 = path,
      dat                  = dat,
      iter                 = 0,
      parmat               = parmat,
      initdist_objects     = initdist_objects,
      initdist_ess_control = initdist_ess_control,
      tparam               = tparam,
      pathmat_prop         = pathmat_prop,
      censusmat            = censusmat,
      draws_prop           = draws_prop,
      ess_draws_prop       = ess_draws_prop,
      emitmat              = emitmat,
      flow_matrix          = flow_matrix,
      stoich_matrix        = stoich_matrix,
      census_times         = census_times,
      forcing_inds         = forcing_inds,
      forcing_tcov_inds    = forcing_tcov_inds,
      forcings_out         = forcings_out,
      forcing_transfers    = forcing_transfers,
      param_vec            = param_vec,
      param_inds           = param_inds,
      const_inds           = const_inds,
      tcovar_inds          = tcovar_inds,
      initdist_inds        = initdist_inds,
      param_update_inds    = param_update_inds,
      census_indices       = census_indices,
      event_inds           = event_inds,
      measproc_indmat      = measproc_indmat,
      svd_d                = svd_d,
      svd_U                = svd_U,
      svd_V                = svd_V,
      proc_pointer         = proc_pointer,
      set_pars_pointer     = set_pars_pointer,
      d_meas_pointer       = d_meas_pointer,
      do_prevalence        = do_prevalence,
      step_size            = step_size
  # time-varying parameter update
  if(!is.null(tparam)) {
      path               = path, 
      dat                = dat,
      iter               = 0,
      parmat             = parmat,
      tparam             = tparam,
      tparam_ess_control = tparam_ess_control,
      pathmat_prop       = pathmat_prop,
      censusmat          = censusmat,
      draws_prop         = draws_prop,
      ess_draws_prop     = ess_draws_prop,
      emitmat            = emitmat,
      flow_matrix        = flow_matrix,
      stoich_matrix      = stoich_matrix,
      census_times       = census_times,
      forcing_inds       = forcing_inds,
      forcing_tcov_inds  = forcing_tcov_inds,
      forcings_out       = forcings_out,
      forcing_transfers  = forcing_transfers,
      param_vec          = param_vec,
      param_inds         = param_inds,
      const_inds         = const_inds,
      tcovar_inds        = tcovar_inds,
      initdist_inds      = initdist_inds,
      param_update_inds  = param_update_inds,
      census_indices     = census_indices,
      event_inds         = event_inds,
      measproc_indmat    = measproc_indmat,
      svd_d              = svd_d,
      svd_U              = svd_U,
      svd_V              = svd_V,
      proc_pointer       = proc_pointer,
      set_pars_pointer   = set_pars_pointer,
      d_meas_pointer     = d_meas_pointer,
      do_prevalence      = do_prevalence,
      step_size          = step_size

# Preallocate storage for the recorded MCMC output. Every container is filled
# with zeros and has one slot (entry, row, or array slice) per recorded sample.
mcmc_samples <- list()

# per-sample data log-likelihood and parameter log-prior
mcmc_samples$data_log_lik <- rep(0, times = n_samples)
mcmc_samples$params_log_prior <- rep(0, times = n_samples)

# parameter samples: one row per sample, columns named by param_names_nat and
# param_names_est respectively
mcmc_samples$parameter_samples_nat <- matrix(
  0,
  nrow = n_samples,
  ncol = n_model_params,
  dimnames = list(NULL, param_names_nat)
)
mcmc_samples$parameter_samples_est <- matrix(
  0,
  nrow = n_samples,
  ncol = n_model_params,
  dimnames = list(NULL, param_names_est)
)

# latent paths: census times x ("time" + one column per row of flow_matrix)
# x recorded samples
mcmc_samples$latent_paths <- array(
  0,
  dim = c(length(census_times), 1 + n_rates, n_samples),
  dimnames = list(NULL, c("time", rownames(flow_matrix)), NULL)
)

# Extra storage that is only needed when the path is sampled with the LNA.
# NOTE(review): the closing brace of this block was absent from this copy of
# the script and has been restored; confirm against the upstream source.
if (method == "lna") {
  # vector for saving the log-likelihood of the LNA draws (one entry per sample)
  mcmc_samples$lna_log_lik <- rep(0.0, n_samples)
  # array for saving LNA draws: rows named by rownames(flow_matrix), one column
  # per interval between consecutive census times, one slice per sample
  mcmc_samples$lna_draws <-
    array(0.0, dim = c(n_rates, length(census_times) - 1, n_samples),
          dimnames = list(rownames(flow_matrix), NULL, NULL))
}

# Storage for the initial-distribution quantities, only needed when the initial
# compartment counts are not fixed.
# NOTE(review): in this copy of the script the `matrix(0.0,` call head inside
# the loop and the closing braces were absent. They have been restored from
# the argument alignment and the matrix() call just above; confirm against the
# upstream source.
if (!fixed_inits) {
  # per-sample log-likelihood of the initial distribution
  mcmc_samples$initdist_log_lik <- rep(0.0, n_samples)
  # sampled initial states: one row per sample, columns named by initdist_names
  mcmc_samples$initdist_samples <-
    matrix(0.0, nrow = n_samples, ncol = n_compartments,
           dimnames = list(NULL, initdist_names))
  # one matrix of draws per initdist object: rows match the length of that
  # object's current draws, one column per sample
  mcmc_samples$initdist_draws <-
    vector("list", length = length(initdist_objects))
  for (i in seq_along(initdist_objects)) {
    mcmc_samples$initdist_draws[[i]] <-
      matrix(0.0,
             nrow = length(initdist_objects[[i]]$draws_cur),
             ncol = n_samples)
  }
}

# Storage for the time-varying parameter (tparam) quantities, only needed when
# tparam is supplied.
# NOTE(review): in this copy of the script the `matrix(0.0,` call head inside
# the loop and the closing braces were absent. They have been restored from
# the argument alignment; confirm against the upstream source.
if (!is.null(tparam)) {
  # per-sample log-likelihood of the tparam draws
  mcmc_samples$tparam_log_lik <- rep(0.0, n_samples)
  # tparam values: n_times x number of tparam elements x samples
  mcmc_samples$tparam_samples <-
    array(0.0, dim = c(n_times, length(tparam), n_samples))
  # One matrix of draws per tparam element. The list is sized by
  # length(tparam) because the loop below fills one slot per tparam element;
  # it was previously sized by length(initdist_objects), which is unrelated.
  mcmc_samples$tparam_draws <-
    vector("list", length = length(tparam))
  for (i in seq_along(tparam)) {
    mcmc_samples$tparam_draws[[i]] <-
      matrix(0.0,
             nrow = tparam[[i]]$n_draws,
             ncol = n_samples)
  }
}

# Counters used when recording output; both start at zero. rec_ind is later
# used to index the recorded samples, and ess_rec_ind is only created when an
# ESS record was requested.
rec_ind <- 0
if (return_ess_rec) {
  ess_rec_ind <- 0
}

# initialize the status file if status updates are required
# NOTE(review): this block is incomplete in this copy of the script. The
# right-hand side of the `status_file <-` assignment, the head of the call that
# takes "Beginning MCMC" / file / sep / append (these match cat()'s arguments
# as used in the progress printing further down -- TODO confirm), and the
# closing `)` and `}` lines are all absent. Restore them from the upstream
# source before running.
if (print_progress) {
  status_file <- 
    "Beginning MCMC",
    file = status_file,
    sep = "\n",
    # append = FALSE: start the status output fresh rather than appending
    append = FALSE

# begin the MCMC
# start.time is read again after the loop to compute the total runtime
start.time <- Sys.time()
# keep_going is switched to FALSE by the log-likelihood checks inside the loop,
# which ends the run at the next evaluation of the loop condition
keep_going <- TRUE
iter <- 0
# Use the scalar, short-circuiting `&&` for the loop condition: both operands
# are length-one, and `&` is meant for elementwise vector logic.
while (keep_going && iter < iterations) {
  iter <- iter + 1
  # Sample new parameter values
  # block_order = sample.int(length(param_blocks))
  # for(ind in block_order) {
  #   if(param_blocks[[ind]]$alg == "mvnmh") {
  #     # save the covariance matrix if stopping adaptation
  #     if (iter == min(max_adaptation, iterations)) {
  #       param_blocks[[ind]]$sigma <- 
  #         param_blocks[[ind]]$mvnmh_objects$proposal_scaling * 
  #         param_blocks[[ind]]$kernel_cov
  #       colnames(param_blocks[[ind]]$sigma) <- 
  #         rownames(param_blocks[[ind]]$sigma) <- 
  #         param_blocks[[ind]]$param_names_est
  #       comp_chol(param_blocks[[ind]]$kernel_cov_chol, 
  #                 param_blocks[[ind]]$sigma)
  #     }
  #     # sample new parameters        
  #     mvnmh_update(
  #       param_blocks      = param_blocks, 
  #       ind               = ind,
  #       iter              = iter,
  #       parmat            = parmat,
  #       dat               = dat,
  #       path              = path,
  #       pathmat_prop      = pathmat_prop,
  #       tparam            = tparam, 
  #       census_times      = census_times,
  #       flow_matrix       = flow_matrix,
  #       stoich_matrix     = stoich_matrix,
  #       censusmat         = censusmat,
  #       emitmat           = emitmat,
  #       param_vec         = param_vec, 
  #       param_inds        = param_inds, 
  #       const_inds        = const_inds, 
  #       tcovar_inds       = tcovar_inds, 
  #       param_update_inds = param_update_inds, 
  #       initdist_inds     = initdist_inds,
  #       census_indices    = census_indices,
  #       event_inds        = event_inds,
  #       measproc_indmat   = measproc_indmat,
  #       forcing_inds      = forcing_inds,
  #       forcing_tcov_inds = forcing_tcov_inds,
  #       forcings_out      = forcings_out,
  #       forcing_transfers = forcing_transfers,
  #       proc_pointer      = proc_pointer,
  #       d_meas_pointer    = d_meas_pointer,
  #       set_pars_pointer  = set_pars_pointer,
  #       do_prevalence     = do_prevalence,
  #       step_size         = step_size,
  #       svd_d             = svd_d,
  #       svd_U             = svd_U,
  #       svd_V             = svd_V)
  #   } else if(param_blocks[[ind]]$alg == "mvnss") {
  #     # save the covariance matrix and bracket if stopping adaptation
  #     if (iter == min(max_adaptation, iterations)) {
  #       # covariance matrix
  #       param_blocks[[ind]]$sigma <- param_blocks[[ind]]$kernel_cov
  #       colnames(param_blocks[[ind]]$sigma) <- 
  #         rownames(param_blocks[[ind]]$sigma) <- 
  #         param_blocks[[ind]]$param_names_est
  #       # cholesky
  #       comp_chol(param_blocks[[ind]]$kernel_cov_chol, 
  #                 param_blocks[[ind]]$sigma)
  #     }
  #     # sample new parameters        
  #     mvnss_update(
  #       param_blocks      = param_blocks, 
  #       ind               = ind,
  #       iter              = iter,
  #       parmat            = parmat,
  #       dat               = dat,
  #       path              = path,
  #       pathmat_prop      = pathmat_prop,
  #       tparam            = tparam,
  #       census_times      = census_times,
  #       flow_matrix       = flow_matrix,
  #       stoich_matrix     = stoich_matrix,
  #       censusmat         = censusmat,
  #       emitmat           = emitmat,
  #       param_vec         = param_vec,
  #       param_inds        = param_inds, 
  #       const_inds        = const_inds, 
  #       tcovar_inds       = tcovar_inds, 
  #       param_update_inds = param_update_inds, 
  #       initdist_inds     = initdist_inds,
  #       census_indices    = census_indices,
  #       event_inds        = event_inds,
  #       measproc_indmat   = measproc_indmat,
  #       forcing_inds      = forcing_inds,
  #       forcing_tcov_inds = forcing_tcov_inds,
  #       forcings_out      = forcings_out,
  #       forcing_transfers = forcing_transfers,
  #       proc_pointer      = proc_pointer,
  #       d_meas_pointer    = d_meas_pointer,
  #       set_pars_pointer  = set_pars_pointer,
  #       do_prevalence     = do_prevalence,
  #       step_size         = step_size,
  #       svd_d             = svd_d,
  #       svd_U             = svd_U,
  #       svd_V             = svd_V)
  #   }
  # }
  # update the initial compartment volumes
  # Runs only when the initial state is not fixed and is not being updated
  # jointly with the path (joint_initdist_update is also passed to the path
  # update step later in the loop).
  # NOTE(review): the line naming the function that receives these arguments,
  # and the closing `)` and `}` lines, are absent from this copy of the script;
  # restore them from the upstream source before running.
  if(!fixed_inits && !joint_initdist_update) {
      path                 = path,
      dat                  = dat,
      iter                 = iter,
      parmat               = parmat,
      initdist_objects     = initdist_objects,
      initdist_ess_control = initdist_ess_control,
      tparam               = tparam,
      pathmat_prop         = pathmat_prop,
      censusmat            = censusmat,
      draws_prop           = draws_prop,
      ess_draws_prop       = ess_draws_prop,
      emitmat              = emitmat,
      flow_matrix          = flow_matrix,
      stoich_matrix        = stoich_matrix,
      census_times         = census_times,
      forcing_inds         = forcing_inds,
      forcing_tcov_inds    = forcing_tcov_inds,
      forcings_out         = forcings_out,
      forcing_transfers    = forcing_transfers,
      param_vec            = param_vec,
      param_inds           = param_inds,
      const_inds           = const_inds,
      tcovar_inds          = tcovar_inds,
      initdist_inds        = initdist_inds,
      param_update_inds    = param_update_inds,
      census_indices       = census_indices,
      event_inds           = event_inds,
      measproc_indmat      = measproc_indmat,
      svd_d                = svd_d,
      svd_U                = svd_U,
      svd_V                = svd_V,
      proc_pointer         = proc_pointer,
      set_pars_pointer     = set_pars_pointer,
      d_meas_pointer       = d_meas_pointer,
      do_prevalence        = do_prevalence,
      step_size            = step_size
  # update the tparam draws
  # Runs only when time-varying parameters (tparam) are supplied; the current
  # iteration number is passed through as iter.
  # NOTE(review): the line naming the function that receives these arguments,
  # and the closing `)` and `}` lines, are absent from this copy of the script;
  # restore them from the upstream source before running.
  if (!is.null(tparam)) {
      path               = path, 
      dat                = dat,
      iter               = iter,
      parmat             = parmat,
      tparam             = tparam,
      tparam_ess_control = tparam_ess_control,
      pathmat_prop       = pathmat_prop,
      censusmat          = censusmat,
      draws_prop         = draws_prop,
      ess_draws_prop     = ess_draws_prop,
      emitmat            = emitmat,
      flow_matrix        = flow_matrix,
      stoich_matrix      = stoich_matrix,
      census_times       = census_times,
      forcing_inds       = forcing_inds,
      forcing_tcov_inds  = forcing_tcov_inds,
      forcings_out       = forcings_out,
      forcing_transfers  = forcing_transfers,
      param_vec          = param_vec,
      param_inds         = param_inds,
      const_inds         = const_inds,
      tcovar_inds        = tcovar_inds,
      initdist_inds      = initdist_inds,
      param_update_inds  = param_update_inds,
      census_indices     = census_indices,
      event_inds         = event_inds,
      measproc_indmat    = measproc_indmat,
      svd_d              = svd_d,
      svd_U              = svd_U,
      svd_V              = svd_V,
      proc_pointer       = proc_pointer,
      set_pars_pointer   = set_pars_pointer,
      d_meas_pointer     = d_meas_pointer,
      do_prevalence      = do_prevalence,
      step_size          = step_size
  # Update the path via elliptical slice sampling
  # Runs only when method is "lna". Unlike the other update steps in this loop,
  # the argument list also carries lna_ess_schedule, lna_ess_control and
  # joint_initdist_update.
  # NOTE(review): the line naming the function that receives these arguments,
  # and the closing `)` and `}` lines, are absent from this copy of the script;
  # restore them from the upstream source before running.
  if(method == "lna") {
      path                  = path,
      dat                   = dat,
      iter                  = iter,
      parmat                = parmat,
      lna_ess_schedule      = lna_ess_schedule,
      lna_ess_control       = lna_ess_control,
      initdist_objects      = initdist_objects,
      tparam                = tparam,
      pathmat_prop          = pathmat_prop,
      censusmat             = censusmat,
      draws_prop            = draws_prop,
      ess_draws_prop        = ess_draws_prop,
      emitmat               = emitmat,
      flow_matrix           = flow_matrix,
      stoich_matrix         = stoich_matrix,
      census_times          = census_times,
      forcing_inds          = forcing_inds,
      forcing_tcov_inds     = forcing_tcov_inds,
      forcings_out          = forcings_out,
      forcing_transfers     = forcing_transfers,
      param_vec             = param_vec,
      param_inds            = param_inds,
      const_inds            = const_inds,
      tcovar_inds           = tcovar_inds,
      initdist_inds         = initdist_inds,
      param_update_inds     = param_update_inds,
      census_indices        = census_indices,
      event_inds            = event_inds,
      measproc_indmat       = measproc_indmat,
      svd_d                 = svd_d,
      svd_U                 = svd_U,
      svd_V                 = svd_V,
      proc_pointer          = proc_pointer,
      set_pars_pointer      = set_pars_pointer,
      d_meas_pointer        = d_meas_pointer,
      do_prevalence         = do_prevalence,
      joint_initdist_update = joint_initdist_update,
      step_size             = step_size
  # Stop the chain when the data log-likelihood is NaN or larger than 1e3 in
  # magnitude after the update steps above. keep_going is only read by the loop
  # condition, so the rest of this iteration still executes.
  # NOTE(review): the closing brace of this block was absent from this copy of
  # the script and has been restored; confirm against the upstream source.
  # `||` replaces the elementwise `|`: the condition is a single value.
  if (abs(path$data_log_lik) > 1e3 || is.nan(abs(path$data_log_lik))) {
    keep_going <- FALSE
    # diagnostic flags; presumably they mark where the failure was detected
    # (here: before the data are re-simulated) -- TODO confirm intended use
    post <- FALSE
    pre <- TRUE
    rec <- FALSE
  }
  # simulate a new dataset
  # The second column of dat is overwritten. The mean is the second column of
  # the latent path (first row dropped) scaled by rho, and size is phi; the
  # file header describes rho as the case detection rate and phi as the
  # negative binomial overdispersion.
  # NOTE(review): the call head (function name and any leading argument) is
  # absent from this copy of the script. The mu/size arguments suggest a
  # negative binomial draw -- TODO confirm against the upstream source.
  dat[,2] <- 
            mu = path$latent_path[-1,2] * parmat[1,"rho"],
            size = parmat[1,"phi"])
  # update the data log likelihood
  # Recomputed with the same mu and size, on the log scale, so that it matches
  # the freshly simulated data.
  # NOTE(review): the call head is absent here as well. The two closing
  # parentheses imply two nested calls, presumably a sum over log densities --
  # TODO confirm against the upstream source.
  path$data_log_lik <- 
                mu = path$latent_path[-1,2] * parmat[1,"rho"],
                size = parmat[1,"phi"],
                log = TRUE))
  # Repeat the log-likelihood sanity check after the data log-likelihood has
  # been recomputed: stop the chain when it is NaN or larger than 1e3 in
  # magnitude.
  # NOTE(review): the closing brace of this block was absent from this copy of
  # the script and has been restored; confirm against the upstream source.
  # `||` replaces the elementwise `|`, and TRUE replaces the reassignable `T`.
  if (abs(path$data_log_lik) > 1e3 || is.nan(abs(path$data_log_lik))) {
    keep_going <- FALSE
    # diagnostic flags; presumably they mark where the failure was detected
    # (here: after the data were re-simulated) -- TODO confirm intended use
    post <- TRUE
    pre <- FALSE
    rec <- FALSE
  }
  # check if the sample should be recorded
  # Recording is switched on at the first iteration after max_adaptation and
  # stays on from then on.
  # NOTE(review): the closing brace of this block was absent from this copy of
  # the script and has been restored; confirm against the upstream source.
  if (!record_sample && iter == (max_adaptation + 1)) {
    record_sample <- TRUE
  }
  # Save the MCMC sample if called for in this iteration
  # A sample is saved when recording is switched on and the iteration number is
  # a multiple of thinning_interval.
  # NOTE(review): this block is incomplete in this copy of the script. The
  # lines naming the two functions that receive the argument lists below, any
  # line that advances rec_ind / ess_rec_ind, and every closing `)` and `}` are
  # absent. Restore them from the upstream source before running.
  if(record_sample && iter %% thinning_interval == 0) {
      mcmc_samples     = mcmc_samples,
      rec_ind          = rec_ind,
      path             = path,
      parmat           = parmat,
      param_blocks     = param_blocks,
      initdist_objects = initdist_objects,
      tparam           = tparam,
      tparam_inds      = tparam_inds,
      method           = method
    # stop the chain if the log-likelihood stored at rec_ind is NaN or larger
    # than 1e3 in magnitude
    if(abs(mcmc_samples$data_log_lik[rec_ind]) > 1e3 | 
       is.nan(abs(mcmc_samples$data_log_lik[rec_ind]))) {
      keep_going = FALSE
      # diagnostic flags; presumably they mark that the failure was detected on
      # the recorded sample -- TODO confirm intended use
      post = T; pre = FALSE; rec = T
    # arguments for the ESS record update, used only when an ESS record was
    # requested
    if(return_ess_rec) {
        ess_record            = ess_record,
        ess_rec_ind           = ess_rec_ind,
        lna_ess_schedule      = lna_ess_schedule,
        tparam                = tparam,
        initdist_ess_control  = initdist_ess_control
  # print status messages if called for
  # Every progress_interval iterations, append the iteration number and a short
  # summary for each parameter block to the status file.
  # NOTE(review): the closing braces of the else branch, the for loop, this if
  # block and the enclosing MCMC while loop were absent from this copy of the
  # script and have been restored; confirm against the upstream source.
  if (print_progress && iter %% progress_interval == 0) {
    cat(paste0("Iteration: ", iter),
        file = status_file,
        sep = "\n",
        append = TRUE)
    for (b in seq_along(param_blocks)) {
      if (param_blocks[[b]]$alg == "mvnmh") {
        # "mvnmh" blocks: report the acceptances and proposal scaling fields
        cat(paste0("\t", "Parameter block: ", b),
            paste0("\t", "\t", "Acceptance rate: ",
                   param_blocks[[b]]$mvnmh_objects$acceptances),
            paste0("\t", "\t", "Proposal scaling: ",
                   param_blocks[[b]]$mvnmh_objects$proposal_scaling),
            file = status_file,
            sep = "\n",
            append = TRUE)
      } else {
        # every other block is reported through its mvnss_objects fields
        cat(paste0("\t", "Parameter block: ", b),
            paste0("\t", "\t", "Contractions: ",
                   param_blocks[[b]]$mvnss_objects$n_contractions),
            paste0("\t", "\t", "Expansions: ",
                   param_blocks[[b]]$mvnss_objects$n_expansions),
            file = status_file,
            sep = "\n",
            append = TRUE)
      }
    }
  }
} # end of the MCMC while loop

# timestamp the end of the run
end.time <- Sys.time()

# bundle the elapsed time (in hours) together with the recorded samples
stem_object$results <- list(
  runtime   = difftime(end.time, start.time, units = "hours"),
  posterior = mcmc_samples
)

# attach the ESS record only when it was requested
if (return_ess_rec) {
  stem_object$results$ess_record <- ess_record
}

# keep the current values from the first row of parmat (and tparam, if any) so
# that a later run can restart from them
stem_object$dynamics$parameters <- parmat[1, param_inds + 1]
if (!fixed_inits) {
  stem_object$dynamics$initdist_params <- parmat[1, initdist_inds + 1]
}
if (!is.null(tparam)) {
  stem_object$dynamics$tparam <- tparam
}

# objects needed to resume the sampler
stem_object$restart <- list(
  path             = path,
  param_blocks     = param_blocks,
  initdist_objects = initdist_objects,
  tparam           = tparam
)

# Plots
# Visual check: density-scaled histograms (30 breaks) of the first and second
# rows of the first initdist draws matrix, each overlaid with the standard
# normal density evaluated on a grid from -3 to 3 in steps of 0.1.
# NOTE(review): mcmc_samples$initdist_draws is only created when fixed_inits is
# FALSE, so these plots presumably assume a non-fixed initial state -- confirm.
hist(mcmc_samples$initdist_draws[[1]][1,], 30, freq = FALSE)
lines(seq(-3,3,by=0.1), dnorm(seq(-3,3,by=0.1)))

# same comparison for the second row of the draws
hist(mcmc_samples$initdist_draws[[1]][2,], 30, freq = FALSE)
lines(seq(-3,3,by=0.1), dnorm(seq(-3,3,by=0.1)))
# fintzij/stemr documentation built on March 25, 2022, 12:25 p.m.