Forecasting Ne using BNPR

library(phylodyn)
library(dplyr)
library(ggplot2)
library(PILAF)

# Read the H3N2 tree stored in PILAF's extdata/H3N2/MCC.tree file.
H3N2_tree <- ape::read.nexus(
  system.file("extdata/H3N2/MCC.tree", package = "PILAF")
)

# Draw the ladderized tree with tip labels hidden.
plot(ape::ladderize(H3N2_tree), show.tip.label = FALSE)

# Time of the most recent sample (presumably a decimal year -- confirm against
# the tree's sampling dates).
H3N2_last_time <- 2019.0739726027398

# Root time: the last time minus the largest value in the tree's coal_times.
H3N2_root_time <- H3N2_last_time -
  with(phylodyn::summarize_phylo(H3N2_tree), max(coal_times))
# Model formula without an intercept (-1): an AR(2) term on `time` plus a
# cyclic RW(1) term on `week`. TRUE/FALSE are spelled out because T and F are
# ordinary, reassignable variables.
ar2crw1_formula <- y ~ -1 + f(time, model = "ar", order = 2) +
  f(week, model = "rw1", cyclic = TRUE, constr = FALSE)

# Value passed as `pred` to BNPR_forecast() (presumably the number of forecast
# points -- confirm against the PILAF documentation).
pred <- 20
H3N2_Ne_forecast <- BNPR_forecast(
  H3N2_tree,
  last_time = H3N2_last_time,
  formula = ar2crw1_formula,
  pred = pred
)

# Reference fits on the full tree, with and without preferential sampling.
# NOTE: the results are stored under the same names as the functions. Later
# calls such as BNPR_PS(...) still resolve to the functions because R skips
# non-function objects when looking up a name in call position; the names are
# kept because the rest of the script reads the result through `BNPR`.
BNPR <- BNPR(H3N2_tree)
plot_BNPR(BNPR)
BNPR_PS <- BNPR_PS(H3N2_tree)
plot_BNPR(BNPR_PS)

The next figure compares the estimate from the AR(2) + cRW(1) model (black curve) with the BNPR estimate (red curve).

# Summary of the AR(2) + cRW(1) fit; Time is the last sampling time minus the
# model's time IDs.
p_ar2crw1_data <- data.frame(
  Time = H3N2_last_time - H3N2_Ne_forecast[["result"]]$summary.random$time$ID,
  Median = H3N2_Ne_forecast[["effpop"]],
  Lower = H3N2_Ne_forecast[["effpop025"]],
  Upper = H3N2_Ne_forecast[["effpop975"]],
  Label = "BNPR AR(2) + cRW(1)"
)

# The same summary for the plain BNPR fit, used as the red reference curve.
BNPR_data <- data.frame(
  Time = H3N2_last_time - BNPR[["result"]]$summary.random$time$ID,
  Median = BNPR[["effpop"]],
  Lower = BNPR[["effpop025"]],
  Upper = BNPR[["effpop975"]]
)

# The vertical line sits at entry pred + 1 of the time grid.
# NOTE(review): the positional arguments are presumably the y-range, the x-axis
# breaks and the lower x limit -- confirm against plot_trajectory().
plot_trajectory(
  p_ar2crw1_data,
  range(p_ar2crw1_data$Lower, p_ar2crw1_data$Upper),
  seq(2013, 2019),
  H3N2_root_time,
  vline = p_ar2crw1_data$Time[pred + 1]
) +
  geom_line(data = BNPR_data, aes(x = Time, y = Median), col = "red")

This repeats the analysis above, but on a truncated tree.

truncation_time <- .08 # about 4 weeks
H3N2_truncated <- truncate_data(H3N2_tree, truncation_time, include_trunc = FALSE)

# Single time origin for the truncated analysis. The model is fitted with this
# value as `last_time`, so the plot has to use the same one: the truncation
# time reported by truncate_data(), not the requested `truncation_time`.
# Otherwise the curve and the vertical line shift whenever the two differ.
H3N2_truncated_last_time <- H3N2_last_time - H3N2_truncated$truncation_time

H3N2_Ne_forecast <- BNPR_forecast(
  H3N2_truncated$tree,
  last_time = H3N2_truncated_last_time,
  formula = ar2crw1_formula,
  pred = pred
)

# `Median` is filled from `effpop`, as in the untruncated comparison, so both
# figures plot the same summary rather than `effpopmean` under a "Median" name.
p_ar2crw1_data <- with(H3N2_Ne_forecast, data.frame(
  Time = H3N2_truncated_last_time - result$summary.random$time$ID,
  Median = effpop,
  Lower = effpop025,
  Upper = effpop975,
  Label = "BNPR AR(2) + cRW(1) truncated"
))

# The vertical line sits at entry pred + 1 of the time grid; the red curve is
# the BNPR median on the full tree.
plot_trajectory(
  p_ar2crw1_data,
  range(c(p_ar2crw1_data$Lower, p_ar2crw1_data$Upper)),
  seq(2013, 2019),
  H3N2_root_time,
  vline = H3N2_truncated_last_time -
    H3N2_Ne_forecast$result$summary.random$time$ID[pred + 1]
) +
  geom_line(data = BNPR_data, aes(x = Time, y = Median), col = "red")

BNPR PS forecast

For both H1N1 and H3N2, the effective population size estimated with preferential sampling correlates more strongly with ILI counts. We therefore forecast the H3N2 and H1N1 effective population size trajectories with preferential sampling. Figure 6 shows that different seasons have different epidemic peaks, so a weekly cyclic component may not be helpful in this case. The goal is to forecast the next 4 weeks.

H3N2

# Preferential-sampling model formula: no implicit intercept (-1), an explicit
# `beta0` term, an AR(2) term on `time`, and a copy of that term on `time2`
# weighted by `w`.
# NOTE(review): `beta0`, `hyper`, `w`, `time2` and `beta1_prec` are not defined
# in this script; presumably the fitting functions supply them when the
# formula is evaluated -- confirm.
BNPR_PS_formula <- Y ~ -1 + beta0 +
  f(time, model = "ar", order = 2, hyper = hyper, constr = FALSE) +
  f(time2, w, copy = "time", fixed = FALSE, param = c(0, beta1_prec))

# Forecast starting from week 45 of 2015.
year_start <- 2015
week_start <- 45

# week_start / year_start are passed by name, as in the other
# forecast_starting() calls in this script, so the call does not depend on the
# order of the function's formal arguments.
forecast_2015week46 <- forecast_starting(
  H3N2_tree,
  last_time = H3N2_last_time,
  week_start = week_start,
  year_start = year_start,
  label = "H3N2 BNPR PS",
  formula = BNPR_PS_formula,
  pred = 5
)

# Reference fit: BNPR with preferential sampling on the tree truncated at the
# same year/week. `truncation_time` is converted to a decimal date here.
truncation_time <- lubridate::decimal_date(
  compute_truncation_time(year_start, week_start)
)
tree_trunc <- truncate_data(
  H3N2_tree,
  H3N2_last_time - truncation_time,
  include_trunc = TRUE,
  n = 2,
  thresh = 0.1
)
BNPR_PS_trunc <- BNPR_PS(tree_trunc$tree)

# The reference curve (red) is placed on the time axis by subtracting the
# fit's own time grid from the truncation date. H3N2_root_time is the value
# computed earlier in the script (last time minus the largest coalescent time),
# reused instead of summarising the tree again.
plot_trajectory(
  forecast_2015week46$df,
  range(c(forecast_2015week46$df$Lower, forecast_2015week46$df$Upper)),
  seq(2013, 2019),
  H3N2_root_time,
  vline = forecast_2015week46$res$truncation_time
) +
  geom_line(
    data = data.frame(
      Time = truncation_time - BNPR_PS_trunc$summary$time,
      Median = BNPR_PS_trunc$effpop
    ),
    aes(x = Time, y = Median),
    col = "red"
  )

Forecast given intervals

# forecast tasks
seasons <- seq(2014, 2018)
week_start <- c(45, 1, 10)
pred <- 4

# One row per (year, start week) pair: the 2014 rows with a start week below 40
# are dropped, a final (2019, week 1) task is appended, and every task gets the
# same `pred` value.
tasks <- expand.grid(year = seasons, wk_start = week_start) %>%
  arrange(year, wk_start) %>%
  filter(!(year == 2014 & wk_start < 40)) %>%
  rbind(c(2019, 1)) %>%
  mutate(pred = pred)

# Run one forecast per task; ..1 = year, ..2 = wk_start, ..3 = pred.
forecast_res <- purrr::pmap(
  tasks,
  ~ forecast_starting(
    H3N2_tree,
    last_time = H3N2_last_time,
    week_start = ..2,
    year_start = ..1,
    label = "H3N2_Ne",
    formula = BNPR_PS_formula,
    pred = ..3,
    include_trunc = FALSE
  )
)

# Tasks whose start week is below 21 are assigned the previous calendar year
# (presumably the year in which their flu season began). Computed with a
# vectorised ifelse() so `start` is a plain numeric column rather than a list
# column built one row at a time.
tasks$start <- ifelse(tasks$wk_start < 21, tasks$year - 1, tasks$year)

# reference_res <- purrr::pmap(tasks, ~truncate_BNPR_PS(week_start = ..2, year_start = ..1, H3N2_tree, H3N2_last_time, include_trunc = T, n = 2, thresh = 0.2))
# H3N2_forecasts <- purrr::pmap(list(res = forecast_res, name = 1:length(forecast_res),
#                                    year = tasks$start, ref = reference_res), 
#                                 ~rbind(..1$df[1:pred, ], 
#                                        BNPR_to_df(..4$bnpr, "H3N2_Ne", ..4$trunc_time) %>%
#                                          filter(Time > lubridate::decimal_date(compute_truncation_time(..3, 40)) & Time < as.numeric(..1$df[pred, "Time"]))) %>% 
#                                 mutate(forecast = ..2) %>%
#                                 filter(Time > lubridate::decimal_date(compute_truncation_time(..3, 40)))) %>%
#   dplyr::bind_rows()
# H3N2_forecasts$week <- purrr::map_dbl(H3N2_forecasts$Time, ~ lubridate::week(lubridate::round_date(lubridate::date_decimal(.x))))

# Tag each forecast data frame with its index, keep only rows after the
# truncation time of week 39 in the task's `start` year, and stack them.
# seq_along() replaces 1:length(), which yields c(1, 0) for an empty list.
H3N2_forecasts <- purrr::pmap(
  list(
    res = forecast_res,
    name = seq_along(forecast_res),
    year = tasks$start
  ),
  ~ mutate(..1$df, forecast = ..2) %>%
    filter(Time > lubridate::decimal_date(compute_truncation_time(..3, 39)))
) %>%
  dplyr::bind_rows()

# Week number of each forecast time point: decimal date -> date-time -> week.
H3N2_forecasts$week <- purrr::map_dbl(
  H3N2_forecasts$Time,
  ~ lubridate::week(lubridate::round_date(lubridate::date_decimal(.x)))
)

Overlaying the visualization on BNPR PS (no forecast)

selected_seasons <- c("2014-2015", "2015-2016", "2016-2017", "2017-2018", "2018-2019")

# Fit on the full tree with no forecast points (pred = 0) and pref = TRUE.
# TRUE/FALSE are spelled out because T and F are reassignable variables.
H3N2_Ne_forecast <- BNPR_forecast(
  H3N2_tree,
  last_time = H3N2_last_time,
  formula = BNPR_PS_formula,
  pred = 0,
  pref = TRUE
)

# Summary on the calendar-time axis; here the plotted series is `effpopmean`.
H3N2_Ne_data <- with(
  H3N2_Ne_forecast,
  data.frame(
    time = H3N2_last_time - result$summary.random$time$ID,
    week = weeks,
    H3N2_Ne = effpopmean,
    lwr = effpop025,
    upr = effpop975
  )
)
H3N2_Ne_data$year <- floor(H3N2_Ne_data$time)

# Per-season panels with the forecasts overlaid; the dashed lines mark the
# three forecast start weeks.
p_H3N2_Ne_forecast <- visualize_flu_season(
  H3N2_Ne_data,
  time_series_names = "H3N2_Ne",
  subset_seasons = selected_seasons,
  ncol = 5,
  season_label = FALSE,
  error_band_names = data.frame(lwr = "lwr", upr = "upr", stringsAsFactors = FALSE),
  add_forecasts = H3N2_forecasts
) +
  geom_vline(aes(xintercept = 6), linetype = "dashed", alpha = 0.3) + # week 45
  geom_vline(aes(xintercept = 15), linetype = "dashed", alpha = 0.3) + # week 1
  geom_vline(aes(xintercept = 24), linetype = "dashed", alpha = 0.3) # week 10
  # levels are c(40:53, 1:21)
p_H3N2_Ne_forecast

H1N1

# Read the H1N1 tree stored in PILAF's extdata/H1N1/MCC.tree file.
H1N1_tree <- ape::read.nexus(
  system.file("extdata/H1N1/MCC.tree", package = "PILAF")
)

# Time of the most recent H1N1 sample (presumably a decimal year -- confirm).
H1N1_last_time <- 2019.0931506849315

# Run the same forecast tasks as for H3N2; ..1 = year, ..2 = wk_start,
# ..3 = pred. Any further columns of `tasks` are ignored by the lambda.
forecasts <- purrr::pmap(
  tasks,
  ~ forecast_starting(
    H1N1_tree,
    last_time = H1N1_last_time,
    week_start = ..2,
    year_start = ..1,
    label = "H1N1_Ne",
    formula = BNPR_PS_formula,
    pred = ..3,
    include_trunc = FALSE
  )
)

# reference_res <- purrr::pmap(tasks, ~truncate_BNPR_PS(week_start = ..2, year_start = ..1, H1N1_tree, H1N1_last_time, include_trunc = T, n = 1, thresh = 1))
# H1N1_forecasts <- purrr::pmap(list(res = forecasts, name = 1:length(forecasts),
#                                    year = tasks$start, ref = reference_res), 
#                                 ~rbind(..1$df[1:pred, ], 
#                                        BNPR_to_df(..4$bnpr, "H1N1_Ne", ..4$trunc_time) %>%
#                                          filter(Time > lubridate::decimal_date(compute_truncation_time(..3, 40)) & Time < as.numeric(..1$df[pred, "Time"]))) %>% 
#                                 mutate(forecast = ..2)) %>%
#   dplyr::bind_rows()
# H1N1_forecasts$week <- purrr::map_dbl(H1N1_forecasts$Time, ~lubridate::week(lubridate::round_date(lubridate::date_decimal(.x))))

# Tag each forecast data frame with its index, keep only rows after the
# truncation time of week 39 in the task's `start` year, and stack them.
# seq_along() replaces 1:length(), which yields c(1, 0) for an empty list.
H1N1_forecasts <- purrr::pmap(
  list(
    res = forecasts,
    name = seq_along(forecasts),
    year = tasks$start
  ),
  ~ mutate(..1$df, forecast = ..2) %>%
    filter(Time > lubridate::decimal_date(compute_truncation_time(..3, 39)))
) %>%
  dplyr::bind_rows()

# Week number of each forecast time point: decimal date -> date-time -> week.
H1N1_forecasts$week <- purrr::map_dbl(
  H1N1_forecasts$Time,
  ~ lubridate::week(lubridate::round_date(lubridate::date_decimal(.x)))
)

# Fit on the full H1N1 tree with no forecast points (pred = 0) and
# pref = TRUE. TRUE/FALSE are spelled out because T and F are reassignable.
H1N1_Ne_forecast <- BNPR_forecast(
  H1N1_tree,
  last_time = H1N1_last_time,
  formula = BNPR_PS_formula,
  pred = 0,
  pref = TRUE
)

# Summary on the calendar-time axis; here the plotted series is `effpopmean`.
H1N1_Ne_data <- with(
  H1N1_Ne_forecast,
  data.frame(
    time = H1N1_last_time - result$summary.random$time$ID,
    week = weeks,
    H1N1_Ne = effpopmean,
    lwr = effpop025,
    upr = effpop975
  )
)
H1N1_Ne_data$year <- floor(H1N1_Ne_data$time)

# Per-season panels with the forecasts overlaid; the dashed lines mark the
# three forecast start weeks.
p_H1N1_Ne_forecast <- visualize_flu_season(
  H1N1_Ne_data,
  time_series_names = "H1N1_Ne",
  subset_seasons = selected_seasons,
  ncol = 5,
  season_label = FALSE,
  error_band_names = data.frame(lwr = "lwr", upr = "upr", stringsAsFactors = FALSE),
  add_forecasts = H1N1_forecasts
) +
  geom_vline(aes(xintercept = 6), linetype = "dashed", alpha = 0.3) + # week 45
  geom_vline(aes(xintercept = 15), linetype = "dashed", alpha = 0.3) + # week 1
  geom_vline(aes(xintercept = 24), linetype = "dashed", alpha = 0.3) # week 10
p_H1N1_Ne_forecast
# Theme fragment that removes every x-axis element; applied to the top panel so
# only the bottom panel shows the x axis.
theme_no_x <- theme(
  axis.title.x = element_blank(),
  axis.line.x = element_blank(),
  axis.ticks.x = element_blank(),
  axis.text.x = element_blank()
)

# Stack the H3N2 panel (x axis and facet strips removed) above the H1N1 panel,
# vertically aligned.
p <- cowplot::plot_grid(
  p_H3N2_Ne_forecast + ylab("H3N2 BNPR PS") + theme_no_x +
    theme(strip.text = element_blank()),
  p_H1N1_Ne_forecast + ylab("H1N1 BNPR PS"),
  ncol = 1,
  align = "v"
)

# Arguments are named in full: the previous call relied on `file` partially
# matching `filename` and on the plot falling through to the `plot` argument.
ggsave(
  filename = "../../projectingvirus/manuscript/Figures/data/BNPR_PS_forecast_AR2.pdf",
  plot = p,
  width = 9,
  height = 5
)


Mamie/PILAF documentation built on May 8, 2019, 8:53 p.m.