View source: R/plotting_audit.R
| plot_overlap_checks | R Documentation |
Checks whether the same group identifiers appear in both the training and test partitions within each resample. This is designed to detect leakage from grouped or repeated-measures data (for example, the same subject, batch, plate, or study appearing on both sides of a fold) when group-wise splitting is expected.
plot_overlap_checks(fit, column = NULL)
fit |
A 'LeakFit' object produced by [fit_resample()]. It must contain the split indices and the associated metadata in 'fit@splits@info$coldata'. The metadata rows must align with the data used to create the splits. |
column |
Character scalar naming the metadata column to check (for example '"subject"' or '"batch"'). The function compares unique values of this column between train and test within each resample. Although the signature defaults to 'NULL', there is no usable default: 'NULL' or an unknown column name triggers an error. The choice of 'column' determines which kind of leakage (subject-level, batch-level, etc.) is tested, and therefore the overlap counts. |
For each resample in 'fit@splits@indices', the function counts the number of unique values of 'column' in the train and test sets and the size of their intersection. Any non-zero overlap indicates that at least one group appears in both train and test for that resample. The check is metadata-based only: it relies on exact matches of the supplied column and does not inspect features or outcomes. It only checks train vs test within each resample, so it will not detect overlaps across different resamples or other leakage mechanisms. Inconsistent IDs or missing values in the metadata can hide or inflate overlaps. 'NA' values are treated as regular identifiers and will count toward overlap if they appear in both partitions. Requires ggplot2.
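The counting described above can be sketched in base R. This is an illustration only, not the package's implementation: the vector 'ids', the list 'indices' (with 'train' and 'test' row numbers per resample), and the helper 'count_overlap()' are hypothetical stand-ins, and the real layout of 'fit@splits@indices' may differ.

```r
# Hypothetical group identifiers, one per sample (stand-in for a metadata column).
ids <- c("s1", "s1", "s2", "s2", "s3", "s3", NA, NA)

# Hypothetical resample indices: row numbers for train and test in each fold.
indices <- list(
  list(train = c(1, 2, 3, 4), test = c(5, 6)),       # groups kept apart
  list(train = c(1, 3, 5, 7), test = c(2, 4, 6, 8))  # every group split across sides
)

# Count unique IDs per partition and the size of their intersection.
count_overlap <- function(indices, ids) {
  rows <- lapply(seq_along(indices), function(i) {
    train_ids <- unique(ids[indices[[i]]$train])
    test_ids  <- unique(ids[indices[[i]]$test])
    data.frame(
      fold    = i,
      # intersect() matches NA to NA, so a shared NA counts as an overlap
      overlap = length(intersect(train_ids, test_ids)),
      train   = length(train_ids),
      test    = length(test_ids)
    )
  })
  do.call(rbind, rows)
}

overlap_counts <- count_overlap(indices, ids)
print(overlap_counts)
```

In this toy input the first fold has no shared identifiers, while the second shares all three subjects plus 'NA', which illustrates how missing values can inflate the overlap count.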
A list returned invisibly with:
'overlap_counts': data.frame with one row per resample and columns 'fold' (resample index in 'fit@splits@indices'), 'overlap' (unique IDs shared by train and test), 'train' (unique IDs in train), and 'test' (unique IDs in test).
'column': the metadata column name used for the check.
'plot': the ggplot object showing the three count series across folds.
The plot is also printed. When any overlap is detected, the plot adds a warning annotation.
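Because the list is returned invisibly, it can be captured and checked programmatically. The sketch below uses a hand-built data frame as a stand-in for 'out$overlap_counts' (the values are invented for illustration; only the column names follow the description above) and shows one possible way to flag affected resamples.

```r
# Stand-in for out$overlap_counts, using the documented column names.
# The numbers are made up for illustration.
overlap_counts <- data.frame(
  fold    = 1:3,
  overlap = c(0, 2, 0),
  train   = c(4, 4, 4),
  test    = c(2, 4, 2)
)

# Keep only the resamples where at least one group sits on both sides.
leaky <- overlap_counts[overlap_counts$overlap > 0, ]

if (nrow(leaky) > 0) {
  message("Group overlap found in fold(s): ",
          paste(leaky$fold, collapse = ", "))
}
```

A check like this could sit in a test or pipeline step so that a non-zero overlap stops the analysis instead of relying on the plot annotation alone.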
set.seed(1)
df <- data.frame(
  subject = rep(1:6, each = 2),
  outcome = rbinom(12, 1, 0.5),
  x1 = rnorm(12),
  x2 = rnorm(12)
)
splits <- make_split_plan(df, outcome = "outcome",
                          mode = "subject_grouped", group = "subject", v = 3)
custom <- list(
  glm = list(
    fit = function(x, y, task, weights, ...) {
      stats::glm(y ~ ., data = as.data.frame(x),
                 family = stats::binomial(), weights = weights)
    },
    predict = function(object, newdata, task, ...) {
      as.numeric(stats::predict(object, newdata = as.data.frame(newdata),
                                type = "response"))
    }
  )
)
fit <- fit_resample(df, outcome = "outcome", splits = splits,
                    learner = "glm", custom_learners = custom,
                    metrics = "accuracy", refit = FALSE)
if (requireNamespace("ggplot2", quietly = TRUE)) {
  out <- plot_overlap_checks(fit, column = "subject")
  out$overlap_counts
}