tests/testthat/test-simple_ensemble.R

# set up simple data for test cases
model_outputs <- expand.grid(stringsAsFactors = FALSE,
                             model_id = letters[1:4],
                             location = c("222", "888"),
                             horizon = 1, #week
                             target = "inc death",
                             target_date = as.Date("2021-12-25"),
                             output_type = "quantile",
                             output_type_id = c(.1, .5, .9),
                             value = NA_real_)

v2_1 <- model_outputs$value[model_outputs$location == "222" &
                              model_outputs$output_type_id == .1] <-
  c(10, 30, 15, 20)
v2_5 <- model_outputs$value[model_outputs$location == "222" &
                              model_outputs$output_type_id == .5] <-
  c(40, 40, 45, 50)
v2_9 <- model_outputs$value[model_outputs$location == "222" &
                              model_outputs$output_type_id == .9] <-
  c(60, 70, 75, 80)
v8_1 <- model_outputs$value[model_outputs$location == "888" &
                              model_outputs$output_type_id == .1] <-
  c(100, 300, 400, 250)
v8_5 <- model_outputs$value[model_outputs$location == "888" &
                              model_outputs$output_type_id == .5] <-
  c(150, 325, 500, 300)
v8_9 <- model_outputs$value[model_outputs$location == "888" &
                              model_outputs$output_type_id == .9] <-
  c(250, 350, 500, 350)

fweight2 <- data.frame(model_id = letters[1:4],
                       location = "222",
                       weight = 0.1 * (1:4))
fweight8 <- data.frame(model_id = letters[1:4],
                       location = "888",
                       weight = 0.1 * (4:1))
fweight <- dplyr::bind_rows(fweight2, fweight8)


test_that("non-default columns are dropped from output", {
  output_names <- model_outputs |>
    dplyr::mutate(extra_col_1 = "a", extra_col_2 = "a") |>
    simple_ensemble(
      task_id_cols = c("target_date", "target", "horizon", "location")
    ) |>
    names()

  expect_equal(sort(names(model_outputs)), sort(output_names))
})


test_that("invalid method argument throws error", {
  expect_error(
    simple_ensemble(model_outputs, agg_fun = "linear pool")
  )
})


test_that("component model outputs and resulting ensemble model outputs have identical sorted unique output type ids", {
  quantile_outputs <- expand.grid(stringsAsFactors = FALSE,
                                  model_id = letters[1:4],
                                  location = c("222", "888"),
                                  horizon = 1, #week
                                  target = "inc death",
                                  target_date = as.Date("2021-12-25"),
                                  output_type = "quantile",
                                  output_type_id = c(.1, .5, .9),
                                  value = NA_real_)

  quantile_outputs$value[quantile_outputs$location == "222" &
                           quantile_outputs$output_type_id == .1] <-
    c(10, 30, 15, 20)
  quantile_outputs$value[quantile_outputs$location == "222" &
                           quantile_outputs$output_type_id == .5] <-
    c(40, 40, 45, 50)
  quantile_outputs$value[quantile_outputs$location == "222" &
                           quantile_outputs$output_type_id == .9] <-
    c(60, 70, 75, 80)
  quantile_outputs$value[quantile_outputs$location == "888" &
                           quantile_outputs$output_type_id == .1] <-
    c(100, 300, 400, 250)
  quantile_outputs$value[quantile_outputs$location == "888" &
                           quantile_outputs$output_type_id == .5] <-
    c(150, 325, 500, 300)
  quantile_outputs$value[quantile_outputs$location == "888" &
                           quantile_outputs$output_type_id == .9] <-
    c(250, 350, 500, 350)

  expected_output_type_ids <- sort(unique(quantile_outputs$output_type_id))

  actual_output_type_ids <- quantile_outputs |>
    simple_ensemble(weights = NULL,
                    weights_col_name = NULL,
                    model_id = "hub-ensemble",
                    task_id_cols = NULL) |>
    dplyr::pull("output_type_id") |>
    unique() |>
    sort()

  expect_equal(expected_output_type_ids, actual_output_type_ids)
})

test_that("group_by(output_type_id) produces expected results", {
  explicit_outputs <- expand.grid(stringsAsFactors = FALSE,
                                  model_id = letters[1:4],
                                  location = c("222", "888"),
                                  horizon = 1, #week
                                  target = "inc death",
                                  target_date = as.Date("2021-12-25"),
                                  output_type = "quantile",
                                  output_type_id = c(.025, .1, .25, .75, .9, .975),
                                  value = NA_real_)

  explicit_outputs$value[explicit_outputs$location == "222" &
                           explicit_outputs$output_type_id == .025] <-
    c(4, 12, 6, 8)
  explicit_outputs$value[explicit_outputs$location == "222" &
                           explicit_outputs$output_type_id == .100] <-
    c(10, 30, 15, 20)
  explicit_outputs$value[explicit_outputs$location == "222" &
                           explicit_outputs$output_type_id == .250] <-
    c(20, 40, 25, 30)
  explicit_outputs$value[explicit_outputs$location == "222" &
                           explicit_outputs$output_type_id == .750] <-
    c(50, 50, 55, 60)
  explicit_outputs$value[explicit_outputs$location == "222" &
                           explicit_outputs$output_type_id == .900] <-
    c(60, 70, 75, 80)
  explicit_outputs$value[explicit_outputs$location == "222" &
                           explicit_outputs$output_type_id == .975] <-
    c(70, 80, 85, 90)
  explicit_outputs$value[explicit_outputs$location == "888" &
                           explicit_outputs$output_type_id == .025] <-
    c(40, 120, 160, 100)
  explicit_outputs$value[explicit_outputs$location == "888" &
                           explicit_outputs$output_type_id == .100] <-
    c(100, 300, 400, 250)
  explicit_outputs$value[explicit_outputs$location == "888" &
                           explicit_outputs$output_type_id == .250] <-
    c(150, 325, 475, 300)
  explicit_outputs$value[explicit_outputs$location == "888" &
                           explicit_outputs$output_type_id == .750] <-
    c(200, 325, 500, 325)
  explicit_outputs$value[explicit_outputs$location == "888" &
                           explicit_outputs$output_type_id == .900] <-
    c(250, 350, 500, 350)
  explicit_outputs$value[explicit_outputs$location == "888" &
                           explicit_outputs$output_type_id == .975] <-
    c(350, 450, 550, 450)

  intervals <- c(.50, .80, .95)
  implicit_outputs <- explicit_outputs
  implicit_outputs$output_type_id <- sort(rep(c((1 - intervals) / 2, 1 - (1 - intervals) / 2), 8))

  explicit_ensemble <- simple_ensemble(explicit_outputs, weights = NULL,
                                       weights_col_name = NULL,
                                       model_id = "hub-ensemble",
                                       task_id_cols = NULL)

  implicit_ensemble <- simple_ensemble(implicit_outputs, weights = NULL,
                                       weights_col_name = NULL,
                                       model_id = "hub-ensemble",
                                       task_id_cols = NULL)

  expect_equal(explicit_ensemble, implicit_ensemble)
})


test_that("(weighted) medians and means correctly calculated", {
  median_expected <- mean_expected <-
    weighted_median_expected <- weighted_mean_expected <-
    data.frame(model_id = "hub-ensemble",
               location = rep(c("222", "888"), each = 3),
               horizon = 1,
               target = "inc death",
               target_date = as.Date("2021-12-25"),
               output_type = "quantile",
               output_type_id = rep(c(.1, .5, .9), 2),
               value = NA_real_)

  median_vals <- sapply(list(v2_1, v2_5, v2_9, v8_1, v8_5, v8_9), median)
  mean_vals <- sapply(list(v2_1, v2_5, v2_9, v8_1, v8_5, v8_9), mean)

  weighted_median_vals <- c(sapply(list(v2_1, v2_5, v2_9),
                                   matrixStats::weightedMedian,
                                   w = fweight2$weight,
                                   interpolate = FALSE),
                            sapply(list(v8_1, v8_5, v8_9),
                                   matrixStats::weightedMedian,
                                   w = fweight8$weight,
                                   interpolate = FALSE))
  weighted_mean_vals <- c(sapply(list(v2_1, v2_5, v2_9),
                                 matrixStats::weightedMean,
                                 w = fweight2$weight),
                          sapply(list(v8_1, v8_5, v8_9),
                                 matrixStats::weightedMean,
                                 w = fweight8$weight))

  median_expected$value <- median_vals
  mean_expected$value <- mean_vals
  weighted_mean_expected$value <- weighted_mean_vals
  weighted_median_expected$value <- weighted_median_vals

  median_actual <- simple_ensemble(model_out_tbl = model_outputs,
                                   weights = NULL,
                                   agg_fun = "median")
  mean_actual <- simple_ensemble(model_out_tbl = model_outputs,
                                 weights = NULL,
                                 agg_fun = "mean")

  weighted_median_actual <- simple_ensemble(model_out_tbl = model_outputs,
                                            weights = fweight,
                                            agg_fun = "median")
  weighted_mean_actual <- simple_ensemble(model_out_tbl = model_outputs,
                                          weights = fweight,
                                          agg_fun = "mean")

  expect_equal(as.data.frame(median_actual), median_expected)
  expect_equal(as.data.frame(mean_actual), mean_expected)

  expect_equal(as.data.frame(weighted_median_actual), weighted_median_expected)
  expect_equal(as.data.frame(weighted_mean_actual), weighted_mean_expected)
})


test_that("(weighted) medians and means work with alternate name for weights columns", {
  weighted_median_actual <- model_outputs |>
    simple_ensemble(weights = fweight |> dplyr::rename(w = "weight"),
                    weights_col_name = "w",
                    agg_fun = "median")
  weighted_mean_actual <- model_outputs |>
    simple_ensemble(weights = fweight |> dplyr::rename(w = "weight"),
                    weights_col_name = "w",
                    agg_fun = "mean")

  weighted_median_expected <- simple_ensemble(model_out_tbl = model_outputs,
                                              weights = fweight,
                                              agg_fun = "median")
  weighted_mean_expected <- simple_ensemble(model_out_tbl = model_outputs,
                                            weights = fweight,
                                            agg_fun = "mean")

  expect_equal(weighted_mean_actual, weighted_mean_expected)
  expect_equal(weighted_median_actual, weighted_median_expected)
})

test_that("passing agg_fun as symbol or character results in identical behaviour", {

  weighted_mean_actual_char <- simple_ensemble(model_out_tbl = model_outputs,
                                               weights = fweight,
                                               agg_fun = "mean")
  weighted_mean_actual_symbol <- simple_ensemble(model_out_tbl = model_outputs,
                                                 weights = fweight,
                                                 agg_fun = mean)

  expect_identical(weighted_mean_actual_char, weighted_mean_actual_symbol)
})

test_that("duplicate forecast values still result in correct weighted median", {
  toy_outputs <- data.frame(
    model_id = letters[1:6],
    output_type = rep("quantile", 6),
    output_type_id = rep(0.5, 6),
    value = c(-0.103, -0.89, 0, 0, 0.039, 0.055)
  )
  weights <- data.frame(
    model_id = letters[1:6],
    weight = c(0.08, 0.14, 0.22, 0.12, 0.28, 0.16)
  )

  weighted_median_expected <- data.frame(
    model_id = "hub-ensemble",
    output_type = "quantile",
    output_type_id = 0.5,
    value = 0
  ) |>
    hubUtils::as_model_out_tbl()

  weighted_median_actual <- toy_outputs |>
    simple_ensemble(weights, agg_fun = "median")

  expect_equal(weighted_median_expected, weighted_median_actual)
})

test_that("simple_ensemble accepts custom functions without error", {
  geometric_mean <- function(x) {
    n <- length(x)
    prod(x)^(1 / n)
  }

  model_outputs |>
    hubEnsembles::simple_ensemble(
      agg_fun = geometric_mean,
      model_id = "simple-ensemble-geometric"
    ) |>
    expect_no_error()
})

Try the hubEnsembles package in your browser

Any scripts or data that you put into this service are public.

hubEnsembles documentation built on June 8, 2025, 1:26 p.m.