# R/scenario.R
#
# Defines functions: compare_scenarios, run_scenario, valid_pars
#
# Documented in: compare_scenarios, run_scenario

# Check that a parameter list can be used as time-varying model input.
#
# pars: a list of parameter vectors that must include the time grid `tt`.
#
# Returns TRUE when `pars` is a list containing `tt` and every element has the
# same length as `tt`; FALSE otherwise.
valid_pars <- function(pars) {
  # Without this guard an empty list (or one lacking `tt`) would compare all
  # lengths against 0 and could pass vacuously. `[[` avoids partial matching.
  if (!is.list(pars) || is.null(pars[["tt"]])) {
    return(FALSE)
  }

  all(lengths(pars) == length(pars[["tt"]]))
}


#' Run scenario analysis
#'
#' Re-runs the intervention model for every parameter set of a baseline
#' simulation after the parameters have been modified by a user-supplied
#' function.
#'
#' @param sim baseline simulations
#' @param fn a function for modifying parameters. It receives a list with the
#'   elements \code{tt}, \code{m}, \code{beta}, \code{kappa}, \code{r_rec} and
#'   \code{r_death} (one value per time point) and must return a list in which
#'   every element has the same length as \code{tt}.
#'
#' @return A list of class \code{simScBassSIR} with the elements
#'   \code{Simulations} (an array indexed by time point, model output and
#'   simulation) and \code{Date} (copied from \code{sim}).
#' @export
#'
#' @examples
#' library(R2jags)
#' cases <- as_bass_data(n_covid19$Hubei)
#' est <- BassSIR::fit(cases, r_rec = 1/22.2, r_death = 1/22.3, type = "BassSIR")
#'
#' sim <- simulate(est, nsim = 500)
#'
#' zero_kappa <- function(pars) {
#'   pars$kappa <- rep(0, length(pars$kappa))
#'   return(pars)
#' }
#'
#' lockdown <- run_scenario(sim, zero_kappa)
#'
#' cp <- compare_scenarios(sim, Lockdown = lockdown)
run_scenario <- function(sim, fn) {
  # Pick the model definition shipped with the package; fail loudly on an
  # unsupported model type instead of handing NULL to odin.
  f <- switch(sim$ModelType,
    BassSIR = system.file("sim/IntvBassSIR.txt", package = "BassSIR"),
    SIR = system.file("sim/IntvSIR.txt", package = "BassSIR"),
    stop("Unknown model type: ", sim$ModelType, call. = FALSE)
  )
  # system.file() returns "" when the file cannot be located.
  if (!nzchar(f)) {
    stop("Model definition file not found in package 'BassSIR'", call. = FALSE)
  }

  n_sim <- nrow(sim$Parameters)
  nt <- length(sim$Date)

  # Expand the i-th parameter set into constant per-time-point vectors.
  get_pars <- function(i) {
    p <- sim$Parameters[i, ]

    list(
      tt = seq_len(nt) - 1,
      m = rep(p$m, nt),
      beta = rep(p$beta, nt),
      kappa = rep(p$kappa, nt),
      r_rec = rep(as.numeric(sim$Offsets[1]), nt),
      r_death = rep(as.numeric(sim$Offsets[2]), nt)
    )
  }

  # The time grid is taken from the unmodified parameters; the model is run on
  # half steps and only the whole-number time points are kept.
  pars <- get_pars(1)
  times <- seq(min(pars$tt), max(pars$tt), 0.5)
  keep <- times == round(times)

  pars <- fn(pars)
  stopifnot(valid_pars(pars))

  # A first run is used only to learn the shape and column names of the output.
  model <- odin::odin(f)
  cm <- model(user = pars)
  test <- cm$run(times)
  test <- test[keep, , drop = FALSE]

  sims <- array(0, c(dim(test), n_sim))

  for (i in seq_len(n_sim)) {
    pars <- fn(get_pars(i))
    pars$I0 <- sim$Y0s$I0[i]
    pars$A0 <- sim$Y0s$A0[i]

    cm$set_user(user = pars)
    ys <- cm$run(times)
    sims[, , i] <- ys[keep, , drop = FALSE]
  }

  dimnames(sims) <- list(
    NULL,
    colnames(test),
    dimnames(sim$Simulations)[[3]]
  )

  res <- list(
    Simulations = sims,
    Date = sim$Date
  )

  class(res) <- "simScBassSIR"
  res
}


#' @rdname run_scenario
#' @param ... scenario simulations to compare against \code{sim}; each must
#'   have a \code{Simulations} array. Unnamed scenarios are labelled
#'   \code{Scenario_1}, \code{Scenario_2}, ... by position.
#' @param fn_change measure of change relative to the baseline, in percent:
#'   \code{"Y0"} for \code{scenario / baseline - 1}, \code{"Y1"} for
#'   \code{1 - baseline / scenario}, \code{"Yt"} for
#'   \code{log(scenario / baseline)}.
#' @export
compare_scenarios <- function(sim, ..., fn_change = c("Y0", "Y1", "Yt")) {
  # Validate the choice up front so a bad value fails before any work is done.
  fn_change <- match.arg(fn_change)

  scs <- list(...)
  if (length(scs) == 0) {
    stop("At least one scenario is required", call. = FALSE)
  }

  # Keep supplied names and fall back to position-based labels for the rest.
  # The labels are used for the tables below, so unnamed scenarios are kept.
  ns <- names(scs)
  default_ns <- paste0("Scenario_", seq_along(scs))
  if (is.null(ns)) {
    ns <- default_ns
  } else {
    unnamed <- is.na(ns) | ns == ""
    ns[unnamed] <- default_ns[unnamed]
  }

  # Compare only outputs present in the baseline and in every scenario; the
  # column "t" is excluded.
  vs <- Reduce(
    intersect,
    lapply(scs, function(sc) dimnames(sc$Simulations)[[2]]),
    dimnames(sim$Simulations)[[2]]
  )
  vs <- vs[vs != "t"]

  # All simulations of one output; dimensions are kept so that a single
  # simulation or a single time point still summarises correctly.
  slice_var <- function(obj, v) {
    obj$Simulations[, v, , drop = FALSE]
  }

  # Mean and 95% interval across simulations for each time point.
  summarise_sims <- function(x, label) {
    data.table::data.table(
      Time = sim$Date, Scenario = label,
      mean = apply(x, 1, mean),
      lower = apply(x, 1, stats::quantile, probs = 0.025),
      upper = apply(x, 1, stats::quantile, probs = 0.975)
    )
  }

  traj <- lapply(vs, function(v) {
    parts <- c(
      list(summarise_sims(slice_var(sim, v), "Baseline")),
      lapply(seq_along(scs), function(k) {
        summarise_sims(slice_var(scs[[k]], v), ns[k])
      })
    )
    data.table::rbindlist(parts)
  })
  names(traj) <- vs

  change <- lapply(vs, function(v) {
    baseline <- slice_var(sim, v)

    parts <- c(
      # The baseline does not change relative to itself.
      list(data.table::data.table(
        Time = sim$Date, Scenario = "Baseline",
        mean = 0, lower = 0, upper = 0
      )),
      lapply(seq_along(scs), function(k) {
        sims <- slice_var(scs[[k]], v)
        rel <- switch(fn_change,
          Y0 = sims / baseline - 1,
          Y1 = 1 - baseline / sims,
          Yt = log(sims / baseline)
        ) * 100
        summarise_sims(rel, ns[k])
      })
    )
    data.table::rbindlist(parts)
  })
  names(change) <- vs

  # Result: per-output trajectory tables, per-output change tables and the
  # scenario labels used in them.
  res <- list(
    Trajectories = traj,
    Changes = change,
    Scenarios = ns
  )
  class(res) <- "compare_scenarios"
  res
}
# COVID-19-Modelling/BassSIR documentation built on April 20, 2020, 1:57 a.m.