Let's start by loading the parttree package alongside rpart, which comes bundled with the base R installation and is what we'll use for fitting our decision trees (at least, to start with). For the basic examples that follow, I'll use the well-known Palmer Penguins dataset to demonstrate functionality. You can load this dataset via the parent package (as I have here), or import it directly as a CSV.
library(rpart)     # For fitting decision trees
library(parttree)  # This package (will automatically load ggplot2 too)
theme_set(theme_linedraw())

# install.packages("palmerpenguins")
data("penguins", package = "palmerpenguins")
head(penguins)
Say we are interested in predicting penguin species as a function of 1) flipper length and 2) bill length. We can visualize these relationships as a simple scatter plot prior to doing any formal modeling.
p = ggplot(data = penguins, aes(x = flipper_length_mm, y = bill_length_mm)) +
  geom_point(aes(col = species))
p
Recasting this problem in terms of a decision tree is easily done (e.g., with rpart). However, visualizing the resulting tree predictions against the raw data is hard to do out of the box, and this is where parttree enters the fray. The main function that users will interact with is geom_parttree(), which provides a new geom layer for ggplot2 objects.
## Fit a decision tree using the same variables as the above plot
tree = rpart(species ~ flipper_length_mm + bill_length_mm, data = penguins)

## Visualize the tree partitions by adding it to our plot with geom_parttree()
p +
  geom_parttree(data = tree, aes(fill = species), alpha = 0.1) +
  labs(caption = "Note: Points denote observations. Shaded regions denote model predictions.")
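The shaded regions are just the tree's predictions laid out over the plot area. As a minimal sketch of that idea, the snippet below asks the same kind of rpart model for its predicted class at two hand-picked points. The coordinates in new_obs are arbitrary illustrative values, not observations from the dataset.

```r
library(rpart)

data("penguins", package = "palmerpenguins")
tree = rpart(species ~ flipper_length_mm + bill_length_mm, data = penguins)

## Two made-up points (illustrative values only)
new_obs = data.frame(
  flipper_length_mm = c(185, 215),
  bill_length_mm    = c(38, 48)
)

## Predicted species for each point, i.e. the fill of the region it falls in
pred = predict(tree, newdata = new_obs, type = "class")
pred
```

Any point that lands inside a given shaded rectangle should receive the same predicted class as every other point in that rectangle.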
Trees with a continuous response variable (i.e., regression trees) are also supported. However, I recommend adjusting the plot fill aesthetic, since your model will likely partition the data into intervals that don't match up exactly with the raw data. The easiest way to do this is by setting your colour and fill aesthetics together as part of the same scale_colour_* call.
tree2 = rpart(body_mass_g ~ flipper_length_mm + bill_length_mm, data = penguins)

ggplot(data = penguins, aes(x = flipper_length_mm, y = bill_length_mm)) +
  geom_parttree(data = tree2, aes(fill = body_mass_g), alpha = 0.3) +
  geom_point(aes(col = body_mass_g)) +
  scale_colour_viridis_c(aesthetics = c('colour', 'fill')) # NB: Set colour + fill together
Currently, the package works with decision trees created by the rpart and partykit packages. Moreover, it supports other front-ends that call rpart::rpart() as the underlying engine; in particular the tidymodels (parsnip or workflows) and mlr3 packages. Here's a quick example with parsnip.
set.seed(123) ## For consistent jitter

library(parsnip)
library(titanic) ## Just for a different data set

titanic_train$Survived = as.factor(titanic_train$Survived)

## Build our tree using parsnip (but with rpart as the model engine)
ti_tree =
  decision_tree() |>
  set_engine("rpart") |>
  set_mode("classification") |>
  fit(Survived ~ Pclass + Age, data = titanic_train)

## Plot the data and model partitions
titanic_train |>
  ggplot(aes(x = Pclass, y = Age)) +
  geom_parttree(data = ti_tree, aes(fill = Survived), alpha = 0.1) +
  geom_jitter(aes(col = Survived), alpha = 0.7)
Underneath the hood, geom_parttree() is calling the companion parttree() function, which coerces the rpart tree object into a data frame that is easily understood by ggplot2. For example, consider again our first "tree" model from earlier. Here's the print output of the raw model.
tree
And here's what we get after we feed it to parttree().
parttree(tree)
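Since this data frame holds rectangle bounds, it can in principle be drawn by hand with a plain geom_rect() layer. The following is a rough sketch of that manual route, assuming the output has xmin, xmax, ymin, ymax and species columns; it is not a description of how geom_parttree() is implemented internally.

```r
library(rpart)
library(parttree)
library(ggplot2)

data("penguins", package = "palmerpenguins")
tree = rpart(species ~ flipper_length_mm + bill_length_mm, data = penguins)

## Coerce the tree into a data frame of rectangle bounds
pt = parttree(tree)

## Draw the partitions manually, then overlay the raw observations
p_manual = ggplot(penguins, aes(x = flipper_length_mm, y = bill_length_mm)) +
  geom_rect(
    data = pt,
    aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax, fill = species),
    alpha = 0.1,
    inherit.aes = FALSE ## The rectangles bring their own coordinates
  ) +
  geom_point(aes(col = species))
p_manual
```

The manual version works, but you have to keep track of which tree variable maps to which axis yourself.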
Again, the resulting data frame is designed to be amenable to a ggplot2 geom layer, with columns like xmin, xmax, etc. specifying aesthetics that ggplot2 recognises. (Fun fact: geom_parttree() is really just a thin wrapper around geom_rect().) The goal of the package is to abstract away these kinds of details from the user, so we can just specify geom_parttree(), with a valid tree object as the data input, and be done with it. However, while this generally works well, it can sometimes lead to unexpected behaviour in terms of plot orientation. That's because it's hard to guess ahead of time what the user will specify as the x and y variables (i.e. axes) in their other plot layers. To see what I mean, let's redo our penguin plot from earlier, but this time switch the axes in the main ggplot() call.
## First, redo our first plot but this time switch the x and y variables
p3 = ggplot(
  data = penguins,
  aes(x = bill_length_mm, y = flipper_length_mm) ## Switched!
) +
  geom_point(aes(col = species))

## Add on our tree (and some preemptive titling..)
p3 +
  geom_parttree(data = tree, aes(fill = species), alpha = 0.1) +
  labs(
    title = "Oops!",
    subtitle = "Looks like a mismatch between our x and y axes..."
  )
As was the case here, this kind of orientation mismatch is normally (hopefully) pretty easy to recognize. To fix it, we can use the flipaxes = TRUE argument to flip the orientation of the geom_parttree() layer.
p3 +
  geom_parttree(
    data = tree, aes(fill = species), alpha = 0.1,
    flipaxes = TRUE ## Flip the orientation
  ) +
  labs(title = "That's better")
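Conceptually, flipping the orientation amounts to swapping the x bounds of each rectangle with its y bounds. Here is a toy base R sketch of that idea. The flip_rects() helper is hypothetical (it is not part of parttree) and the numbers in rects are made up for illustration.

```r
## Hypothetical helper (not a parttree function): swap x and y bounds.
## transform() evaluates every expression against the original columns,
## so the swap happens in a single step.
flip_rects = function(rects) {
  transform(rects, xmin = ymin, xmax = ymax, ymin = xmin, ymax = xmax)
}

## Toy rectangles with made-up bounds
rects = data.frame(
  xmin = c(-Inf, 200),
  xmax = c(200, Inf),
  ymin = c(-Inf, 40),
  ymax = c(Inf, 50)
)

flipped = flip_rects(rects)
flipped
```

Applying the helper twice returns the original bounds, which is a handy sanity check when reasoning about orientation.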
While the package has been primarily designed to work with ggplot2, the parttree() infrastructure can also be used to generate plots with base graphics. Here, the ctree() function from partykit is used for fitting the tree.
library(partykit)

## CTree and corresponding partition
ct = ctree(species ~ flipper_length_mm + bill_length_mm, data = penguins)
pt = parttree(ct)

## Color palette
pal = palette.colors(4, "R4")[-1]

## Maximum/minimum for plotting range as rect() does not handle Inf well
m = 1000

## scatter plot() with added rect()
plot(
  bill_length_mm ~ flipper_length_mm,
  data = penguins, col = pal[species], pch = 19
)
rect(
  pmax(-m, pt$xmin), pmax(-m, pt$ymin),
  pmin(m, pt$xmax), pmin(m, pt$ymax),
  col = adjustcolor(pal, alpha.f = 0.1)[pt$species]
)
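The pmax()/pmin() calls above clamp the infinite partition bounds to a finite range before they reach rect(). A small sketch of the same trick, wrapped in a helper, is below. The clamp() function and the bounds vector are hypothetical and purely illustrative.

```r
## Hypothetical helper: restrict x to the interval [lower, upper]
clamp = function(x, lower, upper) {
  pmin(upper, pmax(lower, x))
}

## Made-up partition bounds, including the infinite outer edges
bounds = c(-Inf, 206.5, Inf)

## Infinite values are pulled in to the limits; finite ones are untouched
clamped = clamp(bounds, lower = 170, upper = 235)
clamped
```

In practice the limits could be taken from the data (e.g., via range() with na.rm = TRUE) rather than from a fixed constant like m.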