tests/testthat/test-fair-aaa.R

test_that("basic functionality", {
  data("hpc_cv")

  silly <- function(...) {metric_tibbler("a", "b", 1)}
  silly_metric <- new_class_metric(silly, "minimize")
  silly_new_groupwise_metric <-
    new_groupwise_metric(
      silly_metric,
      "silly",
      aggregate = function(x, ...) {x$.estimate[1]}
    )

  expect_true(inherits(silly_new_groupwise_metric, "function"))

  silly_fairness_Resample <- silly_new_groupwise_metric(Resample)
  expect_equal(attr(silly_fairness_Resample, "by"), "Resample")
  expect_s3_class(silly_fairness_Resample, "class_metric")

  expect_silent(
    m_set <- metric_set(silly_fairness_Resample)
  )

  expect_s3_class(m_set, "class_prob_metric_set")
  expect_equal(
    as_tibble(m_set),
    tibble::tibble(metric = "silly_fairness_Resample",
                   class = "class_metric",
                   direction = "minimize")
  )

  expect_equal(
    m_set(hpc_cv, truth = obs, estimate = pred),
    tibble::tibble(.metric = "silly",
                   .by = "Resample",
                   .estimator = "b",
                   ".estimate" = 1)
  )
})

test_that("new_groupwise_metric() works with grouped input", {
  data("hpc_cv")
  hpc_cv$group <- sample(c("a", "b"), nrow(hpc_cv), replace = TRUE)

  expect_silent(
    m_set <-
      metric_set(
        equal_opportunity(group)
      )
  )

  grouped_res <-
    hpc_cv %>%
    dplyr::group_by(Resample) %>%
    m_set(hpc_cv, truth = obs, estimate = pred)

  hpc_split <- vctrs::vec_split(hpc_cv, hpc_cv$Resample)

  split_res <- dplyr::tibble(
    Resample = hpc_split$key,
    res = lapply(
      hpc_split$val,
      function(x) m_set(x, truth = obs, estimate = pred))
  ) %>%
    tidyr::unnest(cols = res)

  expect_identical(grouped_res, split_res)
})

test_that("metric factory print method works", {
  expect_snapshot(equal_opportunity)
})

test_that("can accommodate redundant sensitive features", {
  data("hpc_cv")

  expect_silent(
    m_set <-
      metric_set(
        demographic_parity(Resample),
        equal_opportunity(Resample)
      )
  )

  expect_s3_class(m_set, "class_prob_metric_set")

  res <- m_set(hpc_cv, truth = obs, estimate = pred)

  expect_equal(res$.metric, c("demographic_parity", "equal_opportunity"))
  expect_equal(res$.by, c("Resample", "Resample"))
})

test_that("can accommodate multiple sensitive features", {
  data("hpc_cv")

  hpc_cv$group <- sample(c("a", "b"), nrow(hpc_cv), replace = TRUE)

  expect_silent(
    m_set <-
      metric_set(
        demographic_parity(Resample),
        equal_opportunity(group)
      )
  )

  expect_s3_class(m_set, "class_prob_metric_set")

  res <- m_set(hpc_cv, truth = obs, estimate = pred)

  expect_equal(res$.metric, c("demographic_parity", "equal_opportunity"))
  expect_equal(res$.by, c("Resample", "group"))
})

test_that("can mix fairness metrics with standard metrics", {
  data("hpc_cv")

  expect_silent(
    m_set <-
      metric_set(
        demographic_parity(Resample),
        sens
      )
  )

  expect_s3_class(m_set, "class_prob_metric_set")

  res <- m_set(hpc_cv, truth = obs, estimate = pred)

  expect_equal(res$.metric, c("demographic_parity", "sens"))
  expect_equal(res$.by, c("Resample", NA_character_))

  m_set_sens <- metric_set(sens)
  expect_equal(
    res[2,] %>% dplyr::select(-.by),
    m_set_sens(hpc_cv, truth = obs, estimate = pred)
  )
})

test_that("can handle metric set input as `fn`", {
  data("hpc_cv")

  expect_silent(
    fairness_mtrc <-
      new_groupwise_metric(
        metric_set(sens, spec),
        "min_sens_spec",
        diff_range
      )
  )

  expect_s3_class(fairness_mtrc(Resample), "class_metric")
  expect_equal(attr(fairness_mtrc(Resample), "by"), "Resample")
  expect_equal(attr(fairness_mtrc(Resample), "direction"), "minimize")

  res <- fairness_mtrc(Resample)(hpc_cv, truth = obs, estimate = pred)

  expect_equal(res$.metric, "min_sens_spec")
  expect_equal(res$.by, "Resample")
})

test_that("handles `direction` input", {
  expect_silent(
    fairness_mtrc <-
      new_groupwise_metric(
        sens,
        "inv_sens_diff",
        function(x) {1/diff(range(x$.estimate))},
        direction = "maximize"
      )
  )

  expect_equal(attr(fairness_mtrc(Resample), "direction"), "maximize")

  expect_snapshot(
    error = TRUE,
    new_groupwise_metric(
      sens,
      "bad_direction",
      diff_range,
      direction = "boop"
    )
  )
})

test_that("errors informatively with bad input", {
  expect_snapshot(
    error = TRUE,
    new_groupwise_metric()
  )

  expect_snapshot(
    error = TRUE,
    new_groupwise_metric(sens)
  )

  expect_snapshot(
    error = TRUE,
    new_groupwise_metric(sens, "bop")
  )

  expect_snapshot(
    error = TRUE,
    new_groupwise_metric("boop", "bop", identity)
  )

  expect_snapshot(
    error = TRUE,
    new_groupwise_metric(identity, "bop", identity)
  )

  expect_snapshot(
    error = TRUE,
    new_groupwise_metric(sens, 1, identity)
  )

  expect_snapshot(
    error = TRUE,
    new_groupwise_metric(sens, "bop", "boop")
  )
})

test_that("outputted function errors informatively with bad input", {
  data("hpc_cv")

  bad_aggregate <- new_groupwise_metric(sens, "bop", identity)
  expect_snapshot(
    error = TRUE,
    bad_aggregate(Resample)(hpc_cv, truth = obs, estimate = pred)
  )

  bad_by <- new_groupwise_metric(sens, "bop", identity)(nonexistent_column)
  expect_snapshot(
    error = TRUE,
    bad_by(hpc_cv, truth = obs, estimate = pred)
  )

  bad_truth_metric_set <-
    new_groupwise_metric(metric_set(sens, spec), "bop", function(x) {1})(Resample)
  expect_snapshot(
    error = TRUE,
    bad_truth_metric_set(hpc_cv, truth = VF, estimate = pred)
  )

  bad_truth_metric <-
    new_groupwise_metric(sens, "bop", function(x) {1})(Resample)
  expect_snapshot(
    error = TRUE,
    bad_truth_metric(hpc_cv, truth = VF, estimate = pred)
  )
})

test_that("outputted function errors informatively with redundant grouping", {
  data("hpc_cv")

  expect_snapshot(
    error = TRUE,
    hpc_cv %>%
      dplyr::group_by(Resample) %>%
      demographic_parity(Resample)(truth = obs, estimate = pred)
  )

  dp_res <- demographic_parity(Resample)

  expect_snapshot(
    error = TRUE,
    hpc_cv %>%
      dplyr::group_by(Resample) %>%
      dp_res(truth = obs, estimate = pred)
  )
})

Try the yardstick package in your browser

Any scripts or data that you put into this service are public.

yardstick documentation built on June 22, 2024, 7:07 p.m.