# Default figure size for every chunk in this vignette.
knitr::opts_chunk$set(
  fig.width = 7.15,
  fig.height = 5
)

# The same two values are also registered with the knit-level options.
# NOTE(review): fig.width / fig.height are chunk options; confirm that knitr
# reads them from opts_knit at all.
knitr::opts_knit$set(
  fig.width = 7.15,
  fig.height = 5
)
The goal of this vignette is to show how the stochastic Shapley values produced from shapFlex
are correlated
with the non-stochastic Shapley values computed in the Python shap package using the
implementation discussed here.
Treating the tree-based Shapley values from shap
as an approximation of the ground truth, we would
like to see if the sampling-based method in shapFlex
can reproduce these values within sampling variability.
We'll use catboost's R
package which has a port of shap
found in the
catboost.get_feature_importance()
function.
While shap
should be the preferred method when modeling with boosted trees, this vignette demonstrates
that the more generic shapFlex
implementation can be applied to all classes of ML prediction models, boosted
tree models included.
library(devtools)

# catboost is not available on CRAN, so it is installed from the GitHub
# release archive (Windows build, version 0.20). The install is skipped when
# the package can already be loaded, so re-running this script does not
# download and reinstall it every time.
if (!requireNamespace("catboost", quietly = TRUE)) {
  devtools::install_url(
    "https://github.com/catboost/catboost/releases/download/v0.20/catboost-R-Windows-0.20.tgz",
    INSTALL_opts = c("--no-multiarch")
  )
}

library(catboost)  # version 0.20
# Packages used below for the Shapley computation, reshaping, and plotting.
library(shapFlex)
library(dplyr)
library(tidyr)
library(ggplot2)

# Load the data_adult dataset from the shapFlex package and work on a copy
# named `data`.
data("data_adult", package = "shapFlex")
data <- data_adult

# Position of the outcome column, looked up by name so that it can be
# excluded wherever only the features are needed.
outcome_name <- "income"
outcome_col <- which(names(data) == outcome_name)
# Positions of the factor columns among the features, minus one.
# NOTE(review): presumably catboost expects zero-based column indices here;
# confirm against the catboost.load_pool() documentation.
cat_features <- which(
  vapply(data[, -outcome_col], is.factor, logical(1))
) - 1

# Training pool: all feature columns plus the outcome, converted to numeric
# and shifted down by one.
# NOTE(review): assumes the outcome is a two-level factor so that the label
# becomes 0/1; verify against the dataset.
data_pool <- catboost.load_pool(
  data = data[, -outcome_col],
  label = as.vector(as.numeric(data[, outcome_col])) - 1,
  cat_features = cat_features
)

# Fixed seed so the fitted model is reproducible.
set.seed(224)
model_catboost <- catboost.train(
  data_pool,
  NULL,
  params = list(
    loss_function = "CrossEntropy",
    iterations = 30,
    logging_level = "Silent"
  )
)
For shapFlex, the required user-defined prediction function takes 2 positional arguments and returns a 1-column data.frame.
Note the creation of the catboost-specific data format inside this function.
# Prediction wrapper passed to shapFlex.
#
# model: the fitted catboost model.
# data:  the feature columns to predict on.
#
# Returns a 1-column data.frame named `y_pred`.
predict_function <- function(model, data) {
  # Convert the incoming features to catboost's pool format. `cat_features`
  # is not an argument; it is looked up in the enclosing environment.
  pool <- catboost.load_pool(data = data, cat_features = cat_features)

  # Predictions and Shapley explanations will be in log-odds space.
  data.frame(y_pred = catboost.predict(model, pool))
}
# Instances to explain: Shapley feature-level predictions are computed for
# the first 300 rows.
explain <- data[seq_len(300), -outcome_col]

# An optional reference population used to compute the baseline prediction.
reference <- data[, -outcome_col]

# Number of Monte Carlo samples.
sample_size <- 100

# Fixed seed so the stochastic Shapley values are reproducible.
set.seed(224)
data_shap <- shapFlex::shapFlex(
  explain = explain,
  reference = reference,
  model = model_catboost,
  predict_function = predict_function,
  sample_size = sample_size
)
# Pool restricted to the first 300 rows, the same rows passed to shapFlex.
# Note that this overwrites the `data_pool` that was used for training.
data_pool <- catboost.load_pool(
  data = data[1:300, -outcome_col],
  label = as.vector(as.numeric(data[1:300, outcome_col])) - 1,
  cat_features = cat_features
)

# Tree-based Shapley values computed by catboost's own implementation.
data_shap_catboost <- catboost.get_feature_importance(
  model_catboost,
  pool = data_pool,
  type = "ShapValues"
)

# Remove the intercept column (the last one). drop = FALSE keeps the result
# two-dimensional even if only one feature column remains; without it the
# subset would collapse to a plain vector and lose its column name.
data_shap_catboost <- data.frame(
  data_shap_catboost[, -ncol(data_shap_catboost), drop = FALSE]
)

# Row identifier used later to line these values up with the shapFlex output.
# seq_len() yields an empty index for zero rows, whereas 1:0 yields c(1, 0).
data_shap_catboost$index <- seq_len(nrow(data_shap_catboost))

# Wide to long: one row per (index, feature_name) pair.
data_shap_catboost <- tidyr::gather(
  data_shap_catboost,
  key = "feature_name",
  value = "shap_effect_catboost",
  -index
)
# Match the shapFlex rows to the catboost rows on index and feature_name,
# keeping only the pairs present in both.
data_all <- dplyr::inner_join(
  data_shap,
  data_shap_catboost,
  by = c("index", "feature_name")
)

# Write the combined results to a CSV in the working directory, then load
# them back from that file.
write.csv(data_all, "data_shap.csv", row.names = FALSE)
data_all <- read.csv("data_shap.csv", stringsAsFactors = FALSE)
# Per-feature correlation between the shapFlex and catboost Shapley values,
# rounded to 3 decimal places.
data_cor <- data_all %>%
  dplyr::group_by(feature_name) %>%
  dplyr::summarise(
    cor_coef = round(cor(shap_effect, shap_effect_catboost), 3)
  )

# Print the correlation table.
data_cor
# Scatter plot of shapFlex values against catboost values, one panel per
# feature with free axis scales. geom_abline() is called with only a colour,
# so it draws its default line (intercept 0, slope 1) as the reference.
p <- ggplot(data_all, aes(shap_effect_catboost, shap_effect)) +
  geom_point(alpha = 0.25) +
  geom_abline(color = "red") +
  facet_wrap(~ feature_name, scales = "free") +
  theme_bw() +
  xlab("catboost tree-based Shapley values") +
  ylab("shapFlex stochastic Shapley values") +
  theme(axis.title = element_text(face = "bold"))

# Render the plot.
p
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.