sv_interaction: SHAP Interaction Plot


sv_interaction R Documentation

SHAP Interaction Plot

Description

Creates a beeswarm plot for each feature pair. Diagonals represent the main effects, while off-diagonals show interactions (multiplied by two due to symmetry). The colors of the beeswarm dots represent min-max scaled feature values. Non-numeric features are first converted to numeric by calling data.matrix(). The features are sorted in decreasing order of SHAP importance (average absolute SHAP values).
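The preprocessing described above (conversion with data.matrix(), then min-max scaling) can be sketched in base R. This is only an illustration under stated assumptions, not the package's internal code: the helper name min_max_scale and its handling of constant columns are hypothetical.

```r
# Hypothetical helper (not part of shapviz): rescale a numeric vector to [0, 1].
min_max_scale <- function(z) {
  r <- range(z, na.rm = TRUE)
  if (r[1] == r[2]) {
    # Constant column: avoid division by zero (0.5 is an arbitrary choice made here)
    return(rep(0.5, length(z)))
  }
  (z - r[1]) / (r[2] - r[1])
}

# data.matrix() turns factors such as Species into their integer codes
X_num <- data.matrix(iris)

# Scale each column separately; the result is a numeric matrix with values in [0, 1]
X_scaled <- apply(X_num, 2, min_max_scale)
range(X_scaled[, "Species"])
```

After scaling, every feature uses the same color range, which is why one color bar can serve all panels.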

Usage

sv_interaction(object, ...)

## Default S3 method:
sv_interaction(object, ...)

## S3 method for class 'shapviz'
sv_interaction(
  object,
  kind = c("beeswarm", "no"),
  max_display = 7L,
  alpha = 0.3,
  bee_width = 0.3,
  bee_adjust = 0.5,
  viridis_args = getOption("shapviz.viridis_args"),
  color_bar_title = "Row feature value",
  sort_features = TRUE,
  ...
)

## S3 method for class 'mshapviz'
sv_interaction(
  object,
  kind = c("beeswarm", "no"),
  max_display = 7L,
  alpha = 0.3,
  bee_width = 0.3,
  bee_adjust = 0.5,
  viridis_args = getOption("shapviz.viridis_args"),
  color_bar_title = "Row feature value",
  sort_features = TRUE,
  ...
)

Arguments

object

An object of class "(m)shapviz" containing the element S_inter (SHAP interaction values).

...

Arguments passed to ggplot2::geom_point(). For instance, passing size = 1 will produce smaller dots.

kind

Set to "no" to return the matrix of average absolute SHAP interactions (or a list of such matrices for an object of class "mshapviz"). Due to symmetry, off-diagonals are multiplied by two. The default is "beeswarm".
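The matrix returned for kind = "no" can be illustrated with a small base R sketch. It assumes an n x p x p array of interaction values; the array below is synthetic and the object names are made up for illustration. The package's own implementation may differ in detail (for example, it also sorts the features).

```r
set.seed(1)
n <- 5
feats <- c("a", "b", "c")
p <- length(feats)

# Synthetic n x p x p array standing in for the S_inter element
S_inter <- array(
  rnorm(n * p * p),
  dim = c(n, p, p),
  dimnames = list(NULL, feats, feats)
)
# SHAP interactions are symmetric in the two feature dimensions
S_inter <- (S_inter + aperm(S_inter, c(1, 3, 2))) / 2

# Average absolute value per feature pair, taken over the n observations
avg_abs <- apply(abs(S_inter), 2:3, mean)

# Each interaction is split across (i, j) and (j, i), so double the off-diagonals
off_diag <- row(avg_abs) != col(avg_abs)
avg_abs[off_diag] <- 2 * avg_abs[off_diag]
avg_abs
```

The diagonal entries keep the average absolute main effects, while each off-diagonal entry reports the full strength of the pairwise interaction.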

max_display

How many features should be plotted? Set to Inf to show all features. Has no effect if kind = "no".

alpha

Transparency of the beeswarm dots. Defaults to 0.3.

bee_width

Relative width of the beeswarms.

bee_adjust

Relative bandwidth adjustment factor used in estimating the density of the beeswarms.

viridis_args

List of viridis color scale arguments. The default points to the global option shapviz.viridis_args, which corresponds to list(begin = 0.25, end = 0.85, option = "inferno"). These values are passed to ggplot2::scale_color_viridis_c(). For example, to switch to standard viridis, either change the default with options(shapviz.viridis_args = list()) or set viridis_args = list().
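As a rough sketch of what such a list means for the color scale, the entries can be forwarded as named arguments with do.call(). How shapviz forwards the list internally is an assumption here; only ggplot2::scale_color_viridis_c() and the documented default values are taken from the text above.

```r
library(ggplot2)

# Documented default of the global option shapviz.viridis_args
viridis_args <- list(begin = 0.25, end = 0.85, option = "inferno")

# Forward the list entries as named arguments to the continuous viridis scale
scale_default <- do.call(scale_color_viridis_c, viridis_args)

# An empty list leaves the scale at its own defaults, i.e., standard viridis
scale_plain <- do.call(scale_color_viridis_c, list())
```

Any other argument accepted by scale_color_viridis_c(), such as direction = -1 to reverse the palette, can be added to the list in the same way.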

color_bar_title

Title of the color bar of the beeswarm plot. Set to NULL to hide the color bar altogether.

sort_features

Should features be sorted in decreasing order of SHAP importance? The default is TRUE.

Value

A "ggplot" (or "patchwork") object, or - if kind = "no" - a named numeric matrix of average absolute SHAP interactions, sorted by the average absolute SHAP values (or a list of such matrices for an "mshapviz" object).

Methods (by class)

  • sv_interaction(default): Default method.

  • sv_interaction(shapviz): SHAP interaction plot for an object of class "shapviz".

  • sv_interaction(mshapviz): SHAP interaction plot for an object of class "mshapviz".

See Also

sv_importance()

Examples

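library(shapviz)

# Fit a small XGBoost model that predicts the first column of iris (Sepal.Length)
dtrain <- xgboost::xgb.DMatrix(
  data.matrix(iris[, -1]), label = iris[, 1], nthread = 1
)
fit <- xgboost::xgb.train(data = dtrain, nrounds = 10, nthread = 1)

# interactions = TRUE is required so that the object contains S_inter
x <- shapviz(fit, X_pred = dtrain, X = iris, interactions = TRUE)

# Matrix of average absolute SHAP interactions instead of a plot
sv_interaction(x, kind = "no")

# Beeswarm plots for the two most important features; size is passed to geom_point()
sv_interaction(x, max_display = 2, size = 3)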

shapviz documentation built on Sept. 14, 2024, 5:07 p.m.