# These tests address #290:
# https://github.com/cmu-delphi/epipredict/issues/290
library(dplyr)
# Shared training data for the tests below: rows dated 2021-10-01 or later,
# keeping the key columns and renaming the 7-day-average case and death rate
# columns to `cr` and `dr`.
train <- epidatasets::cases_deaths_subset |>
  filter(time_value >= as.Date("2021-10-01")) |>
  select(
    geo_value,
    time_value,
    cr = case_rate_7d_av,
    dr = death_rate_7d_av
  )

# Number of distinct geo_value entries in the training data; the tests compare
# the prediction columns against vectors of this length.
ngeos <- train |>
  pull(geo_value) |>
  n_distinct()
test_that("flatline determines target_date where forecast_date exists", {
  # Regression check for #290: when target_date is supplied explicitly it
  # used to be possible for it to be shifted forward by `ahead` by mistake.
  fc_date <- as.Date("2021-12-31")
  tgt_date <- as.Date("2022-01-01")
  flat_args <- flatline_args_list(
    forecast_date = fc_date,
    target_date = tgt_date,
    ahead = 1L
  )

  flat <- flatline_forecaster(train, "dr", args_list = flat_args)
  preds <- flat$predictions

  # Both dates must come back exactly as supplied, once per geo_value.
  expect_identical(preds$target_date, rep(tgt_date, ngeos))
  expect_identical(preds$forecast_date, rep(fc_date, ngeos))

  # Neither prediction column may contain missing values.
  expect_false(any(is.na(preds$.pred_distn)))
  expect_false(any(is.na(preds$.pred)))
})
test_that("arx_forecaster determines target_date where forecast_date exists", {
  # Regression check for #290: when target_date is supplied explicitly it
  # used to be possible for it to be shifted forward by `ahead` by mistake.
  fc_date <- as.Date("2021-12-31")
  tgt_date <- as.Date("2022-01-01")
  arx_args <- arx_args_list(
    forecast_date = fc_date,
    target_date = tgt_date,
    ahead = 1L
  )

  arx <- arx_forecaster(train, "dr", c("dr", "cr"), args_list = arx_args)
  preds <- arx$predictions

  # Both dates must come back exactly as supplied, once per geo_value.
  expect_identical(preds$target_date, rep(tgt_date, ngeos))
  expect_identical(preds$forecast_date, rep(fc_date, ngeos))

  # Neither prediction column may contain missing values.
  expect_false(any(is.na(preds$.pred_distn)))
  expect_false(any(is.na(preds$.pred)))
})
test_that("arx_classifier determines target_date where forecast_date exists", {
  # parsnip::boost_tree() fits with the xgboost engine by default. Skip
  # (rather than error) when that optional package is not available.
  skip_if_not_installed("xgboost")

  # Regression check for #290: when target_date is supplied explicitly it
  # used to be possible for it to be shifted forward by `ahead` by mistake.
  fc_date <- as.Date("2021-12-31")
  tgt_date <- as.Date("2022-01-01")

  arx <- arx_classifier(
    train, "dr", c("dr"),
    trainer = parsnip::boost_tree(mode = "classification", trees = 5),
    args_list = arx_class_args_list(
      forecast_date = fc_date,
      target_date = tgt_date,
      ahead = 1L
    )
  )
  preds <- arx$predictions

  # Both dates must come back exactly as supplied, once per geo_value.
  expect_identical(preds$target_date, rep(tgt_date, ngeos))
  expect_identical(preds$forecast_date, rep(fc_date, ngeos))

  # Every row must carry a predicted class.
  expect_true(all(!is.na(preds$.pred_class)))
})
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.