.development_files/examples_stemr_0.1_ebola_mod/guinea.R

library(stemr)
library(extraDistr)
library(foreach)
library(doParallel)
library(doRNG)

# Data and population size ------------------------------------------------
stopImplicitCluster()
rm(list=ls())
popsize <- 11.8e6
log_popsize <- log(popsize)

dat <- 
      cbind(
            time = 
                  1:73,
            cases = 
                  c(2,0,0,5,3,5,3,5,9,8,27,17,28,27,27,16,10,12,11,8,27,50,28,
                    34,18,24,4,6,8,43,27,21,44,78,104,105,95,83,130,105,143,115,
                    95,95,162,100,170,93,122,88,203,109,78,57,22,48,48,71,61,39,
                    53,59,93,46,55,21,31,17,22,9,7,31,7)
      )

# no strata the stemr object --------------------------------------------------
set.seed(12511)
strata <- NULL
compartments <- c("S", "E", "I", "R")
rates <- list(rate("(alpha + beta * I) * (S - effpop)", "S", "E", lumped = TRUE, incidence = T),
              rate("omega", "E", "I", incidence = T),
              rate("mu", "I", "R", incidence = T))
state_initializer <- list(stem_initializer(c(S = popsize-30, E = 15, I = 10, R = 5), fixed = F,
                                           prior = c(popsize-30, 15, 10, 5)))
adjacency <- NULL
tcovar <- NULL
parameters = c(alpha = 0.25 / (popsize - 11.75e6), beta = 1.25 / (popsize - 11.75e6), omega = 1, mu = 1, rho = 0.75, phi = 10, effpop = 11.75e6)
constants <- c(t0 = 0)
t0 <- 0; tmax <- 73;

dynamics <-
      stem_dynamics(
            rates = rates,
            tmax = tmax,
            timestep = NULL,
            parameters = parameters,
            state_initializer = state_initializer,
            compartments = compartments,
            constants = constants,
            strata = strata,
            adjacency = adjacency,
            tcovar = tcovar,
            messages = T,
            compile_ode = T,
            compile_rates = T,
            compile_lna = T,
            rtol = 1e-6,
            atol = 1e-6,
            step_size = 1e-6
      )

emissions <- list(emission("cases", "negbinomial", c("phi","E2I * rho"), incidence = TRUE, obstimes = seq(1, tmax, by =1)))

measurement_process <- stem_measure(data = dat, emissions = emissions, dynamics = dynamics, messages = T)

stem_object <- stem(dynamics = dynamics, measurement_process = measurement_process)

#### initialize the inference
# {alpha, beta, omega, mu, rho, phi, Neff} 
# -> {log(Rext), log(Reff-1) + log(rho * Neff), log(omega / mu), log(1/mu), logit(rho), log(phi), log(rho * Neff)}
to_estimation_scale <- 
    function (params_nat) {
        
        l_effpop     <- log(popsize - params_nat[7])
        l_Neff_x_rho <- l_effpop + log(params_nat[5]) 
        l_infecdur   <- -log(params_nat[4])
        l_Reff_m1    <- log(exp(log(params_nat[2]) + l_effpop + l_infecdur)-1)
        
        return(
            c(
                log(params_nat[1]) + log(1000),
                l_Reff_m1 + l_Neff_x_rho,
                log(params_nat[3]) + l_infecdur, 
                l_infecdur,
                logit(params_nat[5]),
                log(params_nat[6]),
                l_Neff_x_rho
            )
        )
    }

from_estimation_scale <- 
    function (params_est) {
        
        rho <- expit(params_est[5])
        l_effpop <- params_est[7] - log(rho)
        
        return(c(
            exp(params_est[1] - log(1000)),
            exp(log(exp(params_est[2] - params_est[7])+1) - l_effpop - params_est[4]),
            exp(params_est[3] - params_est[4]),
            exp(-params_est[4]),
            rho,
            exp(params_est[6]),
            popsize - exp(l_effpop)
        ))
    }

## Priors
priors <- list(prior_density =
                   function(params_nat, params_est) {
                       
                       l_effpop <- params_est[7] - log(expit(params_est[5]))
                       l_Reff_m_1 <- params_est[2] - params_est[7]
                       
                       sum(
                           dexp(exp(params_est[1]), 40, log = TRUE) + params_est[1], 
                           dnorm(l_Reff_m_1, log(0.5), 1.08, log = TRUE),
                           dnorm(params_est[3], 0, 0.3, log = TRUE),
                           dnorm(params_est[4], 0, 0.3, log = TRUE),
                           dnorm(params_est[5], 0.85, 0.75, log = TRUE),
                           dexp(exp(params_est[6]), 0.69, log = TRUE) + params_est[6],
                           dnorm(l_effpop, 9.6, 0.62, log = TRUE)
                       )
                   },
               to_estimation_scale   = to_estimation_scale,
               from_estimation_scale = from_estimation_scale)

covmat_names <- c(
    "log_Reff_ext",
    "log_Reff_m_1_o",
    "log_omega_d_mu",
    "log_carriage_dur",
    "logit_rho",
    "log_phi",
    "log_effpop_o"
)
covmat <- diag(0.01, length(parameters))
rownames(covmat) <- colnames(covmat) <- covmat_names

mcmc_kernel <-
    kernel(
        method = "mvnss",
        sigma = covmat,
        scale_constant = 0.5,
        scale_cooling = 0.7,
        stop_adaptation = 5e4,
        step_size = 0.5,
        nugget = 1e-5, 
        harss_warmup = 0,
        mvnss_setting_list = 
            mvnss_settings(n_mvnss_updates = 1, 
                           initial_bracket_width = 0.5,
                           bracket_limits = c(0.001, 5),
                           nugget_cooling = 0.99, 
                           nugget_step_size = 0.001),
        messages = FALSE
    )

stem_object$dynamics$parameters <- function() {
    setNames(from_estimation_scale(to_estimation_scale(parameters) + rnorm(length(parameters), 0, 0.1)),
             names(parameters))
}

registerDoParallel(3)
# 
# results <- foreach(chain = 1:5,
#                    .packages=c("stemr"),
#                    .options.RNG = 52787,
#                    .export = ls(all.names = T)) %dorng% {
#                        
#                        chain_res <- stem_inference(stem_object = stem_object,
#                                                    method = "ode",
#                                                    iterations = 1.5e5,
#                                                    thin_params = 50,
#                                                    thin_latent_proc = 50,
#                                                    initialization_attempts = 500,
#                                                    priors = priors,
#                                                    mcmc_kernel = mcmc_kernel,
#                                                    t0_kernel = t0_kernel,
#                                                    ess_args = ess_settings(n_ess_updates = 1,
#                                                                            ess_warmup = 100, 
#                                                                            initdist_bracket_width = 2*pi,
#                                                                            initdist_bracket_update = 5e3,
#                                                                            lna_bracket_width = 2*pi,
#                                                                            lna_bracket_update = 5e3,
#                                                                            joint_strata_update = FALSE),
#                                                    print_progress = 5e3,
#                                                    messages = F)
#                        return(chain_res)
#                    }
# 
# save(results, file = paste0("guinea_ODE.Rdata"))
# 
# # grab the initial covariance matrix
# covs <- sapply(results, function(x) cov(x$stem_settings$mcmc_kernel$sigma))
# covmat <- matrix(rowMeans(covs), length(parameters), length(parameters), dimnames = list(covmat_names, covmat_names))
# 
# mcmc_kernel <-
#     kernel(
#         method = "mvnss",
#         sigma = covmat,
#         scale_constant = 0.5,
#         scale_cooling = 0.7,
#         stop_adaptation = 5e4,
#         step_size = 0.5,
#         nugget = 1e-5, 
#         harss_warmup = 0,
#         mvnss_setting_list = 
#             mvnss_settings(n_mvnss_updates = 1, 
#                            initial_bracket_width = 0.5,
#                            bracket_limits = c(0.001, 5),
#                            nugget_cooling = 0.99, 
#                            nugget_step_size = 0.001),
#         messages = FALSE
#     )
# 
# # grab the initial parameters and compartment volumes
# init_pars <- lapply(results, function(x) x$dynamics$parameters)
# init_states <- lapply(results, function(x) x$dynamics$initdist_params)

# rm(results)

# registerDoParallel(5)

results <- foreach(chain = 1:5,
                   .packages="stemr",
                   .options.RNG = 52787,
                   .export = ls(all.names = T)) %dorng% {
                       
                       # stem_object$dynamics$parameters <- init_pars[[chain]]
                       # stem_object$dynamics$initdist_params <- init_states[[chain]]
                       
                       chain_res <- stem_inference(stem_object = stem_object,
                                                   method = "lna",
                                                   iterations = 1.5e5,
                                                   thin_params = 50,
                                                   thin_latent_proc = 50,
                                                   initialization_attempts = 500,
                                                   priors = priors,
                                                   mcmc_kernel = mcmc_kernel,
                                                   t0_kernel = t0_kernel,
                                                   ess_args = ess_settings(n_ess_updates = 1,
                                                                           ess_warmup = 100, 
                                                                           initdist_bracket_width = 2*pi,
                                                                           initdist_bracket_update = 5e3,
                                                                           lna_bracket_width = 2*pi,
                                                                           lna_bracket_update = 5e3,
                                                                           joint_strata_update = FALSE),        
                                                   messages = F,
                                                   print_progress = 5e3)
                       return(chain_res)
                   }

save(results, file = paste0("guinea_LNA.Rdata"))
fintzij/stemr documentation built on March 25, 2022, 12:25 p.m.