sgtree: Surrogate trees

View source: R/xgrove.R

Description

Computes surrogate trees of different depths to explain a predictive machine learning model and to analyze complexity versus explanatory power.

Usage

sgtree(model, data, maxdeps = 1:8, cparam = 0, pfun = NULL, ...)

Arguments

model

A model with a corresponding predict() method that returns numeric values.

data

Data that must not (!) contain the target variable.

maxdeps

Integer vector of maximum tree depths; one surrogate tree is grown for each value.

cparam

Complexity parameter for growing the trees.

pfun

Optional prediction function of the form function(model, data), returning a numeric vector with one prediction per row of data. Defaults to the predict() method of the model.
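A custom pfun is useful when predict() does not return predictions on the desired numeric scale by default. The sketch below is purely illustrative: it uses a logistic regression on the built-in mtcars data, whose predict() method returns log-odds unless type = "response" is given, and the name prob_pfun is hypothetical.

```r
# Illustrative model: logistic regression, fitted on the built-in mtcars data
mtcars_data <- mtcars[, c("hp", "wt")]   # predictors only, no target column
logit_model <- glm(am ~ hp + wt, data = mtcars, family = binomial)

# Hypothetical pfun with signature function(model, data): returns one
# predicted probability per row of data instead of the default log-odds
prob_pfun <- function(model, data) {
  predict(model, newdata = data, type = "response")
}

p <- prob_pfun(logit_model, mtcars_data)

# With xgrove installed, the function could then be passed on, e.g.:
# st <- sgtree(logit_model, mtcars_data, maxdeps = 1:4, pfun = prob_pfun)
```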

...

Further arguments to be passed to rpart.control or the predict() method of the model.

Details

For each value in maxdeps, a surrogate tree is grown with rpart on data, using the predictions of the model as the target variable. Note that data must not contain the original target variable!
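The procedure can be sketched in a few lines of R. This is a minimal illustration under stated assumptions, not the package's implementation: the helper name surrogate_trees_sketch is hypothetical, a linear model on the built-in mtcars data stands in for the black-box model, and the maximum depth and complexity parameter are assumed to map onto rpart.control(maxdepth = , cp = ).

```r
library(rpart)

# Hypothetical helper illustrating the idea behind sgtree()
surrogate_trees_sketch <- function(model, data, maxdeps = 1:4, cparam = 0) {
  # 1. The model's predictions become the target of the surrogate trees
  surrogate_data <- data
  surrogate_data$yhat <- as.numeric(predict(model, newdata = data))
  # 2. Grow one regression tree per maximum depth
  trees <- lapply(maxdeps, function(d) {
    rpart(yhat ~ ., data = surrogate_data,
          control = rpart.control(maxdepth = d, cp = cparam))
  })
  names(trees) <- as.character(maxdeps)
  trees
}

# Illustrative black-box model on the built-in mtcars data
fit <- lm(mpg ~ ., data = mtcars)
X <- mtcars[, names(mtcars) != "mpg"]   # target removed, as required above
trees <- surrogate_trees_sketch(fit, X, maxdeps = 1:3)
```

Naming the list by depth mirrors how the Examples section accesses results, e.g. st$model[["3"]].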

Value

A list with the following components:

explanation

Matrix containing, for each tree, its size, the number of rules, the explainability Υ (Upsilon) and the correlation between the predictions of the surrogate tree and those of the original model.
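Following our reading of Szepannek and Luebke (2023), Υ compares the average squared error (ASE) of an explanation g, measured against the model's predictions f, with that of the constant explanation mean(f): Υ = 1 - ASE(f, g) / ASE(f, mean(f)). A value of 1 means the surrogate reproduces the model's predictions exactly. The base R sketch below computes Υ and the correlation on toy vectors; the function name upsilon and the values are illustrative assumptions, and xgrove's own computation may differ in detail.

```r
# Hypothetical helper: explainability Upsilon as 1 minus the ratio of the
# explanation's ASE to the ASE of a constant (mean) explanation
upsilon <- function(model_pred, expl_pred) {
  ase_expl  <- mean((model_pred - expl_pred)^2)
  ase_const <- mean((model_pred - mean(model_pred))^2)
  1 - ase_expl / ase_const
}

f <- c(1, 2, 3, 4)          # toy predictions of the original model
g <- c(1.5, 1.5, 3.5, 3.5)  # toy predictions of a surrogate tree
u <- upsilon(f, g)          # 1 - 0.25 / 1.25 = 0.8
r <- cor(f, g)              # 2 / sqrt(5), about 0.894
```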

rules

List of rules for each tree.
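A rule of a regression tree can be read off as the chain of split conditions from the root to a leaf. The sketch below does this with rpart::path.rpart() on an illustrative tree fitted to mtcars; it is an assumption for illustration only, and the rule format returned by sgtree() may differ.

```r
library(rpart)

# Illustrative tree: predict mpg from two mtcars variables
tree <- rpart(mpg ~ wt + hp, data = mtcars,
              control = rpart.control(maxdepth = 2, cp = 0))

# Leaf nodes are the rows of tree$frame whose split variable is "<leaf>"
leaf_nodes <- as.numeric(rownames(tree$frame)[tree$frame$var == "<leaf>"])

# path.rpart() returns, for each node, the split conditions starting at "root"
leaf_paths <- path.rpart(tree, nodes = leaf_nodes, print.it = FALSE)

# Drop the leading "root" entry and join each path into one readable rule
rules <- vapply(leaf_paths, function(p) paste(p[-1], collapse = " & "),
                character(1))
```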

model

List of the rpart models.

Author(s)

gero.szepannek@web.de

References

  • Szepannek, G. and Laabs, B.H. (2023): Can’t see the forest for the trees – analyzing groves to explain random forests, Behaviormetrika, submitted.

  • Szepannek, G. and Luebke, K. (2023): How much do we see? On the explainability of partial dependence plots for credit risk scoring, Argumenta Oeconomica 50, DOI: 10.15611/aoe.2023.1.07.

Examples

library(xgrove)
library(randomForest)
library(pdp)
data(boston)
set.seed(42)
rf    <- randomForest(cmedv ~ ., data = boston)
data  <- boston[, names(boston) != "cmedv"] # remove target variable
maxds <- 1:7
st    <- sgtree(rf, data, maxds)
st
# rules for tree of depth 3
st$rules[["3"]]
# plot tree of depth 3
rpart.plot::rpart.plot(st$model[["3"]])


xgrove documentation built on Sept. 23, 2024, 1:06 a.m.