plot.shapr    R Documentation

Description

Plots the individual prediction explanations.

Usage
## S3 method for class 'shapr'
plot(
x,
plot_type = "bar",
digits = 3,
index_x_explain = NULL,
top_k_features = NULL,
col = NULL,
bar_plot_phi0 = TRUE,
bar_plot_order = "largest_first",
scatter_features = NULL,
scatter_hist = TRUE,
include_group_feature_means = FALSE,
beeswarm_cex = 1/length(index_x_explain)^(1/4),
...
)
Arguments

x
A shapr object, i.e. the output from explain().
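plot_type
Character.
Specifies the type of plot to produce: "bar" (the default) gives a bar plot of the Shapley values,
"waterfall" gives a waterfall plot, "scatter" plots the feature values against the Shapley values,
and "beeswarm" summarizes the distribution of the Shapley values and feature values for all features.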
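A waterfall plot rests on the additivity of Shapley values: starting from phi0 and adding each feature's contribution in turn ends at the explained prediction. The base R sketch below uses made-up numbers to show that running total; it illustrates the idea only and is not shapr's plotting code.

```r
# Made-up values for a single explained observation (illustration only)
phi0 <- 40                                               # baseline prediction
phi <- c(Temp = 12, Wind = -5, Solar.R = 3, Month = -1)  # Shapley values

# Running total after each contribution, i.e. the steps a waterfall plot draws
steps <- phi0 + cumsum(phi)
print(steps)

# The final step equals phi0 + sum(phi), the prediction being explained
prediction <- phi0 + sum(phi)
stopifnot(unname(tail(steps, 1)) == prediction)
```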
digits
Integer.
Number of significant digits to use in the feature description.
Applicable for plot_type = "bar" and plot_type = "waterfall".
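"Significant digits" are counted from the first non-zero digit, which is what base R's signif() does, in contrast to round(), which counts decimal places. The helper feature_label() and its "Temp = 81.2" label format are hypothetical, chosen for illustration; shapr's exact label string may differ.

```r
# Hypothetical helper: label a feature value using `digits` significant digits
feature_label <- function(name, value, digits = 3) {
  paste0(name, " = ", signif(value, digits))
}

feature_label("Temp", 81.2345)     # 3 significant digits
feature_label("Wind", 0.0123456)   # leading zeros do not count
round(0.0123456, 3)                # round() counts decimal places instead
```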
index_x_explain
Integer vector.
Which of the explained observations to plot. For example, if you have
explained 10 observations using explain(), you can plot the first 5 of them by setting index_x_explain = 1:5.
top_k_features
Integer.
How many features to include in the plot.
E.g. if you have 15 features in your model, you can plot the 5 most important features
for each explanation by setting top_k_features = 5.
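Selecting the most important features per explanation amounts to ranking each observation's Shapley values by absolute size and keeping the first k. The sketch below does this in base R on made-up values; top_k() is a hypothetical helper, not shapr's internal code.

```r
# Made-up Shapley values: one row per explained observation
phi <- matrix(
  c( 3.1, -8.4, 10.2, -0.7,
    -1.5,  2.0, -6.3,  4.8),
  nrow = 2, byrow = TRUE,
  dimnames = list(NULL, c("Solar.R", "Wind", "Temp", "Month"))
)

# Hypothetical helper: names of the k features with the largest |Shapley value|
top_k <- function(phi_row, k) {
  head(names(phi_row)[order(abs(phi_row), decreasing = TRUE)], k)
}

# Each observation gets its own top-k set
res <- lapply(seq_len(nrow(phi)), function(i) top_k(phi[i, ], k = 2))
print(res)
```

Because the ranking is done per row, different observations can end up showing different features.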
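col
Character vector (where length depends on plot type).
The color codes (hex codes or other color names understood by ggplot2) used in the plot.
If col = NULL (the default), the default colors for the plot type are used.
If you want to alter the colors in the plot, the required length of col depends on the plot type;
for example, col = c("red", "black") is used with plot_type = "beeswarm" in the examples below.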
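Since ggplot2 accepts R color names and hex codes, a col vector can be checked in base R with grDevices::col2rgb(), which raises an error for strings it cannot interpret. The helper is_valid_color() is hypothetical and not part of shapr; it is only a quick check to run before plotting.

```r
# Hypothetical helper (not part of shapr): TRUE for each color R can interpret
is_valid_color <- function(cols) {
  vapply(cols, function(cl) {
    tryCatch({
      grDevices::col2rgb(cl)   # errors on unknown names or malformed hex codes
      TRUE
    }, error = function(e) FALSE)
  }, logical(1), USE.NAMES = FALSE)
}

is_valid_color(c("red", "#00BA38", "notacolor"))
```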
bar_plot_phi0
Logical.
Whether to include phi0 in the plot for plot_type = "bar".
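bar_plot_order
Character.
Specifies the order in which to plot the features, with respect to the magnitude of the Shapley values, for plot_type = "bar".
The default, "largest_first", orders the features from the largest to the smallest absolute Shapley value.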
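In ggplot2, the order of bars along a discrete axis follows the levels of the underlying factor, so one common way to get a largest-first ordering is to set the factor levels by decreasing absolute Shapley value. This is a sketch with made-up values, not shapr's internal code.

```r
# Made-up Shapley values for one observation
phi <- c(Solar.R = 3.1, Wind = -8.4, Temp = 10.2, Month = -0.7)

# Rank features by absolute Shapley value, largest first
lvl <- names(phi)[order(abs(phi), decreasing = TRUE)]

# The factor levels determine the bar order on a discrete ggplot2 axis.
# On a vertical discrete axis the first level is drawn at the bottom,
# so rev(lvl) would be needed to put the largest bar on top.
feature <- factor(names(phi), levels = lvl)
levels(feature)
```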
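scatter_features
Integer or character vector.
Only used for plot_type = "scatter". Specifies which features to include in the scatter plot,
either as feature indices or as feature names.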
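Because scatter_features may be given either as positions or as names, both forms have to be mapped to the same feature names. The hypothetical helper resolve_features() below sketches one way to do this in base R; it is an illustration, not shapr's implementation.

```r
# Hypothetical helper (not part of shapr): turn indices or names into feature names
resolve_features <- function(selection, feature_names) {
  if (is.numeric(selection)) {
    out <- feature_names[selection]   # positions in the feature list
  } else {
    out <- as.character(selection)    # names given directly
  }
  # Out-of-range indices give NA; unknown names are not in feature_names
  if (anyNA(out) || !all(out %in% feature_names)) {
    stop("unknown feature(s) in selection")
  }
  out
}

feature_names <- c("Solar.R", "Wind", "Temp", "Month")
resolve_features(c(1, 3), feature_names)
resolve_features("Temp", feature_names)
```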
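scatter_hist
Logical.
Only used for plot_type = "scatter". Whether to include a histogram in the background of the scatter plot,
showing the distribution of the feature values.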
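include_group_feature_means
Logical.
Whether to include the average feature value in a group on the y-axis or not.
If FALSE (the default), no value is shown for the groups. If TRUE, the mean of the feature values
in each group is computed and displayed. Only applicable for group-wise Shapley values.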
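The group mean is simply the average of the feature values within each group, computed per observation. The sketch below does this with rowMeans(); the layout of group as a named list of feature names is an assumption for illustration. It also shows that such averages can mix units, since Wind and Temp are measured on different scales.

```r
# Feature values for two observations (illustration only)
x_explain <- data.frame(
  Solar.R = c(190, 118),
  Wind    = c(7.4, 8.0),
  Temp    = c(67, 72)
)

# Assumed grouping: a named list of feature names per group
group <- list(weather = c("Wind", "Temp"), sun = "Solar.R")

# Mean feature value per observation (rows) and group (columns)
group_means <- sapply(group, function(cols) rowMeans(x_explain[, cols, drop = FALSE]))
print(group_means)
# Note: the "weather" mean averages Wind and Temp, which are on different scales
```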
beeswarm_cex
Numeric.
The cex argument of ggbeeswarm::geom_beeswarm(), controlling the spacing between points in the beeswarm plots.
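The default beeswarm_cex shrinks with the fourth root of the number of plotted observations, so the point spacing passed to ggbeeswarm::geom_beeswarm() decreases slowly as more observations are added. Evaluating the same expression for a few counts shows the effect; default_cex() is a hypothetical helper that only mirrors the default in the usage above.

```r
# Mirrors the default expression 1/length(index_x_explain)^(1/4)
default_cex <- function(n_explain) 1 / n_explain^(1 / 4)

n <- c(1, 16, 81, 10000)   # number of plotted observations
round(default_cex(n), 3)
```

Going from 1 to 10000 observations only reduces the value from 1 to 0.1.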
...
Other arguments passed to underlying functions,
like ggbeeswarm::geom_beeswarm() for plot_type = "beeswarm".
Details

See the examples below, or vignette("general_usage", package = "shapr"),
for examples of how to use the function.
Value

A ggplot object with plots of the Shapley value explanations.

Author(s)

Martin Jullum, Vilde Ung, Lars Henry Berge Olsen
Examples

library(shapr)
if (requireNamespace("party", quietly = TRUE)) {
data("airquality")
airquality <- airquality[complete.cases(airquality), ]
x_var <- c("Solar.R", "Wind", "Temp", "Month")
y_var <- "Ozone"
# Split the data into training data and data to explain
data_train <- head(airquality, -50)
data_explain <- tail(airquality, 50)
x_train <- data_train[, x_var]
x_explain <- data_explain[, x_var]
# Fit a linear model
lm_formula <- as.formula(paste0(y_var, " ~ ", paste0(x_var, collapse = " + ")))
model <- lm(lm_formula, data = data_train)
# Explain predictions
p <- mean(data_train[, y_var])
# Empirical approach
x <- explain(
model = model,
x_explain = x_explain,
x_train = x_train,
approach = "empirical",
phi0 = p,
n_MC_samples = 1e2
)
# requireNamespace() checks a single package, so check each one
if (all(vapply(c("ggplot2", "ggbeeswarm"), requireNamespace, logical(1), quietly = TRUE))) {
# The default plotting option is a bar plot of the Shapley values
# We draw bar plots for the first 4 observations
plot(x, index_x_explain = 1:4)
# We can also make waterfall plots
plot(x, plot_type = "waterfall", index_x_explain = 1:4)
# And only showing the two features with the largest contributions
plot(x, plot_type = "waterfall", index_x_explain = 1:4, top_k_features = 2)
# Or scatter plots showing the distribution of the Shapley values and feature values
plot(x, plot_type = "scatter")
# And only for a specific feature
plot(x, plot_type = "scatter", scatter_features = "Temp")
# Or a beeswarm plot summarising the Shapley values and feature values for all features
plot(x, plot_type = "beeswarm")
plot(x, plot_type = "beeswarm", col = c("red", "black")) # we can change colors
# Additional arguments can be passed to ggbeeswarm::geom_beeswarm() via the '...' argument.
# For instance, the swarms of neighboring features sometimes run into each other.
# This can be fixed with the argument corral = "wrap".
# See ?ggbeeswarm::geom_beeswarm for more information.
plot(x, plot_type = "beeswarm", corral = "wrap")
}
# Example of scatter and beeswarm plot with factor variables
airquality$Month_factor <- as.factor(month.abb[airquality$Month])
airquality <- airquality[complete.cases(airquality), ]
x_var <- c("Solar.R", "Wind", "Temp", "Month_factor")
y_var <- "Ozone"
# Train on all data and explain the last 50 observations
data_train <- airquality
data_explain <- tail(airquality, 50)
x_train <- data_train[, x_var]
x_explain <- data_explain[, x_var]
# Fit a linear model
lm_formula <- as.formula(paste0(y_var, " ~ ", paste0(x_var, collapse = " + ")))
model <- lm(lm_formula, data = data_train)
# Explain predictions
p <- mean(data_train[, y_var])
# Conditional inference tree (ctree) approach
x <- explain(
model = model,
x_explain = x_explain,
x_train = x_train,
approach = "ctree",
phi0 = p,
n_MC_samples = 1e2
)
if (all(vapply(c("ggplot2", "ggbeeswarm"), requireNamespace, logical(1), quietly = TRUE))) {
plot(x, plot_type = "scatter")
plot(x, plot_type = "beeswarm")
}
}