# .development_files/examples_stemr_0.1_ebola_mod/sierra_leone.R

library(stemr)
library(extraDistr)
library(foreach)
library(doParallel)
library(doRNG)

# Data and population size ------------------------------------------------

popsize <- 7.1e6
log_popsize <- log(popsize)

# Observed case counts, one per observation time 1, 2, ..., 55. The time
# column is generated from the length of the case series so the two columns
# can never disagree in length.
dat <- local({
  cases <- c(0, 0, 3, 18, 36, 71, 60, 54, 92, 112, 76, 103, 149, 145, 131, 178, 201,
             224, 240, 418, 410, 404, 428, 543, 558, 617, 614, 533, 525, 646, 451,
             462, 411, 507, 397, 330, 172, 133, 116, 119, 120, 99, 78, 107, 58, 48,
             38, 31, 19, 12, 8, 12, 0, 0, 0)
  cbind(time = seq_along(cases), cases = cases)
})

# Build the stemr object (no strata) --------------------------------------------
set.seed(12511)
strata <- NULL
compartments <- c("S", "E", "I", "R")

# SEIR transitions. The S -> E rate acts on (S - effpop), so only
# popsize - effpop individuals (Neff in the comments below) are effectively
# at risk.
rates <- list(
  rate("(alpha + beta * I) * (S - effpop)", "S", "E",
       lumped = TRUE, incidence = TRUE),
  rate("omega", "E", "I", incidence = TRUE),
  rate("mu", "I", "R", incidence = TRUE)
)

# Initial compartment counts are not fixed; the same counts are supplied as
# the `prior` argument of stem_initializer().
state_initializer <- list(
  stem_initializer(
    c(S = popsize - 30, E = 15, I = 10, R = 5),
    fixed = FALSE,
    prior = c(popsize - 30, 15, 10, 5)
  )
)
adjacency <- NULL
tcovar <- NULL

# alpha and beta are specified per member of popsize - effpop (= Neff).
parameters <- c(
  alpha  = 0.3 / (popsize - 7.05e6),
  beta   = 1.5 / (popsize - 7.05e6),
  omega  = 1,
  mu     = 1,
  rho    = 0.75,
  phi    = 10,
  effpop = 7.05e6
)
constants <- c(t0 = 0)
t0 <- 0
tmax <- nrow(dat)

# Assemble the model dynamics. The compile_* flags request compiled ODE,
# rate, and LNA code; messages are enabled.
dynamics <- stem_dynamics(
  rates = rates,
  tmax = tmax,
  timestep = NULL,
  parameters = parameters,
  state_initializer = state_initializer,
  compartments = compartments,
  constants = constants,
  strata = strata,
  adjacency = adjacency,
  tcovar = tcovar,
  messages = TRUE,
  compile_ode = TRUE,
  compile_rates = TRUE,
  compile_lna = TRUE,
  rtol = 1e-6,
  atol = 1e-6,
  step_size = 1e-6
)

# Negative-binomial emission for "cases", parameterised by phi and
# E2I * rho, observed at times 1, ..., tmax. (E2I is presumably the E -> I
# incidence variable created by `incidence = TRUE` on that rate; confirm
# against the stemr docs.)
emissions <- list(
  emission(
    "cases", "negbinomial", c("phi", "E2I * rho"),
    incidence = TRUE,
    obstimes = seq(1, tmax, by = 1)
  )
)

measurement_process <- stem_measure(
  data = dat,
  emissions = emissions,
  dynamics = dynamics,
  messages = TRUE
)

stem_object <- stem(dynamics = dynamics, measurement_process = measurement_process)

#### initialize the inference
# {alpha, beta, omega, mu, rho, phi, effpop}
# -> {log(1000 * alpha), log(Reff - 1) + log(rho * Neff), log(omega / mu), log(1 / mu), logit(rho), log(phi), log(rho * Neff)}
# where Neff = popsize - effpop and Reff = beta * Neff / mu.
# Map natural-scale parameters to the estimation scale.
#
# params_nat: numeric vector, read by position as
#   (alpha, beta, omega, mu, rho, phi, effpop).
# Returns a length-7 vector:
#   (log(1000 * alpha), log(Reff - 1) + log(rho * Neff), log(omega / mu),
#    log(1 / mu), logit(rho), log(phi), log(rho * Neff)),
# with Neff = popsize - effpop and Reff = beta * Neff / mu. Element names,
# if any, follow those of params_nat.
#
# Reads the global `popsize`. Requires Reff > 1; otherwise log() of a
# non-positive number yields NaN.
to_estimation_scale <- function(params_nat) {
  alpha  <- params_nat[1]
  beta   <- params_nat[2]
  omega  <- params_nat[3]
  mu     <- params_nat[4]
  rho    <- params_nat[5]
  phi    <- params_nat[6]
  effpop <- params_nat[7]

  log_neff     <- log(popsize - effpop)
  log_dur      <- -log(mu)
  log_neff_rho <- log_neff + log(rho)
  log_reff_m1  <- log(exp(log(beta) + log_neff + log_dur) - 1)

  c(log(alpha) + log(1000),
    log_reff_m1 + log_neff_rho,
    log(omega) + log_dur,
    log_dur,
    logit(rho),
    log(phi),
    log_neff_rho)
}

# Inverse of to_estimation_scale: map an estimation-scale vector back to
# (alpha, beta, omega, mu, rho, phi, effpop).
#
# params_est: numeric length-7 vector, read by position (see
#   to_estimation_scale for the layout).
# Reads the global `popsize`. Element names, if any, follow params_est.
from_estimation_scale <- function(params_est) {
  # rho is recovered first because Neff is stored as log(rho * Neff).
  rho      <- expit(params_est[5])
  log_neff <- params_est[7] - log(rho)
  # Element 2 minus element 7 is log(Reff - 1).
  log_reff <- log(exp(params_est[2] - params_est[7]) + 1)

  alpha  <- exp(params_est[1] - log(1000))
  beta   <- exp(log_reff - log_neff - params_est[4])
  omega  <- exp(params_est[3] - params_est[4])
  mu     <- exp(-params_est[4])
  phi    <- exp(params_est[6])
  effpop <- popsize - exp(log_neff)

  c(alpha, beta, omega, mu, rho, phi, effpop)
}

## Priors
## Priors
# prior_density returns the joint log prior density evaluated on the
# estimation scale; params_nat is accepted but not used.
priors <- list(
  prior_density = function(params_nat, params_est) {
    # Derived coordinates: log(Neff) = log(rho * Neff) - log(rho), and
    # log(Reff - 1) = element 2 - element 7 of the estimation-scale vector.
    log_neff    <- params_est[7] - log(expit(params_est[5]))
    log_reff_m1 <- params_est[2] - params_est[7]

    # Exp() priors are placed on exp(params_est[k]); adding params_est[k]
    # (the log-Jacobian) gives the density of the log-scale coordinate.
    lp_ext      <- dexp(exp(params_est[1]), 40, log = TRUE) + params_est[1]
    lp_reff     <- dnorm(log_reff_m1, log(0.5), 1.08, log = TRUE)
    lp_omega_mu <- dnorm(params_est[3], 0, 0.3, log = TRUE)
    lp_dur      <- dnorm(params_est[4], 0, 0.3, log = TRUE)
    lp_rho      <- dnorm(params_est[5], 0.85, 0.75, log = TRUE)
    lp_phi      <- dexp(exp(params_est[6]), 0.69, log = TRUE) + params_est[6]
    lp_neff     <- dnorm(log_neff, 10.7, 0.62, log = TRUE)

    sum(lp_ext, lp_reff, lp_omega_mu, lp_dur, lp_rho, lp_phi, lp_neff)
  },
  to_estimation_scale   = to_estimation_scale,
  from_estimation_scale = from_estimation_scale
)

# Initial proposal covariance: independent, variance 0.01 on each
# estimation-scale coordinate. Names follow the to_estimation_scale layout.
covmat_names <- c("log_Reff_ext", "log_Reff_m_1_o", "log_omega_d_mu",
                  "log_carriage_dur", "logit_rho", "log_phi", "log_effpop_o")
covmat <- diag(0.01, length(parameters))
dimnames(covmat) <- list(covmat_names, covmat_names)

# Multivariate normal slice-sampling (mvnss) kernel for the parameters.
# The constructors are namespace-qualified because `kernel` is also a
# function in the always-attached stats package. The unqualified call only
# resolves to stemr's version because stemr happens to be attached later.
mcmc_kernel <- stemr::kernel(
  method = "mvnss",
  sigma = covmat,
  scale_constant = 0.5,
  scale_cooling = 0.7,
  stop_adaptation = 5e4,
  step_size = 0.5,
  nugget = 1e-5,
  harss_warmup = 0,
  mvnss_setting_list = stemr::mvnss_settings(
    n_mvnss_updates = 1,
    initial_bracket_width = 0.5,
    bracket_limits = c(0.001, 5),
    nugget_cooling = 0.99,
    nugget_step_size = 0.001
  ),
  messages = FALSE
)

# Replace the fixed parameter vector with a generator. Each call perturbs
# `parameters` by N(0, 0.1) noise on the estimation scale, maps the result
# back to the natural scale, and restores the original names. (Presumably
# stemr invokes this to draw per-chain starting values; confirm against
# stem_inference.)
stem_object$dynamics$parameters <- function() {
  perturbed_est <- to_estimation_scale(parameters) +
    rnorm(length(parameters), mean = 0, sd = 0.1)
  draw <- from_estimation_scale(perturbed_est)
  names(draw) <- names(parameters)
  draw
}

registerDoParallel(3)

# results <- foreach(chain = 1:5,
#                    .packages=c("stemr"),
#                    .options.RNG = 52787,
#                    .export = ls(all.names = T)) %dorng% {
#                          
#                          chain_res <- stem_inference(stem_object = stem_object,
#                                                      method = "ode",
#                                                      iterations = 1.5e5,
#                                                      thin_params = 50,
#                                                      thin_latent_proc = 50,
#                                                      initialization_attempts = 500,
#                                                      priors = priors,
#                                                      mcmc_kernel = mcmc_kernel,
#                                                      t0_kernel = t0_kernel,
#                                                      ess_args = ess_settings(n_ess_updates = 1,
#                                                                              ess_warmup = 100, 
#                                                                              initdist_bracket_width = 2*pi,
#                                                                              initdist_bracket_update = 5e3,
#                                                                              lna_bracket_width = 2*pi,
#                                                                              lna_bracket_update = 5e3,
#                                                                              joint_strata_update = FALSE),
#                                                      print_progress = 5e3,
#                                                      messages = F)
#                          return(chain_res)
#                    }
# 
# save(results, file = paste0("sierra_leone_ODE.Rdata"))
# 
# # grab the initial covariance matrix
# covs <- sapply(results, function(x) cov(x$stem_settings$mcmc_kernel$sigma))
# covmat <- matrix(rowMeans(covs), length(parameters), length(parameters), dimnames = list(covmat_names, covmat_names))
# 
# mcmc_kernel <-
#       kernel(
#           method = "mvnss",
#           sigma = covmat,
#           scale_constant = 0.5,
#           scale_cooling = 0.7,
#           stop_adaptation = 5e4,
#           step_size = 0.5,
#           nugget = 1e-5, 
#           harss_warmup = 0,
#           mvnss_setting_list = 
#               mvnss_settings(n_mvnss_updates = 1, 
#                              initial_bracket_width = 0.5,
#                              bracket_limits = c(0.001, 5),
#                              nugget_cooling = 0.99, 
#                              nugget_step_size = 0.001),
#             messages = FALSE
#       )
# 
# # grab the initial parameters and compartment volumes
# init_pars <- lapply(results, function(x) x$dynamics$parameters)
# init_states <- lapply(results, function(x) x$dynamics$initdist_params)
# 
# rm(results)
# 
# registerDoParallel(5)

# Run five LNA chains with foreach/%dorng%. The fixed .options.RNG seed makes
# the per-chain random streams reproducible. All objects in the calling
# environment are exported to the workers, and stemr is attached there.
results <- foreach(
  chain = 1:5,
  .packages = "stemr",
  .options.RNG = 52787,
  .export = ls(all.names = TRUE)
) %dorng% {

  # stem_object$dynamics$parameters <- init_pars[[chain]]
  # stem_object$dynamics$initdist_params <- init_states[[chain]]

  # No `t0_kernel` object is created in this script, and t0 is held fixed
  # through `constants`. NULL is therefore passed explicitly rather than
  # letting the name resolve to whatever happens to be on the search path.
  # (NULL presumably means "t0 not estimated"; confirm in ?stem_inference.)
  stem_inference(
    stem_object = stem_object,
    method = "lna",
    iterations = 1.5e5,
    thin_params = 50,
    thin_latent_proc = 50,
    initialization_attempts = 500,
    priors = priors,
    mcmc_kernel = mcmc_kernel,
    t0_kernel = NULL,
    ess_args = ess_settings(
      n_ess_updates = 1,
      ess_warmup = 100,
      initdist_bracket_width = 2 * pi,
      initdist_bracket_update = 5e3,
      lna_bracket_width = 2 * pi,
      lna_bracket_update = 5e3,
      joint_strata_update = FALSE
    ),
    messages = FALSE,
    print_progress = 5e3
  )
}

save(results, file = "sierra_leone_LNA.Rdata")
# fintzij/stemr documentation built on March 25, 2022, 12:25 p.m.