# Global knitr chunk option: echo = TRUE makes every chunk show its R source
# in the rendered document.
knitr::opts_chunk$set(echo = TRUE)
In this short tutorial, we provide the code that reproduces the results of the application section of our article entitled "Visualizing the Feature Importance for Black Box Models".
We used `batchtools` to run our experiments.
The files `application_pi_simulation.R`, `application_shapley_simulation.R` and `application_importance_realdata.R` contain the `batchtools` code to reproduce the experiments and can be found in this directory.
The directory also includes the results of each of these files as an `.Rds` file; these result files are used in the code below to produce the figures and tables.
# load required packages (attach order is kept as in the original)

# data handling
library(data.table)

# plotting and plot arrangement
library(ggplot2)
library(gridExtra)
library(ggExtra)

# table output
library(xtable)
library(knitr)

# feature importance methods
library(featureImportance)

# project-local helper functions used by the chunks below
# (presumably getImpTable(), pasteMeanSd(), conditionalPFI() -- verify in
# helper/functions.R)
source("helper/functions.R")
# Results of the PFI simulation; `res` is iterated over below, one element
# per replication.
res = readRDS("application_pi_simulation.Rds")

# Apply `fun` to every element of `res` and stack the returned tables.
bindReplications = function(fun) {
  rbindlist(lapply(res, fun))
}

# Collapse every column of an importance table with pasteMeanSd().
meanSdRow = function(imp.tab) {
  imp.tab[, lapply(.SD, pasteMeanSd)]
}

# PFI based on all observations
imp.global = bindReplications(function(rep.pfi) {
  getImpTable(rep.pfi, sort = FALSE)
})

# compute conditional PFI based on feature V3 = 0
imp0 = bindReplications(function(rep.pfi) {
  keep0 = unique(rep.pfi[features == "V3" & feature.value == 0, replace.id])
  getImpTable(rep.pfi, obs.id = keep0, sort = FALSE)
})

# compute conditional PFI based on feature V3 = 1
imp1 = bindReplications(function(rep.pfi) {
  keep1 = unique(rep.pfi[features == "V3" & feature.value == 1, replace.id])
  getImpTable(rep.pfi, obs.id = keep1, sort = FALSE)
})

# one summary row each: global, V3 = 0, V3 = 1
tab = rbind(
  meanSdRow(imp.global),
  meanSdRow(imp0),
  meanSdRow(imp1)
)
kable(tab)
# Plots for the first element of `res`: one panel per feature in `vars`,
# with the output of conditionalPFI() drawn on top, coloured by V3.
pfi = res[[1]]

# Get index of observations for V3 = 0 and V3 = 1
use0 = pfi[features == "V3" & feature.value == 0, unique(replace.id)]
use1 = pfi[features == "V3" & feature.value == 1, unique(replace.id)]

vars = c("V1", "V2")

# NOTE(review): `pi` masks the base constant `pi` for the rest of the session;
# the name is kept so that code relying on it keeps working.
pi = lapply(vars, function(var) {
  # presumably the importance of `var` computed separately for the two groups
  # of observations (use0 / use1) -- see helper/functions.R to confirm
  pi.cond = conditionalPFI(pfi, var, use0, use1, group.var = "V3")

  # strip the letters from the feature name, e.g. "V1" -> "1", for the
  # subscript of the axis label
  ind = gsub("[[:alpha:]]", "", var)

  plotImportance(pfi, feat = var, mid = "mse", hline = FALSE) +
    geom_line(data = pi.cond, aes(color = V3)) +
    geom_point(data = pi.cond, aes(color = V3)) +
    labs(x = bquote(X[.(ind)]), color = bquote(X[3])) +
    ylim(c(-10, 300))
})

# arrange both panels side by side
marrangeGrob(pi, nrow = 1, ncol = 2)
# Results of the Shapley simulation; the columns used below are
# learner, method, feature, repl and mse.
res = readRDS("application_shapley_simulation.Rds")

# use shorter learner name
# (fixed = TRUE: the "." is a literal dot, not a regex wildcard)
res[, learner := factor(gsub("regr.", "", learner, fixed = TRUE))]

# compute ratio of the importance values w.r.t feature "V3"
res[, ratio := mse/mse[feature == "V3"], by = c("method", "learner", "repl")]

### Plot simulation results of all 500 repetitions
shap = subset(res, method %in% c("pfi.diff", "pfi.ratio", "shapley") & feature != "V3")
new.names = setNames(expression(X[1]/X[3], X[2]/X[3]), c("V1", "V2"))

pp = ggplot(data = shap, aes(x = feature, y = ratio)) +
  geom_boxplot(aes(fill = method), lwd = 0.2, outlier.size = 0.8) +
  facet_grid(. ~ learner, scales = "free") +
  scale_fill_grey(labels = c("PFI (Diff.)", "PFI (Ratio)", "SFIMP"), start = 0.4, end = 0.9) +
  scale_x_discrete(labels = new.names) +
  labs(title = "(b) Simulation with 500 repetitions",
    x = "Features involved to compute the ratio", y = "Value of the ratio")

### Plot example of an individual repetition (2nd replication)
shap2 = subset(res, repl == 2 & method %in% c("shapley", "geP"))

# reorder features for plotting
feat.order = c("V3", "V2", "V1", "geP")
shap2[, feature := factor(feature, levels = feat.order)]

# change sign (every method except "geP" gets its mse negated), round to 2 digits
shap2[, mse := round(ifelse(method == "geP", mse, -mse), 2)]

# add column containing proportion of explained importance
shap2[, perc := ifelse(feature == "geP", NA, mse/sum(mse[feature != "geP"])), by = "learner"]

# add column containing drop in MSE + proportion of explained importance
shap2[, lab := ifelse(feature == "geP", mse, paste0(mse, " (", round(perc*100, 0), "%)"))]

# three grey shades followed by one hcl colour, named in the order of feat.order
col = c(gray.colors(3, start = 0.4, end = 0.9), hcl(h = 195, l = 65, c = 100))
col = setNames(col, feat.order)
legend = c("V1" = bquote(phi[1]), "V2" = bquote(phi[2]), "V3" = bquote(phi[3]),
  "geP" = bquote(widehat(GE)[P]))

pp2 = ggplot(shap2, aes(x = learner, y = mse, fill = feature)) +
  geom_bar(stat = "identity", colour = "white", position = "stack") +
  geom_text(aes(label = lab), position = position_stack(vjust = 0.5), size = 3) +
  coord_flip() +
  scale_fill_manual(values = col, name = " performance \n explained by", labels = legend) +
  labs(title = "(a) Comparing the model performance and SFIMP values across different models",
    x = "", y = "performance (MSE)")

# panel (a) on top, panel (b) below
grid.arrange(pp2, pp, heights = c(3, 5))
# Importance results for the real-data application; the columns used below
# are features, feature.value, row.id, replace.id and mse.
pfi = readRDS("application_importance_realdata.Rds")

# get index for LSTAT <= 10 in order to keep those observations
pi.ind = unique(pfi[features == "LSTAT" & feature.value <= 10, replace.id])

# compute integral of each ICI curve and select observations with positive ICI integral
ici = subset(pfi, features == "LSTAT")
ici.area = ici[, lapply(.SD, mean, na.rm = TRUE), .SDcols = "mse", by = "row.id"]
# which() returns positions within ici.area; translate them to the row.id
# values themselves so the selection does not depend on row.id being 1..n in
# table order. which() also drops groups whose mean is NA/NaN.
ici.ind = ici.area$row.id[which(ici.area$mse > 0)]

# produce table: all observations, LSTAT <= 10, positive ICI integral
# (the second positional argument of getImpTable() is presumably `obs.id`
# -- verify in helper/functions.R)
imp = getImpTable(pfi)
imp.pi = getImpTable(pfi, pi.ind)
imp.ici = getImpTable(pfi, ici.ind)
kable(rbind(imp, imp.pi, imp.ici))
features = c("LSTAT", "RM")

# Build two panels per feature -- the PI plot with a marginal histogram on the
# x axis and the ICI plot with two highlighted curves -- and flatten the
# per-feature pairs into one list of four plots.
pp = unlist(lapply(features, function(feat) {
  # rows of `pfi` for the current feature; inside subset() `features` is
  # looked up in `pfi` first, so it refers to the column, not the vector above
  curves = subset(pfi, features == feat)

  # mean mse per row.id
  curve.mean = curves[, list(mse = mean(mse, na.rm = TRUE)), by = "row.id"]

  # row.ids with the smallest and the largest mean mse
  extreme.ids = curve.mean$row.id[c(which.min(curve.mean$mse), which.max(curve.mean$mse))]
  curves.extreme = subset(curves, row.id %in% extreme.ids)

  # PI plot
  pi.plot = plotImportance(pfi, feat, "mse") +
    theme(legend.position = "none")
  pi.panel = ggMarginal(pi.plot, type = "histogram", fill = "transparent", margins = "x")

  # ICI plot
  ici.plot = plotImportance(pfi, feat, "mse", individual = TRUE, grid.points = FALSE, hline = FALSE) +
    geom_line(data = curves.extreme, aes(color = factor(row.id), group = row.id)) +
    theme(legend.position = "none")

  list(pi.panel, ici.plot)
}), recursive = FALSE)

marrangeGrob(pp, nrow = 2, ncol = 2)
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.