GbmExplainR

GbmExplainR is a package to decompose gbm predictions into feature contributions. There is also functionality to plot individual trees from the models and the route for a given observation through the tree to a terminal node. GbmExplainR works with the gbm package.

GbmExplainR is based on the treeinterpreter Python package; there is a blog post on treeinterpreter here.

Decompose gbm predictions into feature contributions

library(gbm)
library(igraph)
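library(devtools)
# load the package from source (run from the repository root);
# if the package is installed, use library(GbmExplainR) instead
devtools::load_all()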
set.seed(1)
N <- 1000
X1 <- runif(N)
X2 <- 2*runif(N)
X3 <- ordered(sample(letters[1:4],N,replace=TRUE),levels=letters[4:1])
X4 <- factor(sample(letters[1:6],N,replace=TRUE))
X5 <- factor(sample(letters[1:3],N,replace=TRUE))
X6 <- 3*runif(N) 
mu <- c(-1,0,1,2)[as.numeric(X3)]

SNR <- 10 # signal-to-noise ratio
Y <- X1**1.5 + 2 * (X2**.5) + mu
sigma <- sqrt(var(Y)/SNR)
Y <- Y + rnorm(N,0,sigma)

# introduce some missing values
X1[sample(1:N,size=500)] <- NA

data <- data.frame(Y=Y,X1=X1,X2=X2,X3=X3,X4=X4,X5=X5,X6=X6)

# fit initial model
gbm1 <- gbm(Y~X1+X2+X3+X4+X5+X6,        
           data=data,                  
           var.monotone=c(0,0,0,0,0,0),
           distribution="gaussian",   
           n.trees=1000,     
           shrinkage=0.05,  
           interaction.depth=3,
           bag.fraction = 0.5,
           train.fraction = 0.5)

The model above is the first example from ?gbm. Let's look at its predicted value for the first row of the data:

predict(gbm1, data[1, ], n.trees = gbm1$n.trees)

For a given prediction from a gbm, the feature contributions can be extracted:

decompose_gbm_prediction(gbm = gbm1, prediction_row = data[1, ])

Notice how the feature contributions sum to give the predicted value.
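
To illustrate why the contributions add up to the prediction, here is a minimal base R sketch of the treeinterpreter idea on a single hand-made tree. It is not GbmExplainR's implementation: the `tree` table, its column names, and the `decompose_toy` function are hypothetical and exist only for this example. Each split's contribution is taken as the change in node mean when moving from a node to the child the observation falls into, and the root mean plays the role of the bias term.

```r
# Hypothetical toy tree: every node stores its mean response; internal
# nodes also store the split variable, threshold and child node indices.
tree <- data.frame(
  node      = 1:5,
  split_var = c("X1", "X2", NA, NA, NA),
  threshold = c(0.5, 1.0, NA, NA, NA),
  left      = c(2, 4, NA, NA, NA),
  right     = c(3, 5, NA, NA, NA),
  mean      = c(10, 8, 12, 7, 9),
  stringsAsFactors = FALSE
)

# Walk the observation's path from the root to a terminal node and credit
# each change in node mean to the variable that was split on.
decompose_toy <- function(tree, obs) {
  contributions <- c(Bias = tree$mean[1])
  node <- 1
  while (!is.na(tree$split_var[node])) {
    v <- tree$split_var[node]
    child <- if (obs[[v]] < tree$threshold[node]) tree$left[node] else tree$right[node]
    delta <- tree$mean[child] - tree$mean[node]
    previous <- if (v %in% names(contributions)) contributions[[v]] else 0
    contributions[v] <- previous + delta
    node <- child
  }
  contributions
}

obs <- list(X1 = 0.3, X2 = 1.5)
contribs <- decompose_toy(tree, obs)
print(contribs)       # Bias, X1 and X2 contributions
print(sum(contribs))  # equals the mean of the terminal node reached
```

The changes in node mean telescope along the path, so the bias plus the contributions always equals the value of the terminal node. For a boosted model the same bookkeeping would be repeated for every tree and the per-variable contributions added together.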

These can be charted with a simple bar chart:

plot_feature_contributions(feature_contributions = decompose_gbm_prediction(gbm1, data[1, ]),
                           cex.names = 0.8)

Tree structure and terminal node path

Individual trees can be plotted, and the route to a terminal node can be highlighted for a given observation:

plot_tree(gbm = gbm1, 
          tree_no = 1, 
          plot_path = data[1, ], 
          edge.label.cex = 1.2,
          vertex.label.cex = 1.2) 

Installation

Install from GitHub with devtools:

library(devtools)
devtools::install_github("richardangell/GbmExplainR")

Other similar works

Other packages in R and Python implement the same method for a variety of tree-based models.



richardangell/GbmExplainR documentation built on May 22, 2019, 12:54 p.m.