#' Plot a SHAP summary for the most influential features of an xgboost model
#'
#' Convenience wrapper that chains `shap_score_rank()`, `shap_prep()` and
#' `plot_shap_summary()`.
#'
#' @param xgb_model xgboost object
#' @param mat_train train data as a matrix
#' @param top_n top n most influential variables to print
#' @param shap_approx logical, forwarded to `shap_score_rank()`; defaults to
#'   `FALSE`, the value this function previously hard-coded
#' @return the object returned by `plot_shap_summary()`
plot_shap <- function(xgb_model, mat_train, top_n = 7, shap_approx = FALSE) {
  ## Prepare data for top N variables
  shap_result <- shap_score_rank(
    xgb_model = xgb_model,
    X_train = mat_train,
    shap_approx = shap_approx
  )
  shap_long <- shap_prep(
    shap = shap_result,
    X_train = mat_train,
    top_n = top_n
  )
  ## Plot shap overall metrics
  plot_shap_summary(data_long = shap_long)
}
#' Return the matrix of SHAP scores and the ranked mean absolute scores
#'
#' Calls `predict()` on the model with `predcontrib = TRUE`, drops the trailing
#' bias column, and ranks the features by mean absolute contribution.
#'
#' @param xgb_model xgboost object. The default refers to an object named
#'   `xgb_mod`, which is looked up lazily from the function's enclosing
#'   environment and must exist there if the argument is not supplied.
#' @param shap_approx logical, passed to `predict()` as `approxcontrib`
#' @param X_train train data as a matrix. The default refers to an object
#'   named `mat_train`, looked up the same way as `xgb_mod`.
#' @return a list with two elements: `shap_score`, a data.table of per-row
#'   contributions without the bias column, and `mean_shap_score`, a named
#'   numeric vector of mean absolute contributions sorted in decreasing order
#' @export
shap_score_rank <- function(xgb_model = xgb_mod, shap_approx = TRUE,
                            X_train = mat_train) {
  # Fail loudly when a dependency is missing; require() would only return
  # FALSE and let the function break later with a less helpful message.
  # Loading the xgboost namespace also makes its predict() method available.
  for (pkg in c("xgboost", "data.table")) {
    if (!requireNamespace(pkg, quietly = TRUE)) {
      stop("Package '", pkg, "' is required by shap_score_rank().",
           call. = FALSE)
    }
  }

  shap_contrib <- predict(xgb_model, X_train,
                          predcontrib = TRUE, approxcontrib = shap_approx)
  shap_contrib <- data.table::as.data.table(shap_contrib)

  # The bias term is the last column of the contribution matrix. Remove it by
  # position so that a feature which happens to share the name is untouched.
  # NOTE(review): newer xgboost releases appear to name this column
  # "(Intercept)" instead of "BIAS" -- confirm against the installed version.
  bias_idx <- ncol(shap_contrib)
  if (bias_idx > 0 &&
      names(shap_contrib)[bias_idx] %in% c("BIAS", "(Intercept)")) {
    shap_contrib[, (bias_idx) := NULL]
  } else {
    warning("No trailing bias column found in the SHAP contributions; ",
            "nothing was removed.", call. = FALSE)
  }

  # Mean absolute contribution per feature, computed once and then sorted
  mean_abs <- colMeans(abs(shap_contrib))
  mean_shap_score <- mean_abs[order(mean_abs, decreasing = TRUE)]

  list(shap_score = shap_contrib,
       mean_shap_score = mean_shap_score)
}
#' Standardize feature values into the same [0, 1] range (min-max scaling)
#'
#' Missing values are ignored when computing the range and stay missing in
#' the result. When all non-missing values are identical the range is zero
#' and plain min-max scaling would yield NaN for every element; in that case
#' the midpoint 0.5 is returned for the non-missing elements instead.
#'
#' @param x vector
#' @return a numeric vector of the same length as `x`
#' @export
std1 <- function(x) {
  rng <- range(x, na.rm = TRUE)
  span <- rng[2] - rng[1]
  if (isTRUE(span == 0)) {
    # Constant input: avoid 0 / 0. `x * 0` keeps NA positions (and names),
    # so only the observed values are mapped to the midpoint.
    return(x * 0 + 0.5)
  }
  (x - rng[1]) / span
}
#' Prepare SHAP data in long format for plotting
#'
#' Keeps the `top_n` features with the largest mean absolute SHAP value and
#' combines, in long format, their SHAP values with the raw and standardized
#' feature values.
#'
#' @param shap shap_result: a list with elements `shap_score` and
#'   `mean_shap_score`, as returned by `shap_score_rank()`. The default refers
#'   to an object named `shap_result` looked up lazily from the function's
#'   enclosing environment.
#' @param X_train train data as a matrix. The default refers to an object
#'   named `mat_train`, looked up the same way.
#' @param top_n top n variables to print in plot. When missing, `NULL`, or not
#'   a single whole number between 1 and the number of columns of `X_train`,
#'   all features are used.
#' @return a data.table keyed by `variable` with columns `variable`, `value`
#'   (SHAP value), `rfvalue` (raw feature value), `stdfvalue` (feature value
#'   standardized per variable with `std1()`) and `mean_value` (mean absolute
#'   SHAP value per variable)
#' @export
shap_prep <- function(shap = shap_result, X_train = mat_train, top_n) {
  if (!requireNamespace("data.table", quietly = TRUE)) {
    stop("Package 'data.table' is required by shap_prep().", call. = FALSE)
  }

  # by default, and for any unusable top_n, use all features
  n_features <- ncol(X_train)
  if (missing(top_n) || is.null(top_n) ||
      !isTRUE(top_n %in% seq_len(n_features))) {
    top_n <- n_features
  }

  # mean_shap_score is in descending order, so its first names are the
  # most influential features
  top_vars <- names(shap$mean_shap_score)[seq_len(top_n)]

  # SHAP values of the selected features, in long format
  shap_score_sub <- data.table::as.data.table(shap$shap_score)
  shap_score_sub <- shap_score_sub[, top_vars, with = FALSE]
  shap_score_long <- data.table::melt.data.table(
    shap_score_sub,
    measure.vars = colnames(shap_score_sub)
  )

  # feature values: the values in the original dataset
  fv_sub <- data.table::as.data.table(X_train)[, top_vars, with = FALSE]
  fv_sub_long <- data.table::melt.data.table(
    fv_sub,
    measure.vars = colnames(fv_sub)
  )
  # standardize feature values within each variable
  fv_sub_long[, stdfvalue := std1(value), by = "variable"]

  # SHAP value: value
  # raw feature value: rfvalue
  # standardized: stdfvalue
  data.table::setnames(fv_sub_long, c("variable", "rfvalue", "stdfvalue"))

  # Both long tables were melted from the same columns in the same order,
  # so their rows line up and can be bound side by side.
  shap_long2 <- cbind(
    shap_score_long,
    fv_sub_long[, c("rfvalue", "stdfvalue"), with = FALSE]
  )
  shap_long2[, mean_value := mean(abs(value)), by = variable]
  data.table::setkey(shap_long2, variable)
  shap_long2
}
#' Plot shap summary
#'
#' Draws a sina plot of SHAP values per feature, colored by the standardized
#' feature value, with the mean absolute SHAP value printed next to each
#' feature.
#'
#' @param data_long data in long format with columns `variable` (factor),
#'   `value` (SHAP value), `stdfvalue` (standardized feature value, used for
#'   color) and `mean_value` (label printed for each variable)
#' @return a ggplot object
#' @export
plot_shap_summary <- function(data_long) {
  # Fail loudly when a plotting dependency is missing; require() would only
  # return FALSE and the failure would surface later as a missing function.
  for (pkg in c("ggplot2", "ggforce")) {
    if (!requireNamespace(pkg, quietly = TRUE)) {
      stop("Package '", pkg, "' is required by plot_shap_summary().",
           call. = FALSE)
    }
  }

  # symmetric axis limits around zero
  x_bound <- max(abs(data_long$value))

  # one label row per variable; going through a plain data.frame lets the
  # function accept either a data.table or a data.frame
  label_data <- unique(
    as.data.frame(data_long)[, c("variable", "mean_value")]
  )

  plot1 <- ggplot2::ggplot(data = data_long) +
    ggplot2::coord_flip() +
    # sina plot:
    ggforce::geom_sina(
      ggplot2::aes(x = variable, y = value, color = stdfvalue)
    ) +
    # print the mean absolute value:
    ggplot2::geom_text(
      data = label_data,
      ggplot2::aes(x = variable, y = -Inf,
                   label = sprintf("%.3f", mean_value)),
      size = 3, alpha = 0.7,
      hjust = -0.2,
      fontface = "bold"
    ) +
    # # add a "SHAP" bar notation
    # annotate("text", x = -Inf, y = -Inf, vjust = -0.2, hjust = 0, size = 3,
    #          label = expression(group("|", bar(SHAP), "|"))) +
    ggplot2::scale_color_gradient(
      low = "#FFCC33", high = "#6600CC",
      breaks = c(0, 1), labels = c("Low", "High")
    ) +
    ggplot2::theme_bw() +
    ggplot2::theme(
      axis.line.y = ggplot2::element_blank(), # remove axis line
      axis.ticks.y = ggplot2::element_blank(),
      legend.position = "bottom"
    ) +
    # reference line at zero (drawn vertically because of coord_flip)
    ggplot2::geom_hline(yintercept = 0) +
    ggplot2::scale_y_continuous(limits = c(-x_bound, x_bound)) +
    # reverse the order of features
    ggplot2::scale_x_discrete(limits = rev(levels(data_long$variable))) +
    ggplot2::labs(
      y = "SHAP value (impact on model output)",
      x = "",
      color = "Feature value"
    )
  plot1
}
#' Importance plot
#'
#' Horizontal bar chart of the mean absolute SHAP value of the top features.
#'
#' @param shap_result shap_result: a list whose element `mean_shap_score` is a
#'   named numeric vector, taken to be already sorted in decreasing order
#' @param top_n top n variables to print in plot; values larger than the
#'   number of available features are clamped to that number
#' @return a ggplot object
#' @export
var_importance <- function(shap_result, top_n = 10) {
  scores <- shap_result$mean_shap_score

  # Clamp top_n: indexing past the last row would add all-NA rows that then
  # show up as a spurious NA bar in the plot.
  n_keep <- max(0L, min(as.integer(top_n), length(scores)))

  importance_tbl <- tibble::tibble(var = names(scores), importance = scores)
  importance_tbl <- importance_tbl[seq_len(n_keep), ]

  ggplot2::ggplot(
    importance_tbl,
    ggplot2::aes(x = reorder(var, importance), y = importance)
  ) +
    ggplot2::geom_bar(stat = "identity") +
    ggplot2::coord_flip() +
    ggplot2::theme_light() +
    ggplot2::theme(axis.title.y = ggplot2::element_blank())
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.