# code/application/weekly-ensemble/build_4_week_ensembles.R

message("starting build_4_week_ensembles.R")

library(tidyverse)
library(zeallot)
library(covidHubUtils)
library(covidEnsembles)
library(covidData)
library(googledrive)
library(yaml)
library(here)

# Open the post-mortem debugger on error only in interactive sessions.
# Non-interactively, recover() cannot start a debugger (it only dumps frames),
# and non-interactive R only halts on an error when no error handler is set,
# so installing one unconditionally would let a scheduled run continue past
# a failure.
if (interactive()) {
  options(error = recover)
}

# All relative paths below are resolved against the project root.
setwd(here())

# When TRUE, the forecast, model-eligibility and model-weights files are
# written out at the end of each outcome's ensemble build.
final_run <- TRUE

# Where to find component model submissions
submissions_root <- "../covid19-forecast-hub/data-processed/"

# Hub clone root (parent of submissions_root), passed as hub_repo_path when
# loading forecasts from the local hub repo
hub_repo_path <- "../covid19-forecast-hub/"

# Where to save ensemble forecasts
root <- "code/application/weekly-ensemble/forecasts/"
if (!dir.exists(root)) dir.create(root, recursive = TRUE)
if (!dir.exists(paste0(root, "ensemble-metadata/"))) {
  dir.create(paste0(root, "ensemble-metadata/"), recursive = TRUE)
}

# Where to save plots
plots_root <- "code/application/weekly-ensemble/plots/COVIDhub-4_week_ensemble/"
if (!dir.exists(plots_root)) dir.create(plots_root, recursive = TRUE)

# Where to save hospitalization exclusion tables and plots
sd_check_table_path <- "code/application/weekly-ensemble/exclusion-outputs/tables/"
if (!dir.exists(sd_check_table_path)) {
  dir.create(sd_check_table_path, recursive = TRUE)
}
sd_check_plot_path <- "code/application/weekly-ensemble/exclusion-outputs/plots/"
if (!dir.exists(sd_check_plot_path)) {
  dir.create(sd_check_plot_path, recursive = TRUE)
}

# Candidate component models: primary and secondary designations, with the
# COVIDhub baseline included and the COVIDhub ensemble left out.
candidate_model_abbreviations_to_include <- get_candidate_models(
  submissions_root = submissions_root,
  include_designations = c("primary", "secondary"),
  include_COVIDhub_ensemble = FALSE,
  include_COVIDhub_baseline = TRUE
)

# Drop hand-excluded models: the JHU APL hospitalizations ensemble, the
# FDANIHASU ensemble, COVIDhub-trained_ensemble and
# KITmetricslab-select_ensemble. Order of the remaining models is kept.
candidate_model_abbreviations_to_include <- Filter(
  function(model_abbr) {
    !(model_abbr %in% c(
      "JHUAPL-SLPHospEns",
      "FDANIHASU-Sweight",
      "COVIDhub-trained_ensemble",
      "KITmetricslab-select_ensemble"
    ))
  },
  candidate_model_abbreviations_to_include
)


# Figure out what day it is; forecast creation date is set to a Monday,
# even if we are delayed and create it Tuesday morning. floor_date() goes back
# to the most recent Sunday (week_start = 7 is passed explicitly so the result
# does not depend on the global "lubridate.week.start" option) and + 1 moves
# to the Monday after it. A run on a Sunday therefore yields the next day.
forecast_date <- lubridate::floor_date(
  Sys.Date(), unit = "week", week_start = 7
) + 1

# Outcomes to build the 4-week ensemble for. The settings branches below also
# cover cum_death, inc_death and inc_case.
response_vars <- c("inc_hosp")
for (response_var in response_vars) {
  message("starting setup of ", response_var, " for 4-wk ensemble")

  # Per-outcome settings: quantile checks, required quantile levels, spatial
  # and temporal resolution, horizon, target names, the date that horizons
  # are counted from, and the truth-data issue date. Every branch assigns the
  # same variables; an unknown outcome is an error so that settings from a
  # previous loop iteration can never silently carry over.
  if (response_var == "cum_death") {
    do_q10_check <- do_nondecreasing_quantile_check <- TRUE
    do_sd_check <- "exclude_none"
    required_quantiles <-
      c(0.01, 0.025, seq(0.05, 0.95, by = 0.05), 0.975, 0.99)
    spatial_resolution <- c("state", "national")
    temporal_resolution <- "wk"
    horizon <- 4L
    targets <- paste0(1:horizon, " wk ahead ", gsub("_", " ", response_var))
    # weekly horizons are counted from two days before forecast_date
    forecast_week_end_date <- forecast_date - 2

    # date for which retrieved deaths truth data should be current
    data_as_of_date <- covidData::available_issue_dates("deaths") %>% max()

  } else if (response_var == "inc_death") {
    do_q10_check <- do_nondecreasing_quantile_check <- FALSE
    do_sd_check <- "exclude_none"
    required_quantiles <-
      c(0.01, 0.025, seq(0.05, 0.95, by = 0.05), 0.975, 0.99)
    spatial_resolution <- c("state", "national")
    temporal_resolution <- "wk"
    horizon <- 4L
    targets <- paste0(1:horizon, " wk ahead ", gsub("_", " ", response_var))
    forecast_week_end_date <- forecast_date - 2

    # date for which retrieved deaths truth data should be current
    # repeated from cum_death block for clarity
    data_as_of_date <- covidData::available_issue_dates("deaths") %>% max()

  } else if (response_var == "inc_case") {
    do_q10_check <- do_nondecreasing_quantile_check <- FALSE
    do_sd_check <- "exclude_none"
    required_quantiles <- c(0.025, 0.100, 0.250, 0.500, 0.750, 0.900, 0.975)
    spatial_resolution <- c("county", "state", "national")
    temporal_resolution <- "wk"
    horizon <- 4L
    targets <- paste0(1:horizon, " wk ahead ", gsub("_", " ", response_var))
    forecast_week_end_date <- forecast_date - 2

    # date for which retrieved cases truth data should be current
    data_as_of_date <- covidData::available_issue_dates("cases") %>% max()

  } else if (response_var == "inc_hosp") {
    do_q10_check <- do_nondecreasing_quantile_check <- FALSE
    do_sd_check <- "exclude_none"
    required_quantiles <-
      c(0.01, 0.025, seq(0.05, 0.95, by = 0.05), 0.975, 0.99)
    spatial_resolution <- c("state", "national")
    temporal_resolution <- "day"
    horizon <- 28L
    # NOTE(review): 6 extra daily targets beyond the horizon, presumably to
    # cover submissions dated up to timezero_window_size (6) days before
    # forecast_date -- confirm.
    targets <- paste0(
      1:(horizon + 6), " day ahead ", gsub("_", " ", response_var)
    )
    forecast_week_end_date <- forecast_date

    # date for which retrieved hospitalization truth data should be current
    data_as_of_date <-
      covidData::available_issue_dates("hospitalizations") %>% max()

  } else {
    stop(
      "no ensemble settings defined for response_var '", response_var, "'",
      call. = FALSE
    )
  }

  message("starting generation of full 4-week ensemble")

  c(model_eligibility, wide_model_eligibility, location_groups, component_forecasts) %<-%
    build_covid_ensemble(
      hub = "US",
      source = "local_hub_repo",
      hub_repo_path = hub_repo_path,
      candidate_model_abbreviations_to_include =
        candidate_model_abbreviations_to_include,
      spatial_resolution = spatial_resolution,
      targets = targets,
      forecast_date = forecast_date,
      forecast_week_end_date = forecast_week_end_date,
      max_horizon = horizon,
      timezero_window_size = 6,
      window_size = 0,
      data_as_of_date = data_as_of_date,
      intercept = FALSE,
      combine_method = "median",
      # a single quantile group spanning every required quantile level
      # (previously hard-coded to 23, which only matched the 23-level sets)
      quantile_groups = rep(1, length(required_quantiles)),
      missingness = "by_location_group",
      backend = NA,
      required_quantiles = required_quantiles,
      do_q10_check = do_q10_check,
      do_nondecreasing_quantile_check = do_nondecreasing_quantile_check,
      do_baseline_check = FALSE,
      do_sd_check = do_sd_check, # implement CDC exclusion requests
      sd_check_table_path = sd_check_table_path,
      sd_check_plot_path = sd_check_plot_path,
      manual_eligibility_adjust = NULL,
      return_eligibility = TRUE,
      return_all = TRUE
    )

  # Subset ensemble forecasts to only location groups where more than one
  # component model contributed (counted as the number of TRUE values across
  # the logical columns of location_groups).
  model_counts <- rowSums(
    dplyr::select(location_groups, dplyr::where(is.logical))
  )
  location_groups <- location_groups[model_counts > 1, ]
  if (nrow(location_groups) == 0L) {
    stop(
      "no location group has more than one contributing component model for ",
      response_var,
      call. = FALSE
    )
  }
  ensemble_predictions <- bind_rows(location_groups[["qra_forecast"]])

  ## instrumentation to try to identify problems, added Jan 24 2023
  message("starting possibly problematic set of operations")
  print(colnames(ensemble_predictions))
  print(nrow(ensemble_predictions))

  # save the results in required format
  formatted_ensemble_predictions <- ensemble_predictions %>%
    left_join(fips_codes, by = "location") %>%
    dplyr::transmute(
      # !! forces the script-level forecast_date rather than any column
      forecast_date = !!forecast_date,
      target = target,
      # forecast_week_end_date plus h days ("day" targets) or h weeks
      # (otherwise), where h is the number that starts the target name
      target_end_date = as.character(
        lubridate::ymd(forecast_week_end_date) +
          as.numeric(substr(target, 1, regexpr(" ", target, fixed = TRUE) - 1)) *
            ifelse(grepl("day", target), 1, 7)
      ),
      location = location,
      location_name = location_name,
      type = "quantile",
      quantile = quantile,
      # Round lower quantiles down, upper quantiles up, and the median to the
      # nearest integer. The median is matched with a tolerance (as for the
      # point rows below) because quantile levels are floating point.
      value = ifelse(
        dplyr::near(quantile, 0.5),
        round(value),
        ifelse(quantile < 0.5, floor(value), ceiling(value))
      )
    )

  # Add a point forecast for each target/location: a copy of the median row.
  formatted_ensemble_predictions <- bind_rows(
    formatted_ensemble_predictions,
    formatted_ensemble_predictions %>%
      filter(dplyr::near(quantile, 0.5)) %>%
      mutate(
        type = "point",
        quantile = NA_real_
      )
  )

  # Reformat model eligibility/weights for output: drop the training, fit and
  # forecast columns, expand to one row per location, and encode logical
  # columns as 0/1.
  model_weights <- location_groups %>%
    dplyr::select(-qfm_train, -qfm_test, -y_train, -qra_fit, -qra_forecast) %>%
    tidyr::unnest(locations) %>%
    dplyr::mutate(dplyr::across(dplyr::where(is.logical), as.numeric))

  if (response_var == response_vars[1]) {
    all_formatted_ensemble_predictions <- formatted_ensemble_predictions
  } else {
    all_formatted_ensemble_predictions <- bind_rows(
      all_formatted_ensemble_predictions,
      formatted_ensemble_predictions
    )
  }

  message("starting to write out 4-week ensemble data")
  if (final_run) {
    # File names are built from forecast_date directly (not from the first
    # prediction row) so they are always well-defined and match names built
    # from forecast_date elsewhere in this script.
    message("writing 4-week ensemble forecast file")
    save_dir <- paste0(root, "data-processed/COVIDhub-4_week_ensemble/")
    if (!dir.exists(save_dir)) dir.create(save_dir, recursive = TRUE)
    # rewritten on every iteration with all outcomes accumulated so far
    write_csv(
      all_formatted_ensemble_predictions %>% select(-location_name),
      paste0(save_dir, forecast_date, "-COVIDhub-4_week_ensemble.csv")
    )

    message("writing 4-week ensemble model eligibility file")
    save_dir <- paste0(root, "4_week_ensemble-metadata/")
    if (!dir.exists(save_dir)) dir.create(save_dir, recursive = TRUE)
    write_csv(
      model_eligibility,
      paste0(
        save_dir, forecast_date, "-", response_var, "-model-eligibility.csv"
      )
    )

    message("writing 4-week ensemble weights file")
    write_csv(
      model_weights,
      paste0(save_dir, forecast_date, "-", response_var, "-model-weights.csv")
    )
  }
}

# Check that every model with any submission for each target appears, for
# every location, in that outcome's model-eligibility metadata file.
for (response_var in response_vars) {
  if (response_var == "inc_hosp") {
    targets <- paste0(1:28, " day ahead inc hosp")
  } else {
    targets <- paste0(1:4, " wk ahead ", gsub("_", " ", response_var))
  }
  all_forecasts <- covidHubUtils::load_forecasts(
    dates = forecast_date,
    models = candidate_model_abbreviations_to_include,
    date_window_size = 6,
    targets = targets,
    # same hub clone as the ensemble build, without the trailing slash
    hub_repo_path = sub("/$", "", hub_repo_path),
    source = "local_hub_repo"
  )
  all_models <- unique(all_forecasts$model)

  save_dir <- paste0(root, "4_week_ensemble-metadata/")
  eligibility <- read_csv(paste0(
    save_dir, forecast_date, "-", response_var, "-model-eligibility.csv"
  ))

  locations <- unique(eligibility$location)

  # Passes only if the file has exactly one row per (location, model) pair
  # over all submitting models and all locations in the file.
  expected_pairs <- tidyr::expand_grid(
    model = all_models,
    location = locations
  ) %>%
    dplyr::mutate(lm = paste0(location, model)) %>%
    dplyr::pull(lm)
  val_result <- identical(
    sort(paste0(eligibility$location, eligibility$model)),
    sort(expected_pairs)
  )
  message(paste0("CHECK THAT ALL MODELS ARE IN ELIGIBILITY FILE: ", response_var))
  message(val_result)
  if (!val_result) {
    warning(
      "model eligibility file for ", response_var,
      " does not have exactly one row per (location, model) pair",
      call. = FALSE
    )
  }
}

# Plot the COVIDhub-4_week_ensemble submission for forecast_date, reading
# submission files from root's data-processed/ directory and saving the
# plots under plots_root.
plot_forecasts_single_model(
  model_abbrs = "COVIDhub-4_week_ensemble",
  forecast_date = forecast_date,
  target_variables = c("deaths", "hospitalizations"),
  submissions_root = paste0(root, "data-processed/"),
  plots_root = plots_root
)
# reichlab/covidEnsembles documentation built on Jan. 31, 2024, 7:21 p.m.