Our package RevoEnhancements on GitHub now contains an enhancement to RevoScaleR that enables plotting of rxDTree and rxDForest objects.
First, load the required packages:
library(RevoScaleR)
library(RevoEnhancements)
library(ggplot2)
rxOptions(computeContext=RxLocalParallel(), reportProgress=0)
The Revolution function to create a decision tree is rxDTree. We use rxDTree to fit a partition tree to the diamonds data (shipped with ggplot2), restricting the maximum tree depth to 3 simply to make the plot easier to read.
### Create and plot an rxDTree object
frm <- formulaExpand(price ~ ., data=diamonds)
fit <- rxDTree(frm, diamonds, maxDepth=3)
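For readers without RevoScaleR, the same two ideas (an explicit formula and a depth cap) can be sketched with the open-source rpart package. This is only an analogy under stated assumptions: the base-R `terms()` call is assumed to be similar in spirit to formulaExpand, rpart's `maxdepth` control is assumed to play the role of rxDTree's maxDepth, and the built-in mtcars data stands in for diamonds.

```r
library(rpart)

# Expand the "." shorthand into explicit terms using base R.
# Assumption: this resembles what formulaExpand does; it is not that function.
frm_expanded <- formula(terms(mpg ~ ., data = mtcars))

# Fit a regression tree with the depth capped at 3 (rpart analogue of maxDepth).
fit_rpart <- rpart(frm_expanded, data = mtcars,
                   control = rpart.control(maxdepth = 3))

# rpart numbers nodes so that node n sits at depth floor(log2(n)),
# with the root (node 1) at depth 0.
node_ids <- as.integer(rownames(fit_rpart$frame))
tree_depth <- max(floor(log2(node_ids)))

print(frm_expanded)
print(tree_depth <= 3)
```

A shallow tree has few terminal nodes, which is what keeps the labels legible when the tree is plotted.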
Since the package RevoEnhancements exports a plot method for rxDTree objects (plot.rxDTree), we simply have to call plot(fit) to visualise the tree:
par(mar=c(0.2, 0.2, 1.5, 0.2))
par(mai=rep(0.1, 4))
plot(fit)
The function plot.rxDTree takes additional arguments to control the plot and text characteristics:
- plotArgs should be a list and gets passed to plot.rpart. Use this to control plot axes, titles, etc.
- textArgs is also a list and gets passed to text.rpart. Use this to control the text on the plot, e.g. colour (col) and character size (cex).
For example, to set the text colour to blue and the character expansion to 0.7, use the textArgs argument:
plot(fit, textArgs=list(col="blue", cex=0.7))
To also set the plot title, use the plotArgs argument:
plot(fit, textArgs=list(col="blue", cex=0.7), plotArgs=list(main="Decision tree"))
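To make the forwarding of plotArgs and textArgs concrete, here is a minimal sketch of how a wrapper could pass two argument lists on to plot.rpart and text.rpart with `do.call`. The helper `plot_tree_sketch` is hypothetical and is not the RevoEnhancements implementation; it works on a plain rpart fit of the kyphosis data that ships with rpart.

```r
library(rpart)

# Hypothetical helper for illustration only (not plot.rxDTree itself).
# plotArgs is forwarded to plot(), textArgs to text(); both dispatch
# to the rpart methods because x is an rpart object.
plot_tree_sketch <- function(x, plotArgs = list(), textArgs = list()) {
  stopifnot(inherits(x, "rpart"), is.list(plotArgs), is.list(textArgs))
  do.call(plot, c(list(x), plotArgs))   # draws the tree skeleton
  do.call(text, c(list(x), textArgs))   # adds split and leaf labels
  invisible(x)
}

# Small classification tree on a dataset bundled with rpart.
fit_demo <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

plot_tree_sketch(fit_demo,
                 plotArgs = list(main = "Decision tree"),
                 textArgs = list(col = "blue", cex = 0.7))
```

Keeping the two lists separate avoids ambiguity, since an argument such as cex could otherwise be meant for either the plot or the labels.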
For comparison, we fit a decision forest to the same diamonds data. This creates an rxDForest object. This object is a list, and the individual trees are contained in its forest element.
### Create and plot an rxDForest object
fit <- rxDForest(frm, diamonds, maxDepth=3)
forest <- fit$forest
Now we can plot the individual trees. Since each individual tree of an rxDForest object is itself an rxDTree object, the plotting is identical to what we illustrated earlier:
par(mar=c(0.2, 0.2, 1.5, 0.2))
par(mai=rep(0.1, 4))
plot(forest[[1]], textArgs=list(col="blue", cex=0.7))
Again, you can take control of the plot with the plotArgs and textArgs arguments:
plot(forest[[1]], textArgs=list(col="blue", cex=0.7), plotArgs=list(main="Decision forest: Tree 1"))
plot(forest[[2]], textArgs=list(col="blue", cex=0.7), plotArgs=list(main="Decision forest: Tree 2"))
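With more than a couple of trees, the repeated plot calls are easier to write as a loop over the list. The sketch below shows the pattern on a stand-in: a list of rpart trees fitted to bootstrap samples of the kyphosis data. The list `trees` is hypothetical and only mimics the shape of the forest element; it is not an rxDForest.

```r
library(rpart)

set.seed(1)  # make the bootstrap samples reproducible

# Stand-in for fit$forest: a plain list of trees, one per bootstrap sample.
trees <- lapply(1:2, function(i) {
  idx <- sample(nrow(kyphosis), replace = TRUE)
  rpart(Kyphosis ~ Age + Number + Start, data = kyphosis[idx, ],
        control = rpart.control(maxdepth = 3))
})

# Plot each tree in turn, building the title from the loop index.
for (i in seq_along(trees)) {
  tree <- trees[[i]]
  # plot.rpart and text.rpart both stop on a root-only tree, so skip those.
  if (nrow(tree$frame) <= 1L) next
  plot(tree, main = paste("Bootstrap tree", i))
  text(tree, col = "blue", cex = 0.7)
}
```

The same loop shape should carry over to forest[[i]] with the textArgs and plotArgs arguments shown above, since each element is plotted independently.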