Introduction to iml: Interpretable Machine Learning in R

knitr::opts_chunk$set(
  comment = "#>",
  fig.width = 7,
  fig.height = 7,
  fig.align = "center"
)
options(
  tibble.print_min = 4L,
  tibble.print_max = 4L
)

Machine learning models usually perform very well at prediction, but are often hard to interpret. The iml package provides model-agnostic tools for analysing any black box machine learning model, such as feature importance, feature effects, interaction strength, surrogate models and explanations of individual predictions.

This document shows you how to use the iml package to analyse machine learning models.

If you want to learn more about the technical details of the methods, read the corresponding chapters of the Interpretable Machine Learning book: https://christophm.github.io/interpretable-ml-book/agnostic.html

Data: Boston Housing

We'll use the MASS::Boston dataset to demonstrate the capabilities of the iml package. This dataset contains median house values for Boston neighbourhoods.

data("Boston", package = "MASS")
head(Boston)

Fitting the machine learning model

First we train a randomForest to predict the Boston median housing value:

set.seed(42)
library("iml")
library("randomForest")
data("Boston", package = "MASS")
rf <- randomForest(medv ~ ., data = Boston, ntree = 10)

Using the iml Predictor() container

We create a Predictor object that holds the model and the data. The iml package uses R6 classes: new objects are created by calling Predictor$new().

X <- Boston[which(names(Boston) != "medv")]
predictor <- Predictor$new(rf, data = X, y = Boston$medv)
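
As a quick sanity check, the Predictor object can also be used to make predictions directly; a minimal sketch using its predict() method:

# Predict with the wrapped model through the Predictor object
predictor$predict(X[1:3, ])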

Feature importance

We can measure how important each feature was for the predictions with FeatureImp. The feature importance measure works by shuffling each feature and measuring how much the performance drops. For this regression task we measure the loss in performance with the mean absolute error ('mae'); another choice would be the mean squared error ('mse').

Once we create a new FeatureImp object, the importance is computed automatically. We can call the plot() function of the object or inspect the results as a data.frame.

imp <- FeatureImp$new(predictor, loss = "mae")
library("ggplot2")
plot(imp)
imp$results
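
Since the mean squared error was mentioned as an alternative loss, here is a minimal sketch of the same computation with loss = "mse":

# Same permutation importance, but measured with the mean squared error
imp.mse <- FeatureImp$new(predictor, loss = "mse")
plot(imp.mse)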

Feature effects

Besides knowing which features were important, we are interested in how the features influence the predicted outcome. The FeatureEffect class implements accumulated local effect (ALE) plots, partial dependence plots and individual conditional expectation curves. The following plot shows the accumulated local effects for the feature 'lstat'. ALE shows how the prediction changes locally when the feature is varied. The marks on the x-axis indicate the distribution of the 'lstat' feature, showing how relevant a region is for interpretation (few or no points mean that we should not over-interpret this region).

ale <- FeatureEffect$new(predictor, feature = "lstat", grid.size = 10)
ale$plot()

If we want to compute the accumulated local effects for another feature, we can simply reset the feature:

ale$set.feature("rm")
ale$plot()
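
The same class can also compute partial dependence and individual conditional expectation curves. A minimal sketch that switches the method (method = "pdp+ice" is taken from the FeatureEffect documentation, so treat the exact value as an assumption):

# Partial dependence plus ICE curves for 'lstat' instead of ALE
pdp.ice <- FeatureEffect$new(predictor, feature = "lstat", method = "pdp+ice", grid.size = 10)
pdp.ice$plot()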

Measure interactions

We can also measure how strongly features interact with each other. The interaction measure quantifies how much of the variance of $f(x)$ is explained by the interaction. The measure ranges from 0 (no interaction) to 1 (100% of the variance of $f(x)$ is due to interactions). For each feature, we measure how much it interacts with any other feature:

interact <- Interaction$new(predictor, grid.size = 15)
plot(interact)

We can also specify a feature and measure all its 2-way interactions with all other features:

interact <- Interaction$new(predictor, feature = "crim", grid.size = 15)
plot(interact)
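
As with the other iml objects, the numeric interaction strengths should also be available in the results field, in case you prefer a data.frame over a plot:

# Interaction strengths as a data.frame instead of a plot
interact$results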

You can also plot the feature effects for all features at once:

effs <- FeatureEffects$new(predictor, grid.size = 10)
plot(effs)
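
If the overview becomes crowded, the computation can be restricted to a subset of features; a sketch assuming the features argument of FeatureEffects$new():

# Effects for a hand-picked subset of features only
effs.sub <- FeatureEffects$new(predictor, features = c("lstat", "rm", "crim"), grid.size = 10)
plot(effs.sub)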

Surrogate model

Another way to make models more interpretable is to replace the black box with a simpler model, such as a decision tree. We take the predictions of the black box model (in our case the random forest) and train a decision tree on the original features and the predicted outcome. The plot shows the terminal nodes of the fitted tree. The maxdepth parameter controls how deep the tree can grow and therefore how interpretable it remains.

tree <- TreeSurrogate$new(predictor, maxdepth = 2)
plot(tree)

We can use the tree to make predictions:

head(tree$predict(Boston))
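
To get a rough sense of how faithful the surrogate is, we can compare its predictions with the random forest's own predictions. A minimal sketch, assuming tree$predict() returns a data.frame with a single prediction column:

# Correlation between surrogate predictions and random forest predictions
surrogate.pred <- tree$predict(Boston)[[1]]
rf.pred <- predict(rf, Boston)
cor(surrogate.pred, rf.pred)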

Explain single predictions with a local model

A global surrogate model can improve our understanding of the global model behaviour. We can also fit a model locally to better understand an individual prediction. The local model fitted by LocalModel is a linear regression model in which the data points are weighted by how close they are to the data point whose prediction we want to explain.

lime.explain <- LocalModel$new(predictor, x.interest = X[1, ])
lime.explain$results
plot(lime.explain)
lime.explain$explain(X[2, ])
plot(lime.explain)
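
By default, LocalModel selects only a few features for the local model. A minimal sketch that requests five features instead (the k argument of LocalModel$new() is taken from the package documentation, so treat it as an assumption):

# Local surrogate with five features instead of the default
lime.k5 <- LocalModel$new(predictor, x.interest = X[1, ], k = 5)
lime.k5$results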

Explain single predictions with game theory

An alternative for explaining individual predictions is a method from coalitional game theory called the Shapley value. Assume that for one data point the feature values play a game together, in which they receive the prediction as a payout. The Shapley value tells us how to fairly distribute the payout among the feature values.

shapley <- Shapley$new(predictor, x.interest = X[1, ], sample.size = 50)
shapley$plot()

We can reuse the object to explain other data points:

shapley$explain(x.interest = X[2, ])
shapley$plot()

The results in data.frame form can be extracted like this:

results <- shapley$results
head(results)
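
The Shapley values (the phi column) should approximately sum to the difference between this prediction and the average prediction of the model; a minimal sketch, assuming the results data.frame contains a phi column:

# Sum of the Shapley values for the explained data point
sum(results$phi)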

