Regression and classification trees (e.g. from packages like partykit or rpart) are a very powerful set of statistical learning algorithms.
Nevertheless, each tree package has its own way of representing and storing trees, usually as a nested recursive list with attributes, which makes them hard to work with directly.
This package provides an interface to convert tree objects from various packages into a "tidy" data frame, with a row for each node showing its defining set of rules and its characteristics.
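To make the storage problem described above concrete, the following sketch inspects a fitted rpart model by hand. It uses only components documented for rpart objects (`frame`, `splits`, `csplit`, and the `"xlevels"` attribute); it does not involve tidytrees at all and is only meant to show how scattered the node information is.

```r
library(rpart)

model <- rpart(Sepal.Width ~ Species + Sepal.Length, data = iris)

# One row per node, but only the split variable is stored here, not the rule
head(model$frame[, c("var", "n", "yval")])

# Numeric split thresholds live in a separate matrix
model$splits

# Categorical splits are encoded as direction codes in yet another matrix
# (NULL if the tree has no split on a factor)
model$csplit

# The factor levels needed to decode those codes are stored as an attribute
x_levels <- attr(model, "xlevels")
x_levels
```

Rebuilding the full rule for a single node means combining all of these pieces while walking back up the tree.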
You can install the latest version of tidytrees from GitHub with:
    # install.packages("devtools")
    devtools::install_github("bakaburg1/tidytrees")
tidytrees exposes the generic function tidy_tree, which has methods for various tree objects (see ?tidy_tree for the list of supported methods). The output is a tibble with one row for each tree node. For each node, the rules that define it are reported, along with other information such as the node id, the number of observations falling in the node in the data from which the model was derived, and the depth of the node in the tree.
    library(tidytrees)
    library(partykit)
    library(rpart)

    # The function works with partykit...
    model <- ctree(Sepal.Width ~ Species + Sepal.Length, data = iris)
    tidy_tree(model)

    # ... and with rpart trees (more models to come)
    model <- rpart(Sepal.Width ~ Species + Sepal.Length, data = iris)
    tidy_tree(model)
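As a rough illustration of what "one row per node" involves, the sketch below builds a small node table by hand from an rpart model. It relies on rpart's node numbering, where the root is node 1 and the children of node k are 2k and 2k + 1, so the depth is `floor(log2(id))`. The column names (`id`, `n_obs`, `depth`, `terminal`, `estimate`) are chosen for this example and are not necessarily those returned by tidy_tree; the rules themselves are not reconstructed here.

```r
library(rpart)

model <- rpart(Sepal.Width ~ Species + Sepal.Length, data = iris)

# rpart stores the node ids as the row names of the frame
node_id <- as.integer(rownames(model$frame))

nodes <- data.frame(
  id       = node_id,
  n_obs    = model$frame$n,                 # observations reaching the node
  depth    = floor(log2(node_id)),          # root has depth 0
  terminal = model$frame$var == "<leaf>",   # leaves are marked "<leaf>"
  estimate = model$frame$yval               # fitted value in the node
)

nodes
```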
The rules can optionally be rendered in an R-compatible format, for easy use as data filters, or returned as a list of individual conditions rather than as a single text string.
    library(tidytrees)
    library(dplyr)
    library(rpart)

    model <- rpart(Sepal.Width ~ Species + Sepal.Length, data = iris)

    # Evaluation-friendly rules
    out <- tidy_tree(model, eval_ready = TRUE)
    iris %>% filter(eval(str2expression(out$rule[3]))) %>% str()

    # Rules as vectors
    out <- tidy_tree(model, rule_as_text = FALSE)
    out$rule[3]

    # Both
    out <- tidy_tree(model, rule_as_text = FALSE, eval_ready = TRUE)
    out$rule[3]
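The filtering step above works because an evaluation-ready rule is just R code stored as text. The base R sketch below shows the same mechanism without dplyr. The rule string is written by hand for illustration and is an assumption; the exact text produced by tidy_tree may differ.

```r
# Hypothetical rule string, written by hand for illustration
rule <- "Species %in% c('versicolor', 'virginica') & Sepal.Length < 5.5"

# Parse the text and evaluate it with the data frame columns in scope
keep <- eval(str2expression(rule), envir = iris)

# The result is a logical vector usable as a row filter
subset_rows <- iris[keep, ]
nrow(subset_rows)
```

Because the rule is evaluated as code, it should only be used with rule strings from a trusted source.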
Tree models tend to create explicit, nested rules with redundant components, useful to retain the whole branching information. The package can simplify such rules in order to make them more human-friendly while keeping the minimal necessary set of conditions to identify a partition. The simplified rules are ordered alphabetically to group conditions on the same variables together.
    library(tidytrees)
    library(dplyr)
    library(rpart)

    model <- rpart(Sepal.Length ~ Species + Sepal.Width, data = iris)

    # Full rules
    tidy_tree(model)$rule[5:9]

    # Simplified rules
    tidy_tree(model, simplify_rules = TRUE)$rule[5:9]

    # It also works on a list of conditions
    tidy_tree(model, rule_as_text = FALSE, simplify_rules = TRUE)$rule[5:9]

    # Can be applied to previously created rules
    rules <- tidy_tree(model)$rule[5:9]
    simplify_rules(rules)
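To give an idea of what simplification means for numeric conditions, here is a minimal base R sketch. It is not the algorithm used by the package: `simplify_conditions` is a hypothetical helper that takes conditions already split into variable, operator and threshold, keeps only the tightest bound per variable and direction, and sorts the result by variable name. Categorical conditions are not handled.

```r
# Hypothetical helper, NOT the tidytrees implementation
simplify_conditions <- function(var, op, value) {
  conds <- data.frame(var = var, op = op, value = value,
                      stringsAsFactors = FALSE)

  # Group conditions that bound the same variable in the same direction
  groups <- split(conds, list(conds$var, conds$op), drop = TRUE)

  kept <- lapply(groups, function(g) {
    # Upper bounds: the smallest threshold implies the others.
    # Lower bounds: the largest threshold implies the others.
    best <- if (g$op[1] %in% c("<", "<=")) min(g$value) else max(g$value)
    data.frame(var = g$var[1], op = g$op[1], value = best,
               stringsAsFactors = FALSE)
  })

  out <- do.call(rbind, kept)
  out <- out[order(out$var), ]   # group conditions on the same variable
  paste(out$var, out$op, out$value, collapse = " & ")
}

simplified <- simplify_conditions(
  var   = c("Sepal.Width", "Sepal.Width", "Sepal.Length", "Sepal.Width"),
  op    = c("<", "<", ">=", ">="),
  value = c(3.5, 3.1, 5.4, 2.6)
)
simplified
```

Here the redundant condition `Sepal.Width < 3.5` is dropped because `Sepal.Width < 3.1` already implies it.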
The output can optionally include the predicted value for each node and its estimation interval, with the possibility to choose the interval coverage (default = 95%).
    library(tidytrees)
    library(dplyr)
    library(rpart)

    # Intervals for continuous...
    model <- rpart(Sepal.Width ~ Species + Sepal.Length, data = iris)
    tidy_tree(model, add_interval = TRUE, interval_level = .89)

    # ... and discrete outcomes
    model <- rpart(Species ~ Sepal.Width + Sepal.Length, data = iris)
    tidy_tree(model, add_interval = TRUE, interval_level = .89)
The default intervals are based on the normal approximation for continuous outcomes and on binom.test() for discrete ones. The estimation function is pluggable, however, so users can provide their own.
    library(tidytrees)
    library(dplyr)
    library(rpart)

    # Quantile intervals for continuous outcomes
    model <- rpart(Sepal.Width ~ Species + Sepal.Length, data = iris)

    tidy_tree(model, add_interval = TRUE, est_fun = function(values, add_interval, interval_level) {
      data.frame(
        estimate = median(values),
        conf.low = quantile(values, (1 - interval_level) / 2),
        conf.high = quantile(values, .5 + interval_level / 2)
      )
    })

    # Bayesian regularized credibility intervals for discrete outcomes
    model <- rpart(Species ~ Sepal.Width + Sepal.Length, data = iris)

    tidy_tree(model, add_interval = TRUE, est_fun = function(values, add_interval, interval_level) {
      table(values) %>%
        lapply(function(cases) {
          qbeta(
            c(.5, (1 - interval_level) / 2, .5 + interval_level / 2),
            cases + 1.1,
            length(values) - cases + 1.1
          ) %>%
            matrix(nrow = 1) %>%
            as.data.frame() %>%
            setNames(c('estimate', 'cred.low', 'cred.high'))
        }) %>%
        bind_rows()
    })
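For comparison, the two default interval types mentioned above can be approximated by hand in base R. This is a sketch under assumptions: it takes "normal approximation" to mean the mean plus or minus a normal quantile times the standard error, and it applies `binom.test()` to each outcome level separately. The package's internal formulas may differ in detail. The function names and the `level` column are made up for this example; the argument list mirrors the `est_fun` signature used above.

```r
# Assumed form of a normal-approximation interval for a continuous outcome
normal_interval <- function(values, add_interval = TRUE, interval_level = 0.95) {
  est <- mean(values)
  se  <- sd(values) / sqrt(length(values))
  z   <- qnorm(1 - (1 - interval_level) / 2)
  data.frame(estimate = est, conf.low = est - z * se, conf.high = est + z * se)
}

# Exact binomial interval for the proportion of each outcome level
binomial_interval <- function(values, add_interval = TRUE, interval_level = 0.95) {
  counts <- table(values)
  rows <- lapply(names(counts), function(level) {
    test <- binom.test(counts[[level]], length(values),
                       conf.level = interval_level)
    data.frame(
      level     = level,
      estimate  = unname(test$estimate),
      conf.low  = test$conf.int[1],
      conf.high = test$conf.int[2]
    )
  })
  do.call(rbind, rows)
}

# Values as they might appear inside a single node
ni <- normal_interval(iris$Sepal.Width, interval_level = 0.89)
bi <- binomial_interval(iris$Species[iris$Sepal.Length < 5.5],
                        interval_level = 0.89)
ni
bi
```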