No black-box model without XAI. This is where packages like
{flashlight} offers the following XAI methods:
`light_performance()`
: Performance metrics like RMSE and/or $R^2$

`light_importance()`
: Permutation variable importance [@fisher]

`light_ice()`
: Individual conditional expectation (ICE) profiles [@goldstein] (centered or uncentered)

`light_profile()`
: Partial dependence [@friedman2001], accumulated local effects (ALE) [@apley], average predicted/observed/residual profiles

`light_profile2d()`
: Two-dimensional version of `light_profile()`

`light_effects()`
: Combines partial dependence, ALE, response and prediction profiles

`light_interaction()`
: Different variants of Friedman's H statistics [@friedman2008]

`light_breakdown()`
: Variable contribution breakdown (approximate SHAP) for single observations [@gosiewska]

`light_global_surrogate()`
: Global surrogate trees [@molnar]

Good to know:
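- Each XAI method works on an explainer object of class `flashlight` (see examples and Section "flashlights").
- Multiple flashlights can be combined into a `multiflashlight()`.
- `plot()` visualizes the results via {ggplot2}.

```r
# From CRAN
install.packages("flashlight")

# Development version
devtools::install_github("mayer79/flashlight")
```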
Let's start with an iris example. For simplicity, we do not split the data into training and testing/validation sets.
```r
library(ggplot2)
library(MetricsWeighted)
library(flashlight)

fit_lm <- lm(Sepal.Length ~ ., data = iris)

# Make explainer object
fl_lm <- flashlight(
  model = fit_lm,
  data = iris,
  y = "Sepal.Length",
  label = "lm",
  metrics = list(RMSE = rmse, `R-squared` = r_squared)
)
```
```r
fl_lm |>
  light_performance() |>
  plot(fill = "darkred") +
  labs(x = element_blank(), title = "Performance on training data")

fl_lm |>
  light_performance(by = "Species") |>
  plot(fill = "darkred") +
  ggtitle("Performance split by Species")
```
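To make the two reported metrics concrete, the following base-R sketch recomputes them by hand for the same linear model, without {flashlight}. It assumes the unweighted textbook definitions RMSE = sqrt(mean((y - pred)^2)) and R² = 1 - SSE/SST, which is how we understand `MetricsWeighted::rmse()` and `MetricsWeighted::r_squared()` to behave without case weights; the per-species split only mimics `by = "Species"`.

```r
# Refit the model from above so the sketch is self-contained
fit_lm <- lm(Sepal.Length ~ ., data = iris)
y <- iris$Sepal.Length
pred <- predict(fit_lm, iris)

# Root mean squared error
rmse_manual <- sqrt(mean((y - pred)^2))

# R-squared: 1 - residual sum of squares / total sum of squares
r2_manual <- 1 - sum((y - pred)^2) / sum((y - mean(y))^2)

# Same metrics within each species, analogous to light_performance(by = "Species")
by_species <- sapply(split(seq_along(y), iris$Species), function(i) {
  c(RMSE = sqrt(mean((y[i] - pred[i])^2)),
    R2 = 1 - sum((y[i] - pred[i])^2) / sum((y[i] - mean(y[i]))^2))
})

rmse_manual
r2_manual
by_species
```

For an OLS fit with intercept, `r2_manual` coincides with `summary(fit_lm)$r.squared`. Note that the within-species R² values are computed against each species' own mean, so they are much lower than the overall R².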
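The plot shows the increase in RMSE when a feature is randomly permuted. Error bars represent standard errors over the `m_repetitions` permutation rounds, i.e., the uncertainty of the estimated importance.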
```r
fl_lm |>
  light_importance(m_repetitions = 4) |>
  plot(fill = "darkred") +
  labs(title = "Permutation importance", y = "Increase in RMSE")
```
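The idea behind permutation importance fits in a few lines of base R. The sketch below is not {flashlight}'s implementation: it shuffles one feature at a time, re-predicts, records the increase in RMSE, and repeats this `m_repetitions = 4` times to get a mean and a standard error. The helper `rmse()` and the standard-error formula sd/sqrt(m) are our own choices for illustration.

```r
set.seed(1)
fit_lm <- lm(Sepal.Length ~ ., data = iris)
y <- iris$Sepal.Length
rmse <- function(actual, predicted) sqrt(mean((actual - predicted)^2))

base_rmse <- rmse(y, predict(fit_lm, iris))
features <- setdiff(names(iris), "Sepal.Length")
m_repetitions <- 4

# For each feature: shuffle its column, re-predict, record the increase in RMSE
imp <- sapply(features, function(v) {
  increases <- replicate(m_repetitions, {
    perm <- iris
    perm[[v]] <- sample(perm[[v]])  # sample() keeps factor columns as factors
    rmse(y, predict(fit_lm, perm)) - base_rmse
  })
  c(importance = mean(increases),
    std_error = sd(increases) / sqrt(m_repetitions))
})

# One row per feature, sorted by decreasing importance
t(imp)[order(-imp["importance", ]), ]
```

More repetitions make the estimate more stable and the standard errors smaller, at the cost of more predictions.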
ICE curves for `Sepal.Width`, first uncentered, then centered at the middle of the evaluation grid (c-ICE):
```r
fl_lm |>
  light_ice("Sepal.Width", n_max = 200) |>
  plot(alpha = 0.3, color = "chartreuse4") +
  labs(title = "ICE curves for 'Sepal.Width'", y = "Prediction")

fl_lm |>
  light_ice("Sepal.Width", n_max = 200, center = "middle") |>
  plot(alpha = 0.3, color = "chartreuse4") +
  labs(title = "c-ICE curves for 'Sepal.Width'", y = "Prediction (centered)")
```
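An ICE curve is simply the model's prediction for one observation while a single feature is moved along a grid and all other features stay fixed. The base-R sketch below computes three such curves and their centered version. The 20-point equally spaced grid, the three example rows, and centering at the middle grid point (our reading of `center = "middle"`) are assumptions for illustration.

```r
fit_lm <- lm(Sepal.Length ~ ., data = iris)
v <- "Sepal.Width"
grid <- seq(min(iris[[v]]), max(iris[[v]]), length.out = 20)
obs <- iris[c(1, 51, 101), ]  # three example rows, one per species

# ICE: for each observation, vary only v over the grid and predict
ice <- t(sapply(seq_len(nrow(obs)), function(i) {
  X <- obs[rep(i, length(grid)), ]
  X[[v]] <- grid
  predict(fit_lm, X)
}))

# Centered ICE: subtract each curve's value at the middle grid point
mid <- ceiling(length(grid) / 2)
ice_centered <- ice - ice[, mid]  # row i minus ice[i, mid]
```

Because the linear model is additive, all ICE curves are parallel, so the centered curves coincide exactly. Diverging c-ICE curves would hint at interactions with other features.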
```r
fl_lm |>
  light_profile("Sepal.Width", n_bins = 40) |>
  plot() +
  ggtitle("PDP for 'Sepal.Width'")

fl_lm |>
  light_profile("Sepal.Width", n_bins = 40, by = "Species") |>
  plot() +
  ggtitle("Same grouped by 'Species'")
```
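Partial dependence is the average of all ICE curves: set the feature to a grid value for every row, predict, and average. This base-R sketch does that on an equally spaced 40-point grid; the exact grid {flashlight} builds from `n_bins` may differ. The grouped version simply averages within each species.

```r
fit_lm <- lm(Sepal.Length ~ ., data = iris)
v <- "Sepal.Width"
grid <- seq(min(iris[[v]]), max(iris[[v]]), length.out = 40)

# Partial dependence: set v to g for all rows, then average the predictions
pd <- sapply(grid, function(g) {
  X <- iris
  X[[v]] <- g
  mean(predict(fit_lm, X))
})

# Grouped version (cf. by = "Species"): average within each species
pd_by <- sapply(split(iris, iris$Species), function(d) {
  sapply(grid, function(g) {
    d[[v]] <- g  # modifies a local copy of d
    mean(predict(fit_lm, d))
  })
})
```

For a linear model the partial dependence is a straight line whose slope equals the coefficient of `Sepal.Width`, which is a handy sanity check.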
```r
fl_lm |>
  light_profile2d(c("Petal.Width", "Petal.Length")) |>
  plot()
```
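Two-dimensional partial dependence is also the building block of Friedman's pairwise H statistic, which `light_interaction()` offers. The base-R sketch below computes the squared pairwise H² for `Petal.Width` and `Petal.Length` from centered 1D and 2D partial dependence, evaluated at the observed data points. The helper `pd_at_data()` is our own; {flashlight} subsamples rows and, by default, may report the square root of this quantity.

```r
fit_lm <- lm(Sepal.Length ~ ., data = iris)

# Centered partial dependence of the features in `vars`,
# evaluated at the observed values of each row of `data`
pd_at_data <- function(model, vars, data) {
  pd <- vapply(seq_len(nrow(data)), function(i) {
    X <- data
    for (col in vars) X[[col]] <- data[[col]][i]  # fix vars at row i's values
    mean(predict(model, X))                        # average over all rows
  }, numeric(1))
  pd - mean(pd)
}

v <- c("Petal.Width", "Petal.Length")
pd_jk <- pd_at_data(fit_lm, v, iris)
pd_j <- pd_at_data(fit_lm, v[1], iris)
pd_k <- pd_at_data(fit_lm, v[2], iris)

# Friedman's pairwise H^2: share of the joint effect's variability
# that is not explained by the two main effects
H2 <- sum((pd_jk - pd_j - pd_k)^2) / sum(pd_jk^2)
H2
```

The linear model has no interaction terms, so the joint effect is exactly the sum of the main effects and H² is zero up to floating-point error. The cost grows quadratically in the number of rows, which is why subsampling is common in practice.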
```r
fl_lm |>
  light_profile("Sepal.Width", type = "ale") |>
  plot() +
  ggtitle("ALE plot for 'Sepal.Width'")
```
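Unlike partial dependence, ALE only moves each observation within its own bin of the feature, which avoids predicting at unrealistic feature combinations. The base-R sketch below uses decile bins, averages the local prediction differences per bin, accumulates them, and centers the result. The binning (type-1 quantiles, so every bin contains data) and the count-weighted centering are our own choices and may differ from {flashlight}'s.

```r
fit_lm <- lm(Sepal.Length ~ ., data = iris)
v <- "Sepal.Width"
x <- iris[[v]]

# Bin edges: observed quantiles (type = 1 keeps edges at data values, so no bin is empty)
z <- unique(quantile(x, probs = seq(0, 1, by = 0.1), type = 1, names = FALSE))
n_bins <- length(z) - 1
bin <- cut(x, breaks = z, include.lowest = TRUE, labels = FALSE)

# Local effects: move v from the lower to the upper edge of each row's own bin
lower <- upper <- iris
lower[[v]] <- z[bin]
upper[[v]] <- z[bin + 1]
local_effect <- predict(fit_lm, upper) - predict(fit_lm, lower)

# Average local effects per bin and accumulate them over the bins
avg_effect <- vapply(seq_len(n_bins), function(k) mean(local_effect[bin == k]), numeric(1))
ale <- c(0, cumsum(avg_effect))  # uncentered ALE at the bin edges z

# Center so that the count-weighted mean of the bin midpoint values is 0
counts <- tabulate(bin, nbins = n_bins)
ale_mid <- (head(ale, -1) + tail(ale, -1)) / 2
ale <- ale - sum(ale_mid * counts) / sum(counts)

data.frame(edge = z, ale = ale)
```

For the additive linear model, ALE and partial dependence agree up to a constant: both increase with slope equal to the `Sepal.Width` coefficient.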
```r
fl_lm |>
  light_effects("Sepal.Width") |>
  plot(use = "all") +
  ggtitle("Different types of profiles for 'Sepal.Width'")
```
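`light_effects()` puts several profiles side by side. The base-R sketch below computes three of them on equally spaced bins of `Sepal.Width`: the response profile (average observed value per bin), the predicted profile (average prediction per bin), and partial dependence at the bin midpoints. The choice of 8 equal-width bins is arbitrary and only assumed to resemble what {flashlight} does.

```r
fit_lm <- lm(Sepal.Length ~ ., data = iris)
v <- "Sepal.Width"
x <- iris[[v]]
pred <- predict(fit_lm, iris)

# Equal-width bins of v (8 bins is an arbitrary choice for this sketch)
breaks <- seq(min(x), max(x), length.out = 9)
bin <- cut(x, breaks = breaks, include.lowest = TRUE)
mids <- (head(breaks, -1) + tail(breaks, -1)) / 2

profiles <- data.frame(
  mid = mids,
  # Response profile: average observed Sepal.Length per bin
  response = as.vector(tapply(iris$Sepal.Length, bin, mean)),
  # Predicted profile: average prediction per bin
  predicted = as.vector(tapply(pred, bin, mean)),
  # Partial dependence at the bin midpoints (other features kept as observed)
  pd = sapply(mids, function(g) {
    X <- iris
    X[[v]] <- g
    mean(predict(fit_lm, X))
  }),
  n = as.vector(table(bin))
)
profiles
```

If the response and predicted profiles are close, the model captures the marginal relationship well. A gap between the predicted profile and partial dependence indicates that `Sepal.Width` is correlated with other features, since the predicted profile mixes in their effects.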
```r
fl_lm |>
  light_breakdown(new_obs = iris[1, ]) |>
  plot()
```
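A breakdown starts from the average prediction and fixes one feature after another at the value of the observation to explain, attributing each change in the average prediction to the feature just fixed. The base-R sketch below uses one fixed visiting order, which is an assumption for illustration; the approximate SHAP variant instead averages such contributions over several orders.

```r
fit_lm <- lm(Sepal.Length ~ ., data = iris)
new_obs <- iris[1, ]
features <- c("Petal.Length", "Petal.Width", "Sepal.Width", "Species")  # fixed order (assumption)

X <- iris
avg <- mean(predict(fit_lm, X))  # baseline: average prediction over the reference data
contrib <- setNames(numeric(length(features)), features)
for (v in features) {
  X[[v]] <- new_obs[[v]]  # fix feature v at new_obs's value for all rows
  new_avg <- mean(predict(fit_lm, X))
  contrib[[v]] <- new_avg - avg
  avg <- new_avg
}
contrib
```

By construction, the contributions add up to the prediction for `new_obs` minus the average prediction. For an additive model like this one the order does not matter; with interactions it would.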
```r
fl_lm |>
  light_global_surrogate() |>
  plot()
```
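A global surrogate is an interpretable model trained to mimic the black-box predictions rather than the observed response. The sketch below fits a shallow {rpart} tree to the linear model's predictions and measures fidelity as the R² between surrogate and model predictions. The depth of 3 and the column name `pred_lm` are our own choices, not {flashlight} defaults.

```r
library(rpart)  # ships with R as a recommended package

fit_lm <- lm(Sepal.Length ~ ., data = iris)

# Fit a shallow tree to the model's predictions (not to the observed response)
surr_data <- iris
surr_data$pred_lm <- predict(fit_lm, iris)
surrogate <- rpart(
  pred_lm ~ Sepal.Width + Petal.Length + Petal.Width + Species,
  data = surr_data,
  control = rpart.control(maxdepth = 3)
)

# Fidelity: how well does the tree reproduce the model's predictions?
surr_pred <- predict(surrogate, surr_data)
fidelity_r2 <- 1 - sum((surr_data$pred_lm - surr_pred)^2) /
  sum((surr_data$pred_lm - mean(surr_data$pred_lm))^2)
fidelity_r2
print(surrogate)
```

A surrogate is only a trustworthy summary if its fidelity is high; a low value means the tree's splits should not be read as an explanation of the model.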
Multiple flashlights can be combined into a multiflashlight.
```r
library(rpart)

fit_tree <- rpart(
  Sepal.Length ~ .,
  data = iris,
  control = list(cp = 0, xval = 0, maxdepth = 5)
)

# Make explainer object
fl_tree <- flashlight(
  model = fit_tree,
  data = iris,
  y = "Sepal.Length",
  label = "tree",
  metrics = list(RMSE = rmse, `R-squared` = r_squared)
)

# Combine with other explainer
fls <- multiflashlight(list(fl_tree, fl_lm))

fls |>
  light_performance() |>
  plot(fill = "chartreuse4") +
  labs(x = "Model", title = "Performance")

fls |>
  light_importance() |>
  plot(fill = "chartreuse4") +
  labs(y = "Increase in RMSE", title = "Permutation importance")

fls |>
  light_profile("Petal.Length", n_bins = 40) |>
  plot() +
  ggtitle("PDP")

fls |>
  light_profile("Petal.Length", n_bins = 40, by = "Species") |>
  plot() +
  ggtitle("PDP by Species")
```
The "flashlight" explainer expects the following information:

`model`
: Fitted model. Currently, this argument must be named.

`data`
: Reference data used to compute the explanations, often (part of) the validation data.

`y`
: Column name in `data` corresponding to the numeric response.

`predict_function`
: Function with the same signature as `stats::predict()`. It takes a `model` and a data.frame `data`, and returns numeric predictions (see below for more details).

`linkinv`
: Optional function applied to the output of `predict_function()`. Despite its name, it can be any transformation ("trafo"), not only an inverse link.

`w`
: Optional column name in `data` corresponding to case weights.

`by`
: Optional column name in `data` used to group the results. Must be discrete.

`metrics`
: List of metrics, by default `list(rmse = MetricsWeighted::rmse)`. For binary (probabilistic) classification, `MetricsWeighted::logLoss` is a good candidate.

`label`
: Mandatory name of the model.

`predict_function`s (a selection):

The default `stats::predict()` works for models fitted with `lm()`, `glm()` (for predictions on the link scale), and `rpart()`. It also works for meta-learner models like
Manual prediction functions are, e.g., required for:

Models whose `predict()` returns a list with element `predictions`
: Use `function(m, X) predict(m, X)$predictions` for regression, and `function(m, X) predict(m, X)$predictions[, 2]` for probabilistic binary classification.

`glm()`
: Use `function(m, X) predict(m, X, type = "response")` to get GLM predictions on the response scale.

A bit more complicated are models whose native predict function does not work on data.frames.
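Example (XGBoost): this works when all non-numeric features are factors (not character columns), because `data.matrix()` converts factors to their integer codes:

```r
x <- c("feature_1", "feature_2")  # hypothetical: names of the feature columns used by the model
predict_function <- function(m, df) predict(m, data.matrix(df[x]))
```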
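To see this wrapping pattern run end to end without installing XGBoost, the base-R sketch below uses a hypothetical stand-in: a least-squares fit from `lm.fit()` whose prediction helper `predict_matrix()` (our own name) only accepts a numeric matrix. The wrapper `predict_function()` then provides the `(model, data.frame)` signature, converting the selected columns with `data.matrix()`.

```r
# Features passed to the model, in a fixed order
x <- c("Sepal.Width", "Petal.Length", "Petal.Width")

# Hypothetical matrix-only "model": least-squares coefficients from lm.fit()
fit_mat <- lm.fit(cbind(1, data.matrix(iris[x])), iris$Sepal.Length)

# Its native prediction step only accepts a numeric matrix
predict_matrix <- function(m, M) drop(cbind(1, M) %*% m$coefficients)

# Wrapper with the (model, data.frame) signature that flashlight expects
predict_function <- function(m, df) predict_matrix(m, data.matrix(df[x]))

pred <- predict_function(fit_mat, iris)
head(pred)
```

Fixing `x` once and selecting `df[x]` inside the wrapper guarantees that the columns reach the model in the same order as during training, even if `data` contains additional columns such as the response.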