R/collapse_shap.R

Defines functions collapse_shap

Documented in collapse_shap

#' Collapse SHAP values
#'
#' Sums up SHAP values (or SHAP interaction values) of feature groups.
#' Typical application: SHAP values have been generated by a model with one or
#' multiple one-hot encoded variables, but the explanations should be done
#' using the original factor.
#'
#' The result contains the untouched features first (in their original order),
#' followed by one column/dimension per element of `collapse` (in the order of
#' `collapse`).
#'
#' Note on missing values: in the matrix case, sums are calculated with
#' `na.rm = TRUE`. In the array case, missing values are propagated to the
#' affected sums. Furthermore, the array case drops the names of the first
#' dimension, while the matrix case keeps row names.
#'
#' @param S Either a (n x p) matrix of SHAP values or a (n x p x p) array of
#'   SHAP interaction values. Must have column names. For arrays, the names of
#'   the second and third dimension must agree.
#' @param collapse A named list of character vectors. Each vector specifies the
#'   feature names whose SHAP values need to be summed up.
#'   The names determine the resulting collapsed column/dimension names.
#'   The vectors must not overlap, and the names must be unique and different
#'   from the names of the untouched features.
#'   If `NULL` or empty, `S` is returned unchanged.
#' @param ... Currently unused.
#' @returns A matrix of SHAP values, or an array of SHAP interaction values.
#' @export
#' @examples
#' S <- cbind(
#'   x = c(0.1, 0.1, 0.1),
#'   `age low` = c(0.2, -0.1, 0.1),
#'   `age mid` = c(0, 0.2, -0.2),
#'   `age high` = c(1, -1, 0)
#' )
#' collapse <- list(age = c("age low", "age mid", "age high"))
#' collapse_shap(S, collapse)
#'
#' # Arrays (as with SHAP interactions)
#' S_inter <- array(
#'   1,
#'   dim = c(2, 4, 4),
#'   dimnames = list(NULL, letters[1:4], letters[1:4])
#' )
#' collapse_shap(S_inter, collapse = list(cd = c("c", "d"), ab = c("a", "b")))
collapse_shap <- function(S, collapse = NULL, ...) {
  if (is.null(collapse) || length(collapse) == 0L) {
    return(S)
  }
  n_dim <- length(dim(S))
  stopifnot(
    is.matrix(S) || is.array(S),
    "'S' must have two or three dimensions" = n_dim %in% 2:3,
    !is.null(colnames(S)),
    "'collapse' must be a named list" =
      is.list(collapse) && !is.null(names(collapse)),
    "'collapse' can't have duplicated names" = !anyDuplicated(names(collapse))
  )
  if (n_dim == 3L) {
    nms <- dimnames(S)
    # Compare lengths first: `==` alone would silently recycle the shorter one
    stopifnot(
      "Names of second and third dimension of 'S' must agree" =
        !is.null(nms[[3L]]) &&
        length(nms[[2L]]) == length(nms[[3L]]) &&
        all(nms[[2L]] == nms[[3L]])
    )
  }
  u <- unlist(collapse, use.names = FALSE, recursive = FALSE)
  keep <- setdiff(colnames(S), u)
  stopifnot(
    "'collapse' cannot have overlapping vectors." = !anyDuplicated(u),
    "Values of 'collapse' should be in colnames(S)" = all(u %in% colnames(S)),
    "Names of 'collapse' must be different from untouched column names" =
      !any(names(collapse) %in% keep)
  )

  # Matrix case is easy
  if (n_dim == 2L) {
    add <- do.call(
      cbind,
      lapply(collapse, function(z) rowSums(S[, z, drop = FALSE], na.rm = TRUE))
    )
    return(cbind(S[, keep, drop = FALSE], add))
  }

  # 3D case: no abind() in base R, so the result is filled block by block.
  # All sums are taken from the original 'S' (indexed by feature name), while
  # the result is indexed by position. This keeps the logic correct even when
  # a group name equals the name of a feature collapsed into another group.
  out_names <- c(keep, names(collapse))
  n_keep <- length(keep)
  i_keep <- seq_len(n_keep)
  out <- array(
    dim = c(nrow(S), length(out_names), length(out_names)),
    dimnames = list(NULL, out_names, out_names)
  )
  out[, i_keep, i_keep] <- S[, keep, keep]

  for (g in seq_along(collapse)) {
    i_g <- n_keep + g
    from_g <- collapse[[g]]

    # Group vs. untouched features. No symmetry of 'S' is assumed, thus
    # both directions are calculated.
    out[, i_g, i_keep] <- apply(
      S[, from_g, keep, drop = FALSE], c(1L, 3L), FUN = sum
    )
    out[, i_keep, i_g] <- apply(
      S[, keep, from_g, drop = FALSE], 1:2, FUN = sum
    )

    # Group vs. group (including the group with itself)
    for (h in seq_along(collapse)) {
      out[, i_g, n_keep + h] <- apply(
        S[, from_g, collapse[[h]], drop = FALSE], 1L, FUN = sum
      )
    }
  }
  out
}

Try the shapviz package in your browser

Any scripts or data that you put into this service are public.

shapviz documentation built on May 29, 2024, 2 a.m.