# Nothing
# Kernel SHAP algorithm for a single row x
# If exact, a single call to predict() is necessary.
# If sampling is involved, we need at least two additional calls to predict().
#
# Arguments (as used in the body):
# - x: Row to be explained; it is replicated with rep_rows() to match the
#   precalculated background blocks.
# - v1, v0: Matrices with K columns; v1 - v0 is the constraint handed to solver().
# - object, pred_fun, bg_w, ...: Passed on to get_vz().
# - feature_names: Names of the p features.
# - exact, deg: Select the exact, hybrid (deg >= 1), or pure sampling (deg = 0) mode.
# - paired, m: Passed on to input_sampling().
# - tol, max_iter: Stopping rule of the sampling loop.
# - precalc: List with precalculated elements "A", "bg_X_exact", "Z", "w", "bg_X_m".
#
# Returns a list with elements
# - beta: (p x K) matrix of estimates,
# - sigma: (p x K) matrix of standard errors (0 if exact, NA if only a single
#   sampling iteration was run),
# - n_iter: number of iterations,
# - converged: logical flag.
kernelshap_one <- function(x, v1, object, pred_fun, feature_names, bg_w, exact, deg,
                           paired, m, tol, max_iter, v0, precalc, ...) {
  p <- length(feature_names)

  # Calculate A_exact and b_exact
  if (exact || deg >= 1L) {
    A_exact <- precalc[["A"]]                                     # (p x p)
    bg_X_exact <- precalc[["bg_X_exact"]]                         # (m_ex*n_bg x p)
    Z <- precalc[["Z"]]                                           # (m_ex x p)
    m_exact <- nrow(Z)
    v0_m_exact <- v0[rep.int(1L, m_exact), , drop = FALSE]        # (m_ex x K)

    # Most expensive part
    vz <- get_vz(                                                 # (m_ex x K)
      X = rep_rows(x, rep.int(1L, nrow(bg_X_exact))),             # (m_ex*n_bg x p)
      bg = bg_X_exact,                                            # (m_ex*n_bg x p)
      Z = Z,                                                      # (m_ex x p)
      object = object,
      pred_fun = pred_fun,
      w = bg_w,
      ...
    )
    # Note: w is correctly replicated along columns of (vz - v0_m_exact)
    b_exact <- crossprod(Z, precalc[["w"]] * (vz - v0_m_exact))   # (p x K)

    # Some of the hybrid cases are exact as well
    if (exact || trunc(p / 2) == deg) {
      beta <- solver(A_exact, b_exact, constraint = v1 - v0)      # (p x K)
      return(list(beta = beta, sigma = 0 * beta, n_iter = 1L, converged = TRUE))
    }
  }

  # Iterative sampling part, always using A_exact and b_exact to fill up the weights
  bg_X_m <- precalc[["bg_X_m"]]                                   # (m*n_bg x p)
  X <- rep_rows(x, rep.int(1L, nrow(bg_X_m)))                     # (m*n_bg x p)
  v0_m <- v0[rep.int(1L, m), , drop = FALSE]                      # (m x K)

  est_m <- list()
  converged <- FALSE
  n_iter <- 0L
  # Only assigned inside the loop from the second iteration on
  beta_n <- NULL
  sigma_n <- NULL
  A_sum <- matrix(                                                # (p x p)
    0, nrow = p, ncol = p, dimnames = list(feature_names, feature_names)
  )
  b_sum <- matrix(                                                # (p x K)
    0, nrow = p, ncol = ncol(v0), dimnames = list(feature_names, colnames(v1))
  )
  # Pure sampling: there is no exact part, so its contribution is zero
  if (deg == 0L) {
    A_exact <- A_sum
    b_exact <- b_sum
  }

  while (!isTRUE(converged) && n_iter < max_iter) {
    n_iter <- n_iter + 1L
    input <- input_sampling(
      p = p, m = m, deg = deg, paired = paired, feature_names = feature_names
    )
    Z <- input[["Z"]]

    # Expensive                                                   # (m x K)
    vz <- get_vz(
      X = X, bg = bg_X_m, Z = Z, object = object, pred_fun = pred_fun, w = bg_w, ...
    )

    # The sum of weights of A_exact and input[["A"]] is 1, same for b
    A_temp <- A_exact + input[["A"]]                              # (p x p)
    b_temp <- b_exact + crossprod(Z, input[["w"]] * (vz - v0_m))  # (p x K)
    A_sum <- A_sum + A_temp                                       # (p x p)
    b_sum <- b_sum + b_temp                                       # (p x K)

    # Least-squares with constraint that beta_1 + ... + beta_p = v_1 - v_0.
    # The additional constraint beta_0 = v_0 is dealt via offset
    est_m[[n_iter]] <- solver(A_temp, b_temp, constraint = v1 - v0)  # (p x K)

    # Covariance calculation would fail in the first iteration
    if (n_iter >= 2L) {
      beta_n <- solver(A_sum / n_iter, b_sum / n_iter, constraint = v1 - v0)  # (p x K)
      sigma_n <- get_sigma(est_m, iter = n_iter)                  # (p x K)
      converged <- all(conv_crit(sigma_n, beta_n) < tol)
    }
  }

  # With max_iter < 2, the loop stops before beta_n and sigma_n are set. Instead of
  # failing with "object 'beta_n' not found", return the single available estimate
  # (A_sum and b_sum equal A_temp and b_temp after one iteration). Standard errors
  # cannot be estimated from one iteration and are thus NA.
  if (is.null(beta_n)) {
    if (n_iter < 1L) {
      stop("max_iter must be at least 1")
    }
    beta_n <- est_m[[1L]]                                         # (p x K)
    sigma_n <- NA_real_ * beta_n                                  # (p x K)
  }

  list(beta = beta_n, sigma = sigma_n, n_iter = n_iter, converged = converged)
}
# Constrained least-squares: solves for beta subject to the column sums of beta
# being equal to `constraint`.
# A: (p x p), b: (p x K), constraint: (1 x K). Returns a (p x K) matrix.
solver <- function(A, b, constraint) {
  n_feat <- ncol(A)
  A_inv <- ginv(A)
  dimnames(A_inv) <- dimnames(A)

  # Column sums of the unconstrained solution, one per output column
  free_sums <- matrix(colSums(A_inv %*% b), nrow = 1L)        # (1 x K)

  # Correction term needed to satisfy the sum constraint
  s <- (free_sums - constraint) / sum(A_inv)                  # (1 x K)
  s_rep <- s[rep.int(1L, n_feat), , drop = FALSE]             # (p x K)

  A_inv %*% (b - s_rep)                                       # (p x K)
}
# Generalized (Moore-Penrose) inverse via singular value decomposition.
# X: numeric or complex matrix (or something as.matrix() can handle).
# tol: singular values not exceeding tol * (first singular value) are treated as 0.
# Returns a matrix with the transposed dimensions of X.
ginv <- function (X, tol = sqrt(.Machine$double.eps)) {
  ##
  ## Based on an original version in package MASS 7.3.56
  ## (with permission)
  ##
  has_valid_type <- is.numeric(X) || is.complex(X)
  if (length(dim(X)) > 2L || !has_valid_type) {
    stop("'X' must be a numeric or complex matrix")
  }
  if (!is.matrix(X)) {
    X <- as.matrix(X)
  }

  dec <- svd(X)
  sing <- dec$d
  # Transposed left singular vectors (conjugated for complex input)
  Ut <- t(if (is.complex(X)) Conj(dec$u) else dec$u)

  # Singular values regarded as non-zero
  keep <- sing > max(tol * sing[1L], 0)

  if (all(keep)) {
    dec$v %*% (1 / sing * Ut)
  } else if (any(keep)) {
    dec$v[, keep, drop = FALSE] %*% ((1 / sing[keep]) * Ut[keep, , drop = FALSE])
  } else {
    # Numerically, X is the zero matrix
    array(0, dim(X)[2L:1L])
  }
}
# Randomly generates an (m x p) matrix of on-off vectors z. The row sums sum(z)
# are drawn from S with probabilities kernel_weights(p, S = S), i.e., Kernel SHAP
# weights renormalized to S. The argument S thus restricts the range of sum(z).
sample_Z <- function(p, m, feature_names, S = 1:(p - 1L)) {
  # Step 1: Draw the number of "on" positions for each of the m vectors
  size_probs <- kernel_weights(p, S = S)
  drawn <- sample.int(length(S), m, replace = TRUE, prob = size_probs)
  n_on <- S[drawn]

  # Step 2: Given that number, place the 1s at random positions.
  # Unvectorized reference implementation:
  # out <- vapply(
  #   N,
  #   function(z) {out <- numeric(p); out[sample(1:p, z)] <- 1; out},
  #   FUN.VALUE = numeric(p)
  # )
  # t(out)
  #
  # Vectorized version by Mathias Ambuehl: Each column first gets its 0s followed
  # by its 1s, then the values are shuffled randomly within each column.
  run_lengths <- as.vector(rbind(p - n_on, n_on))
  Zt <- rep(rep.int(0:1, m), run_lengths)
  dim(Zt) <- c(p, m)
  shuffle <- order(col(Zt), sample.int(m * p))
  Zt[] <- Zt[shuffle]

  rownames(Zt) <- feature_names
  t(Zt)
}
# Standard errors derived from a list of per-iteration estimates.
# est: list of estimates, stacked via abind1(); iter: number of iterations.
# For each slice along the third dimension of the stacked array, the variances
# (diagonal of the covariance matrix) are divided by iter before taking roots.
get_sigma <- function(est, iter) {
  se_per_slice <- function(Y) {
    sqrt(diag(stats::cov(Y)) / iter)
  }
  apply(abind1(est), 3L, FUN = se_per_slice)
}
# Convergence criterion: per column, the largest value of sig divided by the
# range (max - min) of the corresponding column of bet.
# sig and bet must be matrices of identical dimension.
conv_crit <- function(sig, bet) {
  if (any(dim(sig) != dim(bet))) {
    stop("sig must have same dimension as bet")
  }
  value_span <- function(z) {
    r <- range(z)
    r[2L] - r[1L]
  }
  largest_sig <- apply(sig, 2L, FUN = max)
  bet_span <- apply(bet, 2L, FUN = value_span)
  largest_sig / bet_span
}
# Random input for one SHAP sampling iteration. Returns a list with
# - Z: Matrix of sampled on-off vectors z, drawn by sample_Z().
# - w: Constant weight vector of length m. The weights can be constant because the
#   Kernel weights have already been used to draw the z vectors.
# - A: Matrix A = Z'wZ
#
# With deg = 0, the weights sum to 1 (pure sampling). With deg > 0, sum(z) is
# restricted to [deg+1, p-deg-1], and the weights sum to less than 1. The rest of
# the weight belongs to the exactly enumerated part, see input_partly_exact().
#
# If paired = TRUE, only m / 2 vectors are drawn, and their complements 1 - z
# are appended.
input_sampling <- function(p, m, deg, paired, feature_names) {
  if (p < 2L * deg + 2L) {
    stop("p must be >=2*deg + 2")
  }

  allowed_sizes <- (deg + 1L):(p - deg - 1L)
  n_draws <- if (paired) m / 2 else m
  Z <- sample_Z(p = p, m = n_draws, feature_names = feature_names, S = allowed_sizes)
  if (paired) {
    Z <- rbind(Z, 1 - Z)
  }

  # Total weight left for the sampled part
  if (deg == 0L) {
    w_total <- 1
  } else {
    w_total <- 1 - 2 * sum(kernel_weights(p)[seq_len(deg)])
  }
  w <- w_total / m

  list(Z = Z, w = rep.int(w, m), A = crossprod(Z) * w)
}
# Functions required only for handling (partly) exact cases
# Fixed input for the exact case. Returns a list with
# - Z: Matrix of on-off vectors z as returned by exact_Z(keep_extremes = FALSE),
#   presumably all 2^p-2 vectors without the all-0 and all-1 rows
# - w: Row weights of Z, chosen such that the distribution of sum(z) matches
#   the SHAP kernel distribution
# - A: Exact matrix A = Z'wZ, taken from exact_A()
input_exact <- function(p, feature_names) {
  Z <- exact_Z(p, feature_names = feature_names, keep_extremes = FALSE)

  # The Kernel weight of size j is shared equally by all choose(p, j) vectors z
  # having sum(z) = j
  n_per_size <- choose(p, 1:(p - 1L))
  w_per_size <- kernel_weights(p) / n_per_size
  row_w <- w_per_size[rowSums(Z)]

  A <- exact_A(p, feature_names = feature_names)
  list(Z = Z, w = row_w, A = A)
}
#' Exact Matrix A
#'
#' Internal function that calculates exact A. All diagonal elements equal 0.5,
#' and all off-diagonal elements share one common value.
#' Notice the difference to the off-diagonals in the Supplement of
#' Covert and Lee (2021). Credits to David Watson for figuring out the correct formula,
#' see our discussions in https://github.com/ModelOriented/kernelshap/issues/22
#'
#' @noRd
#' @keywords internal
#'
#' @param p Number of features.
#' @param feature_names Feature names.
#' @returns A (p x p) matrix.
exact_A <- function(p, feature_names) {
  sizes <- 1:(p - 1L)
  # Factor s(s-1) / (p(p-1)) for each size s = sum(z)
  pair_factor <- sizes * (sizes - 1) / p / (p - 1)
  # Kernel-weighted sum of the factors gives the common off-diagonal value
  off_diag_value <- sum(kernel_weights(p) * pair_factor)

  A <- matrix(
    off_diag_value,
    nrow = p,
    ncol = p,
    dimnames = list(feature_names, feature_names)
  )
  diag(A) <- 0.5
  A
}
# Enumerates all on-off vectors z of length p with sum(z) in {k, p - k}.
# Returns a matrix with one vector per row and feature names as column names.
# Vectors with sum(z) = k come first, followed by their complements (unless
# p = 2k, where the complements are already part of the first set).
partly_exact_Z <- function(p, k, feature_names) {
  if (k < 1L) {
    stop("k must be at least 1")
  }
  if (p < 2L * k) {
    stop("p must be >=2*k")
  }

  Z <- if (k == 1L) {
    # Unit vectors
    diag(p)
  } else {
    # One indicator vector per subset of size k
    as_indicator <- function(on) {
      z <- numeric(p)
      z[on] <- 1
      z
    }
    t(utils::combn(seq_len(p), k, FUN = as_indicator))
  }

  needs_complements <- p != 2L * k
  if (needs_complements) {
    Z <- rbind(Z, 1 - Z)
  }
  colnames(Z) <- feature_names
  Z
}
# Builds Z, w, A for all vectors z with sum(z) in {k, p-k}, k in {1, ..., deg}.
# For a given k, all rows get the same weight.
# In general, the weights do not sum to one; the remaining weight is contributed
# by input_sampling(p, deg=deg). The exception is the (exact) case deg=p-deg.
# Returns a list with the stacked matrix Z, the weight vector w, and A = Z'wZ.
input_partly_exact <- function(p, deg, feature_names) {
  if (deg < 1L) {
    stop("deg must be at least 1")
  }
  if (p < 2L * deg) {
    stop("p must be >=2*deg")
  }

  kw <- kernel_weights(p)
  degrees <- seq_len(deg)

  Z_parts <- lapply(
    degrees,
    function(k) partly_exact_Z(p, k = k, feature_names = feature_names)
  )
  w_parts <- lapply(
    degrees,
    function(k) {
      n_rows <- nrow(Z_parts[[k]])
      # Sizes k and p-k have the same Kernel weight; it counts once if k = p-k
      w_tot <- kw[k] * (2 - (p == 2L * k))
      rep.int(w_tot / n_rows, n_rows)
    }
  )

  w <- unlist(w_parts, recursive = FALSE, use.names = FALSE)
  Z <- do.call(rbind, Z_parts)
  list(Z = Z, w = w, A = crossprod(Z, w * Z))
}
# Kernel SHAP weights for the sizes in S, a non-empty subset of {1, ..., p-1}.
# The weights are normalized to sum to 1 over S.
kernel_weights <- function(p, S = seq_len(p - 1L)) {
  denom <- choose(p, S) * S * (p - S)
  raw <- (p - 1L) / denom
  raw / sum(raw)
}
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.