library(covidData)
library(covidEnsembles)
library(dplyr)
library(tidyr)
library(readr)
library(purrr)
library(zeallot)
library(gridExtra)
library(yaml)
library(reticulate)
# Process-wide environment variables, set before any analysis code runs.
# NOTE(review): CUDA_VISIBLE_DEVICES = "-1" presumably keeps the Python
# backend used through reticulate off the GPU -- confirm.
Sys.setenv(
  CUDA_VISIBLE_DEVICES = "-1",
  LANG = "en_US.UTF-8"
)
#options(warn=2, error=recover)
#debug(covidEnsembles:::get_by_location_group_ensemble_fits_and_predictions)
#debug(covidEnsembles:::estimate_qra_quantgen)
#debug(covidEnsembles::estimate_qra)
#debug(quantgen:::quantile_ensemble_flex)
# extract arguments specifying details of analysis
#args <- c("local", "inc_death", "2020-11-23", "FALSE", "convex", "mean_impute", "per_quantile", "full_history", "TRUE", "FALSE", "FALSE", "state_national")
#args <- c("local", "inc_hosp", "2021-01-18", "FALSE", "convex", "mean_impute", "per_model", "3", "TRUE", "FALSE", "FALSE", "state")
#args <- c("local", "inc_death", "2021-01-25", "FALSE", "convex", "mean_impute", "per_quantile", "constrain", "full_history", "TRUE", "FALSE", "FALSE", "state_national")
#args <- c("local", "inc_death", "2021-02-15", "FALSE", "convex", "mean_impute", "per_model", "sort", "6", "TRUE", "FALSE", "FALSE", "state_national")
#args <- c("local", "inc_death", "2021-02-15", "FALSE", "convex", "renormalize", "per_model", "sort", "6", "TRUE", "FALSE", "FALSE", "state_national")
#args <- c("local", "inc_death", "2021-02-15", "FALSE", "convex_median", "renormalize", "per_model", "sort", "6", "TRUE", "FALSE", "FALSE", "state_national")
#args <- c("local", "inc_death", "2020-10-26", "FALSE", "convex", "renormalize", "per_model", "sort", "4", "all_models", "TRUE", "FALSE", "FALSE", "state_national")
#args <- c("cluster_single_node", "inc_death", "2021-02-15", "FALSE", "convex", "renormalize", "per_model", "sort", "6", "TRUE", "FALSE", "FALSE", "state_national")
#args <- c("local", "cum_death", "2020-06-22", "FALSE", "convex", "mean_impute", "3_groups", "sort", "full_history", "TRUE", "FALSE", "FALSE", "state")
#args <- c("local", "cum_death", "2020-07-13", "FALSE", "convex", "mean_impute", "3_groups", "sort", "full_history", "TRUE", "FALSE", "FALSE", "state")
#args <- c("local", "cum_death", "2021-02-15", "FALSE", "convex", "mean_impute", "3_groups", "constrain", "5", "TRUE", "FALSE", "FALSE", "state")
#args <- c("local", "inc_hosp", "2021-04-26", "FALSE", "convex", "renormalize", "per_model", "sort", "full_history", "TRUE", "FALSE", "FALSE", "state_national")
#args <- c("local", "inc_case", "2021-04-26", "FALSE", "convex", "renormalize", "per_model", "sort", "full_history", "TRUE", "FALSE", "FALSE", "county")
#args <- c("local", "inc_hosp", "2021-04-26", "FALSE", "median", "none", "per_model", "sort", "full_history", "5", "TRUE", "FALSE", "FALSE", "state_national")
#args <- c("local", "inc_death", "2021-04-26", "FALSE", "convex_median", "renormalize", "per_model", "sort", "full_history", "5", "TRUE", "FALSE", "FALSE", "state")
#args <- c("local", "inc_death", "2020-10-05", "FALSE", "rel_wis_weighted_median", "none", "per_model", "sort", "4", "5", "TRUE", "FALSE", "FALSE", "state")
#args <- c("local", "inc_death", "2021-03-08", "FALSE", "rel_wis_weighted_median", "none", "per_model", "sort", "full_history", "0", "TRUE", "FALSE", "FALSE", "state")
#args <- c("local", "inc_death", "2021-04-26", "FALSE", "rel_wis_weighted_median", "none", "per_model", "sort", "full_history", "0", "TRUE", "FALSE", "FALSE", "state")
#args <- c("local", "inc_death", "2021-04-26", "FALSE", "mean_weights_weighted_median", "none", "per_model", "sort", "4", "0", "TRUE", "FALSE", "FALSE", "state")
#args <- c("cluster_single_node", "inc_death", "2021-04-05", "FALSE", "rel_wis_weighted_median", "none", "per_model", "sort", "4", "0", "TRUE", "FALSE", "FALSE", "state")
#args <- c("local", "inc_death", "2021-04-05", "FALSE", "rel_wis_weighted_median", "renormalize", "per_model", "sort", "4", "0", "TRUE", "FALSE", "FALSE", "state", "TRUE")
#args <- c("local", "inc_death", "2021-04-05", "FALSE", "rel_wis_weighted_median", "renormalize", "per_quantile", "sort", "4", "0", "TRUE", "FALSE", "FALSE", "state", "FALSE", "all")
#args <- c("cluster_single_node", "inc_death", "2020-07-27", "FALSE", "convex", "renormalize", "per_model", "sort", "12", "0", "TRUE", "FALSE", "FALSE", "state", "FALSE", "all")
# Parse the positional command line arguments that define this analysis run.
#
#   1  run_setting: "local" or "cluster_single_node"
#   2  response_var
#   3  forecast_date, parsed with lubridate::ymd
#   4  intercept (logical)
#   5  combine_method
#   6  missingness
#   7  quantile_group_str
#   8  noncross
#   9  window_size_arg: an integer or "full_history"
#   10 top_models_arg: an integer or "all_models"
#   11 check_missingness_by_target (logical)
#   12 do_standard_checks (logical)
#   13 do_baseline_check (logical)
#   14 spatial_resolution_arg
#   15 drop_anomalies (logical)
#   16 horizon_group: "all" or a single integer horizon
args <- commandArgs(trailingOnly = TRUE)

# Fail early with a clear message; otherwise a short argument list only
# surfaces later as NA values in `if` conditions.
if (length(args) < 16L) {
  stop(
    "expected 16 command line arguments but received ", length(args),
    call. = FALSE
  )
}

run_setting <- args[1]
if (!(run_setting %in% c("local", "cluster_single_node"))) {
  stop(
    "unsupported run_setting '", run_setting,
    "'; expected 'local' or 'cluster_single_node'",
    call. = FALSE
  )
}

# run settings passed as command line arguments
response_var <- args[2]
forecast_date <- lubridate::ymd(args[3])
intercept <- as.logical(args[4])
combine_method <- args[5]
missingness <- args[6]
quantile_group_str <- args[7]
noncross <- args[8]
window_size_arg <- args[9]
top_models_arg <- args[10]
check_missingness_by_target <- as.logical(args[11])
do_standard_checks <- as.logical(args[12])
do_baseline_check <- as.logical(args[13])
spatial_resolution_arg <- args[14]
drop_anomalies <- as.logical(args[15])
horizon_group <- args[16]

if (run_setting == "local") {
  # used by covidHubUtils::load_latest_forecasts for loading locally
  submissions_root <- "~/research/epi/covid/covid19-forecast-hub/data-processed/"
  hub_repo_path <- "~/research/epi/covid/covid19-forecast-hub/"
} else {
  # cluster paths; the cluster run also pins the Python used by reticulate
  submissions_root <- "/project/uma_nicholas_reich/covid19-forecast-hub/data-processed/"
  hub_repo_path <- "/project/uma_nicholas_reich/covid19-forecast-hub/"
  reticulate::use_python("/usr/bin/python3.8")
}
# Candidate component models for the ensemble, as returned by
# get_candidate_models for the "primary" and "secondary" designations, with
# the COVIDhub baseline requested and the COVIDhub ensemble not requested.
candidate_model_abbreviations_to_include <- get_candidate_models(
  submissions_root = submissions_root,
  include_designations = c("primary", "secondary"),
  include_COVIDhub_ensemble = FALSE,
  include_COVIDhub_baseline = TRUE
)

# Remove the four models listed below from the candidate set: keep only the
# entries that have no match in the exclusion list.
candidate_model_abbreviations_to_include <-
  candidate_model_abbreviations_to_include[
    is.na(match(
      candidate_model_abbreviations_to_include,
      c(
        "JHUAPL-SLPHospEns",
        "FDANIHASU-Sweight",
        "COVIDhub-trained_ensemble",
        "KITmetricslab-select_ensemble"
      )
    ))
  ]
# Translate the command line missingness option into three settings:
#   case_missingness - label describing the requested handling
#   missingness      - value later passed to build_covid_ensemble
#   impute_method    - imputation method later passed to build_covid_ensemble
if (missingness == "mean_impute") {
  impute_method <- "mean"
  case_missingness <- missingness <- "impute"
} else if (missingness %in% c("renormalize", "none")) {
  impute_method <- "none"
  # record the original option before it is overwritten
  case_missingness <- missingness
  missingness <- "impute"
} else {
  # any other option is passed through unchanged, with no imputation method
  impute_method <- NULL
  case_missingness <- missingness
}
# Settings that depend on the response variable:
#   required_quantiles     - quantile levels the ensemble must provide
#   spatial_resolution     - location level(s) to include
#   temporal_resolution    - "wk" or "day"
#   max_horizon, targets   - forecast horizon(s) and matching target names
#   forecast_week_end_date - reference date passed to build_covid_ensemble
#   full_history_start     - start date used for the "full_history" window
# An unsupported response_var or a non-integer horizon_group is an error.
if (response_var %in% c("inc_death", "cum_death")) {
  required_quantiles <- c(0.01, 0.025, seq(0.05, 0.95, by = 0.05), 0.975, 0.99)
  if (spatial_resolution_arg %in% c("all", "state_national")) {
    spatial_resolution <- c("state", "national")
  } else {
    spatial_resolution <- spatial_resolution_arg
  }
  temporal_resolution <- "wk"
  if (horizon_group == "all") {
    max_horizon <- 4L
    targets <- paste0(1:max_horizon, " wk ahead ", gsub("_", " ", response_var))
  } else {
    max_horizon <- as.integer(horizon_group)
    if (is.na(max_horizon)) {
      stop(
        "horizon_group must be 'all' or an integer, got '", horizon_group, "'",
        call. = FALSE
      )
    }
    targets <- paste0(max_horizon, " wk ahead ", gsub("_", " ", response_var))
  }
  forecast_week_end_date <- forecast_date - 2
  full_history_start <- lubridate::ymd("2020-06-22") - 7 * 10
} else if (response_var == "inc_case") {
  required_quantiles <- c(0.025, 0.100, 0.250, 0.500, 0.750, 0.900, 0.975)
  if (spatial_resolution_arg == "all") {
    spatial_resolution <- c("county", "state", "national")
  } else if (spatial_resolution_arg == "state_national") {
    spatial_resolution <- c("state", "national")
  } else {
    spatial_resolution <- spatial_resolution_arg
  }
  temporal_resolution <- "wk"
  if (horizon_group == "all") {
    max_horizon <- 4L
    targets <- paste0(1:max_horizon, " wk ahead ", gsub("_", " ", response_var))
  } else {
    max_horizon <- as.integer(horizon_group)
    if (is.na(max_horizon)) {
      stop(
        "horizon_group must be 'all' or an integer, got '", horizon_group, "'",
        call. = FALSE
      )
    }
    targets <- paste0(max_horizon, " wk ahead ", gsub("_", " ", response_var))
  }
  forecast_week_end_date <- forecast_date - 2
  full_history_start <- lubridate::ymd("2020-09-14") - 7 * 10
} else if (response_var == "inc_hosp") {
  required_quantiles <- c(0.01, 0.025, seq(0.05, 0.95, by = 0.05), 0.975, 0.99)
  if (spatial_resolution_arg %in% c("all", "state_national")) {
    spatial_resolution <- c("state", "national")
  } else {
    spatial_resolution <- spatial_resolution_arg
  }
  temporal_resolution <- "day"
  if (horizon_group == "all") {
    max_horizon <- 28L
    # targets extend 6 days beyond max_horizon
    targets <- paste0(1:(max_horizon + 6), " day ahead ", gsub("_", " ", response_var))
  } else {
    stop("per horizon estimation is not currently supported for daily targets")
  }
  forecast_week_end_date <- forecast_date
  full_history_start <- lubridate::ymd("2020-11-16") - 7 * 10
} else {
  # without this branch the script would continue with undefined settings
  stop("unsupported response_var '", response_var, "'", call. = FALSE)
}
# Training window size. "full_history" is converted to the number of 7-day
# periods between full_history_start and forecast_date (truncated to an
# integer); any other value is read as an integer.
window_size <- if (window_size_arg == "full_history") {
  as.integer((forecast_date - full_history_start) / 7)
} else {
  as.integer(window_size_arg)
}
# Assign each required quantile level to a group index:
#   per_model    - every quantile level in one group
#   3_groups     - lower levels, middle levels, upper levels
#   per_quantile - every quantile level in its own group
# NOTE(review): presumably levels in the same group share weights -- confirm
# against covidEnsembles.
if (quantile_group_str == "per_model") {
  quantile_groups <- rep(1, length(required_quantiles))
} else if (quantile_group_str == "3_groups") {
  if (length(required_quantiles) == 23) {
    # 4 lowest levels, 15 middle levels, 4 highest levels
    quantile_groups <- c(rep(1, 4), rep(2, 23 - 8), rep(3, 4))
  } else if (length(required_quantiles) == 7) {
    # lowest level, 5 middle levels, highest level
    quantile_groups <- c(1, rep(2, 5), 3)
  } else {
    # previously quantile_groups was silently left undefined in this case
    stop(
      "3_groups is only defined for 23 or 7 required quantiles, got ",
      length(required_quantiles),
      call. = FALSE
    )
  }
} else if (quantile_group_str == "per_quantile") {
  quantile_groups <- seq_along(required_quantiles)
} else {
  stop("invalid quantile_groups")
}
# Directory component for the spatial resolution, used when building the
# output paths: empty for "all", "state_national" for the combined setting,
# and the resolved spatial_resolution otherwise.
spatial_resolution_path <- if (spatial_resolution_arg == "all") {
  ""
} else if (spatial_resolution_arg == "state_national") {
  "state_national"
} else {
  spatial_resolution
}
# Number of top models passed to build_covid_ensemble; "all_models" maps
# to 0L.
# NOTE(review): 0 presumably means no top-model restriction -- confirm.
top_models <- if (top_models_arg == "all_models") {
  0L
} else {
  as.integer(top_models_arg)
}
# Identifier for this analysis configuration, used in output directory and
# file names. It is a "-"-separated list of "<setting>_<value>" pieces.
# Settings currently left out of the identifier: intercept, missingness,
# noncross, check_missingness_by_target, do_standard_checks,
# do_baseline_check.
case_str <- paste(
  paste0("combine_method_", combine_method),
  paste0("quantile_groups_", quantile_group_str),
  paste0("window_size_", window_size_arg),
  paste0("top_models_", top_models_arg),
  paste0("drop_anomalies_", drop_anomalies),
  paste0("horizon_group_", horizon_group),
  sep = "-"
)
# create folder where model fits should be saved
fits_dir <- file.path(
  "code/application/retrospective-qra-comparison/retrospective-fits",
  spatial_resolution_path,
  response_var,
  case_str)
if (!dir.exists(fits_dir)) {
  # the path is nested several levels deep; recursive = TRUE also creates any
  # missing parent directories instead of failing
  dir.create(fits_dir, recursive = TRUE)
}

# full fit object saved with saveRDS
fit_filename <- paste0(
  fits_dir, "/",
  forecast_date, "-",
  case_str, ".rds")
# file name handed to build_covid_ensemble for partial saves
partial_save_filename <- paste0(
  fits_dir, "/",
  forecast_date, "-",
  case_str, ".pkl")
# loss trace of the fit saved with saveRDS
loss_trace_filename <- paste0(
  fits_dir, "/",
  forecast_date, "-",
  case_str, "_loss_trace.rds")
# create folder where model weights should be saved
weights_dir <- file.path(
  "code/application/retrospective-qra-comparison/retrospective-weights",
  spatial_resolution_path,
  response_var,
  case_str)
if (!dir.exists(weights_dir)) {
  # the path is nested several levels deep; recursive = TRUE also creates any
  # missing parent directories instead of failing
  dir.create(weights_dir, recursive = TRUE)
}

# csv file holding the estimated model weights
weight_filename <- paste0(
  weights_dir, "/",
  forecast_date, "-",
  case_str, ".csv")
# create folder where model forecasts should be saved
forecasts_dir <- file.path(
  "code/application/retrospective-qra-comparison/retrospective-forecasts",
  spatial_resolution_path,
  response_var,
  case_str)
if (!dir.exists(forecasts_dir)) {
  # the path is nested several levels deep; recursive = TRUE also creates any
  # missing parent directories instead of failing
  dir.create(forecasts_dir, recursive = TRUE)
}

# csv file holding the formatted ensemble forecasts; its existence is also
# what marks this configuration as already done
forecast_filename <- paste0(
  forecasts_dir, "/",
  forecast_date, "-",
  case_str, ".csv")
# debug(covidEnsembles::estimate_qra)
# Observations to exclude, later passed to build_covid_ensemble. Nothing is
# ever dropped by forecast date. When drop_anomalies is set, the
# (location, target_end_date) pairs are read from the outliers file, keeping
# the rows whose issue_date is the day before forecast_date.
if (drop_anomalies) {
  forecast_date_locations_drop <- NULL
  target_end_date_locations_drop <- dplyr::transmute(
    dplyr::filter(
      readr::read_csv("code/application/retrospective-qra-comparison/data-anomalies/outliers-deaths.csv"),
      issue_date == forecast_date - 1
    ),
    location = location,
    target_end_date = as.character(date)
  )
} else {
  forecast_date_locations_drop <- NULL
  target_end_date_locations_drop <- NULL
}
# Fit the ensemble for this forecast date and save the fit, the estimated
# weights (where the combine method has any), and the formatted forecasts.
# Everything is skipped when the forecast csv for this configuration already
# exists. Total elapsed time is reported at the end.
tic <- Sys.time()
if (!file.exists(forecast_filename)) {
  # the single do_standard_checks flag drives both of these checks
  do_q10_check <- do_nondecreasing_quantile_check <- do_standard_checks

  tictic <- Sys.time()
  results <- build_covid_ensemble(
    hub = "US",
    source = "zoltar",
    # hub_repo_path = hub_repo_path,
    candidate_model_abbreviations_to_include =
      candidate_model_abbreviations_to_include,
    spatial_resolution = spatial_resolution,
    targets = targets,
    forecast_date = forecast_date,
    forecast_week_end_date = forecast_week_end_date,
    max_horizon = max_horizon,
    timezero_window_size = 6,
    window_size = window_size,
    data_as_of_date = forecast_date - 1,
    forecast_date_locations_drop = forecast_date_locations_drop,
    target_end_date_locations_drop = target_end_date_locations_drop,
    intercept = intercept,
    combine_method = combine_method,
    quantile_groups = quantile_groups,
    noncross = noncross,
    missingness = missingness,
    impute_method = impute_method,
    # combine_method is a single string, so a plain if/else picks the backend
    backend = if (combine_method == "rel_wis_weighted_median") {
      "grid_search"
    } else {
      "qenspy"
    },
    required_quantiles = required_quantiles,
    check_missingness_by_target = check_missingness_by_target,
    do_q10_check = do_q10_check,
    do_nondecreasing_quantile_check = do_nondecreasing_quantile_check,
    do_baseline_check = do_baseline_check,
    do_sd_check = FALSE,
    baseline_tol = 1.0,
    top_models = top_models,
    manual_eligibility_adjust = NULL,
    return_eligibility = TRUE,
    return_all = TRUE,
    partial_save_frequency = 10,
    partial_save_filename = partial_save_filename
  )
  toctoc <- Sys.time()
  # elapsed fitting time; only displayed when this line is run interactively,
  # since values nested inside the if block are not auto-printed
  toctoc - tictic

  # save full results including estimated weights, training data, etc.
  # only if running locally; cluster has limited space
  if (run_setting == "local") {
    saveRDS(results, file = fit_filename)
  } else {
    saveRDS(results$location_groups$qra_fit[[1]], file = fit_filename)
  }

  # extract and save just the estimated weights in csv format
  if (!(combine_method %in% c("ew", "mean", "median", "rel_wis_weighted_median", "mean_weights_weighted_median"))) {
    # save loss trace as a function of optimization iteration
    loss_trace <- results$location_groups$qra_fit[[1]]$loss_trace
    saveRDS(loss_trace, file = loss_trace_filename)

    # one set of weights per location group, repeated for each location in the
    # group via a join on a constant key
    estimated_weights <- purrr::pmap_dfr(
      results$location_groups %>% dplyr::select(locations, qra_fit),
      function(locations, qra_fit) {
        covidEnsembles:::extract_weights_qenspy_qra_fit(qra_fit) %>%
          dplyr::mutate(join_field = "temp") %>%
          dplyr::left_join(
            data.frame(
              location = locations,
              join_field = "temp",
              stringsAsFactors = FALSE
            )
          ) %>%
          dplyr::select(-join_field)
        # weights <- qra_fit$coefficients

        # data.frame(
        #   quantile = if ("quantile" %in% colnames(weights)) {
        #       weights$quantile
        #     } else {
        #       rep(NA, nrow(weights))
        #     },
        #   model = weights$model,
        #   weight = weights$beta, #[, 1],
        #   join_field = "temp",
        #   stringsAsFactors = FALSE
        # ) %>%
        #   dplyr::left_join(
        #     data.frame(
        #       location = locations,
        #       join_field = "temp",
        #       stringsAsFactors = FALSE
        #     )
        #   ) %>%
        #   dplyr::select(-join_field)
      }
    )
    write_csv(estimated_weights, weight_filename)
  } else if (combine_method == "mean_weights_weighted_median") {
    # weights are read directly from the coefficients of each fit; the
    # quantile column is NA when the coefficients have no quantile column
    estimated_weights <- purrr::pmap_dfr(
      results$location_groups %>% dplyr::select(locations, qra_fit),
      function(locations, qra_fit) {
        weights <- qra_fit$coefficients

        data.frame(
          quantile = if ("quantile" %in% colnames(weights)) {
            weights$quantile
          } else {
            rep(NA, nrow(weights))
          },
          model = weights$model,
          weight = weights$beta, #[, 1],
          join_field = "temp",
          stringsAsFactors = FALSE
        ) %>%
          dplyr::left_join(
            data.frame(
              location = locations,
              join_field = "temp",
              stringsAsFactors = FALSE
            )
          ) %>%
          dplyr::select(-join_field)
      }
    )
    write_csv(estimated_weights, weight_filename)
  }

  # save csv formatted forecasts
  if (missingness == "impute") {
    # unpack the five elements of results by position
    c(model_eligibility, wide_model_eligibility, location_groups,
      weight_transfer, component_forecasts) %<-% results

    # one logical column per model: TRUE where the first column belonging to
    # that model in the imputed test-set matrix is not NA
    col_index <- attr(location_groups$imputed_qfm_test[[1]], "col_index")
    models_used <- purrr::map_dfc(
      unique(col_index$model),
      function(model) {
        col_ind <- min(which(col_index$model == model))
        result <- data.frame(
          m = !is.na(unclass(location_groups$imputed_qfm_test[[1]])[, col_ind]))
        colnames(result) <- model
        return(result)
      }
    )
    # number of models with non-missing values in each row
    model_counts <- apply(
      models_used,
      1,
      sum
    )
    # locations_to_drop <- unique(
    #   attr(location_groups$imputed_qfm_test[[1]], "row_index")[
    #     model_counts == 1, "location"])

    ensemble_predictions <- location_groups$qra_forecast[[1]] #%>%
      #dplyr::filter(!(location %in% locations_to_drop))
  } else {
    # unpack the four elements of results by position
    c(model_eligibility, wide_model_eligibility, location_groups,
      component_forecasts) %<-% results

    # keep only location groups where more than one logical column is TRUE
    model_counts <- apply(
      location_groups %>% select_if(is.logical),
      1,
      sum)
    location_groups <- location_groups[model_counts > 1, ]

    if (nrow(location_groups) > 0) {
      ensemble_predictions <- bind_rows(location_groups[["qra_forecast"]])
    } else {
      # no location group qualifies: use an empty table so the row-count check
      # below skips the output instead of failing on an undefined object
      ensemble_predictions <- data.frame()
    }
  }

  if (nrow(ensemble_predictions) > 0) {
    # save the results in required format
    formatted_ensemble_predictions <- ensemble_predictions %>%
      dplyr::transmute(
        forecast_date = forecast_date,
        target = target,
        # the horizon is the integer before the first space of the target name
        target_end_date = covidHubUtils::calc_target_end_date(
          forecast_date,
          as.integer(substr(target, 1, regexpr(" ", target, fixed = TRUE) - 1)),
          rep(temporal_resolution, nrow(ensemble_predictions))),
        location = location,
        type = "quantile",
        quantile = quantile,
        # round to integers away from the median: lower quantiles down, upper
        # quantiles up, the median to the nearest integer
        value = ifelse(
          quantile < 0.5,
          floor(value),
          ifelse(
            quantile == 0.5,
            round(value),
            ceiling(value)
          )
        )
      )

    # add a point forecast row for every median row
    formatted_ensemble_predictions <- bind_rows(
      formatted_ensemble_predictions,
      formatted_ensemble_predictions %>%
        filter(format(quantile, digits = 3, nsmall = 3) == "0.500") %>%
        mutate(
          type = "point",
          quantile = NA_real_
        )
    )

    write_csv(formatted_ensemble_predictions, forecast_filename)
  }
}
toc <- Sys.time()
toc - tic
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.