This vignette illustrates how to interpret models with endoR and regularization. We will use the titanic data for this purpose: the survival of passengers is being predicted (= target) using information on passengers (e.g., gender, age, etc = features).
library(tidyverse) #> ── Attaching packages ──────────────────────────────────────────────────────────────────────────────────────────────── tidyverse 1.3.1 ── #> ✔ ggplot2 3.3.6 ✔ purrr 0.3.4 #> ✔ tibble 3.1.7 ✔ dplyr 1.0.9 #> ✔ tidyr 1.2.0 ✔ stringr 1.4.0 #> ✔ readr 2.1.2 ✔ forcats 0.5.1 #> ── Conflicts ─────────────────────────────────────────────────────────────────────────────────────────────────── tidyverse_conflicts() ── #> ✖ dplyr::filter() masks stats::filter() #> ✖ dplyr::lag() masks stats::lag() library(stringr) library("ggpubr") library(igraph) #> #> Attaching package: 'igraph' #> The following objects are masked from 'package:dplyr': #> #> as_data_frame, groups, union #> The following objects are masked from 'package:purrr': #> #> compose, simplify #> The following object is masked from 'package:tidyr': #> #> crossing #> The following object is masked from 'package:tibble': #> #> as_data_frame #> The following objects are masked from 'package:stats': #> #> decompose, spectrum #> The following object is masked from 'package:base': #> #> union library(ggraph) library("inTrees") library(ranger) #> Warning: package 'ranger' was built under R version 4.2.1 library(parallel) library(caret) #> Loading required package: lattice #> #> Attaching package: 'caret' #> The following object is masked from 'package:purrr': #> #> lift library(endoR) #> Warning: replacing previous import 'rlang:::=' by 'data.table:::=' when loading 'endoR' #> Warning: replacing previous import 'data.table::last' by 'dplyr::last' when loading 'endoR' #> Warning: replacing previous import 'data.table::first' by 'dplyr::first' when loading 'endoR' #> Warning: replacing previous import 'data.table::between' by 'dplyr::between' when loading 'endoR' #> Warning: replacing previous import 'dplyr::union' by 'igraph::union' when loading 'endoR' #> Warning: replacing previous import 'rlang::is_named' by 'igraph::is_named' when loading 'endoR' #> Warning: replacing previous import 
'dplyr::as_data_frame' by 'igraph::as_data_frame' when loading 'endoR' #> Warning: replacing previous import 'dplyr::groups' by 'igraph::groups' when loading 'endoR' #> Registered S3 method overwritten by 'randomForest': #> method from #> plot.margin RRF #> Warning: replacing previous import 'ggplot2::margin' by 'randomForest::margin' when loading 'endoR' #> Warning: replacing previous import 'dplyr::combine' by 'randomForest::combine' when loading 'endoR' #> Warning: replacing previous import 'randomForest::importance' by 'ranger::importance' when loading 'endoR' #> Warning: replacing previous import 'igraph::decompose' by 'stats::decompose' when loading 'endoR' #> Warning: replacing previous import 'dplyr::filter' by 'stats::filter' when loading 'endoR' #> Warning: replacing previous import 'dplyr::lag' by 'stats::lag' when loading 'endoR' #> Warning: replacing previous import 'igraph::spectrum' by 'stats::spectrum' when loading 'endoR' #> Warning: replacing previous import 'dplyr::slice' by 'xgboost::slice' when loading 'endoR' library(data.table) #> data.table 1.14.2 using 4 threads (see ?getDTthreads). Latest news: r-datatable.com #> #> Attaching package: 'data.table' #> The following objects are masked from 'package:dplyr': #> #> between, first, last #> The following object is masked from 'package:purrr': #> #> transpose library(clustermq) #> Warning: package 'clustermq' was built under R version 4.2.1 #> * Option 'clustermq.scheduler' not set, defaulting to 'LOCAL' #> --- see: https://mschubert.github.io/clustermq/articles/userguide.html#configuration
sessionInfo() #> R version 4.2.0 (2022-04-22 ucrt) #> Platform: x86_64-w64-mingw32/x64 (64-bit) #> Running under: Windows 10 x64 (build 22000) #> #> Matrix products: default #> #> locale: #> [1] LC_COLLATE=English_Europe.utf8 LC_CTYPE=English_Europe.utf8 LC_MONETARY=English_Europe.utf8 LC_NUMERIC=C #> [5] LC_TIME=English_Europe.utf8 #> #> attached base packages: #> [1] parallel stats graphics grDevices utils datasets methods base #> #> other attached packages: #> [1] clustermq_0.8.95.3 data.table_1.14.2 endoR_0.1.0 caret_6.0-92 lattice_0.20-45 ranger_0.14.1 inTrees_1.3 #> [8] ggraph_2.0.5 igraph_1.3.2 ggpubr_0.4.0 forcats_0.5.1 stringr_1.4.0 dplyr_1.0.9 purrr_0.3.4 #> [15] readr_2.1.2 tidyr_1.2.0 tibble_3.1.7 ggplot2_3.3.6 tidyverse_1.3.1 #> #> loaded via a namespace (and not attached): #> [1] colorspace_2.0-3 ggsignif_0.6.3 ellipsis_0.3.2 class_7.3-20 fs_1.5.2 rstudioapi_0.13 #> [7] listenv_0.8.0 farver_2.1.0 graphlayouts_0.8.0 ggrepel_0.9.1 prodlim_2019.11.13 fansi_1.0.3 #> [13] lubridate_1.8.0 xml2_1.3.3 codetools_0.2-18 splines_4.2.0 arules_1.7-3 knitr_1.39 #> [19] polyclip_1.10-0 jsonlite_1.8.0 pROC_1.18.0 broom_0.8.0 dbplyr_2.2.0 ggforce_0.3.3 #> [25] compiler_4.2.0 httr_1.4.3 backports_1.4.1 assertthat_0.2.1 Matrix_1.4-1 fastmap_1.1.0 #> [31] cli_3.3.0 tweenr_1.0.2 htmltools_0.5.2 tools_4.2.0 gtable_0.3.0 glue_1.6.2 #> [37] reshape2_1.4.4 Rcpp_1.0.8.3 carData_3.0-5 cellranger_1.1.0 vctrs_0.4.1 nlme_3.1-157 #> [43] iterators_1.0.14 timeDate_3043.102 xfun_0.31 gower_1.0.0 globals_0.15.0 rvest_1.0.2 #> [49] lifecycle_1.0.1 rstatix_0.7.0 future_1.26.1 MASS_7.3-56 scales_1.2.0 ipred_0.9-13 #> [55] tidygraph_1.2.1 hms_1.1.1 yaml_2.3.5 gridExtra_2.3 rpart_4.1.16 stringi_1.7.6 #> [61] highr_0.9 randomForest_4.7-1.1 foreach_1.5.2 hardhat_1.1.0 lava_1.6.10 rlang_1.0.2 #> [67] pkgconfig_2.0.3 evaluate_0.15 recipes_0.2.0 tidyselect_1.1.2 parallelly_1.32.0 gbm_2.1.8 #> [73] plyr_1.8.7 magrittr_2.0.3 R6_2.5.1 generics_0.1.2 DBI_1.1.2 pillar_1.7.0 #> [79] haven_2.5.0 
withr_2.5.0 survival_3.3-1 abind_1.4-5 nnet_7.3-17 future.apply_1.9.0 #> [85] modelr_0.1.8 crayon_1.5.1 car_3.0-13 xgboost_1.6.0.1 utf8_1.2.2 tzdb_0.3.0 #> [91] RRF_1.9.4 rmarkdown_2.14 viridis_0.6.2 grid_4.2.0 readxl_1.4.0 ModelMetrics_1.2.2.2 #> [97] reprex_2.0.1 digest_0.6.29 xtable_1.8-4 stats4_4.2.0 munsell_0.5.0 viridisLite_0.4.0
summary(titanic) #> gender age class embarked country fare sibsp #> female: 489 Min. : 0.1667 1st :324 Belfast : 197 England :1125 Min. : 0.000 Min. :0.0000 #> male :1718 1st Qu.:22.0000 2nd :284 Cherbourg : 271 United States: 264 1st Qu.: 0.000 1st Qu.:0.0000 #> Median :29.0000 3rd :709 Queenstown : 123 Ireland : 137 Median : 7.151 Median :0.0000 #> Mean :30.4363 deck crew : 66 Southampton:1616 Sweden : 105 Mean : 19.992 Mean :0.2959 #> 3rd Qu.:38.0000 engineering crew:324 X : 81 3rd Qu.: 21.000 3rd Qu.:0.0000 #> Max. :74.0000 restaurant staff: 69 Lebanon : 71 Max. :512.061 Max. :8.0000 #> victualling crew:431 (Other) : 424 #> parch survived #> Min. :0.0000 no :1496 #> 1st Qu.:0.0000 yes: 711 #> Median :0.0000 #> Mean :0.2284 #> 3rd Qu.:0.0000 #> Max. :9.0000 #>
Out of the 2207 passengers, 711 survived and 1496 perished.
summary(titanic$survived) #> no yes #> 1496 711
Because of the target imbalance, we will use sample weights in the RF model so that survivors and non-survivors are drawn in comparable numbers to fit each tree.
n_yes <- sum(titanic$survived == 'yes') n_samp <- length(titanic$survived) samp_weight <- round(ifelse(titanic$survived == 'yes', 1-n_yes/n_samp, n_yes/n_samp), digits = 2) summary(as.factor(samp_weight)) #> 0.32 0.68 #> 1496 711
set.seed(1313) titanic_rf <- ranger(x = titanic %>% select(-survived), y = titanic$survived , case.weights = samp_weight) titanic_rf #> Ranger result #> #> Call: #> ranger(x = titanic %>% select(-survived), y = titanic$survived, case.weights = samp_weight) #> #> Type: Classification #> Number of trees: 500 #> Sample size: 2207 #> Number of independent variables: 8 #> Mtry: 2 #> Target node size: 1 #> Variable importance mode: none #> Splitrule: gini #> OOB prediction error: 19.30 %
It's not a very good model (about 1/3 of the survivors are misclassified), but it will be sufficient for the tutorial.
titanic_rf$confusion.matrix #> predicted #> true no yes #> no 1320 176 #> yes 250 461
The model2DE_resampling() function draws times = 5
bootstrap samples with replacement (by default,
p = 0.5
is the fraction of samples drawn). One can use sample_weight
to
change the probability of samples to be drawn - this is useful for imbalanced
data.
The function will first extract decisions from the model and discretize
variables. Then, the pruning and calculation of the decision-wide feature and
interaction importances are performed on each bootstrap. It is advised to run
the function in parallel to accelerate it (in_parallel = TRUE
with, by default,
n_cores = parallel::detectCores()-1
).
rules <- model2DE_resampling(model = titanic_rf, model_type = 'ranger' , data = titanic %>% select(-survived) , target = titanic$survived, classPos = 'yes' , times = 5 , sample_weight = samp_weight , discretize = TRUE, K = 2 , prune = TRUE, maxDecay = 0.05, typeDecay = 2 , filter = FALSE , in_parallel = TRUE, n_cores = 2 ) #> Extract rules... #> Discretise data #> Discretise rules #> Initiate parallelisation ... #> Calculate metrics ... #> Pruning ... #> Generate additional decisions ... #> Initiate parallelisation ... #> Calculate metrics ... #> Pruning ... #> Generate additional decisions ... #> Initiate parallelisation ... #> Calculate metrics ... #> Pruning ... #> Generate additional decisions ... #> Initiate parallelisation ... #> Calculate metrics ... #> Pruning ... #> Generate additional decisions ... #> Initiate parallelisation ... #> Calculate metrics ... #> Pruning ... #> Generate additional decisions ...
Stability selection consists in selecting the decisions that were the most
important across all bootstraps. It depends on the parameter alpha
= expected
number of false positive decisions. Regardless of alpha, the feature and
interaction importances will be much higher for true positive features and
interactions than for false positive ones. Hence, we can set higher values of
alpha
to increase the number of recovered true features/interactions (i.e.,
get a higher recall).
Let's have a look at the effect of alpha on the number of stable decisions: we first compute the stable decision ensembles for various alpha values.
alphas <- evaluateAlpha(rules = rules, alphas = c(1:5, 7, 10) , data = rules$data) #> 15.21841 rules per sub-sample selected. 9 decisions in >= 3.5 subsets. #> 21.52208 rules per sub-sample selected. 13 decisions in >= 3.5 subsets. #> 26.35906 rules per sub-sample selected. 14 decisions in >= 3.5 subsets. #> 30.43682 rules per sub-sample selected. 15 decisions in >= 3.5 subsets. #> 34.0294 rules per sub-sample selected. 17 decisions in >= 3.5 subsets. #> 40.26413 rules per sub-sample selected. 21 decisions in >= 3.5 subsets. #> 48.12484 rules per sub-sample selected. 24 decisions in >= 3.5 subsets.
In the summary table, n_dec
= number of decisions and n_samp
= number of
samples that can be predicted with the stable decisions.
alphas$summary_table #> alpha n_dec n_samp #> 1 1 9 2207 #> 2 2 13 2207 #> 3 3 14 2207 #> 4 4 15 2207 #> 5 5 17 2207 #> 6 7 21 2207 #> 7 10 24 2207
de_final <- stabilitySelection(rules = rules, alpha_error = 3) #> 26.35906 rules per sub-sample selected. 14 decisions in >= 3.5 subsets.
We can have a look at the selected rules: by default, we have selected stable
decisions with pi=0.7
, the minimal fraction of bootstraps in which a decision
should have been important to be selected as stable. Hence, when looking at
decisions, we must subset the decisions that were important in
inN >= pi*times = 0.7*5 = 3.5
.
Note that you can find pi
in rules$parameters['pi_thr']
.
The inTrees::presentRules()
function formats the table to include the feature
names.
de_final$rules_summary %>% subset(inN >= .7*5) %>% presentRules(colN = colnames(rules$data)) %>% head #> condition inN len support err pred imp imp_sd n n_sd #> 1: gender__female>0.5 5 1 0.2988214 0.33022598 0.8612167 0.09854184 0.013524617 22.914378 5.405518 #> 2: gender__female>0.5 & class__1st>0.5 5 2 0.1062557 0.08857603 0.9804078 0.08640382 0.004850086 8.148616 1.545313 #> 3: gender__female>0.5 & embarked__Cherbourg>0.5 5 2 0.0774252 0.15246357 0.9603145 0.05341269 0.005265844 11.629315 1.854911 #> 4: gender__female>0.5 & class__2nd>0.5 5 2 0.0743427 0.21159270 0.9342623 0.04315968 0.013847258 28.510079 6.249040 #> 5: class__1st>0.5 5 1 0.1858568 0.39719625 0.7941215 0.03620932 0.006741799 21.664774 2.313154 #> 6: gender__male>0.5 5 1 0.7011786 0.47903961 0.3591693 0.03243141 0.003297562 28.756003 5.650443
Let's first plot the feature importance and influence:
plotFeatures(decision_ensemble = de_final)
This plot is ugly: we would like the influence plot to be larger than the
importance one for clarity. So, we recompute the plotFeatures with
return_all = TRUE
to get the 2 individual plots. We can also provide the
order of levels to tidy the influence plot.
p_feat <- plotFeatures(decision_ensemble = de_final, return_all = TRUE , levels_order = c('male', 'female' , 'engineering crew', 'restaurant staff', 'deck crew' , 'victualling crew' , '3rd', '2nd', '1st' , 'Belfast', 'Cherbourg', 'Queenstown', 'Southampton' , 'Low', 'Medium', 'High') ) names(p_feat) #> [1] "importance" "importance_p" "influences" "influence_p"
Note that each plot was created with ggplot2
so we can modify them as wanted;
for instance, we can change the titles.
options(repr.plot.width=12, repr.plot.height=3) ggarrange(p_feat$importance_p + labs('Importance') , p_feat$influence_p + labs('Influence') , widths = c(0.25, 0.7)) # better!
Women and children first? Yes, but especially if they are wealthy (the class is
one of the most important features).
Now the network: - from the 2nd class: men had low survival chances, women had high ones - from the 1st class: everyone had high survival chances, though they were even higher for women
options(repr.plot.width=8, repr.plot.height=5) plotNetwork(de_final, hide_isolated_nodes = FALSE) #> Warning in grid.Call(C_stringMetric, as.graphicsAnnot(x$label)): font family not found in Windows font database #> Warning in grid.Call(C_stringMetric, as.graphicsAnnot(x$label)): font family not found in Windows font database #> Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font family not found in Windows font database #> Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font family not found in Windows font database #> Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font family not found in Windows font database #> Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font family not found in Windows font database #> Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font family not found in Windows font database
We can also change the layout of the network and format edges and nodes via
ggraph
- see layouts:
https://cran.r-project.org/package=ggraph.
To hide nodes that are not part of the network: hide_isolated_nodes = TRUE
.
options(repr.plot.width=8, repr.plot.height=5) plotNetwork(de_final, hide_isolated_nodes = TRUE , layout = 'fr')+ # I usually prefer the 'fr' layout :) scale_edge_alpha(range = c(0.8, 1)) #> Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font family not found in Windows font database #> Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font family not found in Windows font database #> Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font family not found in Windows font database #> Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font family not found in Windows font database #> Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font family not found in Windows font database #> Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font family not found in Windows font database #> Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font family not found in Windows font database
Instead of running bootstraps one after the other, we can also run them in
parallel. For this, we will first extract decisions and discretize variables
using the preCluster()
function and then run endoR on each bootstrap with the
model2DE_cluster()
function, managed by the Q()
function from the
clustermq
R-package.
The clustermq R-package allows running jobs in parallel
locally or on HPC environments; see its manual for configuration:
https://mschubert.github.io/clustermq/articles/userguide.html. For the tutorial,
I will run it on my computer (clustermq.scheduler = "multiprocess"
).
preclu <- preCluster(model = titanic_rf, model_type = 'ranger' , data = titanic %>% select(-survived) , target = titanic$survived, classPos = 'yes' , times = 5 # number of bootstraps , sample_weight = samp_weight # sample weight for resampling , discretize = TRUE, K = 2 , in_parallel = FALSE) #> Extract rules... #> Discretise data #> Discretise rules
Let's set the clustermq parameters:
options(clustermq.scheduler = "multiprocess")
rules <- Q(model2DE_cluster , partition = preclu$partitions , export=list(data = preclu$data , target = titanic$survived , exec = preclu$exec , classPos = 'yes' , prune = TRUE, maxDecay = 0.05, typeDecay = 2 , filter = FALSE , in_parallel = TRUE, n_cores = 1 # keep to 1 to pass CRAN check but could be higher given your resources ) , n_jobs= 2 # 2 bootstraps will be processed in parallel , pkgs=c('data.table', 'parallel', 'caret', 'stringr', 'scales', 'dplyr' , 'inTrees', 'endoR') , log_worker=FALSE ) #> Warning in (function (...) : Common data is 28.4 Mb. Recommended limit is (set by clustermq.data.warning option) #> Starting 2 processes ... #> Warning in sprintf(log_file, i): one argument not used by format '|' #> Warning in sprintf(log_file, i): one argument not used by format '|' #> Running 5 calculations (10 objs/28.4 Mb common; 1 calls/chunk) ... #> [-------------------------------------------------------------------------------------------------------------] 0% (1/2 wrk) eta: ?s [-------------------------------------------------------------------------------------------------------------] 0% (2/2 wrk) eta: ?s [=====================>---------------------------------------------------------------------------------------] 20% (2/2 wrk) eta: 18m [===========================================>-----------------------------------------------------------------] 40% (2/2 wrk) eta: 8m [================================================================>--------------------------------------------] 60% (2/2 wrk) eta: 7m [======================================================================================>----------------------] 80% (2/2 wrk) eta: 3m [=============================================================================================================] 100% (1/1 wrk) eta: 0s Master: [968.4s 0.1% CPU]; Worker: [avg 1.6% CPU, max 706.7 Mb]
Just like above, except that now data are in preclu$data
and not in the rules.
de_final <- stabilitySelection(rules = rules, alpha_error = 3) #> 26.35906 rules per sub-sample selected. 14 decisions in >= 3.5 subsets.
de_final$rules_summary %>% subset(inN >= .7*5) %>% presentRules(colN = colnames(preclu$data)) %>% head #> condition inN len support err pred imp imp_sd n n_sd #> 1: gender__female>0.5 5 1 0.2988214 0.33022598 0.8612167 0.09854184 0.013524617 22.914378 5.405518 #> 2: gender__female>0.5 & class__1st>0.5 5 2 0.1062557 0.08857603 0.9804078 0.08640382 0.004850086 8.148616 1.545313 #> 3: gender__female>0.5 & embarked__Cherbourg>0.5 5 2 0.0774252 0.15246357 0.9603145 0.05341269 0.005265844 11.629315 1.854911 #> 4: gender__female>0.5 & class__2nd>0.5 5 2 0.0743427 0.21159270 0.9342623 0.04315968 0.013847258 28.510079 6.249040 #> 5: class__1st>0.5 5 1 0.1858568 0.39719625 0.7941215 0.03620932 0.006741799 21.664774 2.313154 #> 6: gender__male>0.5 5 1 0.7011786 0.47903961 0.3591693 0.03243141 0.003297562 28.756003 5.650443
p_feat <- plotFeatures(decision_ensemble = de_final, return_all = TRUE , levels_order = c('male', 'female' , 'engineering crew', 'restaurant staff', 'deck crew' , 'victualling crew' , '3rd', '2nd', '1st' , 'Belfast', 'Cherbourg', 'Queenstown', 'Southampton' , 'Low', 'Medium', 'High') )
options(repr.plot.width=12, repr.plot.height=3) ggarrange(p_feat$importance_p + labs(title = 'Importance') , p_feat$influence_p + labs(title = 'Influence') , widths = c(0.25, 0.7)) # better!
options(repr.plot.width=8, repr.plot.height=5) plotNetwork(de_final, hide_isolated_nodes = FALSE)+ scale_edge_alpha(range = c(0.8, 1)) #> Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font family not found in Windows font database #> Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font family not found in Windows font database #> Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font family not found in Windows font database #> Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font family not found in Windows font database #> Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font family not found in Windows font database #> Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font family not found in Windows font database #> Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font family not found in Windows font database
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.