In this article, we provide more detail on the plots for the interpretability methods, including how to set the number of points plotted for a given feature, how to set the center of the plots, and additional options not discussed in the general overview.
For the examples in the additional articles, we use the well-known iris dataset, which provides information on 50 flowers from each of 3 species of iris, and a random forest model for predicting the "Sepal.Length" variable. To begin making interpretability plots, we must start with an Interpreter object, which we construct as follows:

    library(MASS)
    library(distillML)
    library(Rforestry)

    # Load in data
    data("iris")
    set.seed(491)
    data <- iris

    # Train a random forest on the data set
    forest <- forestry(x = data[, -1], y = data[, 1])

    # Create a predictor wrapper for the forest
    forest_predictor <- Predictor$new(model = forest, data = data,
                                      y = "Sepal.Length", task = "regression")

    # We specify grid.size for clarity (grid.size = 50 by default)
    forest_interpreter <- Interpreter$new(forest_predictor, grid.size = 50)
    print(forest_interpreter)

The following parameters in the Interpreter object determine different aspects of the interpretability plots:

grid.size: This sets the number of grid points used for continuous features in the PDP, ICE, and ALE methods. A larger grid.size plots more points, which increases both the fidelity of the plot and the computation time. It is specified when the Interpreter object is declared.

grid.points: This stores a list of grid points for each feature in the data. For continuous features, the grid points are equally spaced values from the minimum to the maximum value of the feature, with the number of values given by grid.size. This parameter determines which points are plotted by the PDP and ICE methods.

center.at: This stores a list of values at which to center the continuous features plotted for PDP and ICE curves. The center is a value of the feature, not a value of the PDP or ICE curve itself. By default, it is set to the minimum value in the grid.points list for each continuous feature.

Below, we demonstrate these parameters for the feature Sepal.Width:

    # The values of Sepal.Width to be plotted by PDP and ICE curves
    print(forest_interpreter$grid.points$Sepal.Width)

    # The number of grid.points is equal to the grid.size
    print(length(forest_interpreter$grid.points$Sepal.Width))
    print(forest_interpreter$grid.size)

    # The value of Sepal.Width at which we center the PDP and ICE curves
    print(forest_interpreter$center.at$Sepal.Width)

    # Plot the PDP and ICE curves
    plot(forest_interpreter, features = "Sepal.Width")

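To make the roles of grid.size and grid.points concrete, the following base R sketch builds an equally spaced grid by hand and computes uncentered ICE and PDP values on it. It is an illustration under stated assumptions, not the distillML implementation: a linear model (lm) stands in for the random forest so that the example runs without extra packages, and the names grid_size, grid_points, ice, and pdp are introduced here only for illustration.

```r
# Illustrative sketch in base R; not the distillML internals.
data("iris")

# Stand-in model (assumption): a linear model instead of the random forest
fit <- lm(Sepal.Length ~ ., data = iris)

# Equally spaced grid from the minimum to the maximum of the feature
grid_size <- 50
grid_points <- seq(min(iris$Sepal.Width), max(iris$Sepal.Width),
                   length.out = grid_size)

# ICE: for every observation, predict with Sepal.Width replaced by each grid value
ice <- sapply(grid_points, function(v) {
  modified <- iris
  modified$Sepal.Width <- v
  predict(fit, newdata = modified)
})

# PDP: the pointwise average of the ICE curves
pdp <- colMeans(ice)

dim(ice)     # one row per observation, one column per grid point
length(pdp)  # one PDP value per grid point
```

A larger grid_size adds columns to the ice matrix, which is why both the resolution of the curves and the number of model predictions grow with it.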
After initializing the Interpreter, we can still change the centers and grid points of the plot. For example, the code below supplies a new set of grid points for Sepal.Width and centers the curves at the mean of those grid points. Note that a new center value must lie within the range of the grid points.

    # Set new grid points
    set.grid.points(forest_interpreter, "Sepal.Width",
                    values = seq(2, 4.5, length.out = 100))

    # Set new center
    set.center.at(forest_interpreter, "Sepal.Width",
                  mean(seq(2, 4.5, length.out = 100)))

    # New plot
    plot(forest_interpreter, features = "Sepal.Width")

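What centering at a feature value means can be sketched in base R as follows. This is a hedged illustration rather than the package's own code: a linear model again stands in for the random forest, predict_at is a hypothetical helper, and the sketch assumes that centering amounts to subtracting each curve's prediction at the center value, so that every centered curve equals 0 there.

```r
# Illustrative sketch in base R; not the distillML internals.
data("iris")
fit <- lm(Sepal.Length ~ ., data = iris)  # stand-in model (assumption)

grid_points <- seq(2, 4.5, length.out = 100)
center <- mean(grid_points)

# The center must lie within the range of the grid points
stopifnot(center >= min(grid_points), center <= max(grid_points))

# Hypothetical helper: predictions for all rows with Sepal.Width fixed at v
predict_at <- function(v) {
  modified <- iris
  modified$Sepal.Width <- v
  predict(fit, newdata = modified)
}

ice <- sapply(grid_points, predict_at)  # rows: observations, columns: grid points
baseline <- predict_at(center)          # each curve's value at the center

# Subtract each curve's own baseline so that it is 0 at the center
ice_centered <- sweep(ice, 1, baseline, "-")
pdp_centered <- colMeans(ice_centered)
```

Centering removes the vertical offsets between observations, so the plot shows how predictions change relative to the chosen feature value rather than their absolute level.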
In contrast to the fixed grid points for PDP and ICE methods, the ALE method sets its grid points based on quantiles of the marginal distribution of the feature. By doing so, it avoids sparsity within any one neighborhood. The ALE is also always mean-centered. As a result, methods such as set.grid.points and set.center.at do not affect the ALE plots. The key parameter for the ALE method is grid.size, which determines the number of points calculated by the ALE.

    # ALE plot
    plot(forest_interpreter, features = "Sepal.Width", method = "ale")

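The quantile-based grid and the mean-centering described above can be sketched by hand in base R. This is a simplified, first-order ALE estimate under stated assumptions, not the routine used by distillML: a linear model stands in for the random forest, predict_with is a hypothetical helper, break points are taken with quantile type 1 so that every interval contains at least one observation, and the centering averages the curve over the interval midpoints weighted by interval counts.

```r
# Illustrative sketch in base R; not the distillML internals.
data("iris")
fit <- lm(Sepal.Length ~ ., data = iris)  # stand-in model (assumption)

x <- iris$Sepal.Width
grid_size <- 10

# Quantile-based break points; unique() drops repeats caused by ties
breaks <- unique(quantile(x, probs = seq(0, 1, length.out = grid_size + 1),
                          type = 1, names = FALSE))

# Assign every observation to an interval between consecutive break points
interval <- cut(x, breaks = breaks, include.lowest = TRUE, labels = FALSE)

# Hypothetical helper: predictions with Sepal.Width replaced by the vector v
predict_with <- function(v) {
  modified <- iris
  modified$Sepal.Width <- v
  predict(fit, newdata = modified)
}

# Local effect: prediction difference between the upper and lower edge of
# each observation's own interval, averaged within the interval
local_diff <- predict_with(breaks[interval + 1]) - predict_with(breaks[interval])
mean_diff <- tapply(local_diff, interval, mean)

# Accumulate the local effects; the curve starts at 0 at the lowest break point
ale <- c(0, cumsum(mean_diff))

# Mean-center: subtract the count-weighted average of the interval midpoints
counts <- as.vector(table(interval))
midpoints <- (ale[-1] + ale[-length(ale)]) / 2
ale_centered <- ale - sum(midpoints * counts) / sum(counts)
```

Because the break points follow the data, intervals are narrow where observations are dense and wide where they are sparse, which is the sense in which the quantile grid avoids sparse neighborhoods.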
This package also provides a feature to cluster the ICE curves. By default, plot for the Interpreter class plots the ICE curves and their mean, the PDP curve. We use the kmeans unsupervised learning algorithm to better visualize groups of ICE curves. To do so, we set method = "ice" in the plot function, and set the number of clusters and the type of clustering with the clusters and clusterType arguments, respectively.

    # Clustering based on the predicted values of the ICE curves
    plot(forest_interpreter, features = "Sepal.Width", method = "ice",
         clusters = 4, clusterType = "preds")

    # Clustering based on the change in predicted values of the ICE curves
    plot(forest_interpreter, features = "Sepal.Width", method = "ice",
         clusters = 4, clusterType = "gradient")

The two options for clusterType differ as follows:

preds: Each predicted value for an ICE curve is treated as an entry in a vector, and the ICE curves are grouped based on these predicted values.

gradient: Rather than clustering based on the predicted values, the gradient method takes the change in predicted values across consecutive grid points and clusters based on these changes.
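The difference between the two options can be sketched with stats::kmeans in base R. The sketch is only an approximation of the idea described above: how distillML prepares the curves before calling kmeans is not shown in this article, so the matrix layout, the nstart value, and the stand-in linear model (with an interaction term, so that the curves differ in slope) are all assumptions made for illustration.

```r
# Illustrative sketch in base R; not the distillML internals.
data("iris")
set.seed(491)

# Stand-in model (assumption) with an interaction so that ICE curves differ in slope
fit <- lm(Sepal.Length ~ Sepal.Width * Petal.Length + Petal.Width + Species,
          data = iris)

grid_points <- seq(min(iris$Sepal.Width), max(iris$Sepal.Width), length.out = 50)
ice <- sapply(grid_points, function(v) {
  modified <- iris
  modified$Sepal.Width <- v
  predict(fit, newdata = modified)
})

# "preds": cluster the curves on their predicted values at the grid points
clusters_preds <- kmeans(ice, centers = 4, nstart = 10)

# "gradient": cluster on the changes between consecutive grid points
ice_changes <- t(apply(ice, 1, diff))
clusters_gradient <- kmeans(ice_changes, centers = 4, nstart = 10)

# Cross-tabulate the two assignments to see how the groupings differ
table(preds = clusters_preds$cluster, gradient = clusters_gradient$cluster)
```

Clustering on predicted values tends to group curves by their overall level, while clustering on changes groups curves with a similar shape regardless of level, so the two assignments generally do not coincide.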