explain: Fast approximate Shapley values

Description

Compute fast (approximate) Shapley values for a set of features.

Usage

explain(object, ...)

## Default S3 method:
explain(
  object,
  feature_names = NULL,
  X = NULL,
  nsim = 1,
  pred_wrapper = NULL,
  newdata = NULL,
  adjust = FALSE,
  ...
)

## S3 method for class 'lm'
explain(
  object,
  feature_names = NULL,
  X,
  nsim = 1,
  pred_wrapper,
  newdata = NULL,
  exact = FALSE,
  ...
)

## S3 method for class 'xgb.Booster'
explain(
  object,
  feature_names = NULL,
  X = NULL,
  nsim = 1,
  pred_wrapper,
  newdata = NULL,
  exact = FALSE,
  ...
)

Arguments

object

A fitted model object (e.g., a ranger or an xgboost object).

...

Additional optional arguments to be passed on to laply.

feature_names

Character vector giving the names of the predictor variables (i.e., features) of interest. If NULL (default), they will be taken from the column names of X.

X

A matrix-like R object (e.g., a data frame or matrix) containing ONLY the feature columns from the training data. NOTE: This argument is required whenever exact = FALSE.

nsim

The number of Monte Carlo repetitions to use for estimating each Shapley value (only used when exact = FALSE). Default is 1. NOTE: To obtain the most accurate results, nsim should be set as large as is computationally feasible.
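
To give a feel for why a larger nsim helps, the following is a minimal base R sketch of a permutation-sampling Monte Carlo estimator for a single Shapley value. It is an illustration under stated assumptions, not fastshap's internal code: the helper mc_shapley() and its arguments are hypothetical names introduced here. A linear model is used so that the estimate can be compared with the closed-form value that holds for linear models with independent features.

```r
# Hypothetical helper (not part of fastshap): Monte Carlo estimate of the
# Shapley value of one feature for one observation.
mc_shapley <- function(object, X, x, feature, nsim, pred_wrapper) {
  X <- as.matrix(X)
  x <- unlist(x)[colnames(X)]          # observation as a named numeric vector
  j <- match(feature, colnames(X))
  p <- ncol(X)
  contrib <- numeric(nsim)
  for (m in seq_len(nsim)) {
    w <- X[sample(nrow(X), 1), ]       # random background row from X
    perm <- sample(p)                  # random ordering of the features
    after <- perm[seq_len(p) > which(perm == j)]  # features ordered after j
    b1 <- x
    b1[after] <- w[after]              # j and the features before it come from x
    b2 <- b1
    b2[j] <- w[j]                      # identical, except j comes from w
    contrib[m] <- pred_wrapper(object, as.data.frame(t(b1))) -
      pred_wrapper(object, as.data.frame(t(b2)))
  }
  mean(contrib)                        # average over the nsim repetitions
}

fit <- lm(mpg ~ wt + hp, data = mtcars)
X <- mtcars[, c("wt", "hp")]
pfun <- function(object, newdata) as.numeric(predict(object, newdata = newdata))

set.seed(101)  # for reproducibility
rough  <- mc_shapley(fit, X, X[1, ], "wt", nsim = 1,    pred_wrapper = pfun)
better <- mc_shapley(fit, X, X[1, ], "wt", nsim = 2000, pred_wrapper = pfun)

# Closed-form value for a linear model with independent features
exact_wt <- unname(coef(fit)["wt"] * (X[1, "wt"] - mean(X$wt)))
```

With nsim = 1 the estimate rests on a single random draw, so it can land far from exact_wt; averaging many draws should bring the estimate much closer.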

pred_wrapper

Prediction function that requires two arguments, object and newdata. NOTE: This argument is required whenever exact = FALSE. The expected output of this function depends on the type of problem:

Regression

A numeric vector of predicted outcomes.

Binary classification

A vector of predicted class probabilities for the reference class.

Multiclass classification

A vector of predicted class probabilities for the reference class.
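
As a rough illustration of the shapes described above, the sketch below defines two prediction wrappers in base R. The names pfun_reg and pfun_bin are hypothetical, and the glm model is only a stand-in for a binary classifier; here the probability that am equals 1 is treated as the reference-class probability, which is an assumption made for this example.

```r
# Regression: return a plain numeric vector of predicted outcomes
pfun_reg <- function(object, newdata) {
  as.numeric(predict(object, newdata = newdata))
}

# Binary classification: return a numeric vector of predicted probabilities
# (for a binomial glm, type = "response" gives P(am = 1))
pfun_bin <- function(object, newdata) {
  as.numeric(predict(object, newdata = newdata, type = "response"))
}

X <- mtcars[, c("wt", "hp")]  # feature columns only
fit_reg <- lm(mpg ~ wt + hp, data = mtcars)
fit_bin <- glm(am ~ wt + hp, data = mtcars, family = binomial)

p_reg <- pfun_reg(fit_reg, X)  # one prediction per row of X
p_bin <- pfun_bin(fit_bin, X)  # one probability per row of X
```

Either wrapper takes exactly the two arguments object and newdata and returns one value per row, which is the contract described for pred_wrapper.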

newdata

A matrix-like R object (e.g., a data frame or matrix) containing ONLY the feature columns for the observation(s) of interest; that is, the observation(s) you want to compute explanations for. Default is NULL, which produces approximate Shapley values for all the rows in X (i.e., the training data).

adjust

Logical indicating whether or not to adjust the sum of the estimated Shapley values to satisfy the additivity (or local accuracy) property; that is, to equal the difference between the model's prediction for that sample and the average prediction over all the training data (i.e., X).
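
The additivity property can be checked numerically. The sketch below does so in base R for a linear model, where per-feature contributions can be written in closed form as beta_j * (x_j - mean(x_j)) when features are treated as independent. The variable names (phi, gap, max_gap) are introduced here for illustration; this is a check of the property, not a reproduction of what adjust = TRUE does internally.

```r
fit <- lm(mpg ~ wt + hp, data = mtcars)
X <- mtcars[, c("wt", "hp")]  # feature columns only

# Closed-form contributions for a linear model: beta_j * (x_j - mean(x_j))
beta <- coef(fit)[names(X)]
centered <- sweep(as.matrix(X), 2, colMeans(X), FUN = "-")
phi <- sweep(centered, 2, beta, FUN = "*")

# Additivity: each row of contributions should sum to that row's prediction
# minus the average prediction over the training data
pred <- predict(fit, newdata = X)
gap <- rowSums(phi) - (pred - mean(pred))
max_gap <- max(abs(gap))  # numerically zero for these exact contributions
```

Monte Carlo estimates with a small nsim will generally leave a nonzero gap in the same check, which is the discrepancy that adjust = TRUE is meant to remove.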

exact

Logical indicating whether to compute exact Shapley values. Currently only available for lm and xgboost objects. Default is FALSE. Note that setting exact = TRUE will return explanations for each of the terms in an lm object.

Value

A tibble with one column for each feature specified in feature_names (if feature_names = NULL, the default, there will be one column for each feature in X) and one row for each observation in newdata (if newdata = NULL, the default, there will be one row for each observation in X).

Note

Setting exact = TRUE with a linear model (i.e., an lm or glm object) assumes that the input features are independent. Also, setting adjust = TRUE is experimental; it follows the same approach as shap.
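
For reference, the standard LinearSHAP result that the independence assumption leads to can be written as follows. The notation (f, beta, phi) is introduced here and does not appear elsewhere on this page.

```latex
% Linear model
f(x) = \beta_0 + \sum_{j=1}^{p} \beta_j x_j

% Shapley value of feature j, assuming independent features
\phi_j(x) = \beta_j \left( x_j - \mathbb{E}[x_j] \right)

% Additivity follows by summing over the features
\sum_{j=1}^{p} \phi_j(x) = f(x) - \mathbb{E}[f(x)]
```

When features are correlated, the expectation of one feature depends on the values of the others, so this simple per-feature form should be read as an approximation.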

See Also

You can find more examples (with larger and more realistic data sets) on the fastshap GitHub repository: https://github.com/bgreenwell/fastshap.

Examples

#
# A linear regression example
#

library(fastshap)  # provides explain()

# Load the sample data; see ?datasets::mtcars for details
data(mtcars)

# Fit a linear regression model
fit <- lm(mpg ~ ., data = mtcars)

# Compute approximate Shapley values using 10 Monte Carlo simulations
set.seed(101)  # for reproducibility
shap <- explain(fit, X = subset(mtcars, select = -mpg), nsim = 10, 
                pred_wrapper = predict)
shap

# Compute exact Shapley (i.e., LinearSHAP) values
shap <- explain(fit, exact = TRUE)
shap

# Shapley-based plots
library(ggplot2)
autoplot(shap)  # Shapley-based importance plot
autoplot(shap, type = "dependence", feature = "wt", X = mtcars)
autoplot(shap, type = "contribution", row_num = 1)  # explain first row of X

Example output

Warning message:
The `x` argument of `as_tibble.matrix()` must have unique column names if `.name_repair` is omitted as of tibble 2.0.0.
Using compatibility `.name_repair`.
This warning is displayed once every 8 hours.
Call `lifecycle::last_warnings()` to see where this warning was generated. 
# A tibble: 32 x 10
       cyl     disp      hp    drat     wt   qsec      vs     am    gear    carb
     <dbl>    <dbl>   <dbl>   <dbl>  <dbl>  <dbl>   <dbl>  <dbl>   <dbl>   <dbl>
 1  0.0223 -1.40     0.977   0.443   1.94   0.263 -0.0636  1.26   0.328  -0.239 
 2  0.134   0.00960  0.516   0.422   1.50  -0.863 -0.159   1.01  -0.0655 -0.199 
 3  0.290  -1.90     0.969   0.312   3.43   0.981  0.127   1.26   0.262   0.419 
 4 -0.0223  0.768    1.35   -0.582  -1.06   0.291  0.254  -0.504 -0.524   0.339 
 5 -0.0446  2.05    -1.23   -0.128  -2.01  -0.795 -0.0636 -0.756 -0.328   0.259 
 6  0.0446 -0.119    0.589  -0.841  -1.70   1.60   0.159  -1.01  -0.655   0.319 
 7 -0.178   1.49    -1.38   -0.351  -0.875 -1.98  -0.222  -2.02  -0.655  -0.239 
 8  0.223  -1.08     1.92   -0.0575  0.187  1.64   0.127  -1.01  -0.0655  0.120 
 9  0.312  -1.84     1.24    0.260   2.00   4.53   0.127  -0.504  0.328   0.0798
10  0.0446 -1.78     0.0924  0.260  -1.66   0.372  0.159  -1.26   0.131  -0.140 
# … with 22 more rows
# A tibble: 32 x 10
       cyl    disp     hp    drat       wt   qsec     vs    am   gear   carb
     <dbl>   <dbl>  <dbl>   <dbl>    <dbl>  <dbl>  <dbl> <dbl>  <dbl>  <dbl>
 1  0.0209 -0.943   0.788  0.239   2.22    -1.14  -0.139  1.50  0.205 -0.237
 2  0.0209 -0.943   0.788  0.239   1.27    -0.680 -0.139  1.50  0.205 -0.237
 3  0.244  -1.64    1.15   0.199   3.33     0.625  0.179  1.50  0.205  0.361
 4  0.0209  0.364   0.788 -0.407   0.00836  1.31   0.179 -1.02 -0.451  0.361
 5 -0.202   1.72   -0.608 -0.351  -0.828   -0.680 -0.139 -1.02 -0.451  0.162
 6  0.0209 -0.0763  0.896 -0.658  -0.902    1.95   0.179 -1.02 -0.451  0.361
 7 -0.202   1.72   -2.11  -0.304  -1.31    -1.65  -0.139 -1.02 -0.451 -0.237
 8  0.244  -1.12    1.82   0.0735  0.101    1.77   0.179 -1.02  0.205  0.162
 9  0.244  -1.20    1.11   0.255   0.250    4.15   0.179 -1.02  0.205  0.162
10  0.0209 -0.842   0.509  0.255  -0.828    0.370  0.179 -1.02  0.205 -0.237
# … with 22 more rows

fastshap documentation built on March 26, 2020, 7:57 p.m.