# scripts/13_fit_R0_model.R

devtools::load_all()

library(rstan)
library(dplyr)
library(stringr)
library(reshape2)



# define parameters -----------------------------------------------------------


# Human infectious period, T_inf = 1 / (human recovery rate); value taken from
# the Zika paper (units presumably days -- TODO confirm).
T_inf <- 6

# Directory holding the central estimates derived from Mordecai et al.
in_path <- file.path("output", "central_estimates_from_Mordecai")

# Parameters traced and paired in the diagnostic plots below. Only the
# trait-"a" parameters are active; the other traits are disabled for now:
#   "c_b", "T0_b", "Tm_b", "c_c", "T0_c", "Tm_c",
#   "c_lf", "T0_lf", "Tm_lf", "c_PDR", "T0_PDR", "Tm_PDR"
params_to_estimate <- c(
  "c_a",
  "T0_a",
  "Tm_a"
)

# NOTE(review): not referenced anywhere else in this script.
indicator_vars <- c(
  "aMin_indicator",
  "aMax_indicator"
)

# Destination folder for all diagnostic figures written by this script.
out_dir <- file.path("figures", "fit_one_trait")


# load data -------------------------------------------------------------------


# Rescaled FOI covariate table; the columns R0_1 (response) and
# DayTemp_const_term (temperature) are consumed in the run section.
foi_covariates <- readRDS(
  file.path("output", "foi_data_cov_rescaled.rds")
)

# m values from the Mordecai central-estimates directory; passed to Stan
# as data field `m`.
m_values <- readRDS(
  file.path(in_path, "m_mean_DayTemp_const_term.rds")
)


# run -------------------------------------------------------------------------


# Observed response and its temperature covariate, one entry per FOI record.
R0_values <- foi_covariates$R0_1
temp_values <- foi_covariates$DayTemp_const_term
R0_values_n <- length(R0_values)

# Data block for R0_model.stan; field names must match the Stan program.
stan_dat <- list(
  J = R0_values_n,
  y = R0_values,
  temp = temp_values,
  m = m_values,
  T_inf = T_inf
)

# Allow chains to run in parallel.
options(mc.cores = 7)

# 4 chains x 20000 iterations; assuming rstan's default warmup of iter / 2,
# this leaves 10000 post-warmup draws per chain. A high adapt_delta trades
# speed for fewer divergent transitions.
fit_normal <- stan(
  file = "R0_model.stan",
  data = stan_dat,
  iter = 20000,
  chains = 4,
  seed = 132532525,
  control = list(adapt_delta = 0.99)
)


# post processing -------------------------------------------------------------


# Posterior draws with warmup removed and chains merged (rstan's default
# `permuted = TRUE`). Namespaced because other packages (e.g. tidyr, magrittr)
# also export an `extract` that could mask rstan's.
list_of_draws <- rstan::extract(fit_normal)

# Posterior summary restricted to the quartiles.
print(fit_normal, probs = c(0.25, 0.5, 0.75))

p <- traceplot(fit_normal, params_to_estimate)

# Posterior predictive check: pick random draws of the simulated data.
# Indices come from all available draws instead of a hard-coded
# 20000:40000 range, which ignored half of the draws and silently assumed
# exactly 40000 of them (it breaks as soon as `iter` or `chains` change).
n_ppd_samples <- 12
n_draws <- nrow(list_of_draws$SimData)
ppd_data_sample <- sample(seq_len(n_draws), min(n_ppd_samples, n_draws))

# Reshape to long format: one row per (observation id, sampled draw).
# drop = FALSE keeps the draws x observations orientation even when only a
# single draw is selected.
ppd_data <- t(list_of_draws$SimData[ppd_data_sample, , drop = FALSE]) %>%
  as.data.frame() %>%
  rename_all(~ str_replace(., "V", "Sample_")) %>%
  mutate(id = seq_len(n())) %>%
  melt(id.vars = "id",
       variable.name = "sample")

# NOTE(review): ggplot2 is not attached in this script; presumably it is made
# available through devtools::load_all() -- confirm.
p2 <- ggplot(data = ppd_data, mapping = aes(x = value)) +
  geom_histogram(col = "white", bins = 30) +
  facet_wrap(~ sample, ncol = 3) +
  geom_vline(xintercept = 0, linetype = "dashed") +
  scale_x_continuous("predicted R0") +
  scale_y_continuous("Frequency")

# The figure directory may not exist yet on a fresh checkout.
dir.create(out_dir, showWarnings = FALSE, recursive = TRUE)

save_plot(p, out_dir, "trace_plot", wdt = 20, hgt = 10)
save_plot(p2, out_dir, "ppchecks", wdt = 15, hgt = 15)

# Pairs plot of the estimated parameters plus the log posterior (lp__).
# Under rstan's default `condition = "accept_stat__"`, draws with a
# below-median acceptance statistic are plotted below the diagonal.
png(filename = file.path(out_dir, "pairs_plot.png"),
    width = 20,
    height = 15,
    units = "cm",
    pointsize = 12,
    res = 300)
# Close the device even if plotting fails, so later graphics are not
# silently redirected into this file.
invisible(tryCatch(
  pairs(fit_normal, pars = c(params_to_estimate, "lp__"), las = 1),
  finally = dev.off()
))
# lorecatta/DENVclimate documentation built on Dec. 11, 2019, 7:05 a.m.