R/plot_stability.R

#' Plot Stability of Model Predictions
#'
#' Creates a visualization showing the variability of model predictions
#' across multiple runs. This helps identify whether instability is
#' uniform across the dataset or concentrated on specific observations.
#'
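#' Observations are placed along the x-axis in order of their mean
#' prediction. Each point marks the mean prediction for one observation, and
#' its vertical bar spans either the minimum to maximum prediction
#' (\code{type = "range"}) or one standard deviation either side of the mean
#' (\code{type = "sd"}) across runs.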
#'
#' @param predictions_matrix A numeric matrix or data.frame where each row
#'   represents an observation and each column represents predictions from
#'   a single model run or resample.
#' @param type Character string indicating what the error bars represent.
#'   Either \code{"range"} (default) or \code{"sd"} (standard deviation).
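#' @param ... Additional arguments passed to \code{\link[graphics]{plot}},
#'   such as \code{main}. The function sets \code{xlab}, \code{ylab},
#'   \code{ylim}, \code{pch}, and \code{cex} itself, so supplying any of
#'   these through \code{...} causes an error.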
#'
#' @return No return value, called for side effects (plotting).
#'
#' @examples
#' # Simulate predictions from 5 model runs
#' set.seed(42)
#' base_predictions <- sort(rnorm(50))
#' predictions <- matrix(
#'     rep(base_predictions, 5) + rnorm(250, sd = 0.2),
#'     ncol = 5
#' )
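#'
#' # Quick numeric check before plotting: the spread (max minus min) of each
#' # observation's predictions across runs, and the rows that vary the most.
#' # `spread` is just an illustrative local variable.
#' spread <- apply(predictions, 1, function(p) diff(range(p)))
#' head(order(spread, decreasing = TRUE))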
#'
#' plot_stability(predictions, main = "Model Prediction Stability")
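#'
#' # Same predictions, with bars at the mean +/- one standard deviation
#' plot_stability(predictions, type = "sd", main = "Mean +/- SD")
#'
#' # A data.frame with one column per run is accepted as well
#' plot_stability(as.data.frame(predictions))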
#'
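#' # Sketch with real resamples: predictions from 20 bootstrap refits of a
#' # linear model on the built-in mtcars data. The model formula and the
#' # resampling scheme are illustrative choices, not part of plot_stability.
#' boot_preds <- sapply(seq_len(20), function(i) {
#'     idx <- sample(nrow(mtcars), replace = TRUE)
#'     fit <- lm(mpg ~ wt + hp, data = mtcars[idx, ])
#'     predict(fit, newdata = mtcars)  # one column per bootstrap refit
#' })
#' plot_stability(boot_preds, main = "Bootstrap lm Predictions")
#'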
#' @importFrom graphics plot segments points
#' @importFrom stats sd
#' @export
plot_stability <- function(predictions_matrix, type = c("range", "sd"), ...) {
    if (!is.matrix(predictions_matrix) && !is.data.frame(predictions_matrix)) {
        stop("'predictions_matrix' must be a matrix or data.frame.", call. = FALSE)
    }

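    type <- match.arg(type)
    mat <- as.matrix(predictions_matrix)
    if (!is.numeric(mat)) {
        stop("'predictions_matrix' must contain only numeric values.", call. = FALSE)
    }
    # A standard deviation needs at least two runs per observation
    if (type == "sd" && ncol(mat) < 2L) {
        stop("At least two columns (model runs) are needed when type = 'sd'.", call. = FALSE)
    }
    n_obs <- nrow(mat)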

    means <- rowMeans(mat)

    if (type == "range") {
        low <- apply(mat, 1L, min)
        high <- apply(mat, 1L, max)
        ylab_text <- "Mean Prediction (Min to Max)"
    } else {
        sds <- apply(mat, 1L, sd)
        low <- means - sds
        high <- means + sds
        ylab_text <- "Mean Prediction +/- SD"
    }

    # Order by mean for clearer visualization
    ord <- order(means)
    means <- means[ord]
    low <- low[ord]
    high <- high[ord]

    plot(
        seq_len(n_obs), means,
        ylim = range(c(low, high)),
        xlab = "Observations (Sorted by Mean Prediction)",
        ylab = ylab_text,
        pch = 19, cex = 0.7,
        ...
    )

    # Draw the error bars, then redraw the points so they sit on top of the bars
    segments(seq_len(n_obs), low, seq_len(n_obs), high, col = "gray70")
    points(seq_len(n_obs), means, pch = 19, cex = 0.7)

    invisible(NULL)
}

TrustworthyMLR documentation built on Feb. 20, 2026, 5:09 p.m.