R/stability_index.R

#' Stability Index for Model Predictions
#'
#' Computes a Stability Index that quantifies the consistency of machine
#' learning model predictions across multiple runs or resamples. A stability
#' index of 1 indicates perfectly consistent predictions, while values
#' closer to 0 indicate high variability across runs.
#'
#' The index is one minus the ratio of the mean per-observation variance
#' (the variance of each row across runs, averaged over rows) to the
#' overall variance of all predictions pooled together, clamped to lie
#' between 0 and 1. Low per-observation variance relative to overall
#' variance indicates that the model produces consistent results
#' regardless of the specific training run or resample.
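#'
#' In symbols, writing \eqn{p_{ij}} for the prediction of observation
#' \eqn{i} in run \eqn{j}, with \eqn{n} observations, the quantity computed
#' in the function body before clamping corresponds to
#' \deqn{S = 1 - \frac{\frac{1}{n} \sum_{i=1}^{n} \mathrm{Var}_j(p_{ij})}{\mathrm{Var}(p)}}
#' where \eqn{\mathrm{Var}_j} is the sample variance across runs for a
#' fixed observation and \eqn{\mathrm{Var}(p)} is the sample variance of
#' all entries of the matrix taken together. The symbols here are
#' notation introduced for this description only.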
#'
#' @param predictions_matrix A numeric matrix or data.frame where each row
#'   represents an observation and each column represents predictions from
#'   a single model run or resample. Must contain at least two columns and
#'   no missing values.
#'
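#' @return A numeric scalar between 0 and 1, where 1 indicates perfect
#'   stability (identical predictions across all runs) and values near 0
#'   indicate high instability. If every prediction in the matrix is
#'   identical (zero overall variance), 1 is returned. Raw values below 0,
#'   which occur when the mean per-observation variance exceeds the
#'   overall variance, are reported as 0.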
#'
#' @examples
#' # Simulate predictions from 5 model runs for 100 observations
#' set.seed(42)
#' base_predictions <- rnorm(100)
#' predictions <- matrix(
#'   rep(base_predictions, 5) + rnorm(500, sd = 0.1),
#'   ncol = 5
#' )
#' stability_index(predictions)
#'
#' # Perfectly stable predictions yield an index of 1
#' stable_preds <- matrix(rep(1:10, 3), ncol = 3)
#' stability_index(stable_preds)
#'
#' @importFrom stats var
#' @export
stability_index <- function(predictions_matrix) {
  if (!is.matrix(predictions_matrix) && !is.data.frame(predictions_matrix)) {
    stop("'predictions_matrix' must be a matrix or data.frame.", call. = FALSE)
  }

  predictions_matrix <- as.matrix(predictions_matrix)

  if (!is.numeric(predictions_matrix)) {
    stop("'predictions_matrix' must contain numeric values.", call. = FALSE)
  }

  if (ncol(predictions_matrix) < 2L) {
    stop(
      "At least two columns (runs) are required to compute stability.",
      call. = FALSE
    )
  }

  if (anyNA(predictions_matrix)) {
    stop("'predictions_matrix' must not contain NA values.", call. = FALSE)
  }

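  # Variance of each observation's predictions across runs (one per row)
  row_variances <- apply(predictions_matrix, 1L, var)
  mean_row_var <- mean(row_variances)

  # Variance of all predictions pooled into a single vector
  overall_var <- var(as.vector(predictions_matrix))

  # All predictions identical: nothing varies, so treat as perfectly stable
  if (overall_var == 0) {
    return(1)
  }

  # One minus the variance ratio, clamped to [0, 1]
  stability <- 1 - mean_row_var / overall_var
  stability <- max(0, min(1, stability))

  return(stability)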
}
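
# Illustrative sketch, not part of the package API. `stability_components()`
# is a hypothetical helper name introduced here for explanation only. It
# repeats the arithmetic of stability_index() but returns the intermediate
# quantities and the unclamped ratio, which may help when checking why an
# index is reported as 0. It assumes a numeric matrix with at least two
# columns, no missing values, and nonzero overall variance, and it does no
# input validation of its own.
stability_components <- function(predictions_matrix) {
  m <- as.matrix(predictions_matrix)
  # Per-observation variance across runs, averaged over observations
  mean_row_var <- mean(apply(m, 1L, stats::var))
  # Variance of all predictions pooled together
  overall_var <- stats::var(as.vector(m))
  list(
    mean_row_var = mean_row_var,
    overall_var = overall_var,
    raw_index = 1 - mean_row_var / overall_var
  )
}

# Worked example with three observations and two runs:
#   m <- matrix(c(1, 2, 3, 1, 2, 5), ncol = 2)
#   stability_components(m)
# The rows are (1, 1), (2, 2) and (3, 5), so the row variances are 0, 0
# and 2, giving a mean of 2/3. The pooled variance is 34/15, so the raw
# index is 1 - (2/3) / (34/15) = 12/17.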

TrustworthyMLR documentation built on Feb. 20, 2026, 5:09 p.m.