R/simulate.R

Defines functions generate_draws.rt_optimised generate_draws.default generate_draws convert_pars.list generate_parameters

Documented in generate_draws generate_draws.default generate_draws.rt_optimised generate_parameters

#' Generate parameter draws from a pmcmc run
#'
#' Code taken from `squire:::sample_pmcmc()`.
#'
#' Discards a burn-in of `ceiling(n_mcmc / 10)` iterations, de-duplicates the
#' remaining rows of the chain(s) and samples `draws` rows (with replacement)
#' with probability proportional to `exp(log_posterior)`.
#'
#' @param out Output of `squire::pmcmc`
#' @param draws Number of draws from mcmc chain. Default = 10
#' @return A list with one element per draw, each named `"pars"` and holding a
#' one-row slice of the sampled chain (columns whose name contains `"log"` are
#' removed and `start_date` is converted with
#' `squire:::offset_to_start_date()`).
#' @export
generate_parameters <- function(out, draws = 10){
  #set up parameters
  pmcmc_results <- out$pmcmc_results
  n_trajectories <- draws
  burnin <- ceiling(pmcmc_results$inputs$n_mcmc/10)
  if("chains" %in% names(pmcmc_results)) {
    n_chains <- length(pmcmc_results$chains)
  } else {
    n_chains <- 1
  }

  #code from squire: Will need updating if squire undergoes changes
  squire:::assert_pos_int(n_chains)
  if (n_chains == 1) {
    squire:::assert_custom_class(pmcmc_results, "squire_pmcmc")
  } else {
    squire:::assert_custom_class(pmcmc_results, "squire_pmcmc_list")
  }
  squire:::assert_pos_int(burnin)
  squire:::assert_pos_int(n_trajectories)

  if (n_chains > 1) {
    res <- squire::create_master_chain(x = pmcmc_results, burn_in = burnin)
  } else if (n_chains == 1 && burnin > 0) {
    #seq_len() rather than seq_along(): burnin is a scalar, so seq_along()
    #would only ever remove the first row instead of the first `burnin` rows
    res <- pmcmc_results$results[-seq_len(burnin), ]
  } else {
    res <- pmcmc_results$results
  }

  #sample unique parameter sets, weighted by their posterior density
  res <- unique(res)
  probs <- exp(res$log_posterior)
  probs <- probs/sum(probs)
  #if the normalised weights are not all defined (e.g. every exp() evaluates
  #to 0 so the sum is 0), flatten the log posterior by a progressively smaller
  #factor until they are
  drop <- 0.9
  while (any(is.na(probs))) {
    probs <- exp(res$log_posterior * drop)
    probs <- probs/sum(probs)
    drop <- drop^2
  }
  sampled_rows <- sample(x = length(probs), size = n_trajectories,
                         replace = TRUE, prob = probs)
  params_smpl <- res[sampled_rows, !grepl("log", colnames(res))]

  #add adjustment for if using weekly data
  params_smpl$start_date <- squire:::offset_to_start_date(
    get_data_start_date(out),
    round(params_smpl$start_date)
  )

  #return the parameters
  convert_pars.list(params_smpl)
}

#' Convert a parameter sample (i.e. $replicate_parameters) into a pars.list
#' format. Own function as reused in scan fitting
#'
#' @param params_smpl A data frame with one row per sampled parameter set.
#' @return A list with one element per row of `params_smpl`, every element
#' named `"pars"` and holding that row as a one-row data frame. A zero-row
#' input gives an empty list.
#' @noRd
convert_pars.list <- function(params_smpl){
  #seq_len() so that a zero-row sample gives an empty list; 1:nrow() would
  #produce c(1, 0) and split() would then return two empty groups
  pars.list <- split(params_smpl, seq_len(nrow(params_smpl)))
  names(pars.list) <- rep("pars", length(pars.list))
  pars.list
}

#' Generate Draws from a model fit.
#'
#' S3 generic that dispatches on the class of \code{out}. Appends the simulated
#' values to \code{$output}.
#'
#' @param out Output from a fitted MCMC or Rt Optimise
#' @param ... method specific arguments.
#' @export
generate_draws <- function(out, ...){
  #dispatch on the fitted object
  UseMethod("generate_draws", out)
}
#' Generate Draws from pmcmc run
#'
#' Uses furrr, so can be called in parallel. Set the environment variable
#' `SQUIRE_PARALLEL_DEBUG` to `"TRUE"` to run sequentially with purrr instead.
#'
#' @param out Output of `squire::pmcmc`
#' @param pars.list Output of generate_parameters(), default = NULL calls
#' generate_parameters.
#' @param draws Number of draws from mcmc chain. Default = 10, if NULL then uses
#' replicate_parameters as the parameters. Only used when `pars.list` is NULL;
#' otherwise the number of replicates is the length of `pars.list`.
#' @param ... method specific arguments, unused
#' @return An object with the same class as `out`, whose `$output` is a 3-d
#' array (rows x model outputs x draws) of simulated trajectories and whose
#' `$replicate_parameters` holds the parameters used for each draw. If `out`
#' has no `pmcmc_results` element it is returned unchanged.
#' @export
generate_draws.default <- function(out, pars.list = NULL, draws = 10, ...) {
  # handle for no death days. Checked before anything else because every later
  # step (including generate_parameters()) reads out$pmcmc_results
  if(!("pmcmc_results" %in% names(out))) {
    message("`out` was not generated by pmcmc as no deaths for this country. \n",
            "Returning the original object, which assumes epidemic seeded on date ",
            "fits were run")
    return(out)
  }

  #generate parameters if needed
  if(is.null(pars.list)){
    if(is.null(draws)){
      pars.list <- convert_pars.list(out$replicate_parameters)
    } else {
      pars.list <- generate_parameters(out = out, draws = draws)
    }
  }
  #one replicate per parameter set, whatever `draws` was
  replicates <- length(pars.list)

  # grab information from the pmcmc run
  pmcmc_results <- out$pmcmc_results
  inputs <- pmcmc_results$inputs
  squire_model <- inputs$squire_model
  country <- out$parameters$country
  population <- out$parameters$population
  interventions <- out$interventions

  #--------------------------------------------------------
  # Section 3 of pMCMC Wrapper: Sample PMCMC Results
  #--------------------------------------------------------
  n_particles <- 2
  forecast_days <- 0
  log_likelihood <- get_model_likelihood(out)
  #recreate params_smpl object
  params_smpl <- do.call(rbind, pars.list)

  #instead of using squire:::sample_pmcmc we use the pars.list values provided
  #the following code is taken from squire:::sample_pmcmc and will need updating
  #if squire undergoes major changes
  message("Sampling from pMCMC Posterior...")
  if (Sys.getenv("SQUIRE_PARALLEL_DEBUG") == "TRUE") {
    traces <- purrr::map(.x = pars.list, .f = log_likelihood,
                         data = inputs$data, squire_model = inputs$squire_model,
                         model_params = inputs$model_params,
                         pars_obs = inputs$pars_obs, n_particles = n_particles,
                         forecast_days = forecast_days, interventions = inputs$interventions,
                         Rt_args = inputs$Rt_args, return = "full")
  } else{
    traces <- furrr::future_map(.x = pars.list, .f = log_likelihood,
                                data = inputs$data, squire_model = inputs$squire_model,
                                model_params = inputs$model_params,
                                pars_obs = inputs$pars_obs, n_particles = n_particles,
                                forecast_days = forecast_days, interventions = inputs$interventions,
                                Rt_args = inputs$Rt_args, return = "full",
                                .progress = TRUE, .options = furrr::furrr_options(seed = NULL))
  }

  #stack the traces into one array; traces with fewer rows are aligned to the
  #end of the array and padded with NA at the start
  num_rows <- unlist(lapply(traces, nrow))
  max_rows <- max(num_rows)
  seq_max <- seq_len(max_rows)
  max_date_names <- rownames(traces[[which.max(num_rows)]])
  trajectories <- array(
    NA,
    dim = c(max_rows, ncol(traces[[1]]), length(traces)),
    dimnames = list(max_date_names, colnames(traces[[1]]), NULL)
  )
  for (i in seq_along(traces)) {
    trajectories[utils::tail(seq_max, nrow(traces[[i]])), , i] <- traces[[i]]
  }

  #--------------------------------------------------------
  # Section 4 of pMCMC Wrapper: Tidy Output
  #--------------------------------------------------------

  # create a fake run object and fill in the required elements
  r <- squire_model$run_func(country = country,
                             contact_matrix_set = inputs$model_params$contact_matrix_set,
                             tt_contact_matrix = inputs$model_params$tt_matrix,
                             hosp_bed_capacity = inputs$model_params$hosp_bed_capacity,
                             tt_hosp_beds = inputs$model_params$tt_hosp_beds,
                             ICU_bed_capacity = inputs$model_params$ICU_bed_capacity,
                             tt_ICU_beds = inputs$model_params$tt_ICU_beds,
                             population = population,
                             day_return = TRUE,
                             replicates = 1,
                             time_period = nrow(trajectories))

  # and add the parameters that changed between each simulation, i.e. posterior draws
  r$replicate_parameters <- params_smpl

  # as well as adding the pmcmc chains so it's easy to draw from the chains again in the future
  r$pmcmc_results <- pmcmc_results

  # then let's create the output that we are going to use, taking the column
  # names from the fake run
  dimnames(trajectories) <- list(dimnames(trajectories)[[1]], dimnames(r$output)[[2]], NULL)
  r$output <- trajectories

  # and adjust the time as before: take the time column of the first draw with
  # no missing values and shift it by the row of the last data date
  # drop = FALSE keeps the draw dimension so this also works for a single draw
  n_missing <- apply(r$output[, "time", , drop = FALSE], 3, function(x) { sum(is.na(x)) })
  full_row <- match(0, n_missing)
  saved_full <- r$output[,"time",full_row]
  end_row <- which(rownames(r$output) == as.Date(get_data_end_date_inner(out)))
  full_time <- saved_full - end_row + 1L
  for(i in seq_len(replicates)) {
    #keep the NA padding of each draw
    time_i <- full_time
    time_i[is.na(r$output[,"time",i])] <- NA
    r$output[,"time",i] <- time_i
  }

  # second let's recreate the output
  r$model <- inputs$squire_model$odin_model(
    user = inputs$model_params, unused_user_action = "ignore"
  )

  # we will add the interventions here so that we know what times are needed for projection
  r$interventions <- interventions

  # and fix the replicates
  r$parameters$replicates <- replicates
  r$parameters$time_period <- as.numeric(diff(as.Date(range(rownames(r$output)))))
  r$parameters$dt <- inputs$model_params$dt
  r$parameters$seeding_cases <- out$parameters$seeding_cases
  r$parameters$seed <- out$parameters$seed

  #assign the same class as before
  class(r) <- class(out)

  r
}
#' Generate Draws from an Rt optimised fit
#'
#' Runs the model once for every element of \code{out$samples} and stores the
#' combined simulations in \code{out$output}. Uses furrr, so can be called in
#' parallel; set the environment variable `SQUIRE_PARALLEL_DEBUG` to `"TRUE"`
#' to run sequentially with purrr instead.
#'
#' @param out Output of `squire::rt_optimised`
#' @param t_end Time to simulate up to, default = NULL uses the largest
#' `t_end` in `out$inputs$data`
#' @param project_forwards Should we use the previous outputs as our start point
#' and treat this as projection, default = FALSE
#' @param ... method specific arguments, unused
#' @return `out` with `$output` replaced by a 3-d array of the simulations
#' (one slice per sample), whose row names are dates counted from
#' `out$inputs$start_date`.
#' @export
generate_draws.rt_optimised <- function(out, t_end = NULL, project_forwards = FALSE, ...){
  if(is.null(t_end)){
    t_end <- max(out$inputs$data$t_end)
  }

  #format initial conditions
  if(project_forwards){
    #keep the previous output of each sample, one matrix per sample
    outputs <- purrr::map(seq_len(dim(out$output)[3]), ~out$output[, , .x])
    #calculate start for new simulation
    t_start <- as.numeric(max(lubridate::as_date(dimnames(out$output)[[1]])) - out$inputs$start_date)
  } else {
    #placeholders so that pmap has one (unused) output per sample
    outputs <- purrr::map(seq_along(out$samples), ~list())
    t_start <- 0
  }
  #remove outputs to save on memory with parallel
  out$output <- NULL

  pmap_list <- list(
      sample = out$samples,
      output = outputs
  )
  rm(outputs)

  #simulate a single sample; project_forwards is passed in explicitly rather
  #than read from the enclosing function
  map_func <- function(sample, output, parameters, squire_model, t_start, t_end, project_forwards){
    #generate model function
    parameters <- append(parameters, sample)
    parameters$time_period <- t_end
    model <- generate_model_function(squire_model, parameters, use_difference = FALSE, dt = 1)
    if(project_forwards){
      #add initial condition, taken from the last row of the previous output
      initial_state <- update_initial_state(
        setup_parameters(squire_model, parameters),
        utils::tail(output, 1)
      )
    } else {
      initial_state <- NULL
    }
    #run model
    sim <- model(Rt = sample$R0, tt_Rt = sample$tt_R0, t_start = t_start, t_end = t_end,
                 atol = 10^-6, rtol = 10^-6, initial_state = initial_state)
    if(project_forwards){
      #append to old output, without the old final row
      #drop = FALSE so a short previous output stays a matrix
      rbind(output[-nrow(output), , drop = FALSE], sim)
    } else {
      sim
    }
  }

  if (Sys.getenv("SQUIRE_PARALLEL_DEBUG") == "TRUE") {
    sims <- purrr::pmap(
      .l = pmap_list, .f = map_func, parameters = out$parameters, squire_model = out$squire_model,
      t_start = t_start, t_end = t_end, project_forwards = project_forwards
    )
  } else {
    sims <- furrr::future_pmap(
      .l = pmap_list, .f = map_func, parameters = out$parameters, squire_model = out$squire_model,
      t_start = t_start, t_end = t_end, project_forwards = project_forwards,
      .options = furrr::furrr_options(seed = NULL)
    )
  }
  rm(pmap_list)

  out$output <-
    #merge simulation outputs into one
    abind::abind(sims, along = 3, new.names = list(
      as.character(out$inputs$start_date + seq_len(nrow(sims[[1]])) - 1),
      colnames(sims[[1]]),
      NULL
    ))
  out
}
mrc-ide/squire.page documentation built on May 27, 2023, 11:20 a.m.