Nothing
#' Collapse SHAP values
#'
#' This function sums up SHAP values (or SHAP interaction values) of feature
#' groups. Typical application: SHAP values have been generated by a model with
#' one or multiple one-hot encoded variables, but the explanations should be
#' done using the original factor.
#'
#' In the result, the untouched columns come first (in their original order),
#' followed by one column per element of `collapse` (in list order).
#' Missing values are ignored when summing (as in `rowSums(..., na.rm = TRUE)`),
#' so a group consisting only of `NA` values collapses to 0. Untouched columns
#' keep their values, including `NA`.
#'
#' @param S Either a (n x p) matrix of SHAP values or a (n x p x p) array of SHAP
#'   interaction values. Column names must be set; for arrays, the second and
#'   third dimnames must be identical.
#' @param collapse A named list of character vectors. Each vector specifies the
#'   feature names whose SHAP values need to be summed up.
#'   The names determine the resulting collapsed column/dimension names. They
#'   must be non-empty, unique, and different from the untouched column names.
#'   The vectors must not overlap. If `NULL` or empty, `S` is returned as is.
#' @param ... Currently unused.
#' @returns A matrix of SHAP values, or an array of SHAP interaction values,
#'   with one column (dimension entry) per untouched feature plus one per
#'   element of `collapse`.
#' @export
#' @examples
#' S <- cbind(
#'   x = c(0.1, 0.1, 0.1),
#'   `age low` = c(0.2, -0.1, 0.1),
#'   `age mid` = c(0, 0.2, -0.2),
#'   `age high` = c(1, -1, 0)
#' )
#' collapse <- list(age = c("age low", "age mid", "age high"))
#' collapse_shap(S, collapse)
#'
#' # Arrays (as with SHAP interactions)
#' S_inter <- array(1, dim = c(2, 4, 4), dimnames = list(NULL, letters[1:4], letters[1:4]))
#' collapse_shap(S_inter, collapse = list(cd = c("c", "d"), ab = c("a", "b")))
collapse_shap <- function(S, collapse = NULL, ...) {
  if (is.null(collapse) || length(collapse) == 0L) {
    return(S)
  }
  to_names <- names(collapse)
  # stopifnot() evaluates sequentially, so the shape check guards colnames(),
  # which would otherwise fail obscurely on 1D arrays
  stopifnot(
    "'S' must be a matrix or a 3D array" =
      is.array(S) && length(dim(S)) %in% 2:3,
    !is.null(colnames(S)),
    "'collapse' must be a named list" = is.list(collapse) &&
      !is.null(to_names) && !anyNA(to_names) && all(nzchar(to_names)),
    "'collapse' can't have duplicated names" = !anyDuplicated(to_names)
  )
  if (length(dim(S)) == 3L) {
    nms <- dimnames(S)
    # identical() instead of `==`: the latter recycles dimnames of different
    # lengths and could let mismatching dimensions pass
    stopifnot(
      !is.null(nms[[3L]]),
      "Second and third dimnames of 'S' must be identical" =
        identical(nms[[2L]], nms[[3L]])
    )
  }
  u <- unlist(collapse, use.names = FALSE, recursive = FALSE)
  keep <- setdiff(colnames(S), u)
  stopifnot(
    "'collapse' cannot have overlapping vectors." = !anyDuplicated(u),
    "Values of 'collapse' should be in colnames(S)" = all(u %in% colnames(S)),
    "Names of 'collapse' must be different from untouched column names" =
      !any(to_names %in% keep)
  )

  # Matrix case is easy: one row sum per group, appended after untouched columns
  if (length(dim(S)) == 2L) {
    add <- do.call(
      cbind,
      lapply(collapse, function(z) rowSums(S[, z, drop = FALSE], na.rm = TRUE))
    )
    return(cbind(S[, keep, drop = FALSE], add))
  }

  # 3D case is tricky - no abind() in base R. Collapse one group at a time:
  # each pass drops the group's features from dims 2 and 3 and appends the
  # collapsed feature at the end of both.
  row_names <- dimnames(S)[[1L]]
  for (v_to in to_names) {
    v_from <- collapse[[v_to]]
    v_keep <- setdiff(colnames(S), v_from)
    v_new <- c(v_keep, v_to)
    S_to <- array(
      NA_real_,
      dim = c(nrow(S), length(v_new), length(v_new)),
      dimnames = list(row_names, v_new, v_new)
    )
    S_to[, v_keep, v_keep] <- S[, v_keep, v_keep]
    # dims = 1: per row, sum over the whole (v_from x v_from) block
    S_to[, v_to, v_to] <- rowSums(
      S[, v_from, v_from, drop = FALSE], na.rm = TRUE
    )
    # If we assume symmetry, we could spare one of the following two sums.
    # Sum over dim 2 (v_from): move it last, then sum over it with dims = 2
    S_to[, v_to, v_keep] <- rowSums(
      aperm(S[, v_from, v_keep, drop = FALSE], c(1L, 3L, 2L)),
      na.rm = TRUE,
      dims = 2L
    )
    # Sum over dim 3 (v_from)
    S_to[, v_keep, v_to] <- rowSums(
      S[, v_keep, v_from, drop = FALSE], na.rm = TRUE, dims = 2L
    )
    S <- S_to
  }
  S
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.