knitr::opts_chunk$set( collapse = TRUE, comment = "#>", warning = FALSE, out.width = "90%", # fig.width = 8, # dpi = 300, asp = 0.625 )
Start by loading the parttree package alongside rpart, which comes bundled with the base R installation. For the basic examples that follow, I'll use the well-known Palmer Penguins dataset to demonstrate functionality. You can load this dataset via the parent package (as I have here), or import it directly as a CSV here.
library(parttree) # This package library(rpart) # For fitting decisions trees # install.packages("palmerpenguins") data("penguins", package = "palmerpenguins") head(penguins)
data.table::setDTthreads(2)
Dataset in hand, let's say that we are interested in predicting penguin species as a function of 1) flipper length and 2) bill length. We could model this as a simple decision tree:
tree = rpart(species ~ flipper_length_mm + bill_length_mm, data = penguins) tree
Like most tree-based frameworks, rpart comes with a default plot
method
for visualizing the resulting node splits.
plot(tree, compress = TRUE) text(tree, use.n = TRUE)
While this is okay, I don't feel that it provides much intuition about the model's prediction on the scale of the actual data. In other words, what I'd prefer to see is: How has our tree partitioned the original penguin data?
This is where parttree enters the fray. The package is named for its primary
workhorse function parttree()
, which extracts all of the information needed
to produce a nice plot of our tree partitions alongside the original data.
ptree = parttree(tree) plot(ptree)
Et voila! Now we can clearly see how our model has divided up the Cartesian space of the data. Gentoo penguins typically have longer flippers than Chinstrap or Adelie penguins, while the latter have the shortest bills.
From the perspective of the end-user, the ptree
parttree object is not all
that interesting in of itself. It is simply a data frame that contains the basic
information needed for our plot (partition coordinates, etc.). You can think of
it as a helpful intermediate object on our way to the visualization of interest.
# See also `attr(ptree, "parttree")`
ptree
Speaking of visualization, underneath the hood plot.parttree
calls the
powerful
tinyplot
package. All of the latter's various customization arguments can be passed on to
our parttree
plot to make it look a bit nicer. For example:
plot(ptree, pch = 16, palette = "classic", alpha = 0.75, grid = TRUE)
In addition to discrete classification problems, parttree also supports regression trees with continuous independent variables.
tree_cont = rpart(body_mass_g ~ flipper_length_mm + bill_length_mm, data = penguins) tree_cont |> parttree() |> plot(pch = 16, palette = "viridis")
Alongside the rpart model
objects that we have been working with thus far, parttree also supports
decision trees created by the
partykit package. Here we
see how the latter's ctree
(conditional inference tree) algorithm yields a
slightly more sophisticated partitioning that the former's default.
library(partykit) ctree(species ~ flipper_length_mm + bill_length_mm, data = penguins) |> parttree() |> plot(pch = 16, palette = "classic", alpha = 0.5)
parttree also supports a variety of "frontend" modes that call
rpart::rpart()
as the underlying engine. This includes packages from both the
mlr3 and
tidymodels
(parsnip
or workflows)
ecosystems. Here is a quick demonstration using parsnip, where we'll also
pull in a different dataset just to change things up a little.
set.seed(123) ## For consistent jitter library(parsnip) library(titanic) ## Just for a different data set titanic_train$Survived = as.factor(titanic_train$Survived) ## Build our tree using parsnip (but with rpart as the model engine) ti_tree = decision_tree() |> set_engine("rpart") |> set_mode("classification") |> fit(Survived ~ Pclass + Age, data = titanic_train) ## Now pass to parttree and plot ti_tree |> parttree() |> plot(pch = 16, jitter = TRUE, palette = "dark", alpha = 0.7)
The default plot.parttree
method produces a base graphics plot. But we also
support ggplot2 via with a dedicated geom_parttree()
function. Here we
demonstrate with our initial classification tree from earlier.
library(ggplot2) theme_set(theme_linedraw()) ## re-using the tree model object from above... ggplot(data = penguins, aes(x = flipper_length_mm, y = bill_length_mm)) + geom_point(aes(col = species)) + geom_parttree(data = tree, aes(fill=species), alpha = 0.1)
Compared to the "native" plot.parttree
method, note that the ggplot2
workflow requires a few tweaks:
geom_point()
).geom_parttree()
accepts the tree object itself, not the result of parttree()
.^[This is because geom_parttree(data = <tree>)
calls parttree(<tree>)
internally.] Continuous regression trees can also be drawn with geom_parttree
. However, I
recommend adjusting the plot fill aesthetic since your model will likely
partition the data into intervals that don't match up exactly with the raw data.
The easiest way to do this is by setting your colour and fill aesthetic together
as part of the same scale_colour_*
call.
## re-using the tree_cont model object from above... ggplot(data = penguins, aes(x = flipper_length_mm, y = bill_length_mm)) + geom_parttree(data = tree_cont, aes(fill=body_mass_g), alpha = 0.3) + geom_point(aes(col = body_mass_g)) + scale_colour_viridis_c(aesthetics = c('colour', 'fill')) # NB: Set colour + fill together
As we have already said, geom_parttree()
calls the companion parttree()
function internally, which coerces the rpart tree object into a data frame
that is easily understood by ggplot2. For example, consider our initial
"ptree" object from earlier.
# ptree = parttree(tree)
ptree
Again, the resulting data frame is designed to be amenable to a ggplot2 geom
layer, with columns like xmin
, xmax
, etc. specifying aesthetics that
ggplot2 recognizes. (Fun fact: geom_parttree()
is really just a thin
wrapper around geom_rect()
.) The goal of parttree is to abstract away
these kinds of details from the user, so that they can just specify
geom_parttree()
—with a valid tree object as the data input—and be
done with it. However, while this generally works well, it can sometimes lead to
unexpected behaviour in terms of plot orientation. That's because it's hard to
guess ahead of time what the user will specify as the x and y variables (i.e.
axes) in their other ggplot2 layers.^[The default plot.partree
method
doesn't have this problem because it assigns the x and y variables for both the
partitions and raw data points as part of the same function call.] To see what I
mean, let's redo our penguin plot from earlier, but this time switch the axes in
the main ggplot()
call.
## First, redo our first plot but this time switch the x and y variables p3 = ggplot( data = penguins, aes(x = bill_length_mm, y = flipper_length_mm) ## Switched! ) + geom_point(aes(col = species)) ## Add on our tree (and some preemptive titling..) p3 + geom_parttree(data = tree, aes(fill = species), alpha = 0.1) + labs( title = "Oops!", subtitle = "Looks like a mismatch between our x and y axes..." )
As was the case here, this kind of orientation mismatch is normally (hopefully)
pretty easy to recognize. To fix, we can use the flip = TRUE
argument to
flip the orientation of the geom_parttree
layer.
p3 + geom_parttree( data = tree, aes(fill = species), alpha = 0.1, flip = TRUE ## Flip the orientation ) + labs(title = "That's better")
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.