# Nothing
#' tidy shap
#'
#' Plot and summarize shapley values from an xgboost model.
#'
#' Returns a list with the following entries:
#'
#' \describe{
#' \item{\emph{shap_tbl}}{: table of shapley values}
#' \item{\emph{shap_summary}}{: table summarizing shapley values. Includes correlation between shaps and feature values.}
#' \item{\emph{swarmplot}}{: one plot showing the relation between shaps and features}
#' \item{\emph{scatterplots}}{: the top 9 most important non-binary features as determined by sum of absolute shapley values, as a facetted scatterplot of feature vs shap; the string \code{"no continuous vars"} if there are none}
#' \item{\emph{boxplots}}{: the top 9 most important binary features (columns of \code{newdata} with at most 2 distinct values), as facetted boxplots of shap by feature value; the string \code{"no binary vars"} if there are none}
#' }
#'
#' @param model xgboost model
#' @param newdata dataframe similar to model input
#' @param form formula used for model. Must be supplied; the default
#'   \code{NULL} is rejected with an error.
#' @param ... additional parameters passed on to
#'   \code{xgboost::xgb.ggplot.shap.summary}
#' @param top_n top n features shown in the swarm plot
#' @param aggregate a character vector. Predictors containing the string will be aggregated, and renamed to that string.
#'
#' @return list
#' @export
tidy_shap <- function(model, newdata, form = NULL, ..., top_n = 12, aggregate = NULL) {
  # Column names used in the tidy-eval expressions below are bound to NULL
  # up front (presumably to keep R CMD check from reporting them as
  # undefined globals).
  value <- sum_abs <- name <- FEATURE <- SHAP <- BIAS <- TYPE <- agg <- NULL

  model_name <- presenter::get_piped_name()

  # enexpr() + as_label() yields the same string as ensym() for a bare
  # symbol, but also tolerates expressions such as `head(df)`.
  data_name <- rlang::as_label(rlang::enexpr(newdata))

  # The formula's left-hand side is needed below; fail early and clearly
  # instead of erroring deep inside rlang::f_lhs().
  if (!rlang::is_formula(form)) {
    stop("`form` must be the formula used to fit `model`.", call. = FALSE)
  }

  predictors <- f_formula_to_charvec(form, .data = newdata)
  feature_tbl <- dplyr::select(newdata, tidyselect::all_of(predictors))
  feature_mat <- as.matrix(feature_tbl)

  preds <- predict(model, newdata = feature_mat, predcontrib = TRUE)

  # any_of() expects a character vector of names, so convert the formula's
  # left-hand side into the variable names it mentions.
  target_names <- all.vars(rlang::f_lhs(form))

  # NOTE(review): the blanket suppressWarnings() is kept from the original;
  # it probably hides the pivot_wider() duplicate-key warning and cor()
  # warnings on constant columns. Consider narrowing it.
  suppressWarnings({
    shap_tbl <- preds %>%
      tibble::as_tibble() %>%
      dplyr::select(-BIAS)

    ## long shaps, stacked on top of long feature values
    shap_long <- shap_tbl %>%
      dplyr::mutate(TYPE = "SHAP") %>%
      tidyr::pivot_longer(cols = -TYPE)

    feature_long <- feature_tbl %>%
      dplyr::select(-tidyselect::any_of(target_names)) %>%
      tibble::as_tibble() %>%
      dplyr::mutate(TYPE = "FEATURE") %>%
      tidyr::pivot_longer(cols = -TYPE)

    # One row per (feature name, observation) with FEATURE and SHAP columns.
    gplottbl <- dplyr::bind_rows(shap_long, feature_long) %>%
      dplyr::arrange(name, TYPE) %>%
      tidyr::pivot_wider(names_from = TYPE, values_from = value) %>%
      tidyr::unnest(c(FEATURE, SHAP))

    if (!is.null(aggregate)) {
      # Rename every predictor matching one of the `aggregate` strings to
      # that string, then average SHAP within (name, FEATURE).
      agg_pattern <- stringr::str_c(aggregate, collapse = "|")
      gplottbl <- gplottbl %>%
        dplyr::mutate(agg = stringr::str_extract(name, agg_pattern)) %>%
        dplyr::mutate(name = dplyr::coalesce(agg, name)) %>%
        dplyr::select(-agg) %>%
        dplyr::group_by(name, FEATURE) %>%
        dplyr::summarise(SHAP = mean(SHAP), .groups = "drop")
    }

    ## swarm plot
    shaps <- xgboost::xgb.ggplot.shap.summary(
      feature_mat, preds,
      model = model, top_n = top_n, ...
    )

    plot_title <- stringr::str_c(
      as.character(rlang::f_lhs(form)),
      " shaps from model ", model_name,
      " on dataset ", data_name
    )

    swarm_plot <- shaps +
      ggplot2::labs(
        title = plot_title,
        color = "normalized feature value",
        x = "shapley value",
        y = "feature name"
      )

    ## shaps summary, most important feature (largest sum of |SHAP|) first
    shaps_sum <- gplottbl %>%
      dplyr::group_by(name) %>%
      dplyr::summarise(
        cor = stats::cor(FEATURE, SHAP),
        var = stats::var(SHAP),
        sum = sum(SHAP),
        sum_abs = sum(abs(SHAP))
      ) %>%
      dplyr::arrange(dplyr::desc(sum_abs))

    # Columns of `newdata` with at most two distinct values count as binary.
    is_binary <- purrr::map_lgl(newdata, function(col) dplyr::n_distinct(col) <= 2)
    binaries <- names(which(is_binary))

    ## continuous scatterplots
    top_9 <- utils::head(setdiff(dplyr::pull(shaps_sum, name), binaries), 9)

    if (!rlang::is_empty(top_9)) {
      scatterplots <- gplottbl %>%
        dplyr::filter(name %in% top_9) %>%
        ggplot2::ggplot(ggplot2::aes(x = FEATURE, y = SHAP, color = name)) +
        ggplot2::geom_jitter(alpha = .5) +
        ggplot2::geom_smooth() +
        ggplot2::theme_minimal() +
        ggplot2::facet_wrap(~name, scales = "free_x") +
        ggplot2::theme(legend.position = "none")
    } else {
      scatterplots <- "no continuous vars"
    }
  })

  ## binary boxplots
  top_9_binary <- utils::head(intersect(dplyr::pull(shaps_sum, name), binaries), 9)

  if (!rlang::is_empty(top_9_binary)) {
    boxplots <- gplottbl %>%
      dplyr::filter(name %in% top_9_binary) %>%
      ggplot2::ggplot(ggplot2::aes(x = factor(FEATURE), y = SHAP, color = name)) +
      ggplot2::geom_boxplot(alpha = .5) +
      ggplot2::theme_minimal() +
      ggplot2::xlab("BINARY FEATURE") +
      ggplot2::facet_wrap(~name, scales = "free_x") +
      ggplot2::theme(legend.position = "none")
  } else {
    boxplots <- "no binary vars"
  }

  ## combine
  list(
    shap_tbl = shap_tbl,
    shap_summary = shaps_sum,
    swarmplot = swarm_plot,
    scatterplots = scatterplots,
    boxplots = boxplots
  )
}
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.