TreeSurrogate: Decision tree surrogate model

Description

TreeSurrogate fits a decision tree on the predictions of a prediction model.

Details

A conditional inference tree is fitted on the predictions \hat{y} of the machine learning model, with the features in the data as predictors. The tree is fitted with the ctree() function from the partykit package. By default, the maximum tree depth is 2, which keeps the tree interpretable.
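The idea can be sketched directly with partykit, outside of iml. This is a minimal illustration, not iml's internal code: a linear model on the built-in mtcars data stands in for the black box (an assumption made for brevity), and the tree is fitted to that model's predictions instead of the observed outcome. The names black.box, surrogate.data and surrogate are hypothetical.

```r
library("partykit")

# Stand-in black box (assumption): a linear model on the built-in mtcars data
black.box <- lm(mpg ~ ., data = mtcars)
X <- mtcars[, setdiff(names(mtcars), "mpg")]

# The surrogate's target is the black box prediction, not the observed mpg
surrogate.data <- X
surrogate.data$y.hat <- predict(black.box, newdata = X)

# Shallow conditional inference tree; maxdepth is passed on to ctree_control()
surrogate <- ctree(y.hat ~ ., data = surrogate.data, maxdepth = 2)
print(surrogate)

# The tree's predictions approximate the black box, not the true outcome
y.tree <- predict(surrogate, newdata = X)
```

Fitting on \hat{y} rather than on the observed outcome is what makes the tree a description of the model instead of a second model of the data.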

To learn more about global surrogate models, read the Interpretable Machine Learning book: https://christophm.github.io/interpretable-ml-book/global.html

Super class

iml::InterpretationMethod -> TreeSurrogate

Public fields

tree

party
The fitted tree. See also partykit::ctree().

maxdepth

numeric(1)
The maximum tree depth.

r.squared

numeric(1|n.classes)
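R squared measures how well the decision tree approximates the underlying black box model. It is calculated as 1 - (variance of the prediction differences / variance of the black box model predictions), where the prediction differences are the black box predictions minus the tree predictions. Values close to 1 indicate that the tree closely reproduces the black box predictions. For the multi-class case, r.squared contains one value per class.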
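The definition above translates into a few lines of base R. The helper surrogate.r2() is hypothetical (it is not part of iml) and assumes the black box and tree predictions are given as vectors, or as matrices with one column per class; it returns one value per column, mirroring the multi-class behavior described above.

```r
# Hypothetical helper, not part of iml: surrogate fidelity as defined above
surrogate.r2 <- function(y.hat, y.tree) {
  y.hat <- as.matrix(y.hat)
  y.tree <- as.matrix(y.tree)
  stopifnot(identical(dim(y.hat), dim(y.tree)))
  # One value per column, i.e. one per class in the multi-class case
  vapply(seq_len(ncol(y.hat)), function(j) {
    1 - var(y.hat[, j] - y.tree[, j]) / var(y.hat[, j])
  }, numeric(1))
}

y.hat <- c(10, 12, 15, 19)
surrogate.r2(y.hat, y.hat)                # perfect surrogate: 1
surrogate.r2(y.hat, rep(mean(y.hat), 4))  # tree that predicts the mean: 0

# Multi-class: black box and tree class probabilities, one column per class
p.bb <- cbind(a = c(0.9, 0.2, 0.1), b = c(0.1, 0.8, 0.9))
p.tree <- cbind(a = c(0.8, 0.3, 0.1), b = c(0.2, 0.7, 0.9))
surrogate.r2(p.bb, p.tree)                # two values, one per class
```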

Methods

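Public methods

TreeSurrogate$new()
TreeSurrogate$predict()
TreeSurrogate$clone()

Inherited methods

iml::InterpretationMethod$plot()
iml::InterpretationMethod$print()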

Method new()

Create a TreeSurrogate object

Usage
TreeSurrogate$new(predictor, maxdepth = 2, tree.args = NULL)
Arguments
predictor

Predictor
The object (created with Predictor$new()) holding the machine learning model and the data.

maxdepth

numeric(1)
The maximum depth of the tree. Default is 2.

tree.args

(named list)
Further arguments passed to partykit::ctree().
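A minimal sketch of passing tree.args, assuming iml and partykit are installed. A linear model on mtcars stands in for the black box, and minbucket is an illustrative choice: it assumes the entries of tree.args reach partykit::ctree_control() through the ... argument of partykit::ctree(). Because the tree field holds a partykit party object, partykit helpers such as width() and depth() can inspect it.

```r
library("iml")

# Stand-in black box (assumption): a linear model on the built-in mtcars data
fit <- lm(mpg ~ ., data = mtcars)
pred <- Predictor$new(fit, data = mtcars[, -1])

# Deeper surrogate; minbucket (assumed to be forwarded to ctree_control())
# asks for at least 5 observations in every leaf
surrogate <- TreeSurrogate$new(pred, maxdepth = 3,
                               tree.args = list(minbucket = 5))

surrogate$r.squared               # fidelity to the linear model
partykit::width(surrogate$tree)   # number of leaves
partykit::depth(surrogate$tree)   # depth of the fitted tree
```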


Method predict()

Predict new data with the tree. See also predict.TreeSurrogate.

Usage
TreeSurrogate$predict(newdata, type = "prob", ...)
Arguments
newdata

data.frame
Data to predict on.

type

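Prediction type, either "prob" (default) or "class". For classification, "prob" returns the class probabilities and "class" returns the most likely class. The argument is ignored when the surrogate tree does regression.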

...

Further arguments passed to predict().
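A minimal sketch of calling the method, assuming iml is installed and using a linear model on mtcars as a stand-in black box. The R6 method and the predict() S3 generic (see predict.TreeSurrogate) are expected to return the same result; the variable names are hypothetical.

```r
library("iml")

# Stand-in black box (assumption): a linear model on the built-in mtcars data
fit <- lm(mpg ~ ., data = mtcars)
pred <- Predictor$new(fit, data = mtcars[, -1])
surrogate <- TreeSurrogate$new(pred)

newdata <- mtcars[1:5, -1]
p.method <- surrogate$predict(newdata)    # R6 method
p.generic <- predict(surrogate, newdata)  # S3 generic
p.method
```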


Method clone()

The objects of this class are cloneable with this method.

Usage
TreeSurrogate$clone(deep = FALSE)
Arguments
deep

Whether to make a deep clone.

References

Craven, M., & Shavlik, J. W. (1996). Extracting tree-structured representations of trained networks. In Advances in Neural Information Processing Systems (pp. 24-30).

See Also

predict.TreeSurrogate, plot.TreeSurrogate

For the tree implementation, see partykit::ctree().

Examples

library("iml")
library("randomForest")
# Fit a Random Forest on the Boston housing data set
data("Boston", package = "MASS")
rf <- randomForest(medv ~ ., data = Boston, ntree = 50)
# Wrap the model and the feature data in a Predictor object
mod <- Predictor$new(rf, data = Boston[-which(names(Boston) == "medv")])

# Fit a decision tree as a surrogate for the whole random forest
dt <- TreeSurrogate$new(mod)

# Plot the resulting leaf nodes
plot(dt)

# Use the tree to predict new data
predict(dt, Boston[1:10, ])

# Extract the results
dat <- dt$results
head(dat)

# It also works for classification
rf <- randomForest(Species ~ ., data = iris, ntree = 50)
X <- iris[-which(names(iris) == "Species")]
mod <- Predictor$new(rf, data = X, type = "prob")

# Fit a decision tree as a surrogate for the whole random forest
dt <- TreeSurrogate$new(mod, maxdepth = 2)

# Plot the resulting leaf nodes
plot(dt)

# If you want to visualize the tree directly:
plot(dt$tree)

# Use the tree to predict new data
set.seed(42)
iris.sample <- X[sample(1:nrow(X), 10), ]
predict(dt, iris.sample)
predict(dt, iris.sample, type = "class")

# Extract the results
dat <- dt$results
head(dat)

christophM/iml documentation built on April 1, 2024, 1:26 p.m.