View source: R/xgb.plot.multi.trees.R
xgb.plot.multi.trees    R Documentation
Visualization of the ensemble of trees as a single collective unit.
xgb.plot.multi.trees(
model,
features_keep = 5,
plot_width = NULL,
plot_height = NULL,
render = TRUE,
...
)
model: Object of class
features_keep: Number of features to keep in each position of the multi-trees graph (default: 5).
plot_width, plot_height: Width and height of the graph in pixels. The values are passed to
render: Should the graph be rendered? The default is TRUE.
...: Not used. Some arguments that were part of this function in previous XGBoost versions are currently deprecated or have been renamed. If a deprecated or renamed argument is passed, a warning is thrown (by default) and its current equivalent is used instead. This warning becomes an error when the 'strict mode' option is enabled. If an additional argument is passed that is neither a current function argument nor a deprecated or renamed argument, a warning or an error is thrown, depending on the 'strict mode' option. Important:
Note that this function does not work with models that were fitted to categorical data.
This function tries to capture the complexity of a gradient-boosted tree model in a cohesive way by compressing an ensemble of trees into a single tree-graph representation. The goal is to improve the interpretability of a model that is generally seen as a black box.
Note: this function is applicable to tree booster-based models only.
It takes advantage of the fact that the shape of a binary tree is defined only by its depth (therefore, in a boosting model, all trees have a similar shape).
Moreover, the trees tend to reuse the same features.
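To make the depth argument concrete: in a complete binary tree of depth d, the number of node positions depends only on d, so every tree in the ensemble fits into the same fixed set of slots. Real boosted trees are usually not complete, so each one occupies only a subset of these positions.

```latex
N(d) = \sum_{k=0}^{d} 2^{k} = 2^{d+1} - 1, \qquad \text{e.g. } d = 3 \;\Rightarrow\; N(3) = 15
```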
The function projects all trees onto a single tree and, for each position, keeps the first features_keep features, ranked by the Gain per feature measure.
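The projection described above can be approximated by hand with xgb.model.dt.tree(), which returns one row per tree node. The following is a minimal sketch, not the package's internal code: it assumes the Node, ID, Feature, Yes, No and Gain columns of that table, and the path encoding ("R" for the root, "Y"/"N" for the Yes/No branches) is a hypothetical choice made only for illustration.

```r
library(xgboost)
library(data.table)

data(agaricus.train, package = "xgboost")
model <- xgboost(
  agaricus.train$data, factor(agaricus.train$label),
  nrounds = 5, nthreads = 1, max_depth = 3
)

# One row per node; ID, Yes and No hold node IDs of the form "tree-node"
dt <- xgb.model.dt.tree(model)

# Hypothetical position encoding: "R" for the root of every tree, then
# "Y" appended for each "Yes" branch and "N" for each "No" branch.
# Nodes with the same path in different trees share a position.
pos <- setNames(rep(NA_character_, nrow(dt)), dt$ID)
pos[dt[Node == 0, ID]] <- "R"
frontier <- dt[Node == 0 & Feature != "Leaf"]
while (nrow(frontier) > 0) {
  pos[frontier$Yes] <- paste0(pos[frontier$ID], "Y")
  pos[frontier$No] <- paste0(pos[frontier$ID], "N")
  children <- c(frontier$Yes, frontier$No)
  frontier <- dt[ID %in% children & Feature != "Leaf"]
}
dt[, Position := pos[ID]]

# Project all trees onto one: sum each feature's Gain per position,
# then keep the top `features_keep` features at every position
features_keep <- 3
by_pos <- dt[Feature != "Leaf", .(Gain = sum(Gain)), by = .(Position, Feature)]
setorder(by_pos, Position, -Gain)
top <- by_pos[, head(.SD, features_keep), by = Position]

# Features retained at the root and its two children
print(top[Position %in% c("R", "RY", "RN")])
```

The per-position ranking mirrors the features_keep behaviour; xgb.plot.multi.trees() itself may encode positions and break ties differently.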
This function is inspired by this blog post: https://wellecks.wordpress.com/2015/02/21/peering-into-the-black-box-visualizing-lambdamart/
A rendered graph object, which is an htmlwidget of class grViz. Similar to
"ggplot" objects, it needs to be printed when not running from the command
line.
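The sketch below contrasts the two values of render. It assumes the DiagrammeR package is installed; DiagrammeR graph objects normally carry class dgr_graph, and DiagrammeR::render_graph() is used to turn the unrendered graph into a grViz widget afterwards.

```r
library(xgboost)
library(DiagrammeR)

data(agaricus.train, package = "xgboost")
model <- xgboost(
  agaricus.train$data, factor(agaricus.train$label),
  nrounds = 5, nthreads = 1, max_depth = 3
)

# render = TRUE (the default): an htmlwidget of class grViz
p <- xgb.plot.multi.trees(model, features_keep = 3,
                          plot_width = 800, plot_height = 400)
class(p)

# render = FALSE: an unrendered DiagrammeR graph, which can be modified
# or exported before rendering
g <- xgb.plot.multi.trees(model, features_keep = 3, render = FALSE)
class(g)

# Render it later; call print(w) to display it outside the console
w <- render_graph(g, width = 800, height = 400)
```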
library(xgboost)
data(agaricus.train, package = "xgboost")
## Keep the number of threads to 2 for examples
nthread <- 2
data.table::setDTthreads(nthread)
model <- xgboost(
agaricus.train$data, factor(agaricus.train$label),
nrounds = 30,
verbosity = 0L,
nthreads = nthread,
max_depth = 15,
learning_rate = 1,
min_child_weight = 50
)
p <- xgb.plot.multi.trees(model, features_keep = 3)
print(p)
# Below is an example of how to save this plot to a file.
if (require("DiagrammeR") && require("DiagrammeRsvg") && require("rsvg")) {
fname <- file.path(tempdir(), "tree.pdf")
gr <- xgb.plot.multi.trees(model, features_keep = 3, render = FALSE)
export_graph(gr, fname, width = 1500, height = 600)
}