xgb.plot.multi.trees    R Documentation
Source: R/xgb.plot.multi.trees.R
Visualization of the ensemble of trees as a single collective unit.
xgb.plot.multi.trees(
  model,
  feature_names = NULL,
  features_keep = 5,
  plot_width = NULL,
  plot_height = NULL,
  render = TRUE,
  ...
)
Argument | Description
model | produced by the
feature_names | names of each feature as a
features_keep | number of features to keep in each position of the multi trees
plot_width | width in pixels of the graph to produce
plot_height | height in pixels of the graph to produce
render | a logical flag for whether the graph should be rendered (see the return value description below)
... | currently not used
This function tries to capture the complexity of a gradient boosted tree model in a cohesive way by compressing an ensemble of trees into a single tree-graph representation. The goal is to improve the interpretability of a model that is generally seen as a black box.
Note: this function is applicable to tree booster-based models only.
It takes advantage of the fact that the shape of a binary tree is largely determined by its depth (therefore, in a boosting model, where all trees share the same maximum depth, the trees have similar shapes).
Moreover, the trees tend to reuse the same features.
The function projects all trees onto a single tree and, for each node position, keeps the first features_keep features (ranked by the Gain per feature measure).
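The ranking step described above can be sketched in base R. This is a minimal illustration of the idea, not the package's implementation: it assumes the split nodes of all trees have already been pooled into one table in which each node is identified by its position in the tree rather than by the tree it came from (here the made-up labels "root", "L" and "R"). The node table, feature names, gain values and the keep_top_features() helper are all hypothetical.

```r
# Hypothetical split nodes pooled from several trees. Each row is one split;
# "position" identifies where the node sits in the tree, independent of tree.
nodes <- data.frame(
  position = c("root", "root", "root", "L", "L", "R", "R", "R"),
  feature  = c("odor", "odor", "spore", "gill", "stalk", "gill", "gill", "ring"),
  gain     = c(10, 8, 3, 4, 5, 2, 1, 6),
  stringsAsFactors = FALSE
)

# Hypothetical helper: for each position, keep the k features with the
# largest gain summed over all trees.
keep_top_features <- function(nodes, k) {
  # Total gain of each feature at each position, summed across trees
  agg <- aggregate(gain ~ position + feature, data = nodes, FUN = sum)
  # Sort by position, then by decreasing total gain within each position
  agg <- agg[order(agg$position, -agg$gain), ]
  # Keep the first k rows of every position group
  top <- do.call(rbind, lapply(split(agg, agg$position), head, n = k))
  rownames(top) <- NULL
  top
}

top <- keep_top_features(nodes, k = 1)
print(top)
```

With k = 1, each position keeps only its highest-gain feature (for example, "odor" at the root, whose gains 10 and 8 from two trees are summed). Summing before ranking means a feature used at the same position in many trees can outrank one that appears once with a larger single gain.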
This function is inspired by this blog post: https://wellecks.wordpress.com/2015/02/21/peering-into-the-black-box-visualizing-lambdamart/
When render = TRUE: returns a rendered graph object, which is an htmlwidget of class grViz. Similar to ggplot objects, it needs to be printed to be displayed when not running from the command line.
When render = FALSE: silently returns a graph object of DiagrammeR's class dgr_graph. This can be useful if one wants to modify some of the graph attributes before rendering the graph with render_graph.
data(agaricus.train, package='xgboost')
## Keep the number of threads to 2 for examples
nthread <- 2
data.table::setDTthreads(nthread)
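# Train a small boosted tree model on the agaricus mushroom data
bst <- xgboost(
  data = agaricus.train$data, label = agaricus.train$label, max_depth = 15,
  eta = 1, nthread = nthread, nrounds = 30, objective = "binary:logistic",
  min_child_weight = 50, verbose = 0
)
# Plot the whole ensemble as one tree, keeping the top 3 features per position
p <- xgb.plot.multi.trees(model = bst, features_keep = 3)
print(p)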
## Not run:
# Below is an example of how to save this plot to a file.
# Note that for `export_graph` to work, the DiagrammeRsvg and rsvg packages must also be installed.
library(DiagrammeR)
gr <- xgb.plot.multi.trees(model = bst, features_keep = 3, render = FALSE)
export_graph(gr, "tree.pdf", width = 1500, height = 600)
## End(Not run)