# R/utils.R

# Defines functions prep_w warning_burden prepare_bg basic_checks wrowmean_vector rep_each case_p1 summarize_strategy head_list align_pred reorganize_list abind1 check_convergence get_sigma get_vz exact_Z wcolMeans rep_rows

#' Fast Row Subsetting
#'
#' Internal function used to row-subset data.frames.
#' Brings a massive speed-up for plain data.frames by subsetting each column
#' directly and rebuilding the data.frame by hand. All other classes (tibble,
#' data.table, matrix) are subsetted in the usual way via `x[i, , drop = FALSE]`.
#'
#' @noRd
#' @keywords internal
#'
#' @param x A matrix-like object.
#' @param i Logical or (positive) integer vector of rows to pick. Integer indices
#'   may contain repetitions.
#' @returns Subsetted version of `x`.
rep_rows <- function(x, i) {
  if (!identical(class(x), "data.frame")) {
    return(x[i, , drop = FALSE]) # matrix, tibble, data.table, ...
  }
  # The number of resulting rows is derived from length(i) below. A logical mask
  # has one entry per *input* row, so it is translated into row positions first;
  # otherwise the row.names attribute would not match the column lengths.
  if (is.logical(i)) {
    i <- seq_len(nrow(x))[i]
  }
  # Columns can themselves be matrix-like (two dimensions), in which case their
  # rows are picked; ordinary columns are subsetted as vectors.
  out <- lapply(
    x,
    function(z) if (length(dim(z)) != 2L) z[i] else z[i, , drop = FALSE]
  )
  attr(out, "row.names") <- .set_row_names(length(i))
  class(out) <- "data.frame"
  out
}

#' Weighted Version of colMeans()
#'
#' Internal helper computing the (optionally weighted) mean of every column.
#' The input is first coerced with `as.matrix()`.
#'
#' @noRd
#' @keywords internal
#'
#' @param x A matrix-like object.
#' @param w Optional case weights. If `NULL`, plain column means are returned.
#' @param ... Accepted but not used.
#' @returns A (1 x ncol(x)) matrix of column means.
wcolMeans <- function(x, w = NULL, ...) {
  mat <- as.matrix(x)
  if (is.null(w)) {
    col_avg <- colMeans(mat)
  } else {
    # w is recycled down the rows of each column
    col_avg <- colSums(mat * w) / sum(w)
  }
  # Turn the vector of means into a single-row matrix
  t.default(col_avg)
}

#' All on-off Vectors
#'
#' Internal function that creates the matrix of all 2^p on-off vectors of length `p`.
#' Faster than as.matrix(do.call(expand.grid, replicate(p, 0:1, simplify = FALSE)))[, p:1]
#' Row k holds the binary representation of k - 1, most significant bit first.
#'
#' @noRd
#' @keywords internal
#'
#' @param p Number of features (at least 2).
#' @param feature_names Feature names, used as column names of the result.
#' @returns A logical (2^p x p) matrix of all on-off vectors of length `p`.
exact_Z <- function(p, feature_names) {
  if (p < 2L) {
    stop("p must be at least 2 if exact = TRUE.")
  }
  n_vectors <- 2^p
  codes <- seq_len(n_vectors) - 1L
  # intToBits() yields 32 bits per integer, least significant bit first
  bits <- matrix(as.logical(intToBits(codes)), nrow = 32L, ncol = n_vectors)
  # Keep the p lowest bits, reversed so that the highest bit becomes column 1
  Z <- t(bits[rev(seq_len(p)), , drop = FALSE])
  colnames(Z) <- feature_names
  Z
}

#' Masked Predict
#'
#' Internal function.
#' For each on-off vector (rows in `Z`), the (weighted) average prediction is returned.
#' Wherever `Z_rep` is "on", the stacked background data receives the value of the
#' row to be explained; predictions on the modified data are then averaged per
#' block of `bg_n` consecutive rows.
#'
#' @noRd
#' @keywords internal
#'
#' @inheritParams kernelshap
#' @param x Row to be explained.
#' @param bg_rep Background data stacked m times.
#' @param Z_rep A logical ((m * bg_n) x p) matrix with on-off values.
#' @param w A vector with case weights (of the same length as the unstacked
#'   background data).
#' @param bg_n Size of background dataset (unstacked).
#' @returns A (m x K) matrix with vz values.
get_vz <- function(x, bg_rep, Z_rep, object, pred_fun, w, bg_n, ...) {
  # Overwrite "on" positions of every feature with the value from x
  for (feature in colnames(Z_rep)) {
    on <- Z_rep[, feature]
    if (is.matrix(x)) {
      bg_rep[on, feature] <- x[, feature]
    } else {
      bg_rep[[feature]][on] <- x[[feature]]
    }
  }

  preds <- align_pred(pred_fun(object, bg_rep, ...))

  # Number of on-off vectors, i.e., number of stacked copies of the background data
  n_onoff <- nrow(Z_rep) %/% bg_n

  # Fast track for one-dimensional predictions
  if (ncol(preds) == 1L) {
    return(wrowmean_vector(preds, ngroups = n_onoff, w = w))
  }

  # General case: aggregate over blocks of bg_n consecutive rows
  group_id <- rep_each(n_onoff, each = bg_n)
  if (is.null(w)) {
    rowsum(preds, group = group_id, reorder = FALSE) / bg_n
  } else {
    rowsum(preds * w, group = group_id, reorder = FALSE) / sum(w)
  }
}

#' Calculates standard errors over array
#'
#' Calculates standard errors over an (n x p x K) array
#' holding n estimates of a (p x K) matrix.
#'
#' @noRd
#' @keywords internal
#'
#' @param est Array of dimension (n x p x K) with n (p x K) dimensional estimates.
#' @returns A (p x K) matrix with standard errors.
get_sigma <- function(est) {
  n_est <- dim(est)[1L]
  if (n_est < 2L) {
    stop("Cannot calculate standard errors from single estimate.")
  }
  sample_var <- apply(est, 2L:3L, FUN = stats::var)
  # Undo the Bessel correction of stats::var(): each value is already an
  # estimate of multiple observations
  uncorrected_var <- sample_var * (n_est - 1) / n_est
  sqrt(uncorrected_var / n_est)
}

#' Convergence criterion
#'
#' Checks, for every output dimension (column), whether the largest standard
#' error is not larger than `tol` times the range of the current SHAP values
#' `beta`. Convergence requires this to hold for all columns.
#'
#' @noRd
#' @keywords internal
#'
#' @param beta A (p x K) matrix of SHAP values.
#' @param sigma A (p x K) matrix of standard errors.
#' @param tol Tolerance, usually around 0.01.
#' @returns A logical of length 1.
check_convergence <- function(beta, sigma, tol) {
  value_span <- function(z) {
    lo_hi <- range(z)
    lo_hi[2L] - lo_hi[1L]
  }
  shap_range <- apply(beta, 2L, FUN = value_span)
  worst_se <- apply(sigma, 2L, FUN = max)
  all(worst_se <= tol * shap_range)
}

#' Combine Matrices
#'
#' Stacks a list of equally shaped matrices along a new first axis.
#' Dimension names of the first matrix are carried over to axes two and three.
#'
#' @noRd
#' @keywords internal
#'
#' @param a List of n (p x K) matrices.
#' @returns A (n x p x K) array.
abind1 <- function(a) {
  template <- a[[1L]]
  stacked <- array(
    dim = c(length(a), dim(template)),
    dimnames = c(list(NULL), dimnames(template))
  )
  # Fill one slice per list element
  for (k in seq_along(a)) {
    stacked[k, , ] <- a[[k]]
  }
  stacked
}

#' Reorganize List
#'
#' Internal function that turns a list of n (p x K) matrices into a list of K
#' (n x p) matrices. For K = 1, the single matrix is returned instead of a list.
#'
#' @noRd
#' @keywords internal
#'
#' @param alist List of n (p x K) matrices.
#' @returns List of K (n x p) matrices, or a single (n x p) matrix if K = 1.
reorganize_list <- function(alist) {
  if (!is.list(alist)) {
    stop("alist must be a list")
  }
  # Stack into an (n x p x K) array, then cut along the third axis
  per_output <- asplit(abind1(alist), MARGIN = 3L)
  if (length(per_output) == 1L) {
    as.matrix(per_output[[1L]])
  } else {
    lapply(per_output, as.matrix)
  }
}

#' Aligns Predictions
#'
#' Turns predictions into a matrix. A data.frame with a single column is
#' reduced to that column first. Non-numeric (and non-logical) predictions
#' raise an error.
#'
#' @noRd
#' @keywords internal
#'
#' @param x Object representing model predictions.
#' @returns Like `x`, but converted to matrix.
align_pred <- function(x) {
  single_column_df <- is.data.frame(x) && ncol(x) == 1L
  if (single_column_df) {
    x <- x[[1L]]
  }
  x <- if (is.matrix(x)) x else as.matrix(x)
  if (!(is.numeric(x) || is.logical(x))) {
    stop("Predictions must be numeric!")
  }
  x
}

#' Head of List Elements
#'
#' Internal function that returns the top n rows of each element in the input
#' list. Input that is not a list is passed to `utils::head()` directly.
#'
#' @noRd
#' @keywords internal
#'
#' @param x A list or a matrix-like.
#' @param n Number of rows to show.
#' @returns List of first rows of each element in the input, or the first rows
#'   of `x` itself if `x` is not a list.
head_list <- function(x, n = 6L) {
  if (is.list(x)) {
    lapply(x, utils::head, n)
  } else {
    utils::head(x, n)
  }
}

# Describe the chosen algorithm (exact, hybrid, sampling) in one sentence.
# - exact = TRUE: exact values
# - degree equals trunc(p / 2): exact values obtained by the hybrid approach
# - degree 0: pure iterative sampling
# - otherwise: hybrid strategy of the given degree
summarize_strategy <- function(p, exact, deg) {
  if (exact) {
    return("Exact Kernel SHAP values")
  }
  if (trunc(p / 2) == deg) {
    return(paste("Exact Kernel SHAP values", "by the hybrid approach"))
  }
  if (deg == 0L) {
    return("Kernel SHAP values by iterative sampling")
  }
  paste("Kernel SHAP values by the hybrid strategy of degree", deg)
}

# Case p = 1 returns exact Shapley values
#
# With a single feature, the SHAP value of a row is its prediction minus the
# baseline. Builds the "kernelshap" object directly from these differences.
#
# n: Number of rows to be explained.
# feature_names: Name of the single feature; used as column name of the SHAP matrix.
# v0: Baseline predictions; its first row is subtracted from every row of v1.
# v1: (n x K) matrix of predictions.
# X: Stored as is in the result.
# verbose: If TRUE, a message on the algorithm is shown.
#
# Returns an object of class "kernelshap". Its element S is a (n x 1) matrix,
# or, if v1 has more than one column, a list with one such matrix per column.
case_p1 <- function(n, feature_names, v0, v1, X, verbose) {
  txt <- "Exact Shapley values (p = 1)"
  if (verbose) {
    message(txt)
  }
  S <- v1 - v0[rep(1L, n), , drop = FALSE] #  (n x K)
  if (ncol(v1) > 1L) {
    S <- lapply(
      asplit(S, MARGIN = 2L), function(M) {
        # as.matrix() ignores a `dimnames` argument, so the column name
        # needs to be set explicitly
        M <- as.matrix(M)
        colnames(M) <- feature_names
        M
      }
    )
  } else {
    colnames(S) <- feature_names
  }
  out <- list(
    S = S,
    X = X,
    baseline = as.vector(v0),
    bg_X = NULL,
    bg_w = NULL,
    m_exact = 0L,
    prop_exact = 1,
    exact = TRUE,
    txt = txt,
    predictions = v1,
    algorithm = "kernelshap"
  )
  class(out) <- "kernelshap"
  out
}

#' Fast Index Generation (from {hstats})
#'
#' For not too small m, much faster than `rep(seq_len(m), each = each)`.
#' Uses the column indices of an (each x m) matrix, read column by column.
#'
#' @noRd
#' @keywords internal
#'
#' @param m Integer. See `each`.
#' @param each Integer. How many times should each value in `1:m` be repeated?
#' @returns Integer vector of length `m * each`, like `rep(seq_len(m), each = each)`.
#' @examples
#' rep_each(10, 2)
#' rep(1:10, each = 2) # Ditto
rep_each <- function(m, each) {
  as.vector(.col(dim = c(each, m)))
}

#' Grouped Means for Single-Column Matrices (adapted from {hstats})
#'
#' Grouped (weighted) means of a single-column matrix over consecutive groups
#' of equal length. The column is reshaped so that each group forms one column,
#' and the column means are taken.
#'
#' @noRd
#' @keywords internal
#'
#' @param x Matrix with one column.
#' @param ngroups Number of subsequent, equally sized groups.
#' @param w Optional vector of case weights of length `NROW(x) / ngroups`.
#' @returns Matrix with one column and `ngroups` rows. The column name of `x`
#'   is kept, if there is one.
wrowmean_vector <- function(x, ngroups = 1L, w = NULL) {
  if (ncol(x) != 1L) {
    stop("x must have a single column")
  }
  col_name <- colnames(x)
  # One column per group
  dim(x) <- c(length(x) %/% ngroups, ngroups)
  if (is.null(w)) {
    group_means <- colMeans(x)
  } else {
    group_means <- colSums(x * w) / sum(w)
  }
  matrix(
    group_means,
    nrow = ngroups,
    ncol = 1L,
    dimnames = if (!is.null(col_name)) list(NULL, col_name)
  )
}

#' Basic Input Checks
#'
#' Validates the main inputs and stops with an error at the first violated
#' condition:
#' - `X` is a matrix or a data.frame,
#' - every dimension of `X` is at least 1,
#' - at least one feature name is given,
#' - all feature names are column names of `X`,
#' - `pred_fun` is a function.
#'
#' @noRd
#' @keywords internal
#'
#' @inheritParams kernelshap
#'
#' @returns TRUE or an error
basic_checks <- function(X, feature_names, pred_fun) {
  # stopifnot() derives its error messages from the expressions below and
  # requires *all* elements of vector-valued conditions (like `dim(X) >= 1L`)
  # to be TRUE.
  stopifnot(
    is.matrix(X) || is.data.frame(X),
    dim(X) >= 1L,
    length(feature_names) >= 1L,
    all(feature_names %in% colnames(X)),
    is.function(pred_fun)
  )
  TRUE
}

#' Prepare Background Data
#'
#' Determines the background data and its case weights.
#' - If `bg_X` is given, it is checked and reduced to the columns of `X`.
#' - Otherwise, `X` itself is used when it has at most `bg_n` rows (error or
#'   warning if it is very small), or a random sample of `bg_n` rows of `X`
#'   (with the matching elements of `bg_w`) when it is larger.
#'
#' Non-null weights are finally checked against the resulting background data.
#'
#' @noRd
#' @keywords internal
#'
#' @inheritParams kernelshap
#'
#' @returns List with bg_X and bg_w.
prepare_bg <- function(X, bg_X, bg_n, bg_w, verbose) {
  n <- nrow(X)
  if (!is.null(bg_X)) {
    # User-supplied background data
    stopifnot(
      is.matrix(bg_X) || is.data.frame(bg_X),
      is.matrix(X) == is.matrix(bg_X),
      nrow(bg_X) >= 1L,
      all(colnames(X) %in% colnames(bg_X))
    )
    bg_X <- bg_X[, colnames(X), drop = FALSE]
  } else if (n <= bg_n) {
    # X is small enough to be used as is
    if (n < min(20L, bg_n)) {
      stop("X is too small to act as background data. Please specify 'bg_X'.")
    }
    if (n < min(50L, bg_n)) {
      warning("X is quite small to act as background data. Consider specifying a larger 'bg_X'.")
    }
    bg_X <- X
  } else {
    # X is too large: draw bg_n rows at random
    if (verbose) {
      message("Sampling ", bg_n, " rows from X as background data.")
    }
    ix <- sample(n, bg_n)
    bg_X <- X[ix, , drop = FALSE]
    if (!is.null(bg_w)) {
      stopifnot(length(bg_w) == n)
      bg_w <- bg_w[ix]
    }
  }

  if (!is.null(bg_w)) {
    bg_w <- prep_w(bg_w, bg_n = nrow(bg_X))
  }

  list(bg_X = bg_X, bg_w = bg_w)
}

#' Warning on Slow Computations
#'
#' Raises a warning that predictions on a large data set (`m` times `bg_n` rows)
#' are being done, and suggests to use a smaller background data set `bg_X`.
#'
#' @noRd
#' @keywords internal
#'
#' @param m Number of on-off vectors.
#' @param bg_n Number of rows in the background data.
#'
#' @returns TRUE.
warning_burden <- function(m, bg_n) {
  # The background data argument is called `bg_X` (not `X_bg`)
  warning(
    "\nPredictions on large data sets with ", m, "x", bg_n,
    " observations are being done.\n",
    "Consider reducing the computational burden (e.g. use smaller bg_X)"
  )
  TRUE
}

#' Prepare Case Weights
#'
#' Checks that the case weights have one entry per row of the background data,
#' are non-negative, and are not all zero. Valid weights are returned as double
#' vector.
#'
#' @noRd
#' @keywords internal
#'
#' @param w Vector of case weights.
#' @param bg_n Number of rows in the background data.
#'
#' @returns The weights `w` as double vector, or an error.
prep_w <- function(w, bg_n) {
  stopifnot(
    length(w) == bg_n,
    all(w >= 0),
    !all(w == 0)
  )
  # Convert only when needed; weights that are already double are passed through as is
  if (!is.double(w)) as.double(w) else w
}

# Try the kernelshap package in your browser
#
# Any scripts or data that you put into this service are public.
#
# kernelshap documentation built on Nov. 5, 2025, 6:16 p.m.