# R/plotting_audit.R
#
# Defines functions: plot_confounder_sensitivity, plot_calibration, plot_time_acf,
# plot_overlap_checks, plot_fold_balance, plot_perm_distribution
#
# Documented in: plot_calibration, plot_confounder_sensitivity, plot_fold_balance,
# plot_overlap_checks, plot_perm_distribution, plot_time_acf

# Diagnostic plotting helpers --------------------------------------------------

#' Plot permutation distribution for a LeakAudit object
#'
#' Visualizes the label-permutation metric distribution and marks the observed
#' and permuted-mean values to help assess leakage signals. Requires ggplot2.
#'
#' @param audit LeakAudit.
#' @return A list containing the observed value, permuted mean, permutation values,
#'   and a ggplot object.
#' @examples
#' if (requireNamespace("ggplot2", quietly = TRUE)) {
#'   set.seed(42)
#'   df <- data.frame(
#'     subject = rep(1:15, each = 2),
#'     outcome = factor(rep(c(0, 1), 15)),
#'     x1 = rnorm(30),
#'     x2 = rnorm(30)
#'   )
#'   splits <- make_split_plan(df, outcome = "outcome",
#'                             mode = "subject_grouped", group = "subject",
#'                             v = 3, progress = FALSE)
#'   custom <- list(
#'     glm = list(
#'       fit = function(x, y, task, weights, ...) {
#'         stats::glm(y ~ ., data = as.data.frame(x),
#'                    family = stats::binomial(), weights = weights)
#'       },
#'       predict = function(object, newdata, task, ...) {
#'         as.numeric(stats::predict(object, newdata = as.data.frame(newdata),
#'                                   type = "response"))
#'       }
#'     )
#'   )
#'   fit <- fit_resample(df, outcome = "outcome", splits = splits,
#'                       learner = "glm", custom_learners = custom,
#'                       metrics = "auc", refit = FALSE, seed = 1)
#'   audit <- audit_leakage(fit, metric = "auc", B = 20)
#'   plot_perm_distribution(audit)
#' }
#'
#' @export
plot_perm_distribution <- function(audit) {
  stopifnot(inherits(audit, "LeakAudit"))
  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    stop("Package 'ggplot2' is required for plotting. Install it to use plot_perm_distribution().",
         call. = FALSE)
  }
  # Keep only finite permutation metric values; NA/Inf would break the histogram.
  perm_vals <- audit@perm_values
  perm_vals <- perm_vals[is.finite(perm_vals)]
  if (length(perm_vals) == 0L) {
    stop("No finite permutation values available for plotting.", call. = FALSE)
  }
  observed <- audit@permutation_gap$metric_obs
  mean_perm <- mean(perm_vals, na.rm = TRUE)

  # Precompute the histogram (Freedman-Diaconis bins) so geom_col receives
  # fixed bar positions and widths rather than letting ggplot re-bin.
  h <- graphics::hist(perm_vals, breaks = "FD", plot = FALSE)
  bars <- data.frame(mid = h$mids, count = h$counts)
  bar_width <- if (length(h$breaks) > 1L) diff(h$breaks)[1] else 1
  # One vertical reference line per marked quantity, mapped via a factor so
  # colour and linetype share a single legend.
  markers <- data.frame(
    value = c(observed, mean_perm),
    type = factor(c("Observed", "Permuted mean"),
                  levels = c("Observed", "Permuted mean"))
  )

  plt <- ggplot2::ggplot(bars, ggplot2::aes(x = mid, y = count)) +
    ggplot2::geom_col(width = bar_width, fill = "grey80", color = "white") +
    ggplot2::geom_vline(data = markers,
                        ggplot2::aes(xintercept = value, color = type, linetype = type),
                        linewidth = 1) +
    ggplot2::labs(title = "Permutation distribution", x = "Metric", y = "Count",
                  color = NULL, linetype = NULL) +
    ggplot2::scale_color_manual(values = c("Observed" = "red", "Permuted mean" = "blue")) +
    ggplot2::scale_linetype_manual(values = c("Observed" = "solid",
                                              "Permuted mean" = "dashed")) +
    ggplot2::theme_minimal() +
    ggplot2::theme(legend.position = "top")

  # Print only in interactive sessions; return components invisibly so the
  # function pipes cleanly.
  if (interactive()) print(plt)
  invisible(list(
    observed = observed,
    permuted_mean = mean_perm,
    perm_values = perm_vals,
    plot = plt
  ))
}

#' Plot fold balance of class counts per fold
#'
#' Displays a bar chart of class counts per fold. For binomial tasks, it also
#' overlays the positive proportion to diagnose stratification issues. The
#' positive class is taken from \code{fit@info$positive_class} when available;
#' otherwise the second factor level is used. For multiclass tasks, the plot
#' shows per-class counts without a proportion line. Only available for
#' classification tasks. Requires ggplot2.
#'
#' @param fit LeakFit.
#' @return A list containing the fold summary, positive class (if binomial),
#'   and a ggplot object.
#' @examples
#' if (requireNamespace("ggplot2", quietly = TRUE)) {
#'   set.seed(42)
#'   df <- data.frame(
#'     subject = rep(1:15, each = 2),
#'     outcome = factor(rep(c(0, 1), 15)),
#'     x1 = rnorm(30),
#'     x2 = rnorm(30)
#'   )
#'   splits <- make_split_plan(df, outcome = "outcome",
#'                             mode = "subject_grouped", group = "subject",
#'                             v = 3, progress = FALSE)
#'   custom <- list(
#'     glm = list(
#'       fit = function(x, y, task, weights, ...) {
#'         stats::glm(y ~ ., data = as.data.frame(x),
#'                    family = stats::binomial(), weights = weights)
#'       },
#'       predict = function(object, newdata, task, ...) {
#'         as.numeric(stats::predict(object, newdata = as.data.frame(newdata),
#'                                   type = "response"))
#'       }
#'     )
#'   )
#'   fit <- fit_resample(df, outcome = "outcome", splits = splits,
#'                       learner = "glm", custom_learners = custom,
#'                       metrics = "auc", refit = FALSE, seed = 1)
#'   plot_fold_balance(fit)
#' }
#'
#' @export
plot_fold_balance <- function(fit) {
  # Bar chart of per-fold class counts; binomial tasks additionally overlay
  # the positive-class proportion on a secondary axis.
  stopifnot(inherits(fit, "LeakFit"))
  if (!fit@task %in% c("binomial", "multiclass")) {
    stop("plot_fold_balance is only available for classification tasks.", call. = FALSE)
  }
  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    stop("Package 'ggplot2' is required for plotting. Install it to use plot_fold_balance().",
         call. = FALSE)
  }
  if (fit@task == "multiclass") {
    # --- Multiclass branch: stacked per-class counts, no proportion line. ---
    # Prefer factor levels from the first fold whose truth column is a factor;
    # otherwise fall back to the sorted union of observed labels.
    class_levels <- NULL
    for (df in fit@predictions) {
      if (is.factor(df$truth)) {
        class_levels <- levels(df$truth)
        break
      }
    }
    if (is.null(class_levels)) {
      class_levels <- sort(unique(unlist(lapply(fit@predictions, function(df) {
        as.character(df$truth)
      }))))
    }
    if (!length(class_levels)) {
      stop("No class labels available for plotting.", call. = FALSE)
    }
    # One row per (fold, class) with the count of test-set truths in that fold.
    tab <- lapply(seq_along(fit@predictions), function(i) {
      df <- fit@predictions[[i]]
      # Re-factor with the shared level set so table() yields every class,
      # including zero-count ones, in a consistent order.
      y <- factor(as.character(df$truth), levels = class_levels)
      counts <- table(y)
      data.frame(
        fold = i,
        class = factor(names(counts), levels = class_levels),
        count = as.numeric(counts),
        stringsAsFactors = FALSE
      )
    })
    tab <- do.call(rbind, tab)
    p <- ggplot2::ggplot(tab, ggplot2::aes(x = fold, y = count, fill = class)) +
      ggplot2::geom_col(width = 0.7, color = "white") +
      ggplot2::scale_x_continuous(breaks = unique(tab$fold)) +
      ggplot2::labs(title = "Fold class balance", x = "Fold", y = "Count", fill = "Class") +
      ggplot2::theme_minimal() +
      ggplot2::theme(legend.position = "top")
    if (interactive()) print(p)
    # positive_class is NA for multiclass: the notion does not apply.
    return(invisible(list(
      fold_summary = tab,
      positive_class = NA_character_,
      plot = p
    )))
  }
  # --- Binomial branch. ---
  # Resolve the positive class: fit@info$positive_class when it is a single,
  # non-missing, non-empty string; otherwise the second factor level of the
  # first factor truth column found.
  pos_class <- fit@info$positive_class
  if (length(pos_class) != 1L) pos_class <- NULL
  pos_class <- if (!is.null(pos_class)) as.character(pos_class) else NULL
  if (!is.null(pos_class) && (is.na(pos_class) || !nzchar(pos_class))) {
    pos_class <- NULL
  }
  if (is.null(pos_class)) {
    for (df in fit@predictions) {
      if (is.factor(df$truth) && nlevels(df$truth) >= 2) {
        pos_class <- levels(df$truth)[2]
        break
      }
    }
  }
  # Per-fold fallback: if the global positive class is still unknown, use the
  # second factor level of that fold's truth, else the literal label "1".
  resolve_pos_label <- function(y, pos_label) {
    if (!is.null(pos_label)) return(pos_label)
    if (is.factor(y) && nlevels(y) >= 2) return(levels(y)[2])
    "1"
  }
  # One row per fold with positive/negative test-set counts; NA truths are
  # excluded from both counts.
  tab <- lapply(seq_along(fit@predictions), function(i) {
    df <- fit@predictions[[i]]
    y <- df$truth
    pos_label <- resolve_pos_label(y, pos_class)
    y_chr <- as.character(y)
    is_pos <- y_chr == pos_label
    valid <- !is.na(is_pos)
    data.frame(fold = i,
               positives = sum(is_pos[valid]),
               negatives = sum(!is_pos[valid]))
  })
  tab <- do.call(rbind, tab)
  totals <- tab$positives + tab$negatives
  # NA proportion for empty folds avoids division by zero.
  tab$prop_pos <- ifelse(totals > 0, tab$positives / totals, NA_real_)
  # Legend label embeds the positive class name when known.
  pos_legend <- if (!is.null(pos_class)) {
    paste0("Positives (", pos_class, ")")
  } else {
    "Positives"
  }
  # Long-format counts for dodged bars: positives first, then negatives.
  df_counts <- data.frame(
    fold = rep(tab$fold, times = 2),
    class = factor(
      rep(c(pos_legend, "Negatives"), each = nrow(tab)),
      levels = c(pos_legend, "Negatives")
    ),
    count = c(tab$positives, tab$negatives)
  )
  # Scale the [0, 1] proportion up to the count axis so both series share one
  # y scale; the secondary axis below divides it back down for display.
  max_count <- max(df_counts$count, na.rm = TRUE)
  if (!is.finite(max_count) || max_count <= 0) max_count <- 1
  prop_df <- data.frame(
    fold = tab$fold,
    prop_pos = tab$prop_pos,
    prop_scaled = tab$prop_pos * max_count
  )
  prop_df <- prop_df[is.finite(prop_df$prop_scaled), , drop = FALSE]
  prop_df$series <- "Positive proportion"

  p <- ggplot2::ggplot(df_counts, ggplot2::aes(x = fold, y = count, fill = class)) +
    ggplot2::geom_col(position = ggplot2::position_dodge(width = 0.8),
                      width = 0.7, color = "white") +
    ggplot2::scale_fill_manual(values = setNames(c("steelblue", "tan"),
                                                 c(pos_legend, "Negatives"))) +
    ggplot2::scale_x_continuous(breaks = tab$fold) +
    ggplot2::labs(title = "Fold class balance", x = "Fold", y = "Count", fill = NULL)
  # Only add the proportion overlay when at least one fold has a finite value.
  if (nrow(prop_df) > 0) {
    p <- p +
      ggplot2::geom_line(data = prop_df,
                         ggplot2::aes(x = fold, y = prop_scaled,
                                      color = series, linetype = series),
                         inherit.aes = FALSE) +
      ggplot2::geom_point(data = prop_df,
                          ggplot2::aes(x = fold, y = prop_scaled, color = series, shape = series),
                          inherit.aes = FALSE) +
      ggplot2::scale_color_manual(values = c("Positive proportion" = "blue")) +
      ggplot2::scale_linetype_manual(values = c("Positive proportion" = "dashed")) +
      ggplot2::scale_shape_manual(values = c("Positive proportion" = 17)) +
      ggplot2::scale_y_continuous(
        sec.axis = ggplot2::sec_axis(~ . / max_count,
                                     name = "Positive proportion",
                                     labels = function(x) sprintf("%s%%", round(x * 100)))
      )
  }
  # Legend ordering: counts legend first, proportion series second.
  p <- p +
    ggplot2::theme_minimal() +
    ggplot2::theme(legend.position = "top") +
    ggplot2::guides(fill = ggplot2::guide_legend(order = 1),
                    color = ggplot2::guide_legend(order = 2),
                    linetype = ggplot2::guide_legend(order = 2),
                    shape = ggplot2::guide_legend(order = 2))
  if (interactive()) print(p)
  invisible(list(
    fold_summary = tab,
    positive_class = pos_class,
    plot = p
  ))
}

#' Plot overlap diagnostics between train/test groups
#'
#' @description
#' Checks whether the same group identifiers appear in both the training and
#' test partitions within each resample. This is designed to detect leakage
#' from grouped or repeated-measures data (for example, the same subject,
#' batch, plate, or study appearing on both sides of a fold) when group-wise
#' splitting is expected.
#'
#' @details
#' For each resample in `fit@splits@indices`, the function counts the number of
#' unique values of `column` in the train and test sets and the size of their
#' intersection. Any non-zero overlap indicates that at least one group appears
#' in both train and test for that resample. The check is metadata-based only:
#' it relies on exact matches of the supplied column and does not inspect
#' features or outcomes. It only checks train vs test within each resample, so
#' it will not detect overlaps across different resamples or other leakage
#' mechanisms. Inconsistent IDs or missing values in the metadata can hide or
#' inflate overlaps. `NA` values are treated as regular identifiers and will
#' count toward overlap if they appear in both partitions. Requires ggplot2.
#'
#' @param fit A `LeakFit` object produced by [fit_resample()]. It must contain
#'   the split indices and the associated metadata in `fit@splits@info$coldata`.
#'   The metadata rows must align with the data used to create the splits.
#' @param column Character scalar naming the metadata column to check (for
#'   example `"subject"` or `"batch"`). The function compares unique values of
#'   this column between train and test within each resample. There is no
#'   default: `NULL` or an unknown column triggers an error. Changing `column`
#'   changes which kind of leakage (subject-level, batch-level, etc.) is tested
#'   and therefore the overlap counts.
#' @return A list returned invisibly with:
#'   \itemize{
#'     \item `overlap_counts`: data.frame with one row per resample and columns
#'       `fold` (resample index in `fit@splits@indices`), `overlap` (unique IDs
#'       shared by train and test), `train` (unique IDs in train), and `test`
#'       (unique IDs in test).
#'     \item `column`: the metadata column name used for the check.
#'     \item `plot`: the ggplot object showing the three count series across folds.
#'   }
#'   The plot is also printed. When any overlap is detected, the plot adds a
#'   warning annotation.
#' @examples
#' set.seed(1)
#' df <- data.frame(
#'   subject = rep(1:6, each = 2),
#'   outcome = rbinom(12, 1, 0.5),
#'   x1 = rnorm(12),
#'   x2 = rnorm(12)
#' )
#' splits <- make_split_plan(df, outcome = "outcome",
#'                       mode = "subject_grouped", group = "subject", v = 3)
#' custom <- list(
#'   glm = list(
#'     fit = function(x, y, task, weights, ...) {
#'       stats::glm(y ~ ., data = as.data.frame(x),
#'                  family = stats::binomial(), weights = weights)
#'     },
#'     predict = function(object, newdata, task, ...) {
#'       as.numeric(stats::predict(object, newdata = as.data.frame(newdata),
#'                                 type = "response"))
#'     }
#'   )
#' )
#' fit <- fit_resample(df, outcome = "outcome", splits = splits,
#'                     learner = "glm", custom_learners = custom,
#'                     metrics = "accuracy", refit = FALSE)
#' if (requireNamespace("ggplot2", quietly = TRUE)) {
#'   out <- plot_overlap_checks(fit, column = "subject")
#'   out$overlap_counts
#' }
#'
#' @export
plot_overlap_checks <- function(fit, column = NULL) {
  # Count, per resample, how many unique values of `column` appear in both the
  # train and test partitions; any non-zero overlap signals grouped leakage.
  stopifnot(inherits(fit, "LeakFit"))
  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    stop("Package 'ggplot2' is required for plotting. Install it to use plot_overlap_checks().",
         call. = FALSE)
  }
  cd <- fit@splits@info$coldata
  # Distinguish the three failure modes instead of a single generic error, and
  # use call. = FALSE consistently with the rest of this file.
  if (is.null(cd)) {
    stop("No split metadata ('coldata') available for overlap checks.",
         call. = FALSE)
  }
  if (is.null(column)) {
    stop("'column' must be supplied (e.g. \"subject\" or \"batch\").",
         call. = FALSE)
  }
  if (!column %in% names(cd)) {
    stop("Column not available in metadata.", call. = FALSE)
  }
  n <- nrow(cd)
  counts <- lapply(seq_along(fit@splits@indices), function(i) {
    idx <- fit@splits@indices[[i]]
    # Resolve possibly-symbolic fold indices into row positions of `cd`.
    idx <- .bio_resolve_fold_indices(fit@splits, idx, n = n, data = cd)
    tr <- unique(cd[[column]][idx$train])
    te <- unique(cd[[column]][idx$test])
    data.frame(fold = i, overlap = length(intersect(tr, te)),
               train = length(tr), test = length(te))
  })
  counts <- do.call(rbind, counts)
  # Long format: three series (overlap, train unique, test unique) per fold.
  plot_df <- data.frame(
    fold = rep(counts$fold, times = 3),
    metric = factor(rep(c("Overlap", "Train unique", "Test unique"),
                        each = nrow(counts)),
                    levels = c("Overlap", "Train unique", "Test unique")),
    count = c(counts$overlap, counts$train, counts$test)
  )
  p <- ggplot2::ggplot(plot_df,
                       ggplot2::aes(x = fold, y = count, color = metric, linetype = metric)) +
    ggplot2::geom_line() +
    ggplot2::geom_point() +
    ggplot2::scale_color_manual(values = c("Overlap" = "red",
                                           "Train unique" = "grey40",
                                           "Test unique" = "grey70")) +
    ggplot2::scale_linetype_manual(values = c("Overlap" = "solid",
                                              "Train unique" = "dashed",
                                              "Test unique" = "dotted")) +
    ggplot2::scale_x_continuous(breaks = counts$fold) +
    ggplot2::labs(title = sprintf("Overlap diagnostics: %s", column),
                  x = "Fold", y = "Count", color = NULL, linetype = NULL) +
    ggplot2::theme_minimal() +
    ggplot2::theme(legend.position = "top")
  # Add a prominent annotation when any fold shows train/test ID overlap.
  if (any(counts$overlap > 0)) {
    y_max <- max(plot_df$count, na.rm = TRUE)
    if (!is.finite(y_max)) y_max <- 1
    p <- p +
      ggplot2::expand_limits(y = y_max * 1.1) +
      ggplot2::annotate("text",
                        x = max(counts$fold),
                        y = y_max * 1.05,
                        label = "WARNING: Overlaps detected!",
                        hjust = 1, vjust = 0,
                        color = "red", fontface = "bold")
  }
  if (interactive()) print(p)
  invisible(list(
    overlap_counts = counts,
    column = column,
    plot = p
  ))
}

#' Plot ACF of test predictions for time-series leakage checks
#'
#' Uses the autocorrelation function of out-of-fold predictions to detect
#' temporal dependence that may indicate leakage. Predictions are ordered by
#' the split time column before computing the ACF. Requires numeric predictions
#' (regression or survival). Requires ggplot2.
#'
#' @param fit LeakFit.
#' @param lag.max maximum lag to show.
#' @return A list with the autocorrelation results, \code{lag.max}, and a ggplot object.
#' @examples
#' if (requireNamespace("ggplot2", quietly = TRUE)) {
#'   set.seed(42)
#'   df <- data.frame(
#'     id = 1:30,
#'     time = seq.Date(as.Date("2020-01-01"), by = "day", length.out = 30),
#'     y = rnorm(30),
#'     x1 = rnorm(30),
#'     x2 = rnorm(30)
#'   )
#'   splits <- make_split_plan(df, outcome = "y", mode = "time_series",
#'                             time = "time", v = 3, progress = FALSE)
#'   custom <- list(
#'     lm = list(
#'       fit = function(x, y, task, weights, ...) {
#'         stats::lm(y ~ ., data = data.frame(y = y, x))
#'       },
#'       predict = function(object, newdata, task, ...) {
#'         as.numeric(stats::predict(object, newdata = as.data.frame(newdata)))
#'       }
#'     )
#'   )
#'   fit <- fit_resample(df, outcome = "y", splits = splits,
#'                       learner = "lm", custom_learners = custom,
#'                       metrics = "rmse", refit = FALSE, seed = 1)
#'   plot_time_acf(fit, lag.max = 10)
#' }
#'
#' @export
plot_time_acf <- function(fit, lag.max = 20) {
  # ACF of out-of-fold predictions ordered by the split time column; strong
  # autocorrelation in held-out predictions can indicate temporal leakage.
  stopifnot(inherits(fit, "LeakFit"))
  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    stop("Package 'ggplot2' is required for plotting. Install it to use plot_time_acf().",
         call. = FALSE)
  }
  all_pred <- do.call(rbind, fit@predictions)
  if (!"pred" %in% names(all_pred)) {
    stop("Predictions missing 'pred' column for ACF plotting.", call. = FALSE)
  }
  pred <- all_pred$pred
  if (!is.numeric(pred)) {
    stop("plot_time_acf requires numeric predictions (regression or survival).", call. = FALSE)
  }
  if (!"id" %in% names(all_pred)) {
    stop("Predictions are missing sample ids; time ordering unavailable.", call. = FALSE)
  }
  time_col <- fit@splits@info$time
  coldata <- fit@splits@info$coldata
  if (is.null(coldata)) {
    stop("plot_time_acf requires split metadata with time values.", call. = FALSE)
  }
  coldata <- as.data.frame(coldata, check.names = FALSE)
  if (is.null(time_col) || !time_col %in% names(coldata)) {
    stop("plot_time_acf requires a time column in split metadata.", call. = FALSE)
  }

  # Align prediction ids to metadata rows. Each strategy is attempted in turn
  # until one succeeds; the previous if/else-if chain stopped at the first
  # *applicable* strategy even when it failed to match, never reaching the
  # later fallbacks.
  ids_chr <- as.character(all_pred$id)
  time_vals <- NULL
  # Strategy 1: metadata row names.
  rn <- rownames(coldata)
  if (!is.null(rn) && all(ids_chr %in% rn)) {
    time_vals <- coldata[match(ids_chr, rn), time_col]
  }
  # Strategy 2: an explicit, duplicate-free 'row_id' column.
  if (is.null(time_vals) && "row_id" %in% names(coldata)) {
    rid <- as.character(coldata[["row_id"]])
    if (!anyDuplicated(rid) && all(ids_chr %in% rid)) {
      time_vals <- coldata[match(ids_chr, rid), time_col]
    }
  }
  # Strategy 3: sample ids recorded on the fit, aligned row-for-row to coldata.
  if (is.null(time_vals) && !is.null(fit@info$sample_ids) &&
      length(fit@info$sample_ids) == nrow(coldata)) {
    sample_ids <- as.character(fit@info$sample_ids)
    idx <- match(ids_chr, sample_ids)
    if (all(!is.na(idx))) time_vals <- coldata[idx, time_col]
  }
  # Strategy 4: ids that are plain 1-based row indices into coldata.
  if (is.null(time_vals)) {
    ids_int <- suppressWarnings(as.integer(ids_chr))
    if (length(ids_int) > 0L && all(!is.na(ids_int)) &&
        min(ids_int) >= 1L && max(ids_int) <= nrow(coldata)) {
      time_vals <- coldata[ids_int, time_col]
    }
  }
  if (is.null(time_vals) || length(time_vals) != length(pred)) {
    stop("plot_time_acf could not align time metadata to predictions.", call. = FALSE)
  }
  if (!is.numeric(time_vals) && !inherits(time_vals, c("POSIXct", "Date"))) {
    stop("plot_time_acf requires numeric, Date, or POSIXct time values.", call. = FALSE)
  }

  # Drop predictions with missing times (warned) and non-finite values.
  ok_time <- !is.na(time_vals)
  if (!all(ok_time)) {
    warning("plot_time_acf dropped predictions with missing time values.", call. = FALSE)
  }
  pred <- pred[ok_time]
  time_vals <- time_vals[ok_time]
  ok_pred <- is.finite(pred)
  pred <- pred[ok_pred]
  time_vals <- time_vals[ok_pred]
  if (length(pred) < 2) {
    stop("Not enough finite predictions for ACF plotting.", call. = FALSE)
  }
  # Secondary key keeps ties in their original (stable) order.
  ord <- order(time_vals, seq_along(time_vals), na.last = TRUE)
  pred <- pred[ord]
  acf_res <- stats::acf(pred, lag.max = lag.max, plot = FALSE)
  lag_vals <- as.numeric(acf_res$lag)
  acf_vals <- as.numeric(acf_res$acf)
  df <- data.frame(lag = lag_vals, acf = acf_vals)
  # Approximate 95% confidence band for white noise: +/- 1.96 / sqrt(n).
  conf <- if (is.finite(acf_res$n.used) && acf_res$n.used > 0) {
    1.96 / sqrt(acf_res$n.used)
  } else {
    NA_real_
  }

  p <- ggplot2::ggplot(df, ggplot2::aes(x = lag, y = acf)) +
    ggplot2::geom_hline(yintercept = 0, color = "grey50") +
    ggplot2::geom_segment(ggplot2::aes(xend = lag, yend = 0), color = "steelblue") +
    ggplot2::geom_point(color = "steelblue") +
    ggplot2::labs(title = "Prediction autocorrelation", x = "Lag", y = "ACF") +
    ggplot2::theme_minimal()
  if (is.finite(conf)) {
    p <- p +
      ggplot2::geom_hline(yintercept = c(conf, -conf),
                          color = "red", linetype = "dashed")
  }
  if (interactive()) print(p)
  invisible(list(acf = acf_res, lag.max = lag.max, plot = p))
}

#' Plot calibration curve for binomial predictions
#'
#' Visualizes observed outcome rates versus predicted probabilities across
#' bins to diagnose calibration (binomial tasks only). Requires ggplot2.
#'
#' @param fit LeakFit.
#' @param bins Number of probability bins to use.
#' @param min_bin_n Minimum samples per bin shown in the plot.
#' @param learner Optional character scalar. When predictions include multiple
#'   learners, selects the learner to summarize.
#' @return A list containing the calibration curve, metrics, and a ggplot object.
#' @examples
#' if (requireNamespace("ggplot2", quietly = TRUE)) {
#'   set.seed(42)
#'   df <- data.frame(
#'     subject = rep(1:15, each = 2),
#'     outcome = factor(rep(c(0, 1), 15)),
#'     x1 = rnorm(30),
#'     x2 = rnorm(30)
#'   )
#'   splits <- make_split_plan(df, outcome = "outcome",
#'                             mode = "subject_grouped", group = "subject",
#'                             v = 3, progress = FALSE)
#'   custom <- list(
#'     glm = list(
#'       fit = function(x, y, task, weights, ...) {
#'         stats::glm(y ~ ., data = as.data.frame(x),
#'                    family = stats::binomial(), weights = weights)
#'       },
#'       predict = function(object, newdata, task, ...) {
#'         as.numeric(stats::predict(object, newdata = as.data.frame(newdata),
#'                                   type = "response"))
#'       }
#'     )
#'   )
#'   fit <- fit_resample(df, outcome = "outcome", splits = splits,
#'                       learner = "glm", custom_learners = custom,
#'                       metrics = "auc", refit = FALSE, seed = 1)
#'   plot_calibration(fit, bins = 5)
#' }
#'
#' @export
plot_calibration <- function(fit, bins = 10, min_bin_n = 5, learner = NULL) {
  stopifnot(inherits(fit, "LeakFit"))
  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    stop("Package 'ggplot2' is required for plotting. Install it to use plot_calibration().",
         call. = FALSE)
  }
  # Delegate the binning and metric computation to calibration_summary();
  # this function only filters and renders its curve.
  cal_out <- calibration_summary(fit, bins = bins, min_bin_n = min_bin_n,
                                 learner = learner)
  curve_df <- cal_out$curve
  # Drop bins with non-finite coordinates; they cannot be drawn.
  finite_ok <- is.finite(curve_df$pred_mean) & is.finite(curve_df$obs_rate)
  curve_df <- curve_df[finite_ok, , drop = FALSE]
  # Only bins meeting the minimum size are plotted, but the full (finite)
  # curve is still returned for inspection.
  shown <- curve_df[curve_df$n >= min_bin_n, , drop = FALSE]
  if (nrow(shown) == 0L) {
    stop("Not enough data for calibration plotting.", call. = FALSE)
  }
  plt <- ggplot2::ggplot(shown, ggplot2::aes(x = pred_mean, y = obs_rate)) +
    ggplot2::geom_abline(slope = 1, intercept = 0, linetype = "dashed",
                         color = "grey50") +
    ggplot2::geom_line(color = "steelblue") +
    ggplot2::geom_point(ggplot2::aes(size = n), color = "steelblue") +
    ggplot2::scale_size_continuous(range = c(2, 6)) +
    ggplot2::labs(title = "Calibration curve",
                  x = "Mean predicted probability",
                  y = "Observed event rate",
                  size = "Bin n") +
    ggplot2::theme_minimal()
  if (interactive()) print(plt)
  invisible(list(curve = curve_df, metrics = cal_out$metrics, plot = plt))
}

#' Plot confounder sensitivity
#'
#' Shows performance metrics across confounder strata to assess sensitivity to
#' batch/study effects. Requires ggplot2.
#'
#' @param fit LeakFit.
#' @param confounders Character vector of columns in `coldata` to evaluate.
#' @param metric Metric name to compute within each stratum.
#' @param min_n Minimum samples per stratum to display.
#' @param coldata Optional data.frame of sample metadata.
#' @param numeric_bins Number of quantile bins for numeric confounders.
#' @param learner Optional character scalar. When predictions include multiple
#'   learners, selects the learner to summarize.
#' @return A list containing the sensitivity table and a ggplot object.
#' @examples
#' if (requireNamespace("ggplot2", quietly = TRUE)) {
#'   set.seed(42)
#'   df <- data.frame(
#'     subject = rep(1:15, each = 2),
#'     outcome = factor(rep(c(0, 1), 15)),
#'     batch = factor(rep(c("A", "B", "C"), 10)),
#'     x1 = rnorm(30),
#'     x2 = rnorm(30)
#'   )
#'   splits <- make_split_plan(df, outcome = "outcome",
#'                             mode = "subject_grouped", group = "subject",
#'                             v = 3, progress = FALSE)
#'   custom <- list(
#'     glm = list(
#'       fit = function(x, y, task, weights, ...) {
#'         stats::glm(y ~ ., data = as.data.frame(x),
#'                    family = stats::binomial(), weights = weights)
#'       },
#'       predict = function(object, newdata, task, ...) {
#'         as.numeric(stats::predict(object, newdata = as.data.frame(newdata),
#'                                   type = "response"))
#'       }
#'     )
#'   )
#'   fit <- fit_resample(df, outcome = "outcome", splits = splits,
#'                       learner = "glm", custom_learners = custom,
#'                       metrics = "auc", refit = FALSE, seed = 1)
#'   plot_confounder_sensitivity(fit, confounders = "batch", coldata = df)
#' }
#'
#' @export
plot_confounder_sensitivity <- function(fit, confounders = NULL, metric = NULL,
                                        min_n = 10, coldata = NULL, numeric_bins = 4,
                                        learner = NULL) {
  stopifnot(inherits(fit, "LeakFit"))
  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    stop("Package 'ggplot2' is required for plotting. Install it to use plot_confounder_sensitivity().",
         call. = FALSE)
  }
  # All stratification and metric computation happens in
  # confounder_sensitivity(); this wrapper filters and renders the result.
  sens_tab <- confounder_sensitivity(fit, confounders = confounders, metric = metric,
                                     min_n = min_n, coldata = coldata,
                                     numeric_bins = numeric_bins, learner = learner)
  if (is.null(sens_tab) || !nrow(sens_tab)) {
    stop("No confounder sensitivity results available for plotting.", call. = FALSE)
  }
  # Keep only finite metric values in strata large enough to be meaningful.
  keep <- is.finite(sens_tab$value) & sens_tab$n >= min_n
  plotted <- sens_tab[keep, , drop = FALSE]
  if (!nrow(plotted)) {
    stop("No strata meet the minimum size for plotting.", call. = FALSE)
  }
  # Direction ("higher"/"lower" is better) is shown in the subtitle.
  better <- plotted$direction[1] %||% "higher"
  plt <- ggplot2::ggplot(plotted, ggplot2::aes(x = level, y = value)) +
    ggplot2::geom_col(fill = "steelblue") +
    ggplot2::facet_wrap(~ confounder, scales = "free_x") +
    ggplot2::coord_flip() +
    ggplot2::labs(title = "Confounder sensitivity",
                  subtitle = sprintf("%s (better is %s)", plotted$metric[1], better),
                  x = NULL, y = "Metric value") +
    ggplot2::theme_minimal()
  if (interactive()) print(plt)
  # Return the unfiltered sensitivity table alongside the plot.
  invisible(list(data = sens_tab, plot = plt))
}

# Try the bioLeak package in your browser
#
# Any scripts or data that you put into this service are public.
#
# bioLeak documentation built on March 6, 2026, 1:06 a.m.