knitr::opts_chunk$set( collapse = TRUE, comment = "#>", fig.width = 7, fig.height = 2.5 )
This vignette illustrates how to interpret multiclass models with endoR. We will use the iris data for this purpose: the species (= target) is predicted from the length and width of petals and sepals (= features).
library(endoR)
library(randomForest)
library(tidyverse)
library(caret)
library(ggpubr)
library(ggraph)
data(iris)
summary(iris)
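summary(iris) pools the three species together. As a complement, a per-species breakdown in base R can be handy when reading the endoR plots later on. The snippet below is only an illustrative sketch and is not part of endoR; `species_means` is a name introduced here.

```r
# Illustrative sketch (not part of endoR): mean of each feature per species.
data(iris)
species_means <- aggregate(. ~ Species, data = iris, FUN = mean)
species_means
```

Each row of the result corresponds to one species and each column to one feature, so the table can be compared directly with the Low / Medium / High levels shown in the plots.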
set.seed(1313)
mod <- randomForest(Species ~ ., data = iris)
mod
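We will only prune decisions (pruning stays enabled, while importance-based filtering is switched off with filter = FALSE) and discretize the numerical features into K = 3 levels (discretize = TRUE).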
endo_setosa <- model2DE(model = mod, model_type = 'rf'
    , data = select(iris, -Species), target = iris$Species
    , classPos = 'setosa' # our focal class
    , filter = FALSE
    # we discretize the numerical features into K = 3 categories
    , discretize = TRUE, K = 3)
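To give an intuition of what discretizing into K = 3 levels means, here is a minimal base R sketch that bins one feature into 'Low', 'Medium' and 'High' using tertiles. It assumes quantile-based cut points, which may differ from what endoR does internally; `to_levels` is a hypothetical helper, not an endoR function.

```r
# Hypothetical helper (not an endoR function): bin a numeric vector into
# K ordered levels using quantile-based cut points (an assumption).
to_levels <- function(x, K = 3, labels = c("Low", "Medium", "High")) {
  stopifnot(length(labels) == K)
  breaks <- quantile(x, probs = seq(0, 1, length.out = K + 1))
  cut(x, breaks = breaks, labels = labels, include.lowest = TRUE)
}

data(iris)
petal_level <- to_levels(iris$Petal.Length)
# Cross-tabulate the discretized feature against the species
table(petal_level, iris$Species)
```

Note that `cut()` requires unique break points, so this sketch would need adjusting for features with many tied values.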
Plants from the setosa species have small petals and short, wide sepals.
plotFeatures(endo_setosa, levels_order = c('Low', 'Medium', 'High'))
# The font warnings are due to Windows; they can be ignored.
plotNetwork(endo_setosa, hide_isolated_nodes = FALSE, layout = 'fr')
The only interaction used for predictions is the one between sepal length (High = long) and sepal width (Low = narrow). We can also use ggplot2's theme() to format the plot, e.g., to put the legend boxes next to each other.
plotNetwork(endo_setosa, hide_isolated_nodes = TRUE, layout = 'fr') +
    theme(legend.box = "horizontal")
This time we will filter decisions based on their importance to trim the network (filter = TRUE). We will use min_imp = 0.5 to keep at least all decisions with an importance greater than 0.5 times the best importance (the lower min_imp, the lighter the filtering).
endo_versicolor <- model2DE(model = mod, model_type = 'rf'
    , data = select(iris, -Species), target = iris$Species
    , classPos = 'versicolor'
    , K = 3, discretize = TRUE
    , filter = TRUE, min_imp = 0.5)
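The thresholding rule behind min_imp can be sketched in base R. The importance values below are invented for illustration, and `keep_important` is a hypothetical helper; endoR's actual filtering may involve additional steps.

```r
# Hypothetical helper (not an endoR function): keep the decisions whose
# importance exceeds min_imp times the best (maximum) importance.
keep_important <- function(importance, min_imp = 0.5) {
  importance > min_imp * max(importance)
}

# Toy importance values, invented for this example only
imp <- c(d1 = 0.80, d2 = 0.45, d3 = 0.30, d4 = 0.10)

kept_half  <- names(imp)[keep_important(imp, min_imp = 0.5)]  # threshold = 0.40
kept_fifth <- names(imp)[keep_important(imp, min_imp = 0.2)]  # threshold = 0.16
kept_half
kept_fifth
```

Lowering min_imp lowers the threshold, so more decisions survive the filter and the resulting network is larger.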
Petal dimensions are intermediate between those of the setosa and virginica species.
plotFeatures(endo_versicolor, levels_order = c('Low', 'Medium', 'High'))
Sepals have a narrow-to-intermediate width and an intermediate-to-long length (i.e., if the sepals are wide and short, the plant is not a versicolor, as seen on the network).
plotNetwork(endo_versicolor, hide_isolated_nodes = FALSE, layout = 'fr'
    # we show only edges that connect 3 nodes max -> removes edges with
    # lowest importances - for longer paths = more complex network,
    # you can increase path_length
    , path_length = 3
    ) +
    scale_edge_alpha(range = c(0.8, 1)) +
    theme(legend.box = "horizontal")