# tests/testthat/test-case-weights.R

# ------------------------------------------------------------------------------
# .use_case_weights_with_yardstick()

test_that("knows about importance weights", {
  # The helper is expected to report FALSE for importance weights.
  wts <- hardhat::importance_weights(1)
  use_wts <- .use_case_weights_with_yardstick(wts)

  expect_false(use_wts)
})

test_that("knows about frequency weights", {
  # The helper is expected to report TRUE for frequency weights.
  wts <- hardhat::frequency_weights(1)
  use_wts <- .use_case_weights_with_yardstick(wts)

  expect_true(use_wts)
})

test_that("gives informative default error", {
  # A bare numeric is passed instead of a case weights vector; the resulting
  # error message is pinned by the snapshot.
  expect_snapshot(
    {
      .use_case_weights_with_yardstick(1)
    },
    error = TRUE
  )
})

# ------------------------------------------------------------------------------
# extract_case_weights()

test_that("`extract_case_weights()` errors if `col` doesn't exist", {
  skip_if(packageVersion("rlang") < "1.0.6.9000")

  # A bare workflow: no case weights column has been registered on it.
  wf <- workflows::workflow()

  # The error message is pinned by the snapshot.
  expect_snapshot(
    {
      extract_case_weights(mtcars, wf)
    },
    error = TRUE
  )
})

test_that("`extract_case_weights()` errors if case weights column isn't the right class", {
  # A plain integer column, i.e. not created by one of the hardhat case
  # weights constructors.
  mtcars[["weight"]] <- 1L

  wf <- workflows::add_case_weights(workflows::workflow(), weight)

  # The error message is pinned by the snapshot.
  expect_snapshot(
    {
      extract_case_weights(mtcars, wf)
    },
    error = TRUE
  )
})

test_that("`extract_case_weights()` works", {
  # Frequency weights are a proper case weights class, so extraction is
  # expected to succeed here.
  mtcars$weight <- hardhat::frequency_weights(1:32)

  spec <- parsnip::linear_reg()

  wf <- workflows::workflow()
  wf <- workflows::add_variables(wf, mpg, c(disp, cyl))
  wf <- workflows::add_model(wf, spec)
  wf <- workflows::add_case_weights(wf, weight)

  # Called without a namespace prefix, like the other tests in this section.
  extracted <- extract_case_weights(mtcars, wf)

  # Expected result: a one-column tibble holding the weights, renamed to
  # `.case_weights`.
  expected <- mtcars |>
    tibble::as_tibble() |>
    dplyr::select(.case_weights = weight)

  expect_equal(extracted, expected)
})

# ------------------------------------------------------------------------------
# Passed on during tuning

test_that("weights are used during resampling", {
  set.seed(1)

  mtcars$weight <- hardhat::frequency_weights(1:32)

  # Resamples are created right after seeding so the folds stay reproducible.
  folds <- rsample::vfold_cv(mtcars, v = 2)

  spec <- parsnip::linear_reg()
  metrics <- yardstick::metric_set(yardstick::rmse)

  wf <- workflows::workflow() |>
    workflows::add_variables(mpg, c(disp, cyl)) |>
    workflows::add_model(spec) |>
    workflows::add_case_weights(weight)

  res <- fit_resamples(
    object = wf,
    resamples = folds,
    metrics = metrics,
    control = control_resamples(save_pred = TRUE)
  )

  # Recompute RMSE by hand from the first resample's saved predictions, once
  # with the case weights and once without them.
  predictions <- res$.predictions[[1]]

  # `case_weights` is spelled out in full rather than relying on partial
  # argument matching.
  manual_with_wts <- yardstick::rmse(
    predictions,
    mpg,
    .pred,
    case_weights = .case_weights
  )

  manual_with_no_wts <- yardstick::rmse(predictions, mpg, .pred)

  res_metric <- dplyr::select(res$.metrics[[1]], -.config)

  # The stored metric must match the weighted value and differ from the
  # unweighted one.
  expect_equal(res_metric, manual_with_wts)
  expect_false(identical(res_metric, manual_with_no_wts))
})


test_that("weights are used during tuning", {
  skip_if_not_installed("splines2")
  set.seed(1)

  mtcars$weight <- hardhat::frequency_weights(1:32)

  # Resamples are created right after seeding so the folds stay reproducible.
  folds <- rsample::vfold_cv(mtcars, v = 2)

  # The spline's degrees of freedom is the only parameter marked for tuning.
  rec <-
    recipes::recipe(mpg ~ ., data = mtcars) |>
    recipes::step_spline_natural(disp, deg_free = tune())

  spec <- parsnip::linear_reg()
  metrics <- yardstick::metric_set(yardstick::rmse)

  wf <- workflows::workflow() |>
    workflows::add_recipe(rec) |>
    workflows::add_model(spec) |>
    workflows::add_case_weights(weight)

  df_grid <- tibble(deg_free = 3:4)

  # `control_grid()` is the control constructor that pairs with `tune_grid()`.
  res <- tune_grid(
    object = wf,
    resamples = folds,
    metrics = metrics,
    grid = df_grid,
    control = control_grid(save_pred = TRUE)
  )

  # Group by configuration so the manual metrics are computed per candidate.
  predictions <- group_by(res$.predictions[[1]], .config)

  # `case_weights` is spelled out in full rather than relying on partial
  # argument matching.
  manual_with_wts <- yardstick::rmse(
    predictions,
    mpg,
    .pred,
    case_weights = .case_weights
  )

  manual_with_no_wts <- yardstick::rmse(predictions, mpg, .pred)

  # Restrict the stored metrics to the columns of the manual result (in the
  # same order) before comparing.
  res_metric <- res$.metrics[[1]] |>
    dplyr::select(all_of(names(manual_with_wts)))

  expect_equal(res_metric, manual_with_wts)
  expect_false(identical(res_metric, manual_with_no_wts))
})


test_that("weights work with multi-predict", {
  # glmnet has a ">= 3.6.0" version requirement (presumably on R), so this is
  # not tested on CRAN; skip whenever glmnet is unavailable.
  skip_if_not_installed("glmnet")

  set.seed(1)

  mtcars$weight <- hardhat::frequency_weights(1:32)

  # Resamples are created right after seeding so the folds stay reproducible.
  folds <- rsample::vfold_cv(mtcars, v = 2)

  # Both glmnet parameters are marked for tuning.
  spec <- parsnip::linear_reg(penalty = tune(), mixture = tune()) |>
    parsnip::set_engine("glmnet")

  grid <- expand.grid(penalty = 1:3, mixture = (1:5) / 5)

  metrics <- yardstick::metric_set(yardstick::rmse)

  wf <- workflows::workflow() |>
    workflows::add_variables(mpg, c(disp, cyl)) |>
    workflows::add_model(spec) |>
    workflows::add_case_weights(weight)

  res <- tune_grid(
    object = wf,
    resamples = folds,
    grid = grid,
    control = control_grid(save_pred = TRUE),
    metrics = metrics
  )

  # Group by configuration so the manual metrics are computed per candidate.
  predictions <- group_by(res$.predictions[[1]], .config)

  # `case_weights` is spelled out in full rather than relying on partial
  # argument matching.
  manual_with_wts <- yardstick::rmse(
    predictions,
    mpg,
    .pred,
    case_weights = .case_weights
  )

  manual_with_no_wts <- yardstick::rmse(predictions, mpg, .pred)

  # Restrict the stored metrics to the columns of the manual result (in the
  # same order) before comparing.
  res_metric <- res$.metrics[[1]] |>
    dplyr::select(all_of(names(manual_with_wts)))

  expect_equal(res_metric, manual_with_wts)
  expect_false(identical(res_metric, manual_with_no_wts))
})

test_that("importance weights are *not* used during prediction", {
  skip_if_not_installed("splines2")
  set.seed(1)

  mtcars$weight <- hardhat::importance_weights((1:32) / 32)

  # Resamples are created right after seeding so the folds stay reproducible.
  folds <- rsample::vfold_cv(mtcars, v = 2)

  # The spline's degrees of freedom is the only parameter marked for tuning.
  rec <-
    recipes::recipe(mpg ~ ., data = mtcars) |>
    recipes::step_spline_natural(disp, deg_free = tune())

  spec <- parsnip::linear_reg()
  metrics <- yardstick::metric_set(yardstick::rmse)

  wf <- workflows::workflow() |>
    workflows::add_recipe(rec) |>
    workflows::add_model(spec) |>
    workflows::add_case_weights(weight)

  df_grid <- tibble(deg_free = 3:4)

  # `control_grid()` is the control constructor that pairs with `tune_grid()`.
  res <- tune_grid(
    object = wf,
    resamples = folds,
    metrics = metrics,
    grid = df_grid,
    control = control_grid(save_pred = TRUE)
  )

  # Group by configuration so the manual metrics are computed per candidate.
  predictions <- group_by(res$.predictions[[1]], .config)

  # With importance weights, the saved predictions should not carry a
  # `.case_weights` column.
  expect_false(".case_weights" %in% names(predictions))

  manual_with_no_wts <- yardstick::rmse(predictions, mpg, .pred)

  # Restrict the stored metrics to the columns of the manual result (in the
  # same order) before comparing.
  res_metric <- res$.metrics[[1]] |>
    dplyr::select(all_of(names(manual_with_no_wts)))

  # The stored metric must equal the unweighted value.
  expect_equal(res_metric, manual_with_no_wts)
})

# Try the tune package in your browser
#
# Any scripts or data that you put into this service are public.
#
# tune documentation built on Sept. 1, 2025, 5:10 p.m.