tests/testthat/test-pinterval_mondrian.R

# Tests for pinterval_mondrian()

# --- Setup ---
# Shared fixtures for every test in this file: 600 simulated observations in
# three classes ("A", "B", "C"). The response is the class-specific mean `mu`
# plus Gaussian noise (sd = 0.5). The first half of the data is the calibration
# set and the second half is the test set; the "predictions" handed to
# pinterval_mondrian() are the noise-free means `mu`.
set.seed(42)
n <- 600
x <- runif(n)
group <- sample(c("A", "B", "C"), n, replace = TRUE)

# Class-specific intercepts; unname() keeps `mu` a plain unnamed numeric vector.
class_offset <- c(A = 1, B = 2, C = 3)
mu <- unname(class_offset[group]) + x
y <- rnorm(n, mean = mu, sd = 0.5)

# Row indices of the calibration half and the test half.
n_cal <- n %/% 2
cal_idx <- seq_len(n_cal)
test_idx <- (n_cal + 1):n

pred_cal <- mu[cal_idx]
truth_cal <- y[cal_idx]
class_cal <- group[cal_idx]

pred_test <- mu[test_idx]
truth_test <- y[test_idx]
class_test <- group[test_idx]

# ============================================================
# 1. Input validation
# ============================================================

# A character `pred` must be rejected with an error whose message matches the
# pattern below (function name, then "pred", then "numeric").
test_that("pred must be numeric", {
	expected_msg <- "pinterval_mondrian.*pred.*numeric"
	expect_error(
		pinterval_mondrian(
			pred = "a",
			pred_class = class_test,
			calib = pred_cal,
			calib_truth = truth_cal,
			calib_class = class_cal
		),
		regexp = expected_msg
	)
})

# Passing `calib = NULL` (and no calibration truth/class) must raise an error
# whose message matches the pattern below.
test_that("calib must be provided", {
	expected_msg <- "pinterval_mondrian.*calib.*provided"
	expect_error(
		pinterval_mondrian(
			pred = pred_test,
			pred_class = class_test,
			calib = NULL
		),
		regexp = expected_msg
	)
})

# A `pred_class` vector shorter than `pred` (10 labels vs. the full test set)
# must be rejected with a "same length" error that names `pred_class`.
test_that("pred_class and pred must have the same length", {
	truncated_class <- class_test[seq_len(10)]
	expect_error(
		pinterval_mondrian(
			pred = pred_test,
			pred_class = truncated_class,
			calib = pred_cal,
			calib_truth = truth_cal,
			calib_class = class_cal
		),
		regexp = "pred_class.*same length"
	)
})

# A `calib_class` vector shorter than `calib` (10 labels vs. the full
# calibration set) must be rejected with a "same length" error that names
# `calib_class`.
test_that("calib_class and calib must have the same length", {
	truncated_class <- class_cal[seq_len(10)]
	expect_error(
		pinterval_mondrian(
			pred = pred_test,
			pred_class = class_test,
			calib = pred_cal,
			calib_truth = truth_cal,
			calib_class = truncated_class
		),
		regexp = "calib_class.*same length"
	)
})

# `alpha = 0` must be rejected with an error that mentions "alpha".
test_that("alpha must be valid", {
	expect_error(
		pinterval_mondrian(
			pred = pred_test,
			pred_class = class_test,
			calib = pred_cal,
			calib_truth = truth_cal,
			calib_class = class_cal,
			alpha = 0
		),
		regexp = "alpha"
	)
})

# ============================================================
# 2. Output structure
# ============================================================

# The result must be a tibble with one row per test prediction that contains
# at least the `pred`, `lower_bound` and `upper_bound` columns.
test_that("output is a tibble with correct columns", {
	result <- pinterval_mondrian(
		pred = pred_test,
		pred_class = class_test,
		calib = pred_cal,
		calib_truth = truth_cal,
		calib_class = class_cal,
		alpha = 0.1
	)

	required_cols <- c("pred", "lower_bound", "upper_bound")
	has_required_cols <- all(required_cols %in% names(result))

	expect_s3_class(result, "tbl_df")
	expect_true(has_required_cols)
	expect_equal(nrow(result), length(pred_test))
})

# ============================================================
# 3. Correctness / coverage
# ============================================================

# Empirical coverage of the intervals is checked separately within each class
# present in the test set, rather than pooled over all test observations.
test_that("per-class coverage is approximately 1-alpha", {
	# Fix the RNG state before the call.
	# NOTE(review): presumably guards against randomness inside
	# pinterval_mondrian() -- confirm whether it draws random numbers.
	set.seed(123)
	result <- pinterval_mondrian(
		pred = pred_test, pred_class = class_test,
		calib = pred_cal, calib_truth = truth_cal, calib_class = class_cal,
		alpha = 0.1, lower_bound = -2, upper_bound = 8
	)

	# TRUE where the observed value lies inside its interval. Computed once for
	# the whole test set instead of once per class.
	covered <- truth_test >= result$lower_bound &
		truth_test <= result$upper_bound

	for (cls in unique(class_test)) {
		coverage_cls <- mean(covered[class_test == cls], na.rm = TRUE)
		# Nominal coverage is 1 - alpha = 0.90; the 0.80 floor leaves slack for
		# sampling noise in the per-class subsets. expect_gte() reports the
		# shortfall on failure, and the label says which class fell short.
		expect_gte(
			coverage_cls,
			0.80,
			label = paste("Class", cls, "coverage")
		)
	}
})

# ============================================================
# 4. Warns about unseen classes in pred
# ============================================================

# When some test observations carry a class label ("D") that never occurs in
# the calibration classes, the call must emit a warning mentioning "class".
test_that("warns about classes in pred not in calib", {
	# Relabel the first three test observations with the unseen class "D".
	pred_class_new <- replace(class_test, 1:3, "D")

	expect_warning(
		pinterval_mondrian(
			pred = pred_test,
			pred_class = pred_class_new,
			calib = pred_cal,
			calib_truth = truth_cal,
			calib_class = class_cal,
			alpha = 0.1
		),
		regexp = "class"
	)
})

Try the pintervals package in your browser

Any scripts or data that you put into this service are public.

pintervals documentation built on March 3, 2026, 5:06 p.m.