This document explores the performance of different ensemble methods for forecasting incidence near local peaks.

# load packages
library(covidData)
library(covidEnsembles)
library(tidyverse)
library(gridExtra)
library(knitr)
library(DT)

# Report-wide settings: do not echo code chunks and set cache.lazy = FALSE,
# allow wide console output, and use the black-and-white ggplot2 theme.
knitr::opts_chunk$set(cache.lazy = FALSE, echo = FALSE)
options(width = 200)

ggplot2::theme_set(new = ggplot2::theme_bw())

#setwd("code/application/retrospective-qra-comparison/analyses/retrospective-peak-skill/")
# load data

# dates for "truth" data used to compute scores and used in plots
jhu_issue_date <- max(covidData::jhu_deaths_data$issue_date)
healthdata_issue_date <- max(covidData::healthdata_hosp_data$issue_date)

# Reshape a covidData truth table into long format with one row per location,
# target variable and target end date.
#
# truth: data frame from a covidData loader; its `location`, `date`, `inc` and
#   `cum` columns are used.
# horizon_unit: horizon unit used in the target label ("wk" or "day").
# measure_label: measure suffix used in the target label ("death", "case",
#   "hosp").
# Returns columns location, target_variable (e.g. "wk ahead inc death", i.e.
# the forecast target text without its leading horizon number),
# target_end_date (character) and observed, with duplicate rows removed.
tidy_observed <- function(truth, horizon_unit, measure_label) {
  truth %>%
    tidyr::pivot_longer(
      cols = c("inc", "cum"),
      names_to = "target_variable",
      values_to = "observed"
    ) %>%
    dplyr::transmute(
      location = location,
      target_variable = paste0(
        horizon_unit, " ahead ", target_variable, " ", measure_label
      ),
      target_end_date = as.character(date),
      observed = observed
    ) %>%
    dplyr::distinct()
}

# load data
observed_deaths <- covidData::load_jhu_data(
  issue_date = jhu_issue_date,
  spatial_resolution = c("state", "national"),
  temporal_resolution = "weekly",
  measure = "deaths"
) %>%
  tidy_observed(horizon_unit = "wk", measure_label = "death")

observed_cases <- covidData::load_jhu_data(
  issue_date = jhu_issue_date,
  spatial_resolution = c("county", "state", "national"),
  temporal_resolution = "weekly",
  measure = "cases"
) %>%
  tidy_observed(horizon_unit = "wk", measure_label = "case")

# daily hospitalizations; labels are "day ahead inc hosp" / "day ahead cum hosp"
# to match the "<n> day ahead inc hosp" forecast targets
observed_hosps <- covidData::load_healthdata_data(
  issue_date = healthdata_issue_date,
  spatial_resolution = c("state", "national"),
  temporal_resolution = "daily",
  measure = "hospitalizations"
) %>%
  tidy_observed(horizon_unit = "day", measure_label = "hosp")

observed <- dplyr::bind_rows(observed_deaths, observed_cases, observed_hosps)
# Parse a model abbreviation such as
# "intercept_FALSE-combine_method_median-window_size_0" into a one-row data
# frame with one column per recognised "<key>_<value>" part.
#
# model_abbr: a single character string whose parts are separated by "-".
# Returns the column-bound parts. `intercept` is converted to logical; all
# other values stay character. The "estimation_scale" key is stored in a
# column named `estimation_grouping`. Unrecognised parts emit the message
# "Unsupported case part" and contribute no column.
parse_model_case <- function(model_abbr) {
  # recognised key prefix -> name of the output column; the order mirrors the
  # priority in which prefixes are tried
  key_columns <- c(
    intercept = "intercept",
    combine_method = "combine_method",
    missingness = "missingness",
    quantile_groups = "quantile_groups",
    window_size = "window_size",
    check_missingness_by_target = "check_missingness_by_target",
    do_standard_checks = "do_standard_checks",
    do_baseline_check = "do_baseline_check",
    estimation_scale = "estimation_grouping"
  )
  known_keys <- names(key_columns)

  pieces <- strsplit(model_abbr, split = "-")[[1]]
  purrr::map_dfc(pieces, function(piece) {
    matched <- which(startsWith(piece, known_keys))
    if (length(matched) == 0L) {
      message("Unsupported case part")
      return(NULL)
    }
    key <- known_keys[matched[1]]
    # the value follows the key and its separating underscore
    parsed <- substr(piece, nchar(key) + 2, nchar(piece))
    if (key == "intercept") {
      parsed <- as.logical(parsed)
    }
    one_row <- data.frame(parsed)
    names(one_row) <- key_columns[[key]]
    one_row
  })
}
# load scores
all_scores <- readRDS("../retrospective-scores/retrospective_scores.rds")

# one row of parsed ensemble settings per distinct model abbreviation in the
# scores (the "Unsupported case part" messages are suppressed), keyed by model
model_cases <- suppressMessages(
  purrr::map_dfr(unique(all_scores$model), parse_model_case)
) %>%
  dplyr::mutate(model = unique(all_scores$model))

# attach the parsed settings, drop ensemble-switching and "positive" combine
# methods, then add a short model label and a National/State/County scale
# (derived from the location code)
all_scores <- all_scores %>%
  dplyr::left_join(model_cases, by = "model") %>%
  dplyr::filter(
    !grepl("ensemble_switching", model),
    combine_method != "positive"
  ) %>%
  dplyr::mutate(
    model_brief = paste(
      combine_method, "window", window_size, quantile_groups,
      estimation_grouping,
      sep = "_"
    ),
    spatial_scale = ifelse(
      location == "US",
      "National",
      ifelse(nchar(location) == 2, "State", "County")
    )
  ) %>%
  dplyr::arrange(
    combine_method,
    as.integer(window_size),
    quantile_groups,
    estimation_grouping
  )

# model_brief factor levels follow the sorted order established above
all_models <- unique(all_scores$model_brief)
all_scores <- dplyr::mutate(
  all_scores,
  model_brief = factor(model_brief, levels = all_models)
)
# window_10_model_inds <- grepl("window_10", all_models)
# new_levels <- c(
#   all_models[!window_10_model_inds],
#   all_models[window_10_model_inds])
# subset scores to those that are comparable for all models within each
# combination of spatial scale and base target
# only among those models with any forecasts for that combination
all_scores_common_by_target_variable_spatial_scale <-
  purrr::pmap_dfr(
    all_scores %>%
      dplyr::distinct(target_variable, spatial_scale),
    function(target_variable, spatial_scale) {
      # The arguments share names with columns of all_scores, so `!!` injects
      # the argument values rather than the column values.
      reduced_scores <- all_scores %>%
        dplyr::filter(
          target_variable == !!target_variable,
          spatial_scale == !!spatial_scale
        )

      # subset to same forecasts made for each ensemble method:
      # one row per forecast, one abs_error column per model
      scores_to_keep <- reduced_scores %>%
        dplyr::select(
          model, forecast_date, location, horizon, target_variable, abs_error
        ) %>%
        tidyr::pivot_wider(
          names_from = "model", values_from = "abs_error"
        )
      scale_models <- unique(reduced_scores$model)
      # keep a forecast only if every model present for this combination
      # has a non-missing score for it
      scores_to_keep$keep <-
        rowSums(is.na(scores_to_keep[scale_models])) == 0

      scores_to_keep <- scores_to_keep %>%
        dplyr::select(forecast_date, location, horizon, target_variable, keep)

      dplyr::left_join(
        reduced_scores,
        scores_to_keep,
        by = c("forecast_date", "location", "horizon", "target_variable")
      ) %>%
        dplyr::filter(keep) %>%
        dplyr::select(-keep)
    }
  )
# last target date to evaluate:
#  - most recent Saturday with observed data for weekly targets
#  - most recent day with observed data for daily targets
# (target_end_date holds date strings; max() on ISO "YYYY-MM-DD" strings picks
# the latest date -- assumes as.character() of a Date column produced them)
last_weekly_target_date <- max(observed_deaths$target_end_date)
last_daily_target_date <- max(observed_hosps$target_end_date)

# dates for saturdays included in the analysis:
#  - we consider ensemble forecasts generated 2 days after this saturday date
#  - week ahead targets are defined relative to this saturday date
first_forecast_week_end_date <- lubridate::ymd("2020-05-09")
last_forecast_week_end_date <- lubridate::ymd(last_weekly_target_date) - 7
# every 7th day from the first through the last week-end date (inclusive);
# stepping over Dates avoids seq()'s partially matched `length` argument and
# the floating-point week count
forecast_week_end_dates <- as.character(
  seq(first_forecast_week_end_date, last_forecast_week_end_date, by = 7)
)
num_forecast_weeks <- length(forecast_week_end_dates)

# Dates of forecast submission for forecasts included in this analysis:
# 2 days after the saturdays
forecast_dates <- lubridate::ymd(forecast_week_end_dates) + 2


# load forecasts

# targets: 1-4 week ahead cumulative/incident deaths and incident cases, and
# 1-28 day ahead incident hospitalizations
all_targets <- c(
  sprintf("%d wk ahead cum death", 1:4),
  sprintf("%d wk ahead inc death", 1:4),
  sprintf("%d wk ahead inc case", 1:4),
  sprintf("%d day ahead inc hosp", 1:28)
)


# Load the retrospective ensemble forecasts for each spatial scale listed
# below (currently only "state") into one data frame, tagging each row with
# its spatial_scale.
all_forecasts <- purrr::map_dfr(
#  c("national", "state", "state_national", "county"),
  c("state"),
  function(spatial_scale) {
    # Path to forecasts to evaluate
    submissions_root <- paste0(
      "../../retrospective-forecasts/",
      spatial_scale, "/"
    )

    # models to read in: a fixed list of ensemble specifications (the model
    # directories present under submissions_root are not scanned)
    model_abbrs <- c(
      "intercept_FALSE-combine_method_median-missingness_by_location_group-quantile_groups_per_model-window_size_0-check_missingness_by_target_FALSE-do_standard_checks_FALSE-do_baseline_check_FALSE",
      "intercept_FALSE-combine_method_convex-missingness_impute-quantile_groups_per_model-window_size_8-check_missingness_by_target_FALSE-do_standard_checks_FALSE-do_baseline_check_FALSE"
    )

    if (spatial_scale == "county") {
      model_abbrs <- c(
        "intercept_FALSE-combine_method_convex-missingness_impute-quantile_groups_per_quantile-window_size_3-check_missingness_by_target_FALSE-do_standard_checks_FALSE-do_baseline_check_FALSE",
        "intercept_FALSE-combine_method_convex-missingness_impute-quantile_groups_per_quantile-window_size_4-check_missingness_by_target_FALSE-do_standard_checks_FALSE-do_baseline_check_FALSE",
        "intercept_FALSE-combine_method_ew-missingness_by_location_group-quantile_groups_per_model-window_size_0-check_missingness_by_target_FALSE-do_standard_checks_FALSE-do_baseline_check_FALSE",
        "intercept_FALSE-combine_method_median-missingness_by_location_group-quantile_groups_per_model-window_size_0-check_missingness_by_target_FALSE-do_standard_checks_FALSE-do_baseline_check_FALSE"
      )
    }

    # response variables forecast at this spatial scale; an unknown scale is
    # an explicit error instead of a later "object not found"
    response_vars <- switch(
      spatial_scale,
      national = ,
      state_national = c("cum_death", "inc_death", "inc_case"),
      state = c("cum_death", "inc_death", "inc_case", "inc_hosp"),
      county = "inc_case",
      stop("Unsupported spatial scale: ", spatial_scale, call. = FALSE)
    )

    spatial_scale_forecasts <- purrr::map_dfr(
      response_vars,
      function(response_var) {
        # quantile levels, horizon, target labels and candidate locations
        # for this response variable
        if (response_var %in% c("inc_death", "cum_death")) {
          required_quantiles <-
            c(0.01, 0.025, seq(0.05, 0.95, by = 0.05), 0.975, 0.99)
          temporal_resolution <- "wk"
          horizon <- 4L
          targets <-
            paste0(1:horizon, " wk ahead ", gsub("_", " ", response_var))
          all_locations <- unique(observed_deaths$location)
        } else if (response_var == "inc_case") {
          required_quantiles <-
            c(0.025, 0.100, 0.250, 0.500, 0.750, 0.900, 0.975)
          temporal_resolution <- "wk"
          horizon <- 4L
          targets <- paste0(
            1:horizon, " wk ahead ", gsub("_", " ", response_var))
          all_locations <- unique(observed_cases$location)
        } else if (response_var == "inc_hosp") {
          required_quantiles <-
            c(0.01, 0.025, seq(0.05, 0.95, by = 0.05), 0.975, 0.99)
          temporal_resolution <- "day"
          horizon <- 28L
          # day-ahead targets extend 6 days past the horizon; the
          # relative-horizon loader is given horizon = 28 separately
          targets <- paste0(
            1:(horizon + 6), " day ahead ", gsub("_", " ", response_var))
          all_locations <- unique(observed_hosps$location)
        } else {
          stop("Unsupported response variable: ", response_var, call. = FALSE)
        }

        load_covid_forecasts_relative_horizon(
          monday_dates = forecast_dates,
          model_abbrs = model_abbrs,
          timezero_window_size = 6,
          locations = all_locations,
          targets = targets,
          horizon = horizon,
          required_quantiles = required_quantiles,
          submissions_root = submissions_root,
          include_null_point_forecasts = FALSE,
          keep_last = FALSE
        )
      }
    ) %>%
      dplyr::mutate(spatial_scale = spatial_scale)

    spatial_scale_forecasts
  }
)

# Derive target_variable (the target text after its leading horizon number),
# append the estimation scale to each model abbreviation so it can be joined
# to the parsed settings in model_cases, and build the short model label.
all_forecasts <- all_forecasts %>%
  dplyr::mutate(
    target_variable = sub("^[^ ]* ", "", target),
    model = paste0(model, "-estimation_scale_", spatial_scale)
  ) %>%
  dplyr::left_join(model_cases, by = "model") %>%
  dplyr::mutate(
    model_brief = paste(
      combine_method, "window", window_size, quantile_groups,
      estimation_grouping,
      sep = "_"
    )
  )

Performance Near Local Peaks for State Level Incident Cases

We look at how well the ensemble methods forecast incident cases near local peaks, in states that have recently had a local peak. For now, these locations form a manually curated list: all locations with at least two weeks of observed data after a recent local peak are included:

# FIPS codes of the locations examined below; the commented-out codes
# (02, 27, 41) are currently excluded.
locations_to_keep <- c(
  # "02",
  "08", "17", "19", "20",
  # "27",
  "29", "30", "31", "35", "38", "40",
  # "41",
  "46", "49",
  "55", "56", "66"
)

# Weekly incident cases for the selected locations, joined to
# covidData::fips_codes to attach the location metadata (abbreviation,
# location_name) used for facets and titles below.
observed_to_examine <- dplyr::left_join(
  dplyr::filter(
    observed_cases,
    location %in% locations_to_keep,
    target_variable == "wk ahead inc case"
  ),
  covidData::fips_codes,
  by = "location"
)

# one line per location, each facet with its own y-axis scale
ggplot(
  observed_to_examine,
  aes(
    x = lubridate::ymd(target_end_date),
    y = observed,
    group = abbreviation
  )
) +
  geom_line() +
  facet_wrap(~ abbreviation, scales = "free_y")

Plots per location

For each of the above locations, we display four plots:

  1. A display of the component model weights for the trained ensemble, with the timing of the state-specific peak in incident cases indicated by a vertical line. Note that the estimated weights for this ensemble are the same across all states. In practice there may be some minor differences across states due to different patterns of missing forecasts from different models; these differences are not shown here.
  2. The mean WIS for the median ensemble and the trained ensemble, averaged across all prospective week-ahead forecasts for which the outcome had been observed at the time the report was generated; dashed horizontal lines show each method's mean WIS across all scored weeks. Forecasts are subset to the weeks common to all ensemble methods we considered, which means that the weeks with available scores are determined by the cutoff for the ensemble with a window size of 10 weeks. The horizontal axis is the forecast date (the Monday of submission), not the target end date. For later target end dates, only scores for the shorter forecast horizons are available. The vertical line marks the timing of the state-specific peak incidence.
  3. The observed incident cases for that state, with a vertical line at the local peak. The observed data are weekly incident case counts as of the Saturday ending each week.
  4. Finally, a separate faceted plot shows the forecasts from each ensemble approach during the two weeks before the local peak, the week of the local peak, and the week after the local peak.

The questions we hope these plots can help answer are:

  1. How does each method perform in the weeks immediately before and after a local peak? There is some variation across locations, but our sense is that immediately before a peak the trained ensemble is generally more aggressive in predicting a continued rise in incidence. This leads to better scores during the rise, but worse scores at the time of the peak. However, there is a quick recovery after the peak and the two methods are generally pretty similar immediately after the peak.
  2. Overall, in the locations where we have seen a peak, what is the relative ranking of the methods? For most locations, the trained approach is better than the median ensemble when averaging across all scored weeks. In terms of mean WIS, the improvements during the rise in incidence offset the penalty incurred for over-prediction at the peak.
# Estimated component weights for the 8-week-window ensemble on the incident
# case target, taken from location "01" only. Components are ordered by total
# weight (largest first), and hub_locations is joined on fips to attach
# location metadata.
weights <- readRDS(
  "../retrospective-weight-estimates/retrospective_weight_estimates.rds"
) %>%
  dplyr::filter(
    location == "01",
    window_size == 8,
    target_variable == "inc_case"
  ) %>%
  dplyr::mutate(model = reorder(model, -weight, FUN = sum)) %>%
  dplyr::left_join(covidHubUtils::hub_locations, by = c("location" = "fips"))

# p <- weights %>%
#   ggplot(aes(x = forecast_date, y = weight, fill = model)) +
#     geom_bar(stat = "identity") +
#     facet_grid(spatial_resolution ~ target_variable)

#location <- locations_to_keep[1]
# For each selected location, draw one faceted figure with component model
# weights, per-forecast-date mean ensemble WIS (dashed lines: overall mean)
# and observed weekly incident cases, each marked at the location's peak week;
# then plot the ensemble forecasts made within 14 days of that peak.
for (location in locations_to_keep) {
  # observed weekly incident cases on or after 2020-06-01 for this location,
  # in the column layout passed to covidHubUtils::plot_forecast below
  observed <- observed_to_examine %>%
    dplyr::filter(
      location == !!location,
      target_end_date >= "2020-06-01"
    ) %>%
    dplyr::transmute(
      model = "Observed Data (JHU)",
      target_variable = "inc case",
      target_end_date = lubridate::ymd(target_end_date),
      location = location,
      value = observed,
      geo_type = "state",
      location_name = location_name,
      abbreviation = abbreviation
    )

  # With no observed data there is no peak to mark: which.max() would return
  # an empty index and building the vline data frame below would error.
  if (nrow(observed) == 0L) {
    message(
      "No observed data on or after 2020-06-01 for location ", location,
      "; skipping"
    )
    next
  }

  # target end date with the largest observed weekly incidence
  peak_week <- observed$target_end_date[which.max(observed$value)]

  # common-subset incident case scores for the ensembles in all_forecasts
  location_scores <- all_scores_common_by_target_variable_spatial_scale %>%
    dplyr::filter(
      model_brief %in% unique(all_forecasts$model_brief),
      location == !!location,
      target_variable == "wk ahead inc case"
    )

  # mean WIS per model and forecast date (the Monday 2 days after the
  # week-end Saturday); "drop_last" is summarise()'s default grouping,
  # stated explicitly to avoid the regrouping message
  location_scores_per_week <- location_scores %>%
    dplyr::mutate(
      forecast_date = forecast_week_end_date + 2
    ) %>%
    dplyr::group_by(model_brief, forecast_date) %>%
    dplyr::summarise(wis = mean(wis), .groups = "drop_last")

  # mean WIS per model across all scored forecasts for this location
  location_scores_overall <- location_scores %>%
    dplyr::group_by(model_brief) %>%
    dplyr::summarise(wis = mean(wis), .groups = "drop")

  p <- ggplot() +
    geom_line(
      data = dplyr::bind_rows(
        observed %>%
          dplyr::mutate(
            type = "Observed Data",
            model_brief = "Observed Data",
            forecast_date = target_end_date
          ),
        location_scores_per_week %>%
          dplyr::mutate(type = "Ensemble WIS", value = wis)
      ),
      mapping = aes(x = forecast_date, y = value, color = model_brief)
    ) +
    geom_hline(
      data = location_scores_overall %>%
        dplyr::mutate(type = "Ensemble WIS"),
      mapping = aes(yintercept = wis, color = model_brief),
      linetype = 2
    ) +
    geom_bar(
      data = weights %>% dplyr::mutate(type = "Component Weights"),
      mapping = aes(x = forecast_date, y = weight, fill = model),
      stat = "identity"
    ) +
    # the peak is marked in every facet
    geom_vline(
      data = data.frame(
        type = c("Ensemble WIS", "Observed Data", "Component Weights"),
        peak_date = peak_week,
        stringsAsFactors = FALSE
      ),
      mapping = aes(xintercept = peak_date)
    ) +
    scale_color_manual(
      "Model or Data",
      values = c(
        "Observed Data" = "black",
        "median_window_0_per_model_state" = "orange",
        "convex_window_8_per_model_state" = "cornflowerblue"
      )
    ) +
    facet_wrap( ~ type, ncol = 1, scales = "free_y") +
    ggtitle(paste0("Component Model Weights, Mean Ensemble WIS and Observed Data: ", observed$location_name[1]))
  print(p)

  # incident case quantile forecasts for this location in plot_forecast's
  # format, plus the median repeated as a point forecast
  forecasts <- all_forecasts %>%
    dplyr::filter(
      location == !!location,
      grepl("wk ahead inc case", target)
    ) %>%
    dplyr::transmute(
      model = model_brief,
      forecast_date = timezero,
      location = location,
      location_name = location_name,
      geo_type = "state",
      horizon = as.integer(horizon),
      temporal_resolution = "wk",
      target_variable = "inc case",
      target_end_date = target_end_date,
      type = "quantile",
      quantile = as.numeric(quantile),
      value = value
    )
  forecasts <- dplyr::bind_rows(
    forecasts,
    forecasts %>%
      dplyr::filter(quantile == 0.5) %>%
      dplyr::mutate(
        type = "point",
        quantile = NA
      )
  )

  # NOTE(review): this call's return value is not printed inside the loop;
  # display relies on plot_forecast() drawing the plot itself -- confirm
  covidHubUtils::plot_forecast(
    forecast_data = forecasts %>%
      dplyr::filter(
        abs(forecast_date - peak_week) <= 14
      ),
    target_variable = "inc case",
    truth_data = observed,
    truth_source = "JHU",
    intervals = c(.5, .8, .95),
    fill_by_model = TRUE,
    facet = . ~ model + forecast_date,
    facet_nrow = 2
  )
}


reichlab/covidEnsembles documentation built on Jan. 31, 2024, 7:21 p.m.