plot.shapr | R Documentation |
Plots the individual prediction explanations.
## S3 method for class 'shapr'
plot(
  x,
  plot_type = "bar",
  digits = 3,
  index_x_explain = NULL,
  top_k_features = NULL,
  col = NULL,
  bar_plot_phi0 = TRUE,
  bar_plot_order = "largest_first",
  scatter_features = NULL,
  scatter_hist = TRUE,
  include_group_feature_means = FALSE,
  beeswarm_cex = 1/length(index_x_explain)^(1/4),
  ...
)
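The default for beeswarm_cex in the usage above shrinks with the fourth root of the number of plotted observations. The base-R sketch below only evaluates that formula for a few counts; the helper name default_beeswarm_cex is hypothetical, and it assumes index_x_explain holds the indices of the observations actually plotted.

```r
# Hypothetical helper mirroring the default expression in the usage:
# beeswarm_cex = 1 / length(index_x_explain)^(1/4)
default_beeswarm_cex <- function(index_x_explain) {
  1 / length(index_x_explain)^(1 / 4)
}

# Evaluate the default for 1, 16, 81 and 10000 plotted observations
n_obs <- c(1, 16, 81, 10000)
cex <- vapply(n_obs, function(n) default_beeswarm_cex(seq_len(n)), numeric(1))

# The values are 1, 0.5, 1/3 and 0.1: more observations give a smaller cex
print(data.frame(n_obs = n_obs, beeswarm_cex = round(cex, 3)))
```

A smaller value may be passed explicitly through beeswarm_cex if the points still overlap.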
x |
An |
plot_type |
Character.
Specifies the type of plot to produce.
|
digits |
Integer.
Number of significant digits to use in the feature description.
Applicable for |
index_x_explain |
Integer vector.
Which of the test observations to plot. E.g. if you have
explained 10 observations using |
top_k_features |
Integer.
How many features to include in the plot.
E.g. if you have 15 features in your model you can plot the 5 most important features,
for each explanation, by setting |
col |
Character vector (where length depends on plot type).
The color codes (hex codes or other names understood by If you want to alter the colors in the plot, the length of the |
bar_plot_phi0 |
Logical.
Whether to include |
bar_plot_order |
Character.
Specifies the order in which to plot the features with respect to the magnitude of the Shapley values with
|
scatter_features |
Integer or character vector.
Only used for |
scatter_hist |
Logical.
Only used for |
include_group_feature_means |
Logical.
Whether to include the average feature value in a group on the y-axis or not.
If |
beeswarm_cex |
Numeric.
The cex argument of |
... |
Other arguments passed to underlying functions,
like |
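To make the idea behind top_k_features and the "largest_first" ordering concrete, the base-R sketch below ranks features by the magnitude of their Shapley values and keeps the k largest. The Shapley values are made-up numbers for a single observation, and this is only an illustration of the ranking idea, not the package's internal implementation.

```r
# Hypothetical Shapley values for one explained observation (made-up numbers)
phi <- c(Solar.R = 2.1, Wind = -6.4, Temp = 9.8, Month = -0.7)

# Rank the features by the magnitude (absolute value) of their Shapley values
largest_first <- names(sort(abs(phi), decreasing = TRUE))

# Keep only the k features with the largest magnitude
top_k <- 2
top_features <- head(largest_first, top_k)

# Signed Shapley values of the retained features
print(phi[top_features])
```

The sign is ignored when ranking, so a large negative contribution such as Wind is kept ahead of a smaller positive one.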
See the examples below, or vignette("general_usage", package = "shapr"),
for examples of how to use the function.
ggplot object with plots of the Shapley value explanations
Martin Jullum, Vilde Ung, Lars Henry Berge Olsen
data("airquality")
airquality <- airquality[complete.cases(airquality), ]
x_var <- c("Solar.R", "Wind", "Temp", "Month")
y_var <- "Ozone"
# Split data into test- and training data
data_train <- head(airquality, -50)
data_explain <- tail(airquality, 50)
x_train <- data_train[, x_var]
x_explain <- data_explain[, x_var]
# Fit a linear model
lm_formula <- as.formula(paste0(y_var, " ~ ", paste0(x_var, collapse = " + ")))
model <- lm(lm_formula, data = data_train)
# Explain predictions
p <- mean(data_train[, y_var])
# Empirical approach
x <- explain(
  model = model,
  x_explain = x_explain,
  x_train = x_train,
  approach = "empirical",
  phi0 = p,
  n_MC_samples = 1e2
)
if (requireNamespace("ggplot2", quietly = TRUE) && requireNamespace("ggbeeswarm", quietly = TRUE)) {
  # The default plotting option is a bar plot of the Shapley values
  # We draw bar plots for the first 4 observations
  plot(x, index_x_explain = 1:4)
  # We can also make waterfall plots
  plot(x, plot_type = "waterfall", index_x_explain = 1:4)
  # And only show the 2 features with the largest contribution
  plot(x, plot_type = "waterfall", index_x_explain = 1:4, top_k_features = 2)
  # Or scatter plots showing the distribution of the Shapley values and feature values
  plot(x, plot_type = "scatter")
  # And only for a specific feature
  plot(x, plot_type = "scatter", scatter_features = "Temp")
  # Or a beeswarm plot summarising the Shapley values and feature values for all features
  plot(x, plot_type = "beeswarm")
  plot(x, plot_type = "beeswarm", col = c("red", "black")) # we can change colors
  # Additional arguments can be passed to ggbeeswarm::geom_beeswarm() using the '...' argument.
  # For instance, sometimes the points in a beeswarm plot overlap too much.
  # This can be fixed with the 'corral = "wrap"' argument.
  # See ?ggbeeswarm::geom_beeswarm for more information.
  plot(x, plot_type = "beeswarm", corral = "wrap")
}
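For intuition about what a waterfall plot accumulates, the base-R sketch below computes closed-form Shapley values for the same linear model under an assumed feature independence, where each contribution is the coefficient times the deviation from the training mean. This is an assumption made for illustration and is not what approach = "empirical" estimates, so the numbers will generally differ from those in x. It only shows that the baseline plus the Shapley values adds up to the prediction.

```r
# Base-R sketch (no shapr call). ASSUMPTION: independent features, so the
# Shapley values of a linear model have a simple closed form.
data("airquality")
aq <- airquality[complete.cases(airquality), ]
x_var <- c("Solar.R", "Wind", "Temp", "Month")
aq_train <- head(aq, -50)
aq_explain <- tail(aq, 50)
lin_model <- lm(Ozone ~ Solar.R + Wind + Temp + Month, data = aq_train)

# Baseline: the mean training response, as used for phi0 in the example
phi0 <- mean(aq_train$Ozone)

# phi_j = beta_j * (x_j - mean of x_j in the training data)
beta <- coef(lin_model)[x_var]
centered <- sweep(as.matrix(aq_explain[, x_var]), 2, colMeans(aq_train[, x_var]))
phi <- sweep(centered, 2, beta, `*`)

# Starting at phi0 and adding each contribution ends at the model prediction
waterfall_end <- unname(phi0 + rowSums(phi))
pred <- unname(predict(lin_model, newdata = aq_explain))
print(all.equal(waterfall_end, pred))
```

The same additive identity is what the waterfall plot visualizes step by step for each explained observation.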
# Example of scatter and beeswarm plot with factor variables
airquality$Month_factor <- as.factor(month.abb[airquality$Month])
airquality <- airquality[complete.cases(airquality), ]
x_var <- c("Solar.R", "Wind", "Temp", "Month_factor")
y_var <- "Ozone"
# Use all observations for training and explain the last 50
data_train <- airquality
data_explain <- tail(airquality, 50)
x_train <- data_train[, x_var]
x_explain <- data_explain[, x_var]
# Fit a linear model
lm_formula <- as.formula(paste0(y_var, " ~ ", paste0(x_var, collapse = " + ")))
model <- lm(lm_formula, data = data_train)
# Explain predictions
p <- mean(data_train[, y_var])
# ctree approach
x <- explain(
  model = model,
  x_explain = x_explain,
  x_train = x_train,
  approach = "ctree",
  phi0 = p,
  n_MC_samples = 1e2
)
if (requireNamespace("ggplot2", quietly = TRUE) && requireNamespace("ggbeeswarm", quietly = TRUE)) {
  plot(x, plot_type = "scatter")
  plot(x, plot_type = "beeswarm")
}