Nothing
#' SHAP Dependence Plot
#'
#' Scatterplot of the SHAP values of a feature against its feature values.
#' If SHAP interaction values are available, setting `interactions = TRUE` allows
#' to focus on pure interaction effects (multiplied by two) or on pure main effects.
#' By default, the feature on the color scale is selected via SHAP interactions
#' (if available) or an interaction heuristic, see [potential_interactions()].
#'
#' @importFrom rlang .data
#'
#' @param object An object of class "(m)shapviz".
#' @param v Column name of feature to be plotted. Can be a vector/list if `object` is
#' of class "shapviz".
#' @param color_var Feature name to be used on the color scale to investigate
#' interactions. The default ("auto") uses SHAP interaction values (if available),
#' or a heuristic to select the strongest interacting feature. Set to `NULL` to not
#' use the color axis. Can be a vector/list if `object` is of class "shapviz".
#' @param color Color to be used if `color_var = NULL`. Can be a vector/list if `v`
#' is a vector.
#' @param viridis_args List of viridis color scale arguments, see
#' `?ggplot2::scale_color_viridis_c`. The default points to the global option
#' `shapviz.viridis_args`, which corresponds to
#' `list(begin = 0.25, end = 0.85, option = "inferno")`.
#' These values are passed to `ggplot2::scale_color_viridis_*()`.
#' For example, to switch to a standard viridis scale, you can either change the
#' default via `options(shapviz.viridis_args = list())`, or set
#' `viridis_args = list()`. Only relevant if `color_var` is not `NULL`.
#' @param jitter_width The amount of horizontal jitter. The default (`NULL`) will
#' use a value of 0.2 in case `v` is discrete, and no jitter otherwise.
#' (Numeric variables are considered discrete if they have at most 7 unique values.)
#' Can be a vector/list if `v` is a vector.
#' @param interactions Should SHAP interaction values be plotted? Default is `FALSE`.
#' Requires SHAP interaction values. If `color_var = NULL` (or it is equal to `v`),
#' the pure main effect of `v` is visualized. Otherwise, twice the SHAP interaction
#' values between `v` and the `color_var` are plotted.
#' @param ih_nbins,ih_color_num,ih_scale,ih_adjusted Interaction heuristic (ih)
#' parameters used to select the color variable, see [potential_interactions()].
#' Only used if `color_var = "auto"` and if there are no SHAP interaction values.
#' @param ... Arguments passed to [ggplot2::geom_jitter()].
#' @returns An object of class "ggplot" (or "patchwork") representing a dependence plot.
#' @examples
#' dtrain <- xgboost::xgb.DMatrix(
#' data.matrix(iris[, -1]), label = iris[, 1], nthread = 1
#' )
#' fit <- xgboost::xgb.train(data = dtrain, nrounds = 10, nthread = 1)
#' x <- shapviz(fit, X_pred = dtrain, X = iris)
#' sv_dependence(x, "Petal.Length")
#' sv_dependence(x, "Petal.Length", color_var = "Species")
#' sv_dependence(x, "Petal.Length", color_var = NULL)
#' sv_dependence(x, c("Species", "Petal.Length"))
#' sv_dependence(x, "Petal.Width", color_var = c("Species", "Petal.Length"))
#'
#' # SHAP interaction values/main effects
#' x2 <- shapviz(fit, X_pred = dtrain, X = iris, interactions = TRUE)
#' sv_dependence(x2, "Petal.Length", interactions = TRUE)
#' sv_dependence(
#' x2, c("Petal.Length", "Species"), color_var = NULL, interactions = TRUE
#' )
#' @export
#' @seealso [potential_interactions()]
sv_dependence <- function(object, ...) {
UseMethod("sv_dependence")
}
#' @describeIn sv_dependence
#' Default method.
#' @export
sv_dependence.default <- function(object, ...) {
stop("No default method available.")
}
#' @describeIn sv_dependence
#' SHAP dependence plot for "shapviz" object.
#' @export
sv_dependence.shapviz <- function(object, v, color_var = "auto", color = "#3b528b",
viridis_args = getOption("shapviz.viridis_args"),
jitter_width = NULL, interactions = FALSE,
ih_nbins = NULL, ih_color_num = TRUE,
ih_scale = FALSE, ih_adjusted = FALSE, ...) {
p <- length(v)
if (p > 1L || length(color_var) > 1L) {
if (is.null(color_var)) {
color_var <- replicate(p, NULL)
}
if (is.null(jitter_width)) {
jitter_width <- replicate(p, NULL)
}
plot_list <- mapply(
FUN = sv_dependence,
v = v,
color_var = color_var,
color = color,
jitter_width = jitter_width,
MoreArgs = list(
object = object,
viridis_args = viridis_args,
interactions = interactions,
ih_nbins = ih_nbins,
ih_color_num = ih_color_num,
ih_scale = ih_scale,
ih_adjusted = ih_adjusted,
...
),
SIMPLIFY = FALSE
)
nms <- if (length(v) > 1L) v
plot_list <- add_titles(plot_list, nms = nms) # see sv_waterfall()
return(patchwork::wrap_plots(plot_list))
}
S <- get_shap_values(object)
X <- get_feature_values(object)
S_inter <- get_shap_interactions(object)
nms <- colnames(object)
stopifnot(
v %in% nms,
is.null(color_var) || (color_var %in% c("auto", nms))
)
if (interactions && is.null(S_inter)) {
stop("No SHAP interaction values available in 'object'.")
}
# Set jitter value
if (is.null(jitter_width)) {
jitter_width <- 0.2 * .is_discrete(X[[v]], n_unique = 7L)
}
# Set color value if "auto"
if (!is.null(color_var) && color_var == "auto" && !("auto" %in% nms)) {
scores <- potential_interactions(
object,
v,
nbins = ih_nbins,
color_num = ih_color_num,
scale = ih_scale,
adjusted = ih_adjusted
)
# 'scores' can be NULL, or a sorted vector like c(0.1, 0, -0.01, NA)
# Thus, let's take the first positive one (or NULL)
scores <- scores[!is.na(scores) & scores > 0] # NULL stays NULL
color_var <- if (length(scores) >= 1L) names(scores)[1L]
}
if (isTRUE(interactions)) {
if (is.null(color_var)) {
color_var <- v
}
if (color_var == v) {
y_lab <- "SHAP main effect"
} else {
y_lab <- "SHAP interaction"
}
s <- S_inter[, v, color_var]
if (color_var != v) {
s <- 2 * s # Off-diagonals need to be multiplied by 2 for symmetry reasons
}
} else {
y_lab <- "SHAP value"
s <- S[, v]
}
dat <- data.frame(s, X[[v]])
colnames(dat) <- c("shap", v)
if (is.null(color_var) || color_var == v) {
p <- ggplot2::ggplot(dat, ggplot2::aes(x = .data[[v]], y = shap)) +
ggplot2::geom_jitter(color = color, width = jitter_width, height = 0, ...) +
ggplot2::ylab(y_lab)
return(p)
}
dat[[color_var]] <- X[[color_var]]
if (.is_discrete(dat[[color_var]], n_unique = 0L)) { # only if non-numeric
vir <- ggplot2::scale_color_viridis_d
} else {
vir <- ggplot2::scale_color_viridis_c
}
if (is.null(viridis_args)) {
viridis_args <- list()
}
ggplot2::ggplot(
dat, ggplot2::aes(x = .data[[v]], y = shap, color = .data[[color_var]])
) +
ggplot2::geom_jitter(width = jitter_width, height = 0, ...) +
ggplot2::ylab(y_lab) +
do.call(vir, viridis_args) +
ggplot2::theme(legend.box.spacing = grid::unit(0, "pt"))
}
#' @describeIn sv_dependence
#' SHAP dependence plot for "mshapviz" object.
#' @export
sv_dependence.mshapviz <- function(object, v, color_var = "auto", color = "#3b528b",
viridis_args = getOption("shapviz.viridis_args"),
jitter_width = NULL, interactions = FALSE,
ih_nbins = NULL, ih_color_num = TRUE,
ih_scale = FALSE, ih_adjusted = FALSE, ...) {
stopifnot(
length(v) == 1L,
length(color_var) <= 1L
)
plot_list <- lapply(
object,
FUN = sv_dependence,
# Argument list (simplify via match.call() or some rlang magic?)
v = v,
color_var = color_var,
color = color,
viridis_args = viridis_args,
jitter_width = jitter_width,
interactions = interactions,
ih_nbins = ih_nbins,
ih_color_num = ih_color_num,
ih_scale = ih_scale,
ih_adjusted = ih_adjusted,
...
)
plot_list <- add_titles(plot_list, nms = names(object)) # see sv_waterfall()
patchwork::wrap_plots(plot_list)
}
# Helper functions
# Checks if z is discrete
.is_discrete <- function(z, n_unique) {
is.factor(z) || is.character(z) || is.logical(z) || (length(unique(z)) <= n_unique)
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.