Plotting Methods

In this article, we provide more detail on the different plots for the interpretability methods, including how to set the number of points plotted for a given feature, how to set the center of the plots, and additional options not discussed in the general overview.

Interpreter Class Initialization

For the examples in these articles, we use the well-known iris dataset, which contains measurements of 50 flowers from each of 3 species of iris, and a random forest model that predicts the "Sepal.Length" variable. To begin making interpretability plots, we must first create an Interpreter object:

library(MASS)
library(distillML)
library(Rforestry)

# Load in data 
data("iris")
set.seed(491)
data <- iris

# Train a random forest on the data set
forest <- forestry(x=data[,-1],
                   y=data[,1])

# Create a predictor wrapper for the forest
forest_predictor <- Predictor$new(model = forest,
                                  data=data,
                                  y="Sepal.Length",
                                  task = "regression")

# We specify grid.size for clarity (grid.size = 50 by default)
forest_interpreter <- Interpreter$new(forest_predictor,
                                      grid.size = 50)
print(forest_interpreter)

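The following fields of the Interpreter object determine different aspects of the interpretability plots:
grid.size: the number of grid points used for each feature (50 by default).
grid.points: the values of each feature at which the PDP and ICE curves are plotted.
center.at: the value of each feature at which the PDP and ICE curves are centered.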

Below, we inspect these fields for the feature Sepal.Width and then plot its PDP and ICE curves:

# The values of Sepal.Width to be plotted by PDP and ICE curves
print(forest_interpreter$grid.points$Sepal.Width)

# The number of grid.points is equal to the grid.size
print(length(forest_interpreter$grid.points$Sepal.Width))
print(forest_interpreter$grid.size)

# The value of Sepal.Width at which we center the PDP and ICE curves
print(forest_interpreter$center.at$Sepal.Width)

# plot the PDP and ICE curves
plot(forest_interpreter, features = "Sepal.Width")
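To make the roles of grid.points and grid.size concrete, here is a minimal sketch of how ICE and PDP values can be computed by hand. It is not distillML's implementation: a base R linear model stands in for the forestry model so the example runs without extra packages, the helper ice_curves is hypothetical, and the evenly spaced 50-point grid is an assumption (distillML may choose its default grid points differently).

```r
# Sketch only: a linear model stands in for the forestry model (assumption)
data(iris)
fit <- lm(Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width + Species,
          data = iris)

# Hypothetical helper (not part of distillML): returns a matrix with one row
# per observation (an ICE curve) and one column per grid point
ice_curves <- function(model, data, feature, grid) {
  sapply(grid, function(g) {
    modified <- data
    modified[[feature]] <- g  # set the feature to g for every observation
    unname(predict(model, newdata = modified))
  })
}

# Evenly spaced grid of 50 points over the observed range (an assumption)
grid <- seq(min(iris$Sepal.Width), max(iris$Sepal.Width), length.out = 50)
ice <- ice_curves(fit, iris, "Sepal.Width", grid)

# The PDP is the pointwise mean of the ICE curves
pdp <- colMeans(ice)
```

Because this stand-in model has no interaction terms, all of its ICE curves are parallel; a random forest can produce ICE curves with different shapes.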

After initializing the Interpreter, we can still change the grid points and centers of the plots. For example, the code below sets a new grid of 100 evenly spaced points for Sepal.Width and centers the plot at the mean of these new grid points. Note that a new center value must lie within the range of the grid points.

# Set new grid points
set.grid.points(forest_interpreter, "Sepal.Width", 
                values = seq(2, 4.5, length.out = 100))

# Set new center
set.center.at(forest_interpreter, "Sepal.Width", 
              mean(seq(2, 4.5, length.out = 100)))

# New plot
plot(forest_interpreter, features = "Sepal.Width")
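Centering shifts each ICE curve so that it equals zero at the center value. The sketch below illustrates this under stated assumptions: a base R linear model stands in for the forestry model, the helpers predict_at and centered_ice are hypothetical, and the baseline is obtained by predicting directly at the center value, which need not be one of the grid points (here 3.25 falls between two of them). How distillML handles such a center internally is not something this sketch claims to reproduce.

```r
data(iris)
fit <- lm(Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width + Species,
          data = iris)

# Hypothetical helper: predictions with `feature` fixed at `value` for all rows
predict_at <- function(model, data, feature, value) {
  modified <- data
  modified[[feature]] <- value
  unname(predict(model, newdata = modified))
}

# Hypothetical helper: ICE curves shifted so every curve equals 0 at `center`
centered_ice <- function(model, data, feature, grid, center) {
  # Mirrors the documented rule: the center must lie within the grid's range
  if (center < min(grid) || center > max(grid)) {
    stop("center must lie within the range of the grid points")
  }
  ice <- sapply(grid, function(g) predict_at(model, data, feature, g))
  baseline <- predict_at(model, data, feature, center)
  # baseline has one entry per row, so observation i's baseline is
  # subtracted from every column of row i
  ice - baseline
}

grid <- seq(2, 4.5, length.out = 100)  # same grid as the set.grid.points call
center <- mean(grid)                   # 3.25, same value as the set.center.at call
cice <- centered_ice(fit, iris, "Sepal.Width", grid, center)
cpdp <- colMeans(cice)                 # centered PDP
```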

In contrast to the fixed grid points used by the PDP and ICE methods, the ALE method places its grid points at quantiles of the feature's marginal distribution, so that each interval between neighboring grid points contains roughly the same number of observations and no interval is sparsely populated. The ALE is also always mean-centered. As a result, set.grid.points and set.center.at do not affect the ALE plots. The key parameter for the ALE method is grid.size, which determines the number of points at which the ALE is calculated.

# ALE plot
plot(forest_interpreter, features = "Sepal.Width", method = "ale")
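The quantile grid and mean-centering described above can be sketched as a first-order ALE computation. This is an illustration, not distillML's code: a base R linear model stands in for the forestry model, ale_1d is a hypothetical helper, and the centering step follows one common convention (weighting each interval's midpoint value by its number of observations), which may differ from distillML's exact choice.

```r
data(iris)
fit <- lm(Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width + Species,
          data = iris)

# Hypothetical helper: first-order ALE for one numeric feature
ale_1d <- function(model, data, feature, grid.size = 50) {
  x <- data[[feature]]
  # Grid points at quantiles of the marginal distribution; type = 1 returns
  # observed values, and unique() drops repeats caused by ties
  z <- unique(quantile(x, probs = seq(0, 1, length.out = grid.size),
                       type = 1, names = FALSE))
  K <- length(z) - 1
  # Interval k is (z[k], z[k + 1]]; the minimum is placed in interval 1
  k <- pmax(findInterval(x, z, left.open = TRUE), 1)
  lower <- data; lower[[feature]] <- z[k]
  upper <- data; upper[[feature]] <- z[k + 1]
  # Local effect: change in prediction across each observation's own interval
  delta <- predict(model, newdata = upper) - predict(model, newdata = lower)
  mean_delta <- vapply(seq_len(K), function(i) mean(delta[k == i]), numeric(1))
  ale <- c(0, cumsum(mean_delta))  # accumulate the averaged local effects
  # Mean-center: weight each interval's midpoint value by its observation count
  counts <- tabulate(k, nbins = K)
  mid <- (ale[-1] + ale[-(K + 1)]) / 2
  data.frame(x = z, ale = ale - sum(mid * counts) / sum(counts))
}

res <- ale_1d(fit, iris, "Sepal.Width", grid.size = 50)
```

Since Sepal.Width has many tied values in iris, this sketch ends up with fewer distinct grid points than grid.size; the grid depends only on the data, which is why setting custom grid points or a custom center has no role in it.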

Advanced Feature: Clustering

This package also provides a feature for clustering ICE curves. By default, the plot method for the Interpreter class draws the ICE curves together with their mean, the PDP curve. To make groups of ICE curves easier to see, we use the k-means unsupervised learning algorithm. To do so, we set method = "ice" in the plot function and choose the number of clusters and the type of clustering with the clusters and clusterType arguments, respectively.

# Clustering Based on the Predicted Values of the ICE Curves
plot(forest_interpreter,
     features = "Sepal.Width",
     method = "ice",
     clusters = 4,
     clusterType = "preds")

# Clustering Based on the Change in Predicted Values of the ICE Curves
plot(forest_interpreter,
     features = "Sepal.Width",
     method = "ice",
     clusters = 4,
     clusterType = "gradient")

The two options for clusterType differ as follows:
preds: each ICE curve is treated as a vector of its predicted values at the grid points, and the ICE curves are grouped based on these predicted values.
gradient: rather than clustering on the predicted values, this option computes the change in predicted values between consecutive grid points and clusters the ICE curves based on these changes.
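The difference between the two clusterType options can be sketched with stats::kmeans applied to a matrix of ICE curves. This is an illustration under assumptions, not distillML's implementation: a base R linear model with a Sepal.Width:Petal.Length interaction stands in for the forestry model (the interaction makes each curve's slope depend on Petal.Length, so the curves are not all parallel), and nstart = 10 is an arbitrary choice.

```r
data(iris)
set.seed(491)
# Stand-in model: the interaction lets ICE slopes vary across observations
fit <- lm(Sepal.Length ~ Sepal.Width * Petal.Length + Petal.Width + Species,
          data = iris)

grid <- seq(min(iris$Sepal.Width), max(iris$Sepal.Width), length.out = 50)
# One row per observation (ICE curve), one column per grid point
ice <- sapply(grid, function(g) {
  modified <- iris
  modified$Sepal.Width <- g
  unname(predict(fit, newdata = modified))
})

# "preds": cluster the curves on their predicted values
preds_clusters <- kmeans(ice, centers = 4, nstart = 10)$cluster

# "gradient": cluster on changes between consecutive grid points
grad <- t(apply(ice, 1, diff))  # 150 x 49 matrix of differences
grad_clusters <- kmeans(grad, centers = 4, nstart = 10)$cluster
```

Clustering on preds groups curves that sit at similar heights even when their shapes differ, whereas clustering on gradient groups curves with similar shapes regardless of their vertical offset.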



forestry-labs/interpretability_sandbox documentation built on April 26, 2023, 4:14 p.m.