R/mcmc_utils.R

Defines functions run_deterministic_comparison_iran iran_log_likelihood generate_draws generate_parameters

Documented in generate_draws generate_parameters run_deterministic_comparison_iran

#' Generate parameter draws from a squire pmcmc run
#'
#' Samples parameter sets from the post burn-in pMCMC chain(s) stored in
#' `out$pmcmc_results`. The checks and sampling mirror `squire:::sample_pmcmc`
#' and will need updating if squire's internals change.
#'
#' @param out Output of [squire::pmcmc()]; must contain `pmcmc_results`.
#' @param draws Number of draws from mcmc chain. Default = 10
#' @param burnin Number of initial iterations discarded from each chain before
#'   sampling. Default = 1000
#' @param ll Logical. If `TRUE` (default) draws are taken with replacement from
#'   the unique post burn-in samples, weighted by `exp(log_posterior)`
#'   (tempered if those weights underflow). If `FALSE` draws are taken
#'   uniformly without replacement.
#' @return A list of length `draws` whose elements are all named `"pars"`.
#'   Each element is a one-row data frame of sampled parameters (columns whose
#'   names contain "log" removed) with `start_date` converted from an offset
#'   via `squire:::offset_to_start_date`.
generate_parameters <- function(out, draws = 10, burnin = 1000, ll = TRUE) {

  pmcmc_results <- out$pmcmc_results
  n_trajectories <- draws

  # multi-chain runs keep each chain under `chains`
  if ("chains" %in% names(pmcmc_results)) {
    n_chains <- length(pmcmc_results$chains)
  } else {
    n_chains <- 1
  }

  # validation copied from squire: will need updating if squire undergoes changes
  squire:::assert_pos_int(n_chains)
  if (n_chains == 1) {
    squire:::assert_custom_class(pmcmc_results, "squire_pmcmc")
  } else {
    squire:::assert_custom_class(pmcmc_results, "squire_pmcmc_list")
  }
  squire:::assert_pos_int(burnin)
  squire:::assert_pos_int(n_trajectories)

  # pool the post burn-in samples into a single data frame
  if (n_chains > 1) {
    res <- squire::create_master_chain(x = pmcmc_results, burn_in = burnin)
  } else if (n_chains == 1 && burnin > 0) {
    res <- pmcmc_results$results[-seq_len(burnin), ]
  } else {
    res <- pmcmc_results$results
  }

  if (ll) {
    squire:::assert_neg(res$log_posterior, zero_allowed = FALSE)
    res <- unique(res)
    probs <- exp(res$log_posterior)
    probs <- probs / sum(probs)

    # if every exp(log_posterior) underflows to 0 the normalisation yields NaN;
    # progressively temper the log posterior until the weights are finite
    drop <- 0.9
    while (anyNA(probs)) {
      probs <- exp(res$log_posterior * drop)
      probs <- probs / sum(probs)
      drop <- drop^2
    }

    idx <- sample(x = length(probs), size = n_trajectories,
                  replace = TRUE, prob = probs)
  } else {
    idx <- sample(x = nrow(res), size = n_trajectories, replace = FALSE)
  }

  # drop = FALSE keeps a data frame even if only one parameter column remains
  params_smpl <- res[idx, !grepl("log", colnames(res)), drop = FALSE]
  params_smpl$start_date <- squire:::offset_to_start_date(
    pmcmc_results$inputs$data$date[1],
    round(params_smpl$start_date)
  )

  pars_list <- split(params_smpl, seq_len(nrow(params_smpl)))
  names(pars_list) <- rep("pars", length(pars_list))
  pars_list
}

#' Generate draws using parameters drawn from posterior
#'
#' Re-runs the fitted model for each parameter set in `pars_list` (via
#' `iran_log_likelihood(..., return = "full")`) and packages the resulting
#' trajectories into a squire run object. The code mirrors sections 3 and 4 of
#' squire's pMCMC wrapper and will need updating if squire undergoes major
#' changes. If the environment variable `SQUIRE_PARALLEL_DEBUG` is `"TRUE"`
#' the simulations are run with `purrr::map` rather than `furrr::future_map`.
#'
#' @param out Output of [squire::pmcmc()]
#' @param pars_list Output of [generate_parameters()]
#' @param parallel Are we simulating in parallel. Default = FALSE. When `TRUE`
#'   a `future::multisession` plan is used for this call and the previous plan
#'   is restored on exit.
#' @param draws How many draws are being used from pars_list. Default = NULL,
#'   which will use all the pars; otherwise the first `draws` elements are used.
#' @param interventions Are new interventions being used or default. Default =
#'   NULL. If supplied it replaces `pmcmc_results$inputs$interventions` for the
#'   simulations and is stored as the returned `interventions`; if NULL the
#'   pmcmc inputs are simulated and `out$interventions` is stored.
#' @param ... Further arguments passed to `iran_log_likelihood()`, e.g. `rt_mult`.
#' @return `out` unchanged (with a message) if it has no `pmcmc_results`;
#'   otherwise a squire run object whose `output` array has one slice per draw,
#'   with `replicate_parameters`, `pmcmc_results`, `model` and `interventions`
#'   filled in.
generate_draws <- function(out, pars_list, parallel = FALSE,
                           draws = NULL, interventions = NULL, ...) {

  # handle for no death days
  if (!("pmcmc_results" %in% names(out))) {
    message("`out` was not generated by pmcmc as no deaths for this country. \n",
            "Returning the original object, which assumes epidemic seeded on date ",
            "fits were run")
    return(out)
  }

  # grab information from the pmcmc run
  pmcmc <- out$pmcmc_results
  squire_model <- pmcmc$inputs$squire_model
  country <- out$parameters$country
  population <- out$parameters$population
  data <- pmcmc$inputs$data

  # parallel plan is scoped to this call: restore whatever plan was active
  if (parallel) {
    old_plan <- suppressWarnings(future::plan(future::multisession))
    on.exit(future::plan(old_plan), add = TRUE)
  }

  if (!is.null(interventions)) {
    # if making a change add that intervention here
    pmcmc$inputs$interventions <- interventions
  } else {
    # else this is the interventions that come with the object
    interventions <- out$interventions
  }

  # only simulate the requested number of parameter sets
  if (is.null(draws)) {
    draws <- length(pars_list)
  }
  if (draws < 1 || draws > length(pars_list)) {
    stop("`draws` must be between 1 and length(pars_list) (",
         length(pars_list), ")")
  }
  pars_list <- pars_list[seq_len(draws)]

  #--------------------------------------------------------
  # Section 3 of pMCMC Wrapper: Sample PMCMC Results
  #--------------------------------------------------------
  # names kept as in squire:::sample_pmcmc so it is simple to update this code
  pmcmc_results <- pmcmc
  n_particles <- 2
  forecast_days <- 0
  replicates <- draws
  params_smpl <- do.call(rbind, pars_list)

  # instead of using squire:::sample_pmcmc we use the pars_list values provided
  message("Sampling from pMCMC Posterior...")
  if (Sys.getenv("SQUIRE_PARALLEL_DEBUG") == "TRUE") {
    traces <- purrr::map(.x = pars_list, .f = iran_log_likelihood,
                         data = pmcmc_results$inputs$data, squire_model = pmcmc_results$inputs$squire_model,
                         model_params = pmcmc_results$inputs$model_params,
                         pars_obs = pmcmc_results$inputs$pars_obs, n_particles = n_particles,
                         forecast_days = forecast_days, interventions = pmcmc_results$inputs$interventions,
                         Rt_args = pmcmc_results$inputs$Rt_args, return = "full", ...)
  } else {
    traces <- furrr::future_map(.x = pars_list, .f = iran_log_likelihood,
                                data = pmcmc_results$inputs$data, squire_model = pmcmc_results$inputs$squire_model,
                                model_params = pmcmc_results$inputs$model_params,
                                pars_obs = pmcmc_results$inputs$pars_obs, n_particles = n_particles,
                                forecast_days = forecast_days, interventions = pmcmc_results$inputs$interventions,
                                Rt_args = pmcmc_results$inputs$Rt_args, return = "full", ...,
                                .progress = TRUE, .options = furrr::furrr_options(seed = NULL))
  }

  # traces may differ in length (different start dates): right-align them in
  # a rows x columns x draws array, padding earlier rows with NA
  num_rows <- vapply(traces, nrow, integer(1))
  max_rows <- max(num_rows)
  seq_max <- seq_len(max_rows)
  max_date_names <- rownames(traces[[which.max(num_rows)]])
  trajectories <- array(NA, dim = c(max_rows, ncol(traces[[1]]), length(traces)),
                        dimnames = list(max_date_names, colnames(traces[[1]]), NULL))
  for (i in seq_along(traces)) {
    trajectories[utils::tail(seq_max, nrow(traces[[i]])), , i] <- traces[[i]]
  }
  pmcmc_samples <- list(trajectories = trajectories, sampled_PMCMC_Results = params_smpl,
                        inputs = list(squire_model = pmcmc_results$inputs$squire_model,
                                      model_params = pmcmc_results$inputs$model_params,
                                      interventions = pmcmc_results$inputs$interventions,
                                      data = pmcmc_results$inputs$data, pars_obs = pmcmc_results$inputs$pars_obs))
  class(pmcmc_samples) <- "squire_sample_PMCMC"

  #--------------------------------------------------------
  # Section 4 of pMCMC Wrapper: Tidy Output
  #--------------------------------------------------------

  # create a fake run object and fill in the required elements
  r <- squire_model$run_func(country = country,
                             contact_matrix_set = pmcmc$inputs$model_params$contact_matrix_set,
                             tt_contact_matrix = pmcmc$inputs$model_params$tt_matrix,
                             hosp_bed_capacity = pmcmc$inputs$model_params$hosp_bed_capacity,
                             tt_hosp_beds = pmcmc$inputs$model_params$tt_hosp_beds,
                             ICU_bed_capacity = pmcmc$inputs$model_params$ICU_bed_capacity,
                             tt_ICU_beds = pmcmc$inputs$model_params$tt_ICU_beds,
                             population = population,
                             day_return = TRUE,
                             replicates = 1,
                             time_period = nrow(pmcmc_samples$trajectories))

  # and add the parameters that changed between each simulation, i.e. posterior draws
  r$replicate_parameters <- pmcmc_samples$sampled_PMCMC_Results

  # as well as adding the pmcmc chains so it's easy to draw from the chains again in the future
  r$pmcmc_results <- pmcmc

  # then let's create the output that we are going to use
  names(pmcmc_samples)[names(pmcmc_samples) == "trajectories"] <- "output"
  dimnames(pmcmc_samples$output) <- list(dimnames(pmcmc_samples$output)[[1]], dimnames(r$output)[[2]], NULL)
  r$output <- pmcmc_samples$output

  # re-index "time" relative to the row of the last data date, using a
  # replicate with no missing times as the template. matrix() keeps a
  # rows x draws shape even when there is only a single draw.
  time_mat <- matrix(r$output[, "time", ], nrow = dim(r$output)[1])
  full_row <- match(0, colSums(is.na(time_mat)))
  saved_full <- time_mat[, full_row]
  full_to_place <- saved_full - which(rownames(r$output) == as.Date(max(data$date))) + 1L
  for (i in seq_len(replicates)) {
    time_i <- full_to_place
    time_i[is.na(time_mat[, i])] <- NA
    r$output[, "time", i] <- time_i
  }

  # second let's recreate the output
  r$model <- pmcmc_samples$inputs$squire_model$odin_model(
    user = pmcmc_samples$inputs$model_params, unused_user_action = "ignore"
  )

  # we will add the interventions here so that we know what times are needed for projection
  r$interventions <- interventions

  # and fix the replicates
  r$parameters$replicates <- replicates
  r$parameters$time_period <- as.numeric(diff(as.Date(range(rownames(r$output)))))
  r$parameters$dt <- pmcmc$inputs$model_params$dt

  if ("province" %in% names(out$parameters)) {
    r$parameters$province <- out$parameters$province
  }

  r
}


#' Specific log_likelihood wrapper for Iran simulations
#'
#' Adapted from `squire:::calc_loglikelihood`. It first converts the dated
#' changes in `interventions` into odin change times relative to
#' `pars$start_date`. It then builds `beta_set` from the Rt trajectory, scaled
#' by `rt_mult`. Finally it runs either squire's particle filter (stochastic
#' models) or `run_deterministic_comparison_iran()` (deterministic models).
#'
#' @param pars Named parameter set containing at least `R0` and `start_date`.
#' @param data,model_params,pars_obs,n_particles,forecast_days,Rt_args As for
#'   `squire:::calc_loglikelihood`.
#' @param squire_model A squire model object inheriting from `"stochastic"` or
#'   `"deterministic"`.
#' @param return `"ll"` or `"full"`; anything else is an error.
#' @param interventions List of intervention changes (e.g. `date_R0_change`,
#'   `ICU_bed_capacity`, `date_vaccine_change`, ...).
#' @param rt_mult Multiplier applied to the evaluated Rt trajectory. Default 1.
#' @return The result of the particle filter / deterministic comparison.
#' @noRd
iran_log_likelihood <- function(pars, data, squire_model, model_params, pars_obs, n_particles,
                                forecast_days = 0, return = "ll", Rt_args, interventions, rt_mult = 1) {

  # translate the requested return type into particle-filter settings
  switch(return, full = {
    save_particles <- TRUE
    full_output <- TRUE
    pf_return <- "sample"
  }, ll = {
    save_particles <- FALSE
    forecast_days <- 0
    full_output <- FALSE
    pf_return <- "single"
  }, {
    stop("Unknown return type to calc_loglikelihood")
  })

  squire:::assert_in(c("R0", "start_date"), names(pars), message = "Must specify R0, start date to infer")
  R0 <- pars[["R0"]]
  start_date <- pars[["start_date"]]
  squire:::assert_pos(R0)
  squire:::assert_date(start_date)

  steps_per_day <- round(1 / model_params$dt)

  # Each block below only touches model_params when the corresponding change
  # dates are supplied; otherwise model_params keeps its existing values.

  # R0 changes (R0_change / date_R0_change are also needed for Rt evaluation)
  R0_change <- interventions$R0_change
  date_R0_change <- interventions$date_R0_change
  if (!is.null(date_R0_change)) {
    tt_list <- squire:::intervention_dates_for_odin(
      dates = date_R0_change, change = R0_change, start_date = start_date,
      steps_per_day = steps_per_day, starting_change = 1
    )
    model_params$tt_beta <- tt_list$tt
    R0_change <- tt_list$change
    date_R0_change <- tt_list$dates
  }

  # contact matrix changes: `change` indexes the first dim of mix_mat_set
  if (!is.null(interventions$date_contact_matrix_set_change)) {
    tt_list <- squire:::intervention_dates_for_odin(
      dates = interventions$date_contact_matrix_set_change,
      change = seq_along(interventions$contact_matrix_set)[-1],
      start_date = start_date, steps_per_day = steps_per_day,
      starting_change = 1
    )
    model_params$tt_matrix <- tt_list$tt
    model_params$mix_mat_set <- model_params$mix_mat_set[tt_list$change, , ]
  }

  # ICU bed capacity changes
  if (!is.null(interventions$date_ICU_bed_capacity_change)) {
    tt_list <- squire:::intervention_dates_for_odin(
      dates = interventions$date_ICU_bed_capacity_change,
      change = interventions$ICU_bed_capacity[-1], start_date = start_date,
      steps_per_day = steps_per_day,
      starting_change = interventions$ICU_bed_capacity[1]
    )
    model_params$tt_ICU_beds <- tt_list$tt
    model_params$ICU_beds <- tt_list$change
  }

  # hospital bed capacity changes
  if (!is.null(interventions$date_hosp_bed_capacity_change)) {
    tt_list <- squire:::intervention_dates_for_odin(
      dates = interventions$date_hosp_bed_capacity_change,
      change = interventions$hosp_bed_capacity[-1], start_date = start_date,
      steps_per_day = steps_per_day,
      starting_change = interventions$hosp_bed_capacity[1]
    )
    model_params$tt_hosp_beds <- tt_list$tt
    model_params$hosp_beds <- tt_list$change
  }

  # vaccination capacity changes
  if (!is.null(interventions$date_vaccine_change)) {
    tt_list <- squire:::intervention_dates_for_odin(
      dates = interventions$date_vaccine_change,
      change = interventions$max_vaccine[-1], start_date = start_date,
      steps_per_day = steps_per_day,
      starting_change = interventions$max_vaccine[1]
    )
    model_params$tt_vaccine <- tt_list$tt
    model_params$max_vaccine <- tt_list$change
  }

  # vaccine efficacy against infection: `change` indexes the first dim
  if (!is.null(interventions$date_vaccine_efficacy_infection_change)) {
    tt_list <- squire:::intervention_dates_for_odin(
      dates = interventions$date_vaccine_efficacy_infection_change,
      change = seq_along(interventions$vaccine_efficacy_infection)[-1],
      start_date = start_date, steps_per_day = steps_per_day,
      starting_change = 1
    )
    model_params$tt_vaccine_efficacy_infection <- tt_list$tt
    model_params$vaccine_efficacy_infection <- model_params$vaccine_efficacy_infection[tt_list$change, , ]
  }

  # vaccine efficacy against disease is applied through prob_hosp
  if (!is.null(interventions$date_vaccine_efficacy_disease_change)) {
    tt_list <- squire:::intervention_dates_for_odin(
      dates = interventions$date_vaccine_efficacy_disease_change,
      change = seq_along(interventions$vaccine_efficacy_disease)[-1],
      start_date = start_date, steps_per_day = steps_per_day,
      starting_change = 1
    )
    model_params$tt_vaccine_efficacy_disease <- tt_list$tt
    model_params$prob_hosp <- model_params$prob_hosp[tt_list$change, , ]
  }

  # Rt trajectory (optionally scaled) -> transmission rates
  R0 <- squire:::evaluate_Rt_pmcmc(R0_change = R0_change, R0 = R0, date_R0_change = date_R0_change,
                                   pars = pars, Rt_args = Rt_args)
  R0 <- R0 * rt_mult
  model_params$beta_set <- squire:::beta_est(squire_model = squire_model,
                                             model_params = model_params, R0 = R0)

  if (inherits(squire_model, "stochastic")) {
    squire:::run_particle_filter(data = data, squire_model = squire_model,
                                 model_params = model_params, model_start_date = start_date,
                                 obs_params = pars_obs, n_particles = n_particles,
                                 forecast_days = forecast_days, save_particles = save_particles,
                                 full_output = full_output, return = pf_return)
  } else if (inherits(squire_model, "deterministic")) {
    run_deterministic_comparison_iran(data = data,
                                      squire_model = squire_model, model_params = model_params,
                                      model_start_date = start_date, obs_params = pars_obs,
                                      forecast_days = forecast_days, save_history = save_particles,
                                      return = pf_return)
  } else {
    stop("`squire_model` must inherit from \"stochastic\" or \"deterministic\"")
  }
}


#' Poisson log likelihood of observed counts given model output
#'
#' The model expectation is scaled by `phi` and jittered with exponential
#' noise (rate `exp_noise`) before evaluating `dpois(..., log = TRUE)`, so the
#' result depends on the RNG state.
#'
#' @param data Observed counts.
#' @param model Model-predicted counts (presumably aligned with `data`).
#' @param phi Multiplier applied to `model`.
#' @param k Unused by this Poisson likelihood (kept for call compatibility).
#' @param exp_noise Rate of the exponential noise added to the expectation.
#' @return Vector of per-observation log likelihoods.
#' @noRd
ll_pois <- function(data, model, phi, k, exp_noise) {
  noise <- rexp(length(model), rate = exp_noise)
  expected <- phi * model
  dpois(data, lambda = expected + noise, log = TRUE)
}

#' Specific deterministic model run for Iran with timing of Delta included
#'
#' Adapted from `squire:::run_deterministic_comparison`. Differences visible
#' here:
#' - the final data row is extended to a 7-day (weekly deaths) window;
#' - from 2021-05-01 there are optional changes to `gamma_R` (via
#'   `obs_params$dur_R`) and `prob_hosp_multiplier` (via
#'   `obs_params$prob_hosp_multiplier`);
#' - deaths use a Poisson likelihood ([ll_pois()]);
#' - there is an optional binomial seroprevalence likelihood (via
#'   `obs_params$sero_df` and `obs_params$sero_det`).
#'
#' @inheritParams squire:::run_deterministic_comparison
#' @return Depending on `return`: `"ll"` the total log likelihood; `"sample"`
#'   the model output matrix (only when `save_history = TRUE`, otherwise
#'   `NULL`); `"single"`/`"full"` a list with `log_likelihood` plus `states`
#'   (when `save_history = TRUE`) or, for `"single"`, `sample_state`.
run_deterministic_comparison_iran <- function(data, squire_model, model_params, model_start_date = "2020-02-02",
                                              obs_params = list(
                                                phi_cases = 0.1,
                                                k_cases = 2,
                                                phi_death = 1,
                                                k_death = 2,
                                                exp_noise = 1e+06
                                              ), forecast_days = 0, save_history = FALSE,
                                              return = "ll") {

  if (!(return %in% c("full", "ll", "sample", "single"))) {
    stop("return argument must be full, ll, sample or single")
  }

  # model_start_date may be a Date or a "%Y-%m-%d" string
  start_date <- as.Date(model_start_date, "%Y-%m-%d")

  # which() ignores NA deaths; with no positive deaths the check is skipped
  first_death <- as.Date(data$date[which(data$deaths > 0)[1]], "%Y-%m-%d")
  if (!is.na(first_death) && first_death < start_date) {
    stop("Model start date is later than data start date")
  }

  # set up as normal
  steps_per_day <- round(1 / model_params$dt)
  data <- squire:::particle_filter_data(data = data, start_date = model_start_date,
                                        steps_per_day = steps_per_day)

  # correct for weekly deaths: extend the final observation window to 7 days
  last <- nrow(data)
  data$day_end[last] <- data$day_start[last] + 7
  data$step_end[last] <- data$step_start[last] + 7 * steps_per_day

  # back to normal: convert change times from steps to days
  model_params$tt_beta <- round(model_params$tt_beta * model_params$dt)
  # NOTE(review): iran_log_likelihood() stores contact-matrix change times in
  # `tt_matrix`, not `tt_contact_matrix`, so this line may not rescale them —
  # confirm whether dt is always 1 for these runs.
  model_params$tt_contact_matrix <- round(model_params$tt_contact_matrix * model_params$dt)
  model_params$tt_hosp_beds <- round(model_params$tt_hosp_beds * model_params$dt)
  model_params$tt_ICU_beds <- round(model_params$tt_ICU_beds * model_params$dt)

  # steps as normal
  steps <- c(0, data$day_end)
  fore_steps <- seq(data$day_end[last], length.out = forecast_days + 1L)
  steps <- unique(c(steps, fore_steps))

  # days from model start to 2021-05-01, when the optional changes below begin
  ch_date <- as.integer(as.Date("2021-05-01") - start_date)

  # for 60 days from 2021-05-01 gamma_R is set to 2 / dur_R, then reverts
  if ("dur_R" %in% names(obs_params) && obs_params$dur_R != 365) {
    model_params$tt_dur_R <- c(0, ch_date, ch_date + 60)
    model_params$gamma_R <- c(model_params$gamma_R, 2 / obs_params$dur_R, model_params$gamma_R)
  }

  # from 2021-05-01 the hospitalisation multiplier switches to the fitted value
  if ("prob_hosp_multiplier" %in% names(obs_params) && obs_params$prob_hosp_multiplier != 1) {
    model_params$tt_prob_hosp_multiplier <- c(0, ch_date)
    model_params$prob_hosp_multiplier <- c(model_params$prob_hosp_multiplier,
                                           obs_params$prob_hosp_multiplier)
  }

  # run model
  model_func <- squire_model$odin_model(user = model_params,
                                        unused_user_action = "ignore")
  out <- model_func$run(t = seq(0, tail(steps, 1), 1), atol = 1e-6, rtol = 1e-6)
  index <- squire:::odin_index(model_func)

  # modelled deaths per observation window from the cumulative D compartments;
  # the first window starts 7 days before the second day_end
  Ds <- diff(rowSums(out[c(data$day_end[2] - 7, data$day_end[-1]), index$D]))
  Ds[Ds < 0] <- 0
  deaths <- data$deaths[-1]

  # what type of ll for deaths (treated_deaths_only is optional in obs_params)
  if (isTRUE(obs_params$treated_deaths_only)) {
    Ds_healthcare <- diff(rowSums(out[, index$D_get]))
    Ds_healthcare <- Ds_healthcare[data$day_end[-1]]
    ll <- ll_pois(deaths, Ds_healthcare, obs_params$phi_death,
                  obs_params$k_death, obs_params$exp_noise)
  } else {
    ll <- ll_pois(deaths, Ds, obs_params$phi_death, obs_params$k_death,
                  obs_params$exp_noise)
  }

  # dates corresponding to each row of the model output
  out_dates <- data$date[[1]] + seq_len(nrow(out)) - 1L

  # now the ll for the seroprevalence
  sero_df <- obs_params$sero_df
  lls <- 0
  if (!is.null(sero_df) && nrow(sero_df) > 0) {

    # convolve recent symptom incidence with the detection curve `det`,
    # capped at 0.99 of the population
    sero_at_date <- function(date, symptoms, det, dates, N) {
      di <- which(dates == date)
      to_sum <- tail(symptoms[seq_len(di)], length(det))
      min(sum(rev(to_sum) * head(det, length(to_sum)), na.rm = TRUE) / N, 0.99)
    }

    # get symptom incidence
    symptoms <- rowSums(out[, index$E2]) * model_params$gamma_E

    N <- sum(model_params$population)
    # each survey is evaluated at its end, start and midpoint dates
    sero_mid <- sero_df$date_start + as.integer((sero_df$date_end - sero_df$date_start) / 2)
    sero_dates <- list(sero_df$date_end, sero_df$date_start, sero_mid)
    unq_sero_dates <- unique(c(sero_df$date_end, sero_df$date_start, sero_mid))
    det <- obs_params$sero_det

    # estimate model seroprev once per unique date, then map back per survey
    sero_model <- vapply(unq_sero_dates, sero_at_date, numeric(1), symptoms, det, out_dates, N)
    sero_model_mat <- do.call(cbind, lapply(sero_dates, function(x) {
      sero_model[match(x, unq_sero_dates)]
    }))

    # likelihood of model obvs, averaged over the three evaluation dates
    lls <- rowMeans(dbinom(sero_df$sero_pos, sero_df$samples, sero_model_mat, log = TRUE))
  }

  # and wrap up as normal
  rownames(out) <- as.character(out_dates)
  attr(out, "date") <- out_dates
  pf_results <- list(log_likelihood = sum(ll) + sum(lls))
  if (save_history) {
    pf_results$states <- out
  } else if (return == "single") {
    pf_results$sample_state <- out[nrow(out), ]
  }

  if (return == "ll") {
    pf_results$log_likelihood
  } else if (return == "sample") {
    pf_results$states
  } else {
    pf_results
  }
}
OJWatson/iran-ascertainment documentation built on April 24, 2022, 10:09 p.m.