knitr::opts_chunk$set(echo = TRUE)
library(tidyverse)
library(rstan)

Simulate from the SIRD model for a single location, using fixed parameter values.

# Turn on rstan's auto_write option before any Stan model is compiled below.
rstan_options(auto_write = TRUE)

# Raw inputs, read via relative paths into the repository's data-raw folder:
# the JHU incident series and the per-location traits table (per file names).
covid <- read_csv("../../../../data-raw/jhu-incident.csv")
pops <- read_csv("../../../../data-raw/location_traits.csv")

# Keep only the rows whose location abbreviation is New York.
ny_covid <- dplyr::filter(covid, location_abbreviation == "NY")

# Compile the predictive SIRD Stan program for a single location.
predictive_model <- stan_model(
  file = "predictive_sird_one_location.stan"
)

# Draw one fixed-parameter simulation from the predictive SIRD model for NY.
#
# The model compiled above (`predictive_model`) is reused through `sampling()`
# instead of pointing `stan()` at the .stan file a second time. All sampler
# settings (one chain, one iteration, fixed seed, "Fixed_param" algorithm) are
# unchanged.
predictive_sample <- sampling(
  object = predictive_model,
  data = list(
    # Population of the location, looked up by postal code.
    N = pops$totalpop[pops$postalCode == "NY"],
    # Number of rows in the NY series and the values in those rows.
    T = nrow(ny_covid),
    y = ny_covid$value,
    d0 = 0.0,
    # Every "raw_*" quantity is passed in as data here: the value itself,
    # together with a matching *_mean and *_sd entry.
    raw_theta = rep(0.0, 3L),
    raw_theta_mean = c(0.4, -0.7, -2.3),
    raw_theta_sd = c(0.5, 0.5, 0.5),
    raw_state_init = rep(0.0, 2L),
    raw_state_init_mean = c(5.0, 1.0),
    raw_state_init_sd = rep(1.0, 2L),
    raw_phi = 0.0,
    raw_phi_mean = 1.0,
    raw_phi_sd = 1.0
  ),
  iter = 1,
  chains = 1,
  seed = 894711,
  algorithm = "Fixed_param",
  verbose = TRUE
)
## Simulate the D compartment of a deterministic SIRD model.
##
## The state is expressed as proportions and always starts from
## S = 1 - 1e-6, I = 1e-6, D = 0, R = 0.
##
## Arguments:
##   parameters  named numeric vector with elements `beta` (infection),
##               `gamma` (recovery) and `u` (flow from I into D)
##   t           final time; the solution is reported at 0, 1, ..., t
##
## Returns: numeric vector holding the D compartment at each reported time.
run_sir <- function(parameters, t) {
  ## Right-hand side of the ODE system in the (time, state, parms) form that
  ## the solver calls; derivatives are returned in the order S, I, D, R,
  ## matching the order of `init` below.
  sird_rhs <- function(time, state, parameters) {
    with(as.list(c(state, parameters)), {
      dS <- -beta * S * I
      dI <- beta * S * I - gamma * I - u * I
      dD <- u * I
      dR <- gamma * I
      list(c(dS, dI, dD, dR))
    })
  }

  ## Initial proportion in each compartment.
  init <- c(S = 1 - 1e-6, I = 1e-6, D = 0.0, R = 0.0)

  ## Unit-spaced output times from 0 to t.
  times <- seq(0, t, by = 1)

  ## `ode` lives in deSolve, which this file never attaches with library();
  ## qualify the call so the function does not depend on the search path.
  out <- deSolve::ode(y = init, times = times, func = sird_rhs, parms = parameters)

  ## Convert to a data frame so the D column can be pulled out by name.
  out <- data.frame(out)
  out$D
}

# D-compartment trajectory at times 0..10 for one fixed parameter setting.
observedData <- run_sir(
  c(beta = 0.9, gamma = 0.25, u = 0.001),
  10
)
# Fit the one-location SIRD Stan model to a hard-coded 18-point series,
# using a single chain of 1000 iterations with a fixed seed.
#
# Entries that were left commented out inside the `init` list of this call
# (kept here for reference, not passed to Stan):
#   raw_theta_mean = c(0.33, -0.7, -7.5), raw_theta_sd = c(0.1, 0.1, 0.1),
#   raw_state_init_mean = c(7.0, 0.0),    raw_state_init_sd = c(0.1, 0.1),
#   raw_phi_mean = 1.0,                   raw_phi_sd = 0.0
fit_sample <- stan(
  file = "sird_one_location.stan",
  chains = 1,
  iter = 1000,
  seed = 894711,
  verbose = TRUE,
  data = list(
    N = 19616658,
    T = 18,
    # Seven leading zeros followed by the eleven non-zero observations.
    y = c(
      rep(0, 7),
      6, 189, 1393, 4403, 6507, 6262, 3249, 2189, 2414, 1437, 982
    ),
    d0 = 0
  ),
  # One inner list per chain; every raw parameter starts at zero.
  init = list(
    list(
      raw_theta = rep(0, 3),
      raw_state_init = rep(0, 2),
      raw_phi = 0
    )
  )
)
# Pull the posterior predictive array out of the fit. The call is namespaced
# because tidyr (attached through tidyverse) also exports an `extract`.
preds <- rstan::extract(fit_sample, pars = "y_pred")

# Flatten the first 500 x 22 x 100 entries of y_pred into a long data frame
# with one row per (first index, third index, week) combination.
#
# This replaces a nested map_dfr() that built and row-bound 50,000 small data
# frames. The rows come out in the same order as before: week varies fastest,
# then the third array index, then the first.
# NOTE(review): the first index presumably runs over posterior draws and the
# third over simulated replicates -- confirm against the Stan program.
preds_df <- local({
  n_first <- 500L
  n_weeks <- 22L
  n_third <- 100L

  y_pred_block <- preds$y_pred[
    seq_len(n_first), seq_len(n_weeks), seq_len(n_third),
    drop = FALSE
  ]

  data.frame(
    week = rep(seq_len(n_weeks), times = n_first * n_third),
    # Reorder dimensions to (week, third, first) so that column-major
    # flattening yields the row order described above.
    y_hat = as.vector(aperm(y_pred_block, c(2L, 3L, 1L)))
  )
})
# Overlay the predictive draws (blue) on the observed NY series (orange).
ggplot() +
  geom_point(
    data = preds_df,
    mapping = aes(x = week, y = y_hat),
    color = "blue",
    # Constant transparency belongs outside aes(); inside aes() it is treated
    # as a mapped variable, which rescales it and adds a spurious legend.
    alpha = 0.1
  ) +
  # Observed values are drawn at as.numeric(week) - 3.
  # NOTE(review): the shift of 3 presumably aligns the data's week numbering
  # with the model's time index -- confirm.
  geom_point(
    data = ny_covid,
    mapping = aes(x = as.numeric(week) - 3, y = value),
    color = "orange"
  ) +
  geom_line(
    data = ny_covid,
    mapping = aes(x = as.numeric(week) - 3, y = value),
    color = "orange"
  ) +
  theme_bw()
# Extract the `state_hat` quantity from the fit and show its array dimensions.
state_hat <- extract(fit_sample, pars = "state_hat")
dim(state_hat[["state_hat"]])
# Fit the one-location SEIRD Stan model to the same hard-coded 18-point
# series, again with a single chain of 1000 iterations and a fixed seed.
# Note that this reassigns `fit_sample`.
#
# Entries that were left commented out inside the `init` list of this call
# (kept here for reference, not passed to Stan):
#   raw_theta_mean = c(0.33, -0.7, -7.5), raw_theta_sd = c(0.1, 0.1, 0.1),
#   raw_state_init_mean = c(7.0, 0.0),    raw_state_init_sd = c(0.1, 0.1),
#   raw_phi = 0.0, raw_phi_mean = 1.0,    raw_phi_sd = 0.0
fit_sample <- stan(
  file = "seird_one_location.stan",
  chains = 1,
  iter = 1000,
  seed = 894711,
  verbose = TRUE,
  data = list(
    N = 19616658,
    T = 18,
    # Seven leading zeros followed by the eleven non-zero observations.
    y = c(
      rep(0, 7),
      6, 189, 1393, 4403, 6507, 6262, 3249, 2189, 2414, 1437, 982
    ),
    d0 = 0
  ),
  # One inner list per chain; only these two raw quantities are given
  # explicit starting values (all zeros).
  init = list(
    list(
      raw_theta = rep(0, 4),
      raw_state_init = rep(0, 2)
    )
  )
)
# Pull the posterior predictive array out of the fit. The call is namespaced
# because tidyr (attached through tidyverse) also exports an `extract`.
preds <- rstan::extract(fit_sample, pars = "y_pred")

# Flatten the first 500 x 22 x 100 entries of y_pred into a long data frame
# with one row per (first index, third index, week) combination.
#
# This replaces a nested map_dfr() that built and row-bound 50,000 small data
# frames. The rows come out in the same order as before: week varies fastest,
# then the third array index, then the first.
# NOTE(review): the first index presumably runs over posterior draws and the
# third over simulated replicates -- confirm against the Stan program.
preds_df <- local({
  n_first <- 500L
  n_weeks <- 22L
  n_third <- 100L

  y_pred_block <- preds$y_pred[
    seq_len(n_first), seq_len(n_weeks), seq_len(n_third),
    drop = FALSE
  ]

  data.frame(
    week = rep(seq_len(n_weeks), times = n_first * n_third),
    # Reorder dimensions to (week, third, first) so that column-major
    # flattening yields the row order described above.
    y_hat = as.vector(aperm(y_pred_block, c(2L, 3L, 1L)))
  )
})
# Overlay the predictive draws (blue) on the observed NY series (orange).
ggplot() +
  geom_point(
    data = preds_df,
    mapping = aes(x = week, y = y_hat),
    color = "blue",
    # Constant transparency belongs outside aes(); inside aes() it is treated
    # as a mapped variable, which rescales it and adds a spurious legend.
    alpha = 0.1
  ) +
  # Observed values are drawn at as.numeric(week) - 3.
  # NOTE(review): the shift of 3 presumably aligns the data's week numbering
  # with the model's time index -- confirm.
  geom_point(
    data = ny_covid,
    mapping = aes(x = as.numeric(week) - 3, y = value),
    color = "orange"
  ) +
  geom_line(
    data = ny_covid,
    mapping = aes(x = as.numeric(week) - 3, y = value),
    color = "orange"
  ) +
  theme_bw()


reichlab/covidEnsembles documentation built on Jan. 31, 2024, 7:21 p.m.