# tests/testthat/test-weights.R

# Test file for variable fold weights functionality

# Setup test data
# Shared fixtures for the tests below. Each seed is set immediately before the
# RNG-consuming calls that follow it, so the exact draws (and every
# data-dependent result in this file) depend on this statement order.
set.seed(42)
# 50 rows of three independent standard-normal predictors
test_data <- data.frame(
  x1 = rnorm(50),
  x2 = rnorm(50),
  x3 = rnorm(50)
)
# Response is linear in x1 and x2 plus Gaussian noise (sd = 0.5); x3 does not
# enter the formula for y.
test_data$y <- 2 * test_data$x1 + 3 * test_data$x2 + rnorm(50, sd = 0.5)

set.seed(123)
# 3-fold CV on mtcars, reused by many tests below; their weight vectors are
# therefore of length 3.
folds <- rsample::vfold_cv(mtcars, v = 3)

# Helper: returns a new, untuned parsnip linear-regression spec using the
# "lm" engine on every call.
create_test_model <- function() {
  spec <- parsnip::linear_reg()
  parsnip::set_engine(spec, "lm")
}


test_that("add_resample_weights() validates inputs correctly", {
  # Every call below is expected to error. The error output is stored as a
  # snapshot, so the code inside each expect_snapshot() must stay unchanged.

  # First argument is a character string, not an rset
  expect_snapshot(
    add_resample_weights("not_an_rset", c(0.5, 0.3, 0.2)),
    error = TRUE
  )

  # Non-numeric weights
  expect_snapshot(
    add_resample_weights(folds, c("a", "b", "c")),
    error = TRUE
  )

  # Two weights for an rset with three resamples
  expect_snapshot(
    add_resample_weights(folds, c(0.5, 0.3)),
    error = TRUE
  )

  # A negative weight
  expect_snapshot(
    add_resample_weights(folds, c(-0.1, 0.5, 0.6)),
    error = TRUE
  )

  # All-zero weights: their sum is 0, so they cannot be rescaled to sum to 1
  expect_snapshot(
    add_resample_weights(folds, c(0, 0, 0)),
    error = TRUE
  )
})

test_that("add_resample_weights() adds weights correctly", {
  raw <- c(0.1, 0.5, 0.4)
  result <- add_resample_weights(folds, raw)

  # The result is still an rset with one row per resample ...
  expect_s3_class(result, "rset")
  # ... carrying the input weights rescaled to a unit sum
  expect_equal(attr(result, ".resample_weights"), raw / sum(raw))
  expect_equal(nrow(result), nrow(folds))
})

test_that("calculate_resample_weights() works correctly", {
  auto_weights <- calculate_resample_weights(folds)

  # One strictly positive double weight per resample ...
  expect_type(auto_weights, "double")
  expect_length(auto_weights, nrow(folds))
  expect_true(all(auto_weights > 0))

  # ... normalized to sum to 1. expect_equal() reports the actual sum when it
  # fails, unlike expect_true(abs(...) < tol); the explicit tolerance keeps the
  # original 1e-10 strictness.
  expect_equal(sum(auto_weights), 1, tolerance = 1e-10)
})

test_that("weights are preserved through tuning pipeline", {
  weighted_folds <- add_resample_weights(folds, c(0.1, 0.5, 0.4))
  rmse_only <- yardstick::metric_set(yardstick::rmse)
  quiet <- control_grid(verbose = FALSE)

  # Tuning over a weighted rset should still produce one summarised RMSE row
  res <- suppressWarnings(
    create_test_model() |>
      tune_grid(
        mpg ~ .,
        resamples = weighted_folds,
        grid = 1,
        metrics = rmse_only,
        control = quiet
      )
  )

  summary_tbl <- collect_metrics(res)
  expect_identical(nrow(summary_tbl), 1L)
  expect_true(any(names(summary_tbl) == "mean"))
  expect_true(is.numeric(summary_tbl$mean))
})

test_that("weights affect metric aggregation", {
  weights <- c(0.1, 0.5, 0.4)
  weighted_folds <- add_resample_weights(folds, weights)

  mod <- create_test_model()

  suppressWarnings({
    # Unweighted results
    res_unweighted <- tune_grid(
      mod,
      mpg ~ .,
      resamples = folds,
      grid = 1,
      metrics = yardstick::metric_set(yardstick::rmse),
      control = control_grid(verbose = FALSE)
    )

    # Weighted results
    res_weighted <- tune_grid(
      mod,
      mpg ~ .,
      resamples = weighted_folds,
      grid = 1,
      metrics = yardstick::metric_set(yardstick::rmse),
      control = control_grid(verbose = FALSE)
    )
  })

  unweighted_rmse <- collect_metrics(res_unweighted)$mean[1]
  weighted_rmse <- collect_metrics(res_weighted)$mean[1]

  expect_true(is.numeric(unweighted_rmse))
  expect_true(is.numeric(weighted_rmse))
  expect_false(is.na(unweighted_rmse))
  expect_false(is.na(weighted_rmse))

  # The aggregation must actually change. Unless every fold happens to have
  # exactly the same RMSE, non-uniform weights move the summary away from the
  # plain mean. Without this check the test would still pass if the weights
  # were silently ignored.
  expect_false(isTRUE(all.equal(unweighted_rmse, weighted_rmse)))
})

test_that("extreme weights show larger effect", {
  skip_if_not_installed("kknn")

  # Create folds for this specific test
  set.seed(42)
  test_folds <- rsample::vfold_cv(test_data, v = 3)

  # Regular weights
  weights <- c(0.6, 0.2, 0.2)
  weighted_folds <- add_resample_weights(test_folds, weights)

  # Extreme weights
  extreme_weights <- c(0.95, 0.025, 0.025)
  extreme_weighted_folds <- add_resample_weights(test_folds, extreme_weights)

  # Create a model with tuning parameter
  knn_spec <- parsnip::nearest_neighbor(neighbors = tune()) |>
    parsnip::set_engine("kknn") |>
    parsnip::set_mode("regression")

  param_grid <- data.frame(neighbors = c(3, 5))

  suppressWarnings({
    # Unweighted
    res_unweighted <- tune_grid(
      knn_spec,
      y ~ .,
      resamples = test_folds,
      grid = param_grid,
      metrics = yardstick::metric_set(yardstick::rmse),
      control = control_grid(verbose = FALSE)
    )

    # Regular weights
    res_weighted <- tune_grid(
      knn_spec,
      y ~ .,
      resamples = weighted_folds,
      grid = param_grid,
      metrics = yardstick::metric_set(yardstick::rmse),
      control = control_grid(verbose = FALSE)
    )

    # Extreme weights
    res_extreme <- tune_grid(
      knn_spec,
      y ~ .,
      resamples = extreme_weighted_folds,
      grid = param_grid,
      metrics = yardstick::metric_set(yardstick::rmse),
      control = control_grid(verbose = FALSE)
    )
  })

  unweighted_metrics <- collect_metrics(res_unweighted)
  weighted_metrics <- collect_metrics(res_weighted)
  extreme_metrics <- collect_metrics(res_extreme)

  # Check that results exist and are sensible
  expect_equal(nrow(unweighted_metrics), 2)
  expect_equal(nrow(weighted_metrics), 2)
  expect_equal(nrow(extreme_metrics), 2)

  # Largest shift (over the two candidates) away from the unweighted summary
  regular_diff <- max(abs(unweighted_metrics$mean - weighted_metrics$mean))
  extreme_diff <- max(abs(unweighted_metrics$mean - extreme_metrics$mean))

  expect_true(all(is.finite(c(regular_diff, extreme_diff))))

  # Assuming the weighted summary is the weighted mean of the per-fold RMSEs
  # r1..r3: both weight vectors put equal mass on folds 2 and 3, so each shift
  # from the plain mean is (w1 - 1/3) * (r1 - (r2 + r3) / 2). The extreme shift
  # is therefore (0.95 - 1/3) / (0.6 - 1/3) ~= 2.3 times the regular one for
  # every candidate. The previous `>= 0` checks were always true and did not
  # test this at all.
  expect_gt(regular_diff, 0)
  expect_gt(extreme_diff, regular_diff)
})

test_that("weight normalization works correctly", {
  normalize <- function(w) tune:::.validate_resample_weights(w, 3)

  # Raw weights are divided by their total (18 here) so they sum to 1
  expect_equal(normalize(c(3, 6, 9)), c(3, 6, 9) / 18)

  # Weights that already sum to 1 come back unchanged
  expect_equal(normalize(c(0.2, 0.3, 0.5)), c(0.2, 0.3, 0.5))
})

test_that("equal weights return NULL", {
  # Uniform weights are treated as "no weighting": the validator returns NULL
  # whatever the common value or the number of resamples.
  uniform_cases <- list(
    list(w = c(2, 2, 2), n = 3), # whole numbers
    list(w = rep(1 / 3, 3), n = 3), # fractions
    list(w = rep(1, 5), n = 5) # more resamples
  )

  for (case in uniform_cases) {
    expect_null(tune:::.validate_resample_weights(case$w, case$n))
  }
})

test_that("unequal weights do not return NULL", {
  # Decimal weights that already sum to 1 are kept as they are
  decimal <- tune:::.validate_resample_weights(c(0.1, 0.5, 0.4), 3)
  expect_true(!is.null(decimal))
  expect_equal(decimal, c(0.1, 0.5, 0.4))

  # Whole-number weights come back as fractions of their total
  whole <- tune:::.validate_resample_weights(c(1, 2, 3), 3)
  expect_true(!is.null(whole))
  expect_equal(whole, c(1, 2, 3) / 6)
})

test_that("add_resample_weights with equal weights returns NULL attribute", {
  uniform <- add_resample_weights(folds, rep(1, 3))

  # Uniform weights are not stored ...
  expect_null(attr(uniform, ".resample_weights"))
  # ... and the object keeps its rset class
  expect_s3_class(uniform, "rset")
})

test_that("equal weights produce same results as no weights", {
  rmse_only <- yardstick::metric_set(yardstick::rmse)
  quiet <- control_grid(verbose = FALSE)

  # Tune the shared linear model over `resamples`, muting tune_grid() warnings
  fit_on <- function(resamples) {
    suppressWarnings(
      tune_grid(
        create_test_model(),
        mpg ~ .,
        resamples = resamples,
        grid = 1,
        metrics = rmse_only,
        control = quiet
      )
    )
  }

  baseline <- collect_metrics(fit_on(folds))
  uniform <- collect_metrics(fit_on(add_resample_weights(folds, c(1, 1, 1))))

  # Uniform weights must reproduce the unweighted summary
  expect_equal(baseline$mean, uniform$mean)
  expect_equal(baseline$std_err, uniform$std_err)
})

test_that("weighted statistics functions work correctly", {
  vals <- c(1, 2, 3, 4, 5)
  wts <- c(0.1, 0.2, 0.3, 0.2, 0.2)

  # Ordinary input gives a usable, non-negative number
  sd_full <- tune:::.weighted_sd(vals, wts)
  expect_true(is.numeric(sd_full))
  expect_true(!is.na(sd_full))
  expect_true(sd_full >= 0)

  # Dropping the missing value (and its weight) beforehand still gives a number
  vals_na <- replace(vals, 3, NA)
  keep <- !is.na(vals_na)
  expect_true(is.numeric(tune:::.weighted_sd(vals_na[keep], wts[keep])))

  # A single observation yields NA
  expect_true(is.na(tune:::.weighted_sd(1, 1)))
})

test_that("fold weight extraction works", {
  raw <- c(0.1, 0.5, 0.4)
  weighted_folds <- add_resample_weights(folds, raw)

  res <- suppressWarnings(
    tune_grid(
      create_test_model(),
      mpg ~ .,
      resamples = weighted_folds,
      grid = 1,
      metrics = yardstick::metric_set(yardstick::rmse),
      control = control_grid(verbose = FALSE)
    )
  )

  # The tuning result exposes the normalized (unit-sum) resample weights
  expect_equal(tune:::.get_resample_weights(res), raw / sum(raw))
})

test_that("individual fold metrics can be collected", {
  weighted_folds <- add_resample_weights(folds, c(0.1, 0.5, 0.4))

  res <- suppressWarnings(
    tune_grid(
      create_test_model(),
      mpg ~ .,
      resamples = weighted_folds,
      grid = 1,
      metrics = yardstick::metric_set(yardstick::rmse),
      control = control_grid(verbose = FALSE)
    )
  )

  # Unsummarised metrics: expect at least one row for each of the 3 folds
  per_fold <- collect_metrics(res, summarize = FALSE)

  expect_true(nrow(per_fold) >= 3)
  expect_true(all(c("id", ".estimate") %in% names(per_fold)))
  expect_true(all(is.finite(per_fold$.estimate)))
})

test_that("backwards compatibility - no weights", {
  # A plain rset without any weights attribute
  res <- suppressWarnings(
    create_test_model() |>
      tune_grid(
        mpg ~ .,
        resamples = folds,
        grid = 1,
        metrics = yardstick::metric_set(yardstick::rmse),
        control = control_grid(verbose = FALSE)
      )
  )

  summary_tbl <- collect_metrics(res)
  expect_identical(nrow(summary_tbl), 1L)
  expect_true(any(names(summary_tbl) == "mean"))
  expect_true(is.numeric(summary_tbl$mean))
  expect_true(!is.na(summary_tbl$mean))
})

test_that("rset tibble conversion includes fold weights", {
  weights <- c(0.1, 0.4, 0.5)
  weighted_folds <- add_resample_weights(folds, weights)

  # Convert to tibble manually (this is what our print method does). The
  # weight column is taken from the rset itself rather than from the local
  # `weights` vector. Otherwise the test would only re-check a value it had
  # just assigned. If the attribute were missing, the assignment of NULL
  # would add no column, and the expectations below would fail.
  x_tbl <- tibble::as_tibble(weighted_folds)
  x_tbl$resample_weight <- attr(weighted_folds, ".resample_weights")

  # Verify the structure; `weights` already sums to 1, so normalization
  # leaves it unchanged
  expect_true("resample_weight" %in% names(x_tbl))
  expect_equal(x_tbl$resample_weight, weights)
  expect_equal(nrow(x_tbl), 3)
})

test_that("extract_resample_weights() works with rset objects", {
  # An rset without weights reports NULL
  expect_null(extract_resample_weights(folds))

  # These weights already sum to 1, so they are returned unchanged
  w <- c(0.2, 0.3, 0.5)
  expect_equal(extract_resample_weights(add_resample_weights(folds, w)), w)
})

test_that("extract_resample_weights() works with tune_results objects", {
  raw <- c(0.1, 0.5, 0.4)

  res <- suppressWarnings(
    create_test_model() |>
      tune_grid(
        mpg ~ .,
        resamples = add_resample_weights(folds, raw),
        grid = 1,
        metrics = yardstick::metric_set(yardstick::rmse),
        control = control_grid(verbose = FALSE)
      )
  )

  # Weights recovered from the tuning result are the normalized inputs
  expect_equal(extract_resample_weights(res), raw / sum(raw))
})

test_that("extract_resample_weights() validates input types", {
  # Both calls are expected to error. The error output is stored as a
  # snapshot, so the code inside each expect_snapshot() must stay unchanged.

  # A character string is neither an rset nor a tuning result
  expect_snapshot(
    extract_resample_weights("not_valid_input"),
    error = TRUE
  )

  # A plain data frame is rejected as well
  expect_snapshot(
    extract_resample_weights(data.frame(x = 1:3)),
    error = TRUE
  )
})

# Try the tune package in your browser

# Any scripts or data that you put into this service are public.

# tune documentation built on April 17, 2026, 5:07 p.m.