TreeSurrogate | R Documentation |
TreeSurrogate fits a decision tree on the predictions of a prediction model.
A conditional inference tree is fitted on the predicted \hat{y} from
the machine learning model and the data. The partykit package and its
ctree() function are used to fit the tree. By default, a tree with a maximum
depth of 2 is fitted to improve interpretability.
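As a rough illustration of the idea described above, the following sketch fits a surrogate tree by hand, without iml. It assumes the partykit package is installed. The linear model on mtcars is only a stand-in for a black box model, and the names black_box, surrogate_data and surrogate_tree are hypothetical; this is not the package's internal code.

```r
library(partykit)

# Stand-in "black box": any fitted model with a predict() method would do.
black_box <- lm(mpg ~ ., data = mtcars)

# Replace the observed target with the model's predictions (y hat).
features <- mtcars[names(mtcars) != "mpg"]
surrogate_data <- cbind(y.hat = predict(black_box, newdata = mtcars), features)

# Fit a conditional inference tree of depth at most 2 on the predictions.
surrogate_tree <- ctree(
  y.hat ~ .,
  data = surrogate_data,
  control = ctree_control(maxdepth = 2)
)

print(surrogate_tree)
```

The tree never sees the observed mpg values; it only learns to mimic what the model predicts.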
To learn more about global surrogate models, read the Interpretable Machine Learning book: https://christophm.github.io/interpretable-ml-book/global.html
iml::InterpretationMethod -> TreeSurrogate
tree
party
The fitted tree. See also partykit::ctree.
maxdepth
numeric(1)
The maximum tree depth.
r.squared
numeric(1|n.classes)
R squared measures how well the decision tree approximates the
underlying model. It is calculated as 1 - (variance of prediction
differences / variance of black box model predictions). For the
multi-class case, r.squared contains one measure per class.
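The formula above can be sketched in base R as follows. This follows the verbal description only and is not necessarily how the package computes the value internally; the helper r_squared() and all prediction values are made up for illustration.

```r
# R squared as described above:
# 1 - var(prediction differences) / var(black box predictions).
r_squared <- function(black_box_pred, tree_pred) {
  1 - var(tree_pred - black_box_pred) / var(black_box_pred)
}

# Made-up predictions for eight observations; the tree predicts one
# constant value per leaf (here the mean of each half).
black_box_pred <- c(1, 2, 3, 4, 5, 6, 7, 8)
tree_pred <- c(2.5, 2.5, 2.5, 2.5, 6.5, 6.5, 6.5, 6.5)

r2 <- r_squared(black_box_pred, tree_pred)
print(r2)

# Multi-class case: one value per class, computed column by column.
bb_prob <- cbind(a = c(0.9, 0.8, 0.2, 0.1), b = c(0.1, 0.2, 0.8, 0.9))
tree_prob <- cbind(a = c(0.85, 0.85, 0.15, 0.15), b = c(0.15, 0.15, 0.85, 0.85))
r2_per_class <- sapply(colnames(bb_prob), function(k) {
  r_squared(bb_prob[, k], tree_prob[, k])
})
print(r2_per_class)
```

A value close to 1 means the tree reproduces the black box predictions well; a low value means the tree is too coarse to be trusted as an explanation.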
new()
Create a TreeSurrogate object
TreeSurrogate$new(predictor, maxdepth = 2, tree.args = NULL)
predictor
Predictor
The object (created with Predictor$new()) holding the machine
learning model and the data.
maxdepth
numeric(1)
The maximum depth of the tree. Default is 2.
tree.args
(named list)
Further arguments for partykit::ctree().
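A minimal sketch of the constructor arguments, assuming the iml, randomForest and MASS packages are installed. The choice of minbucket = 20 is only an illustration, and it is an assumption here that this argument is forwarded through tree.args to the tree's control settings.

```r
library("iml")
library("randomForest")

# Black box model: a random forest on the Boston housing data
data("Boston", package = "MASS")
rf <- randomForest(medv ~ ., data = Boston, ntree = 50)
X <- Boston[names(Boston) != "medv"]
mod <- Predictor$new(rf, data = X)

# Deeper surrogate tree; minbucket is assumed to reach the ctree controls
dt <- TreeSurrogate$new(mod, maxdepth = 3, tree.args = list(minbucket = 20))

dt$maxdepth   # the depth limit that was requested
dt$r.squared  # how well the tree approximates the random forest
```

A deeper tree usually approximates the model more closely, at the cost of being harder to read.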
predict()
Predict new data with the tree. See also predict.TreeSurrogate.
TreeSurrogate$predict(newdata, type = "prob", ...)
newdata
data.frame
Data to predict on.
type
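Prediction type: "prob" for class probabilities or "class" for class labels. Ignored if the surrogate tree does regression.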
...
Further arguments passed to predict().
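The method can also be called directly on the object. The sketch below assumes the iml and randomForest packages are installed and that type accepts "prob" and "class" for a classification surrogate; the variable names are hypothetical.

```r
library("iml")
library("randomForest")

# Classification black box on the iris data
rf <- randomForest(Species ~ ., data = iris, ntree = 50)
X <- iris[names(iris) != "Species"]
mod <- Predictor$new(rf, data = X, type = "prob")
dt <- TreeSurrogate$new(mod, maxdepth = 2)

newdata <- X[1:5, ]
probs <- dt$predict(newdata)                    # default type = "prob"
classes <- dt$predict(newdata, type = "class")  # predicted class labels
```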
clone()
The objects of this class are cloneable with this method.
TreeSurrogate$clone(deep = FALSE)
deep
Whether to make a deep clone.
Craven, M., & Shavlik, J. W. (1996). Extracting tree-structured representations of trained networks. In Advances in neural information processing systems (pp. 24-30).
predict.TreeSurrogate plot.TreeSurrogate
For the tree implementation: partykit::ctree()
library("randomForest")

# Fit a Random Forest on the Boston housing data set
data("Boston", package = "MASS")
rf <- randomForest(medv ~ ., data = Boston, ntree = 50)

# Create a model object
mod <- Predictor$new(rf, data = Boston[-which(names(Boston) == "medv")])

# Fit a decision tree as a surrogate for the whole random forest
dt <- TreeSurrogate$new(mod)

# Plot the resulting leaf nodes
plot(dt)

# Use the tree to predict new data
predict(dt, Boston[1:10, ])

# Extract the results
dat <- dt$results
head(dat)

# It also works for classification
rf <- randomForest(Species ~ ., data = iris, ntree = 50)
X <- iris[-which(names(iris) == "Species")]
mod <- Predictor$new(rf, data = X, type = "prob")

# Fit a decision tree as a surrogate for the whole random forest
dt <- TreeSurrogate$new(mod, maxdepth = 2)

# Plot the resulting leaf nodes
plot(dt)

# If you want to visualize the tree directly:
plot(dt$tree)

# Use the tree to predict new data
set.seed(42)
iris.sample <- X[sample(1:nrow(X), 10), ]
predict(dt, iris.sample)
predict(dt, iris.sample, type = "class")

# Extract the dataset
dat <- dt$results
head(dat)