tests/testthat/helper-samplers.R

# Custom testthat expectations for FeatureSampler testing
#
# These follow testthat 3e conventions.
# See: https://testthat.r-lib.org/articles/custom-expectation.html

# -----------------------------------------------------------------------------
# Helper: Generate task from sampler's supported feature types
# -----------------------------------------------------------------------------

#' Build a regression task whose features match a sampler's supported types
#'
#' Two numeric features (`x_num1`, `x_num2`) are always created. One extra
#' feature is appended for each of "integer", "factor", "ordered" and
#' "logical" that appears in `supported_types`. The target `y` is a noisy
#' linear combination of the two numeric features.
#'
#' @param supported_types Character vector of supported feature types
#' @param n Number of observations (rows) to generate
#' @return An mlr3 regression task with target "y"
generate_test_task = function(supported_types, n = 100) {
	wants = function(type) type %in% supported_types

	xdat = data.table::data.table(x_num1 = rnorm(n), x_num2 = runif(n))

	if (wants("integer")) {
		xdat[, x_int := sample(1L:10L, n, replace = TRUE)]
	}
	if (wants("factor")) {
		fct_values = sample(c("a", "b", "c"), n, replace = TRUE)
		xdat[, x_fct := factor(fct_values)]
	}
	if (wants("ordered")) {
		ord_levels = c("low", "mid", "high")
		ord_values = sample(ord_levels, n, replace = TRUE)
		xdat[, x_ord := ordered(ord_values, levels = ord_levels)]
	}
	if (wants("logical")) {
		xdat[, x_lgl := sample(c(TRUE, FALSE), n, replace = TRUE)]
	}

	# Target depends only on the always-present numeric features
	xdat[, y := x_num1 + 0.5 * x_num2 + rnorm(n, sd = 0.1)]
	mlr3::as_task_regr(xdat, target = "y")
}

# -----------------------------------------------------------------------------
# expect_feature_type_consistency
# -----------------------------------------------------------------------------

#' Expect sampled output has consistent feature types with task
#'
#' Compares the class of every task feature in the sampled data against the
#' type recorded in `task$feature_types`. Features are matched by name, so the
#' check does not depend on the row order of `task$feature_types` agreeing with
#' the column order of `sampled`. A feature that is missing from `sampled` is
#' reported as a mismatch with actual type "NULL" instead of raising an error.
#'
#' @param sampled A data.table returned from sampler$sample()
#' @param task The mlr3 task used for sampling
#' @return Invisibly returns sampled data for piping
expect_feature_type_consistency = function(sampled, task) {
	expected_classes = stats::setNames(
		task$feature_types$type,
		task$feature_types$id
	)
	# Look each feature up by name so `actual_classes` is aligned element-wise
	# with `expected_classes`; `!=` below compares named vectors by position.
	# class(x)[1] is the value compared against the task's type label.
	actual_classes = vapply(
		names(expected_classes),
		function(feat) class(sampled[[feat]])[1],
		character(1)
	)

	mismatches = expected_classes != actual_classes
	if (any(mismatches)) {
		bad_feats = names(expected_classes)[mismatches]
		msg = glue::glue(
			"Feature type mismatch in sampled data.\n",
			"Features with wrong type: {paste(bad_feats, collapse = ', ')}\n",
			"Expected: {paste(expected_classes[mismatches], collapse = ', ')}\n",
			"Actual: {paste(actual_classes[mismatches], collapse = ', ')}"
		)
	} else {
		msg = ""
	}

	expect(!any(mismatches), msg)
	invisible(sampled)
}

# -----------------------------------------------------------------------------
# expect_non_sampled_unchanged
# -----------------------------------------------------------------------------

#' Expect non-sampled features are unchanged after sampling
#'
#' Verifies that features not passed to $sample() remain identical
#' between the original data and the sampled output. Each listed feature must
#' exist in both data sets; a missing column is reported as a failure rather
#' than passing vacuously.
#'
#' @param sampled A data.table returned from sampler$sample()
#' @param original The original data before sampling
#' @param features Character vector of feature names that should be unchanged
#' @return Invisibly returns sampled data for piping
expect_non_sampled_unchanged = function(sampled, original, features) {
	if (length(features) == 0) {
		expect(TRUE, "No features to check")
		return(invisible(sampled))
	}

	for (feat in features) {
		# identical(NULL, NULL) is TRUE, so without this guard a feature absent
		# from both data sets would pass without anything being compared.
		present = feat %in% names(sampled) && feat %in% names(original)
		if (!present) {
			expect(
				FALSE,
				glue::glue(
					"Feature '{feat}' is not present in both the sampled and the original data.\n",
					"Cannot verify that this non-sampled feature remained unchanged."
				)
			)
			next
		}

		ok = identical(sampled[[feat]], original[[feat]])
		expect(
			ok,
			glue::glue(
				"Feature '{feat}' was modified in sampled data.\n",
				"Non-sampled features must remain unchanged during sampling."
			)
		)
	}

	invisible(sampled)
}

# -----------------------------------------------------------------------------
# expect_sampled_features_changed
# -----------------------------------------------------------------------------

#' Expect sampled features differ from original (stochastic check)
#'
#' Verifies that at least some values in sampled features differ from original.
#' This is a probabilistic check - with sufficient data variability, randomly
#' sampling identical values is extremely unlikely. Each feature must exist in
#' both data sets; a missing column is reported as a failure instead of being
#' counted as "changed".
#'
#' @param sampled A data.table returned from sampler$sample()
#' @param original The original data before sampling
#' @param sampled_features Character vector of feature names that were sampled
#' @return Invisibly returns sampled data for piping
expect_sampled_features_changed = function(sampled, original, sampled_features) {
	for (feat in sampled_features) {
		# A column missing from `sampled` yields NULL, which is never identical to
		# the original column and would otherwise silently pass as "changed".
		present = feat %in% names(sampled) && feat %in% names(original)
		if (!present) {
			expect(
				FALSE,
				glue::glue(
					"Sampled feature '{feat}' is not present in both the sampled and the original data.\n",
					"Cannot verify that the sampler modified the feature."
				)
			)
			next
		}

		ok = !identical(sampled[[feat]], original[[feat]])
		expect(
			ok,
			glue::glue(
				"Sampled feature '{feat}' is identical to original.\n",
				"This suggests the sampler did not modify the feature.\n",
				"(Note: This could theoretically be a false positive with very low probability)"
			)
		)
	}

	invisible(sampled)
}

# -----------------------------------------------------------------------------
# expect_sampler_output_structure
# -----------------------------------------------------------------------------

#' Expect sampler output has correct structure
#'
#' Verifies the sampled data is a data.table with correct columns (target
#' first, then features, in task order) and, optionally, the expected number
#' of rows. If `sampled` has no row count (e.g. it is a plain list), the row
#' check fails with "Actual: NA" instead of passing an empty condition to
#' expect().
#'
#' @param sampled A data.table returned from sampler$sample()
#' @param task The mlr3 task used for sampling
#' @param nrows Expected number of rows (NULL to skip check)
#' @return Invisibly returns sampled data for piping
expect_sampler_output_structure = function(sampled, task, nrows = NULL) {
	expect(
		data.table::is.data.table(sampled),
		glue::glue(
			"Sampled data is not a data.table.\n",
			"Actual class: {paste(class(sampled), collapse = ', ')}"
		)
	)

	expected_cols = c(task$target_names, task$feature_names)
	actual_cols = names(sampled)

	expect(
		identical(actual_cols, expected_cols),
		glue::glue(
			"Sampled data has incorrect columns.\n",
			"Expected: {paste(expected_cols, collapse = ', ')}\n",
			"Actual: {paste(actual_cols, collapse = ', ')}"
		)
	)

	if (!is.null(nrows)) {
		# nrow() is NULL for non-tabular input, and NULL == nrows is logical(0):
		# not a usable TRUE/FALSE condition and no readable row count for the
		# failure message. Substitute NA so the check fails cleanly.
		actual_nrows = nrow(sampled)
		if (is.null(actual_nrows)) {
			actual_nrows = NA_integer_
		}
		expect(
			isTRUE(actual_nrows == nrows),
			glue::glue(
				"Sampled data has incorrect number of rows.\n",
				"Expected: {nrows}\n",
				"Actual: {actual_nrows}"
			)
		)
	}

	invisible(sampled)
}

# -----------------------------------------------------------------------------
# expect_marginal_sampling
# -----------------------------------------------------------------------------

#' Expect a conditional sampler copes with an empty conditioning set
#'
#' Calls `sampler$sample()` with `conditioning_set = character(0)`, i.e. a
#' request for marginal sampling. It then checks the output's structure and
#' feature types, and that every feature other than `feature` is unchanged.
#'
#' @param sampler A ConditionalSampler instance
#' @param feature Feature to sample
#' @param row_ids Row IDs to sample
#' @return Invisibly returns sampled data for piping
expect_marginal_sampling = function(sampler, feature, row_ids = 1:10) {
	is_conditional = inherits(sampler, "ConditionalSampler")
	expect(
		is_conditional,
		"Sampler is not a ConditionalSampler. Marginal sampling test only applies to conditional samplers."
	)

	task = sampler$task
	original = task$data(rows = row_ids)

	# The empty conditioning set is the point of this check: the sampler must
	# accept character(0) and fall back to marginal sampling without erroring.
	sampled = sampler$sample(
		feature = feature,
		row_ids = row_ids,
		conditioning_set = character(0)
	)

	expect_sampler_output_structure(sampled, task, nrows = length(row_ids))
	expect_feature_type_consistency(sampled, task)

	untouched = setdiff(task$feature_names, feature)
	expect_non_sampled_unchanged(sampled, original, untouched)

	invisible(sampled)
}

# -----------------------------------------------------------------------------
# expect_conditional_sampling
# -----------------------------------------------------------------------------

#' Expect conditional sampler handles conditional sampling correctly
#'
#' Tests that a conditional sampler leaves the conditioning features and every
#' other non-sampled feature unchanged, and modifies the sampled features.
#' It also checks output structure and feature types.
#'
#' @param sampler A ConditionalSampler instance
#' @param feature Feature(s) to sample
#' @param conditioning_set Features to condition on
#' @param row_ids Row IDs to sample
#' @return Invisibly returns sampled data for piping
expect_conditional_sampling = function(sampler, feature, conditioning_set, row_ids = 1:10) {
	expect(
		inherits(sampler, "ConditionalSampler"),
		"Sampler is not a ConditionalSampler. Conditional sampling test only applies to conditional samplers."
	)

	original = sampler$task$data(rows = row_ids)

	sampled = sampler$sample(
		feature = feature,
		row_ids = row_ids,
		conditioning_set = conditioning_set
	)

	# Check structure
	expect_sampler_output_structure(sampled, sampler$task, nrows = length(row_ids))

	# Check types
	expect_feature_type_consistency(sampled, sampler$task)

	# Non-sampled features must be unchanged. Check the conditioning set plus
	# every other feature that was not sampled, as expect_marginal_sampling()
	# does. Checking only the conditioning set would miss a sampler that
	# corrupts features outside it.
	must_be_unchanged = union(
		conditioning_set,
		setdiff(sampler$task$feature_names, feature)
	)
	expect_non_sampled_unchanged(sampled, original, must_be_unchanged)

	# Sampled features should change (stochastic check)
	expect_sampled_features_changed(sampled, original, feature)

	invisible(sampled)
}

# -----------------------------------------------------------------------------
# Omnibus test functions (combine multiple expectations)
# -----------------------------------------------------------------------------

#' Run comprehensive feature type tests for a sampler class
#'
#' Builds a task from the feature types the sampler class declares in its
#' public `feature_types` field. For each feature, it samples that feature and
#' checks output structure and type consistency. For conditional samplers it
#' additionally runs a conditional check (conditioning on one other feature)
#' and a marginal check (empty conditioning set).
#'
#' @param sampler_class R6 class for the sampler to test
#' @param ... Additional arguments passed to sampler constructor
#' @return NULL (used for side effects via testthat expectations)
test_sampler_feature_types = function(sampler_class, ...) {
	task = generate_test_task(sampler_class$public_fields$feature_types)
	sampler = sampler_class$new(task, ...)
	conditional = inherits(sampler, "ConditionalSampler")

	# 50 rows make a resampled column coincidentally equal to the original
	# negligible, even for low-cardinality features such as ordered factors.
	rows = 1:50
	n_rows = length(rows)

	for (feat in task$feature_names) {
		sampled = sampler$sample(feat, row_ids = rows)
		expect_sampler_output_structure(sampled, task, nrows = n_rows)
		expect_feature_type_consistency(sampled, task)

		if (!conditional) {
			next
		}

		# Condition on the first remaining feature, if there is one
		candidates = setdiff(task$feature_names, feat)
		if (length(candidates) > 0) {
			expect_conditional_sampling(
				sampler,
				feature = feat,
				conditioning_set = candidates[1],
				row_ids = rows
			)
		}

		# Marginal case: empty conditioning set
		expect_marginal_sampling(sampler, feature = feat, row_ids = rows)
	}

	invisible(NULL)
}

#' Test conditioning_set parameter behavior for conditional samplers
#'
#' Verifies that a conditional sampler correctly:
#' 1. Stores conditioning_set in param_set when provided during initialization
#' 2. Can sample without specifying conditioning_set (uses stored value)
#' 3. Can override conditioning_set in $sample() calls
#' 4. Leaves param_set's conditioning_set NULL when none is given at initialization
#' 5. Accepts conditioning_set in $sample() when none was set during initialization
#' 6. Handles NULL conditioning_set (defaults to all other features)
#' 7. Handles empty conditioning_set character(0) (marginal sampling)
#'
#' The first task feature is the one sampled; the second and third serve as
#' alternative conditioning sets. All sampling calls use rows 1 to 5.
#'
#' @param sampler_class R6 class for the sampler to test
#' @param task mlr3 Task to use for testing (must have at least 3 features)
#' @param ... Additional arguments passed to sampler constructor
#' @return NULL (used for side effects via testthat expectations)
test_conditioning_set_behavior = function(sampler_class, task, ...) {
	features = task$feature_names
	checkmate::assert_true(
		length(features) >= 3,
		.var.name = "task must have at least 3 features"
	)

	target_feature = features[1]
	cond_set_1 = features[2]
	cond_set_2 = features[3]
	other_features = setdiff(features, target_feature)

	# Every check samples the same rows, so they share one reference copy.
	row_ids = 1:5
	n_rows = length(row_ids)

	# Test 1: conditioning_set stored in param_set when provided
	sampler_with_cond = sampler_class$new(task, conditioning_set = cond_set_1, ...)
	expect_identical(
		sampler_with_cond$param_set$values$conditioning_set,
		cond_set_1
	)

	# Test 2: Can sample using stored conditioning_set
	original_data = task$data(rows = row_ids)
	result_stored = sampler_with_cond$sample(feature = target_feature, row_ids = row_ids)
	expect_sampler_output_structure(result_stored, task, nrows = n_rows)
	expect_non_sampled_unchanged(result_stored, original_data, cond_set_1)

	# Test 3: Can override conditioning_set in $sample() call
	result_override = sampler_with_cond$sample(
		feature = target_feature,
		row_ids = row_ids,
		conditioning_set = cond_set_2
	)
	expect_sampler_output_structure(result_override, task, nrows = n_rows)
	expect_non_sampled_unchanged(result_override, original_data, cond_set_2)

	# Test 4: NULL conditioning_set during initialization
	sampler_no_cond = sampler_class$new(task, ...)
	expect_null(sampler_no_cond$param_set$values$conditioning_set)

	# Test 5: Can specify conditioning_set in $sample() when not set during init
	result_specified = sampler_no_cond$sample(
		feature = target_feature,
		row_ids = row_ids,
		conditioning_set = cond_set_1
	)
	expect_sampler_output_structure(result_specified, task, nrows = n_rows)
	expect_non_sampled_unchanged(result_specified, original_data, cond_set_1)

	# Test 6: NULL conditioning_set should default to all other features.
	# If so, every other feature was conditioned on and must be unchanged.
	result_null = sampler_no_cond$sample(
		feature = target_feature,
		row_ids = row_ids,
		conditioning_set = NULL
	)
	expect_sampler_output_structure(result_null, task, nrows = n_rows)
	expect_non_sampled_unchanged(result_null, original_data, other_features)

	# Test 7: character(0) requests an empty conditioning set (marginal sampling);
	# the sampler must handle it without error.
	result_empty = sampler_no_cond$sample(
		feature = target_feature,
		row_ids = row_ids,
		conditioning_set = character(0)
	)
	expect_sampler_output_structure(result_empty, task, nrows = n_rows)
	# The sampled feature should still be modified.
	# NOTE(review): with only 5 rows, a low-cardinality target feature (e.g. a
	# logical or small factor) could be resampled identically by chance.
	expect_sampled_features_changed(result_empty, original_data, target_feature)

	invisible(NULL)
}

Try the xplainfi package in your browser

Any scripts or data that you put into this service are public.

xplainfi documentation built on Feb. 27, 2026, 1:08 a.m.