Nothing
# ===========================================================================
# BORG Audit Functions
# ===========================================================================
#' Audit Predictions for Data Leakage
#'
#' Validates that predictions were generated correctly without data leakage.
#' Checks that predictions correspond to test data only and that the
#' prediction process did not use information from the test set.
#'
#' @param predictions Vector of predictions (numeric or factor).
#' @param train_idx Integer vector of training indices.
#' @param test_idx Integer vector of test indices.
#' @param actual Optional vector of actual values for comparison.
#' @param data Optional data frame containing the original data.
#' @param model Optional fitted model object for additional checks.
#'
#' @return A BorgRisk object with audit results.
#'
#' @examples
#' # Create data and split
#' set.seed(42)
#' data <- data.frame(y = rnorm(100), x = rnorm(100))
#' train_idx <- 1:70
#' test_idx <- 71:100
#'
#' # Fit model and predict
#' model <- lm(y ~ x, data = data[train_idx, ])
#' predictions <- predict(model, newdata = data[test_idx, ])
#'
#' # Audit predictions
#' result <- audit_predictions(predictions, train_idx, test_idx)
#'
#' @export
audit_predictions <- function(predictions, train_idx, test_idx,
actual = NULL, data = NULL, model = NULL) {
risks <- list()
# ===========================================================================
# Basic validation
# ===========================================================================
if (is.null(predictions) || length(predictions) == 0) {
stop("'predictions' must be a non-empty vector")
}
n_pred <- length(predictions)
n_test <- length(test_idx)
# Check 1: Prediction count matches test set
if (n_pred != n_test) {
risks <- c(risks, list(.new_risk(
type = "prediction_count_mismatch",
severity = "hard_violation",
description = sprintf(
"Number of predictions (%d) does not match test set size (%d)",
n_pred, n_test
),
affected_indices = test_idx,
source_object = "predictions"
)))
}
# ===========================================================================
# Perfect prediction check (potential target leak)
# ===========================================================================
if (!is.null(actual) && length(actual) == n_pred) {
if (is.numeric(predictions) && is.numeric(actual)) {
# Check for suspiciously perfect correlation
if (n_pred >= 3) {
cor_val <- tryCatch(
cor(predictions, actual, use = "pairwise.complete.obs"),
error = function(e) NA
)
if (!is.na(cor_val) && abs(cor_val) > 0.999) {
risks <- c(risks, list(.new_risk(
type = "perfect_prediction",
severity = "hard_violation",
description = sprintf(
"Predictions have near-perfect correlation with actual values (r = %.4f). Likely target leakage.",
cor_val
),
affected_indices = test_idx,
source_object = "predictions"
)))
}
}
# Check for exact matches (suspicious for continuous data)
exact_matches <- sum(abs(predictions - actual) < 1e-10, na.rm = TRUE)
if (exact_matches == n_pred && n_pred > 5) {
risks <- c(risks, list(.new_risk(
type = "exact_predictions",
severity = "hard_violation",
description = sprintf(
"All %d predictions exactly match actual values. Almost certainly target leakage.",
n_pred
),
affected_indices = test_idx,
source_object = "predictions"
)))
}
} else if (is.factor(predictions) || is.factor(actual)) {
# Classification: check for perfect accuracy
pred_char <- as.character(predictions)
actual_char <- as.character(actual)
accuracy <- mean(pred_char == actual_char, na.rm = TRUE)
if (accuracy > 0.99 && n_pred > 20) {
risks <- c(risks, list(.new_risk(
type = "suspiciously_high_accuracy",
severity = "soft_inflation",
description = sprintf(
"Classification accuracy is %.1f%% on %d test samples. Verify no target leakage.",
100 * accuracy, n_pred
),
affected_indices = test_idx,
source_object = "predictions"
)))
}
}
}
# ===========================================================================
# Model-based checks
# ===========================================================================
if (!is.null(model)) {
# Check if model was trained on correct data
model_risks <- borg_inspect(model, train_idx, test_idx, data = data)
if (model_risks@n_hard > 0) {
risks <- c(risks, list(.new_risk(
type = "model_leakage",
severity = "hard_violation",
description = sprintf(
"Model used for predictions has %d hard violations",
model_risks@n_hard
),
affected_indices = test_idx,
source_object = class(model)[1]
)))
}
}
# ===========================================================================
# Data-based checks
# ===========================================================================
if (!is.null(data)) {
# Check that we're not predicting on training data
# This would indicate predictions were made on wrong subset
n_total <- nrow(data)
# If predictions match training size instead of test size
if (n_pred == length(train_idx) && n_pred != n_test) {
risks <- c(risks, list(.new_risk(
type = "wrong_prediction_set",
severity = "hard_violation",
description = sprintf(
"Prediction count (%d) matches training set size, not test set. Predictions may be on wrong data.",
n_pred
),
affected_indices = train_idx,
source_object = "predictions"
)))
}
# If predictions match full data size
if (n_pred == n_total && n_pred != n_test) {
risks <- c(risks, list(.new_risk(
type = "full_data_prediction",
severity = "hard_violation",
description = sprintf(
"Prediction count (%d) matches full dataset size. Predictions include training data.",
n_pred
),
affected_indices = train_idx,
source_object = "predictions"
)))
}
}
# ===========================================================================
# Build result
# ===========================================================================
.make_borg_risk(risks, train_idx, test_idx, call = match.call())
}
#' Generate CV Leakage Report
#'
#' Generates a detailed report of cross-validation leakage issues.
#'
#' @param cv_object A cross-validation object (trainControl, vfold_cv, etc.).
#' @param train_idx Integer vector of training indices.
#' @param test_idx Integer vector of test indices.
#'
#' @return A list with detailed CV leakage information.
#'
#' @examples
#' # Using caret trainControl
#' if (requireNamespace("caret", quietly = TRUE)) {
#' folds <- list(Fold1 = 1:10, Fold2 = 11:20, Fold3 = 21:25)
#' ctrl <- caret::trainControl(method = "cv", index = folds)
#' report <- cv_leakage_report(ctrl, train_idx = 1:25, test_idx = 26:32)
#' print(report)
#' }
#'
#' @export
cv_leakage_report <- function(cv_object, train_idx, test_idx) {
report <- list(
summary = list(
n_train = length(train_idx),
n_test = length(test_idx),
overlap = intersect(train_idx, test_idx)
),
folds = list(),
issues = list()
)
# Detect CV type and extract folds
if (inherits(cv_object, "vfold_cv") || inherits(cv_object, "rset")) {
# rsample object
if ("splits" %in% names(cv_object)) {
for (i in seq_len(nrow(cv_object))) {
split <- cv_object$splits[[i]]
fold_id <- if ("id" %in% names(cv_object)) cv_object$id[i] else paste0("Fold", i)
# Get analysis (training) indices for this fold
analysis_idx <- NULL
if (!is.null(split$in_id)) {
analysis_idx <- split$in_id
}
if (!is.null(analysis_idx)) {
# Check for test indices in analysis set
leaked_test <- intersect(analysis_idx, test_idx)
leaked_train <- setdiff(analysis_idx, train_idx)
fold_info <- list(
fold_id = fold_id,
n_analysis = length(analysis_idx),
n_assessment = if (!is.null(split$out_id)) length(split$out_id) else NA,
leaked_test_indices = leaked_test,
n_leaked_test = length(leaked_test),
non_train_indices = leaked_train,
n_non_train = length(leaked_train)
)
report$folds[[fold_id]] <- fold_info
if (length(leaked_test) > 0) {
report$issues <- c(report$issues, list(list(
fold = fold_id,
type = "test_in_analysis",
severity = "hard_violation",
description = sprintf("%s: %d test indices in analysis set", fold_id, length(leaked_test)),
indices = leaked_test
)))
}
if (length(leaked_train) > 0) {
report$issues <- c(report$issues, list(list(
fold = fold_id,
type = "non_train_in_analysis",
severity = "hard_violation",
description = sprintf("%s: %d non-training indices in analysis set", fold_id, length(leaked_train)),
indices = leaked_train
)))
}
}
}
}
} else if (.is_trainControl(cv_object)) {
# caret trainControl
if (!is.null(cv_object$index)) {
for (i in seq_along(cv_object$index)) {
fold_idx <- cv_object$index[[i]]
fold_id <- names(cv_object$index)[i]
if (is.null(fold_id)) fold_id <- paste0("Fold", i)
leaked_test <- intersect(fold_idx, test_idx)
fold_info <- list(
fold_id = fold_id,
n_in_fold = length(fold_idx),
leaked_test_indices = leaked_test,
n_leaked_test = length(leaked_test)
)
report$folds[[fold_id]] <- fold_info
if (length(leaked_test) > 0) {
report$issues <- c(report$issues, list(list(
fold = fold_id,
type = "test_in_fold",
severity = "hard_violation",
description = sprintf("%s: %d test indices in CV fold", fold_id, length(leaked_test)),
indices = leaked_test
)))
}
}
}
}
# Summary statistics
report$summary$n_folds <- length(report$folds)
report$summary$n_issues <- length(report$issues)
report$summary$has_leakage <- report$summary$n_issues > 0
class(report) <- c("borg_cv_report", "list")
report
}
#' Audit Feature Importance Calculations
#'
#' Detects when feature importance (SHAP, permutation importance, etc.) is
#' computed using test data, which can lead to biased feature selection and
#' data leakage.
#'
#' @param importance A vector, matrix, or data frame of importance values.
#' @param data The data used to compute importance.
#' @param train_idx Integer vector of training indices.
#' @param test_idx Integer vector of test indices.
#' @param method Character indicating the importance method. One of
#' "shap", "permutation", "gain", "impurity", or "auto" (default).
#' @param model Optional fitted model object for additional validation.
#'
#' @return A BorgRisk object with audit results.
#'
#' @details
#' Feature importance computed on test data is a form of data leakage because:
#' \itemize{
#' \item SHAP values computed on test data reveal test set structure
#' \item Permutation importance on test data uses test labels
#' \item Feature selection based on test importance leads to overfit models
#' }
#'
#' This function checks if the data used for importance calculation includes
#' test indices and flags potential violations.
#'
#' @examples
#' set.seed(42)
#' data <- data.frame(y = rnorm(100), x1 = rnorm(100), x2 = rnorm(100))
#' train_idx <- 1:70
#' test_idx <- 71:100
#'
#' # Simulate importance values
#' importance <- c(x1 = 0.6, x2 = 0.4)
#'
#' # Good: importance computed on training data
#' result <- audit_importance(importance, data[train_idx, ], train_idx, test_idx)
#'
#' # Bad: importance computed on full data (includes test)
#' result_bad <- audit_importance(importance, data, train_idx, test_idx)
#'
#' @export
audit_importance <- function(importance, data, train_idx, test_idx,
method = "auto", model = NULL) {
risks <- list()
# ===========================================================================
# Input validation
# ===========================================================================
if (is.null(importance)) {
stop("'importance' must be provided")
}
if (is.null(data)) {
stop("'data' must be provided to check for test data usage")
}
n_data <- if (is.data.frame(data) || is.matrix(data)) nrow(data) else length(data)
n_train <- length(train_idx)
n_test <- length(test_idx)
n_total <- max(c(train_idx, test_idx))
# ===========================================================================
# Check 1: Data size indicates test data was used
# ===========================================================================
# If data size matches full dataset (train + test), likely includes test
if (n_data == n_total || n_data == (n_train + n_test)) {
risks <- c(risks, list(.new_risk(
type = "importance_on_full_data",
severity = "hard_violation",
description = sprintf(
"Feature importance computed on %d observations (full dataset). Should use only training data (%d observations).",
n_data, n_train
),
affected_indices = test_idx,
source_object = "importance"
)))
}
# If data size is larger than train but not full, something is off
if (n_data > n_train && n_data < n_total) {
risks <- c(risks, list(.new_risk(
type = "importance_data_size_mismatch",
severity = "soft_inflation",
description = sprintf(
"Feature importance data has %d observations. Expected %d (train) or %d (full). Verify correct data was used.",
n_data, n_train, n_total
),
affected_indices = integer(0),
source_object = "importance"
)))
}
# ===========================================================================
# Check 2: Method-specific checks
# ===========================================================================
method <- tolower(method)
if (method %in% c("shap", "auto")) {
# SHAP values should be computed on training data only for model explanation
# If computing on test data, it's using test information for feature selection
if (n_data == n_test) {
risks <- c(risks, list(.new_risk(
type = "shap_on_test_data",
severity = "hard_violation",
description = sprintf(
"SHAP values computed on %d observations (matches test set size). SHAP should be computed on training data for feature selection.",
n_data
),
affected_indices = test_idx,
source_object = "importance"
)))
}
}
if (method %in% c("permutation", "auto")) {
# Permutation importance requires labels - using test labels is leakage
if (n_data == n_test) {
risks <- c(risks, list(.new_risk(
type = "permutation_on_test_data",
severity = "hard_violation",
description = "Permutation importance computed on test data uses test labels, causing information leakage.",
affected_indices = test_idx,
source_object = "importance"
)))
}
}
# ===========================================================================
# Check 3: Model-based checks
# ===========================================================================
if (!is.null(model)) {
# Validate model was trained correctly
model_risk <- borg_inspect(model, train_idx, test_idx)
if (model_risk@n_hard > 0) {
risks <- c(risks, list(.new_risk(
type = "importance_from_leaked_model",
severity = "hard_violation",
description = sprintf(
"Model used for feature importance has %d hard violations. Importance values are unreliable.",
model_risk@n_hard
),
affected_indices = test_idx,
source_object = class(model)[1]
)))
}
}
# ===========================================================================
# Check 4: Importance value sanity checks
# ===========================================================================
# Extract numeric importance values
imp_vals <- if (is.data.frame(importance)) {
unlist(importance[sapply(importance, is.numeric)])
} else if (is.matrix(importance)) {
as.numeric(importance)
} else {
as.numeric(importance)
}
imp_vals <- imp_vals[!is.na(imp_vals)]
if (length(imp_vals) > 0) {
# Check for suspiciously uniform importance (might indicate random/meaningless)
if (length(imp_vals) > 2) {
imp_sd <- sd(imp_vals)
imp_range <- diff(range(imp_vals))
if (imp_range > 0 && imp_sd / imp_range < 0.05) {
risks <- c(risks, list(.new_risk(
type = "uniform_importance",
severity = "soft_inflation",
description = "Feature importance values are nearly uniform. May indicate incorrect calculation or uninformative features.",
affected_indices = integer(0),
source_object = "importance"
)))
}
}
# Check for negative importance (valid for some methods, suspicious for others)
if (any(imp_vals < 0) && method %in% c("gain", "impurity")) {
risks <- c(risks, list(.new_risk(
type = "negative_importance",
severity = "soft_inflation",
description = "Negative importance values detected for gain/impurity method. This is unusual.",
affected_indices = integer(0),
source_object = "importance"
)))
}
}
# ===========================================================================
# Build result
# ===========================================================================
.make_borg_risk(risks, train_idx, test_idx, call = match.call())
}
#' Print CV Leakage Report
#' @param x A borg_cv_report object.
#' @param ... Additional arguments (ignored).
#' @return The \code{borg_cv_report} object \code{x}, returned invisibly.
#' Called for the side effect of printing a human-readable leakage summary
#' to the console.
#' @export
print.borg_cv_report <- function(x, ...) {
cat("BORG CV Leakage Report\n")
cat("======================\n\n")
cat(sprintf("Training set: %d indices\n", x$summary$n_train))
cat(sprintf("Test set: %d indices\n", x$summary$n_test))
if (length(x$summary$overlap) > 0) {
cat(sprintf("Index overlap: %d (train/test share indices!)\n",
length(x$summary$overlap)))
}
cat(sprintf("CV folds analyzed: %d\n", x$summary$n_folds))
cat("\n")
if (x$summary$has_leakage) {
cat("--- ISSUES DETECTED ---\n\n")
for (issue in x$issues) {
cat(sprintf("[%s] %s\n", issue$severity, issue$description))
if (length(issue$indices) <= 10) {
cat(sprintf(" Affected indices: %s\n", paste(issue$indices, collapse = ", ")))
} else {
cat(sprintf(" Affected indices: %s, ... (%d more)\n",
paste(head(issue$indices, 5), collapse = ", "),
length(issue$indices) - 5))
}
cat("\n")
}
} else {
cat("No CV leakage detected.\n")
}
invisible(x)
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.