# Nothing
# Test file for variable fold weights functionality
# Shared fixtures ----

# Synthetic regression data: three standard-normal predictors and a response
# built from x1 and x2 plus Gaussian noise (sd = 0.5); x3 does not enter y.
set.seed(42)
test_data <- data.frame(
  x1 = rnorm(50),
  x2 = rnorm(50),
  x3 = rnorm(50)
)
test_data$y <- with(test_data, 2 * x1 + 3 * x2) + rnorm(50, sd = 0.5)

# Three-fold cross-validation split of mtcars, seeded for reproducibility.
set.seed(123)
folds <- rsample::vfold_cv(mtcars, v = 3)
# Build the model specification used by the tuning tests: a parsnip
# `linear_reg()` spec with its engine set to "lm".
create_test_model <- function() {
  lm_spec <- parsnip::linear_reg()
  parsnip::set_engine(lm_spec, "lm")
}
test_that("add_resample_weights() validates inputs correctly", {
# Each snapshot below records the error raised for one kind of invalid input
# (`error = TRUE` means an error is expected). `folds` was created with
# v = 3, so a valid weight vector has three elements.

# First argument is a character string instead of the resampling object.
expect_snapshot(
add_resample_weights("not_an_rset", c(0.5, 0.3, 0.2)),
error = TRUE
)
# Weights are character values rather than numbers.
expect_snapshot(
add_resample_weights(folds, c("a", "b", "c")),
error = TRUE
)
# Only two weights are supplied for three folds.
expect_snapshot(
add_resample_weights(folds, c(0.5, 0.3)),
error = TRUE
)
# One of the weights is negative.
expect_snapshot(
add_resample_weights(folds, c(-0.1, 0.5, 0.6)),
error = TRUE
)
# Every weight is zero.
expect_snapshot(
add_resample_weights(folds, c(0, 0, 0)),
error = TRUE
)
})
test_that("add_resample_weights() adds weights correctly", {
  fold_wts <- c(0.1, 0.5, 0.4)
  folds_wtd <- add_resample_weights(folds, fold_wts)

  # The stored weights are expected to be the inputs rescaled to sum to 1
  normalized <- fold_wts / sum(fold_wts)
  stored <- attr(folds_wtd, ".resample_weights")

  expect_s3_class(folds_wtd, "rset")
  expect_equal(stored, normalized)
  expect_equal(nrow(folds_wtd), nrow(folds))
})
test_that("calculate_resample_weights() works correctly", {
  wts <- calculate_resample_weights(folds)
  n_folds <- nrow(folds)

  # One double per resample ...
  expect_type(wts, "double")
  expect_length(wts, n_folds)

  # ... each strictly positive, summing to 1 up to floating-point error
  expect_true(all(wts > 0))
  expect_true(abs(sum(wts) - 1) < 1e-10)
})
test_that("weights are preserved through tuning pipeline", {
  weights <- c(0.1, 0.5, 0.4)
  weighted_folds <- add_resample_weights(folds, weights)
  mod <- create_test_model()

  suppressWarnings({
    res <- tune_grid(
      mod,
      mpg ~ .,
      resamples = weighted_folds,
      grid = 1,
      metrics = yardstick::metric_set(yardstick::rmse),
      control = control_grid(verbose = FALSE)
    )
  })

  # The point of this test: the weights attached to the resamples must still
  # be retrievable from the tuning result (in normalized form). Without this
  # check the test would pass even if tune_grid() dropped the weights.
  expect_equal(extract_resample_weights(res), weights / sum(weights))

  # The summarized metrics must still have the usual shape: one row (single
  # metric, single candidate) with a numeric `mean` column.
  metrics <- collect_metrics(res)
  expect_equal(nrow(metrics), 1)
  expect_true("mean" %in% names(metrics))
  expect_true(is.numeric(metrics$mean))
})
test_that("weights affect metric aggregation", {
  fold_wts <- c(0.1, 0.5, 0.4)
  folds_wtd <- add_resample_weights(folds, fold_wts)
  lm_spec <- create_test_model()

  # Fit on the plain (unweighted) folds first ...
  fit_plain <- suppressWarnings(
    tune_grid(
      lm_spec,
      mpg ~ .,
      resamples = folds,
      grid = 1,
      metrics = yardstick::metric_set(yardstick::rmse),
      control = control_grid(verbose = FALSE)
    )
  )
  # ... then the identical setup on the weighted folds
  fit_wtd <- suppressWarnings(
    tune_grid(
      lm_spec,
      mpg ~ .,
      resamples = folds_wtd,
      grid = 1,
      metrics = yardstick::metric_set(yardstick::rmse),
      control = control_grid(verbose = FALSE)
    )
  )

  rmse_plain <- collect_metrics(fit_plain)[["mean"]][1]
  rmse_wtd <- collect_metrics(fit_wtd)[["mean"]][1]

  # Both aggregated RMSE values must be non-missing numbers; the two values
  # are not compared against each other here.
  expect_true(is.numeric(rmse_plain))
  expect_true(is.numeric(rmse_wtd))
  expect_false(is.na(rmse_plain))
  expect_false(is.na(rmse_wtd))
})
test_that("extreme weights show larger effect", {
  skip_if_not_installed("kknn")

  # Dedicated 3-fold split of the synthetic data for this test
  set.seed(42)
  sim_folds <- rsample::vfold_cv(test_data, v = 3)

  # Two weightings of the same folds: moderate and heavily skewed
  moderate_wts <- c(0.6, 0.2, 0.2)
  folds_moderate <- add_resample_weights(sim_folds, moderate_wts)
  skewed_wts <- c(0.95, 0.025, 0.025)
  folds_skewed <- add_resample_weights(sim_folds, skewed_wts)

  # k-nearest-neighbour regression with `neighbors` left for tuning
  knn_spec <- parsnip::set_mode(
    parsnip::set_engine(
      parsnip::nearest_neighbor(neighbors = tune()),
      "kknn"
    ),
    "regression"
  )
  param_grid <- data.frame(neighbors = c(3, 5))

  # Identical tuning run on the unweighted, moderate and skewed folds
  fit_plain <- suppressWarnings(
    tune_grid(
      knn_spec,
      y ~ .,
      resamples = sim_folds,
      grid = param_grid,
      metrics = yardstick::metric_set(yardstick::rmse),
      control = control_grid(verbose = FALSE)
    )
  )
  fit_moderate <- suppressWarnings(
    tune_grid(
      knn_spec,
      y ~ .,
      resamples = folds_moderate,
      grid = param_grid,
      metrics = yardstick::metric_set(yardstick::rmse),
      control = control_grid(verbose = FALSE)
    )
  )
  fit_skewed <- suppressWarnings(
    tune_grid(
      knn_spec,
      y ~ .,
      resamples = folds_skewed,
      grid = param_grid,
      metrics = yardstick::metric_set(yardstick::rmse),
      control = control_grid(verbose = FALSE)
    )
  )

  plain_tbl <- collect_metrics(fit_plain)
  moderate_tbl <- collect_metrics(fit_moderate)
  skewed_tbl <- collect_metrics(fit_skewed)

  # One summary row per candidate value of `neighbors`
  expect_equal(nrow(plain_tbl), 2)
  expect_equal(nrow(moderate_tbl), 2)
  expect_equal(nrow(skewed_tbl), 2)

  # Largest absolute shift in mean RMSE relative to the unweighted run
  shift_moderate <- max(abs(plain_tbl$mean - moderate_tbl$mean))
  shift_skewed <- max(abs(plain_tbl$mean - skewed_tbl$mean))

  expect_true(shift_moderate >= 0)
  expect_true(shift_skewed >= 0)
  expect_true(all(is.finite(c(shift_moderate, shift_skewed))))
})
test_that("weight normalization works correctly", {
  # Raw weights come back rescaled so that they sum to 1
  rescaled <- tune:::.validate_resample_weights(c(3, 6, 9), 3)
  expect_equal(rescaled, c(1 / 6, 1 / 3, 1 / 2))

  # Weights that already sum to 1 come back with the same values
  untouched <- tune:::.validate_resample_weights(c(0.2, 0.3, 0.5), 3)
  expect_equal(untouched, c(0.2, 0.3, 0.5))
})
test_that("equal weights return NULL", {
  # Identical integer-valued weights
  all_twos <- tune:::.validate_resample_weights(rep(2, 3), 3)
  expect_null(all_twos)

  # Identical fractional weights
  all_thirds <- tune:::.validate_resample_weights(rep(1 / 3, 3), 3)
  expect_null(all_thirds)

  # Same check with more resamples
  all_ones <- tune:::.validate_resample_weights(rep(1, 5), 5)
  expect_null(all_ones)
})
test_that("unequal weights do not return NULL", {
  # Decimal weights that already sum to 1
  from_decimals <- tune:::.validate_resample_weights(c(0.1, 0.5, 0.4), 3)
  expect_false(is.null(from_decimals))
  expect_equal(from_decimals, c(0.1, 0.5, 0.4))

  # Integer weights are returned as their share of the total
  from_integers <- tune:::.validate_resample_weights(c(1, 2, 3), 3)
  expect_false(is.null(from_integers))
  expect_equal(from_integers, c(1 / 6, 2 / 6, 3 / 6))
})
test_that("add_resample_weights with equal weights returns NULL attribute", {
  folds_equal <- add_resample_weights(folds, c(1, 1, 1))

  # With identical weights, no weight attribute is expected on the result
  stored <- attr(folds_equal, ".resample_weights")
  expect_null(stored)

  # The result must still be an rset
  expect_s3_class(folds_equal, "rset")
})
test_that("equal weights produce same results as no weights", {
  lm_spec <- create_test_model()

  # Baseline: folds without any weights
  fit_plain <- suppressWarnings(
    tune_grid(
      lm_spec,
      mpg ~ .,
      resamples = folds,
      grid = 1,
      metrics = yardstick::metric_set(yardstick::rmse),
      control = control_grid(verbose = FALSE)
    )
  )

  # Same folds carrying identical weights
  folds_equal <- suppressWarnings(add_resample_weights(folds, c(1, 1, 1)))
  fit_equal <- suppressWarnings(
    tune_grid(
      lm_spec,
      mpg ~ .,
      resamples = folds_equal,
      grid = 1,
      metrics = yardstick::metric_set(yardstick::rmse),
      control = control_grid(verbose = FALSE)
    )
  )

  summary_plain <- collect_metrics(fit_plain)
  summary_equal <- collect_metrics(fit_equal)

  # Mean and standard error must agree between the two runs
  expect_equal(summary_plain$mean, summary_equal$mean)
  expect_equal(summary_plain$std_err, summary_equal$std_err)
})
test_that("weighted statistics functions work correctly", {
  vals <- c(1, 2, 3, 4, 5)
  wts <- c(0.1, 0.2, 0.3, 0.2, 0.2)

  # Complete data: result is a non-missing, non-negative number
  sd_full <- tune:::.weighted_sd(vals, wts)
  expect_true(is.numeric(sd_full))
  expect_false(is.na(sd_full))
  expect_true(sd_full >= 0)

  # Missing value: it and its weight are dropped before calling the helper
  vals_na <- c(1, 2, NA, 4, 5)
  observed <- !is.na(vals_na)
  sd_observed <- tune:::.weighted_sd(vals_na[observed], wts[observed])
  expect_true(is.numeric(sd_observed))

  # Edge case: a single observation gives NA
  sd_single <- tune:::.weighted_sd(1, 1)
  expect_true(is.na(sd_single))
})
test_that("fold weight extraction works", {
  fold_wts <- c(0.1, 0.5, 0.4)
  folds_wtd <- add_resample_weights(folds, fold_wts)
  lm_spec <- create_test_model()

  fit <- suppressWarnings(
    tune_grid(
      lm_spec,
      mpg ~ .,
      resamples = folds_wtd,
      grid = 1,
      metrics = yardstick::metric_set(yardstick::rmse),
      control = control_grid(verbose = FALSE)
    )
  )

  # The internal accessor should hand back the weights rescaled to sum to 1
  normalized <- fold_wts / sum(fold_wts)
  from_result <- tune:::.get_resample_weights(fit)
  expect_equal(from_result, normalized)
})
test_that("individual fold metrics can be collected", {
  fold_wts <- c(0.1, 0.5, 0.4)
  folds_wtd <- add_resample_weights(folds, fold_wts)
  lm_spec <- create_test_model()

  fit <- suppressWarnings(
    tune_grid(
      lm_spec,
      mpg ~ .,
      resamples = folds_wtd,
      grid = 1,
      metrics = yardstick::metric_set(yardstick::rmse),
      control = control_grid(verbose = FALSE)
    )
  )

  # Unsummarized metrics: one row per fold at minimum
  per_fold <- collect_metrics(fit, summarize = FALSE)
  per_fold_cols <- names(per_fold)

  expect_true(nrow(per_fold) >= 3)
  expect_true("id" %in% per_fold_cols)
  expect_true(".estimate" %in% per_fold_cols)
  expect_true(all(is.finite(per_fold$.estimate)))
})
test_that("backwards compatibility - no weights", {
  lm_spec <- create_test_model()

  # Plain folds with no weights attached
  fit <- suppressWarnings(
    tune_grid(
      lm_spec,
      mpg ~ .,
      resamples = folds,
      grid = 1,
      metrics = yardstick::metric_set(yardstick::rmse),
      control = control_grid(verbose = FALSE)
    )
  )

  summary_tbl <- collect_metrics(fit)

  # One summary row with a numeric, non-missing `mean`
  expect_equal(nrow(summary_tbl), 1)
  expect_true("mean" %in% names(summary_tbl))
  expect_true(is.numeric(summary_tbl$mean))
  expect_false(is.na(summary_tbl$mean))
})
test_that("rset tibble conversion includes fold weights", {
  # These weights already sum to 1, so normalization leaves them unchanged
  # and the stored weights can be compared with `weights` directly.
  weights <- c(0.1, 0.4, 0.5)
  weighted_folds <- add_resample_weights(folds, weights)

  # Convert to a tibble manually and fill the weight column from the weights
  # stored on the rset. Assigning the local `weights` vector here (as this
  # test previously did) would make the assertions below tautological; if
  # the weights were lost, the extractor returns NULL, no column is created,
  # and the test fails.
  # NOTE(review): the earlier comment said this mirrors what the print method
  # does -- confirm against the print method's implementation.
  x_tbl <- tibble::as_tibble(weighted_folds)
  x_tbl$resample_weight <- extract_resample_weights(weighted_folds)

  # Verify the structure
  expect_true("resample_weight" %in% names(x_tbl))
  expect_equal(x_tbl$resample_weight, weights)
  expect_equal(nrow(x_tbl), 3)
})
test_that("extract_resample_weights() works with rset objects", {
  fold_wts <- c(0.2, 0.3, 0.5)
  folds_wtd <- add_resample_weights(folds, fold_wts)

  # Weighted rset: the supplied weights come back
  from_weighted <- extract_resample_weights(folds_wtd)
  expect_equal(from_weighted, fold_wts)

  # Unweighted rset: nothing to extract
  from_plain <- extract_resample_weights(folds)
  expect_null(from_plain)
})
test_that("extract_resample_weights() works with tune_results objects", {
  fold_wts <- c(0.1, 0.5, 0.4)
  folds_wtd <- add_resample_weights(folds, fold_wts)
  lm_spec <- create_test_model()

  fit <- suppressWarnings(
    tune_grid(
      lm_spec,
      mpg ~ .,
      resamples = folds_wtd,
      grid = 1,
      metrics = yardstick::metric_set(yardstick::rmse),
      control = control_grid(verbose = FALSE)
    )
  )

  # The tuning result should yield the weights in normalized form
  normalized <- fold_wts / sum(fold_wts)
  from_result <- extract_resample_weights(fit)
  expect_equal(from_result, normalized)
})
test_that("extract_resample_weights() validates input types", {
# Each snapshot records the error raised for an unsupported input type
# (`error = TRUE` means an error is expected).

# A character string.
expect_snapshot(
extract_resample_weights("not_valid_input"),
error = TRUE
)
# A plain data frame.
expect_snapshot(
extract_resample_weights(data.frame(x = 1:3)),
error = TRUE
)
})
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.