R/tidy_shap.R

Defines functions tidy_shap

Documented in tidy_shap

#' tidy shap
#'
#' Plot and summarize shapley values from an xgboost model.
#'
#' Returns a list with the following entries
#'
#' \describe{
#' \item{\emph{shap_tbl}}{: table of shapley values (one column per predictor, bias column removed)}
#' \item{\emph{shap_summary}}{: table summarizing shapley values per feature. Includes the correlation between shaps and feature values, the variance, the sum and the sum of absolute shaps. Sorted by decreasing sum of absolute shaps.}
#' \item{\emph{swarmplot}}{: one plot showing the relation between shaps and features}
#' \item{\emph{scatterplots}}{: the top 9 non-binary features as determined by sum of absolute shapley values, as a facetted scatterplot of feature vs shap. The string \code{"no continuous vars"} if there are none.}
#' \item{\emph{boxplots}}{: the top 9 binary features (at most 2 distinct values in \code{newdata}) as determined by sum of absolute shapley values, as facetted boxplots of shap by feature value. The string \code{"no binary vars"} if there are none.}
#' }
#'
#' @param model xgboost model
#' @param newdata dataframe similar to model input
#' @param form formula used for model. Required; the \code{NULL} default only exists for argument ordering and results in an error.
#' @param ... additional arguments passed on to \code{xgboost::xgb.ggplot.shap.summary}
#' @param top_n top n features shown in the swarm plot (passed on to \code{xgboost::xgb.ggplot.shap.summary})
#' @param aggregate a character vector. Predictors containing the string will be aggregated, and renamed to that string.
#'
#' @return list
#' @export
tidy_shap <- function(model, newdata, form = NULL, ..., top_n = 12, aggregate = NULL){

  # Placeholder bindings for the column names used in tidy-eval expressions
  # below, so static code checks do not report undefined global variables.
  value <- sum_abs <- name <- FEATURE <- SHAP <- BIAS <- TYPE <- agg <- NULL

  # The formula's left-hand side is extracted further down, which cannot work
  # on NULL; fail early with an actionable message instead.
  if (is.null(form)) {
    stop("`form` must be supplied: pass the formula used to fit `model`.",
         call. = FALSE)
  }

  model_name <- presenter::get_piped_name()

  # Label of the `newdata` argument as written by the caller. Works for bare
  # symbols as well as for expressions such as `df[1:10, ]`.
  data_name <- rlang::as_label(rlang::enquo(newdata))

  predictors <- form %>%
    f_formula_to_charvec(.data = newdata)

  # Predictor columns only, both as a data frame and as a matrix.
  newdata1 <- newdata %>%
    dplyr::select(tidyselect::all_of(predictors))

  newdata2 <- as.matrix(newdata1)

  preds <- predict(model, newdata = newdata2, predcontrib = TRUE)

  # Names of the variables on the left-hand side of the formula, as a
  # character vector (empty for a one-sided formula), which is the input type
  # any_of() expects.
  lhs_vars <- all.vars(rlang::f_lhs(form))

  suppressWarnings({

    ## shap table without the bias column
    preds1 <- preds %>%
      tibble::as_tibble() %>%
      dplyr::select(-BIAS)

    ## long table pairing every shap value with its feature value
    long_features <- newdata1 %>%
      dplyr::select(-tidyselect::any_of(lhs_vars)) %>%
      tibble::as_tibble() %>%
      dplyr::mutate(TYPE = "FEATURE") %>%
      tidyr::pivot_longer(cols = -TYPE)

    gplottbl <- preds1 %>%
      dplyr::mutate(TYPE = "SHAP") %>%
      tidyr::pivot_longer(cols = -TYPE) %>%
      dplyr::bind_rows(long_features) %>%
      dplyr::arrange(name, TYPE) %>%
      tidyr::pivot_wider(names_from = TYPE, values_from = value) %>%
      tidyr::unnest(c(FEATURE, SHAP))

    ## optional aggregation of predictors sharing a common substring
    if (!is.null(aggregate)) {

      agg_pattern <- stringr::str_c(aggregate, collapse = "|")

      gplottbl <- gplottbl %>%
        dplyr::mutate(agg = stringr::str_extract(name, agg_pattern)) %>%
        dplyr::mutate(name = dplyr::coalesce(agg, name)) %>%
        dplyr::select(-agg) %>%
        dplyr::group_by(name, FEATURE) %>%
        dplyr::summarise(SHAP = mean(SHAP), .groups = "drop")
    }

    ## swarm plot
    shaps <- xgboost::xgb.ggplot.shap.summary(
      newdata2, preds, model = model, top_n = top_n, ...
    )

    new_name <- form %>%
      rlang::f_lhs() %>%
      as.character() %>%
      stringr::str_c(" shaps from model ", model_name, " on dataset ", data_name)

    # NOTE(review): if xgb.ggplot.shap.summary() applies coord_flip(), the two
    # axis labels below end up swapped on the rendered plot -- confirm.
    swarm_plot <- shaps +
      ggplot2::labs(title = new_name, color = "normalized feature value") +
      ggplot2::xlab("shapley value") +
      ggplot2::ylab("feature name")

    ## shaps summary, most influential feature first
    shaps_sum <- gplottbl %>%
      dplyr::group_by(name) %>%
      dplyr::summarise(cor = stats::cor(FEATURE, SHAP),
                       var = stats::var(SHAP),
                       sum = sum(SHAP),
                       sum_abs = sum(abs(SHAP))) %>%
      dplyr::arrange(dplyr::desc(sum_abs))

    ## continuous scatterplots

    # Columns of `newdata` with at most two distinct values count as binary.
    binaries <- newdata %>%
      purrr::map_lgl(~dplyr::n_distinct(.) <= 2) %>%
      which() %>%
      names()

    top_9 <- shaps_sum %>%
      dplyr::pull(name) %>%
      setdiff(binaries) %>%
      utils::head(9)

    if (!rlang::is_empty(top_9)) {
      scatterplots <- gplottbl %>%
        dplyr::filter(name %in% top_9) %>%
        ggplot2::ggplot(ggplot2::aes(x = FEATURE, y = SHAP, color = name)) +
        ggplot2::geom_jitter(alpha = .5) +
        ggplot2::geom_smooth() +
        ggplot2::theme_minimal() +
        ggplot2::facet_wrap(~name, scales = "free_x") +
        ggplot2::theme(legend.position = "none")
    } else {
      scatterplots <- "no continuous vars"
    }
  })

  ## binary boxplots

  top_9_binary <- shaps_sum %>%
    dplyr::pull(name) %>%
    intersect(binaries) %>%
    utils::head(9)

  if (!rlang::is_empty(top_9_binary)) {
    boxplots <- gplottbl %>%
      dplyr::filter(name %in% top_9_binary) %>%
      ggplot2::ggplot(ggplot2::aes(x = factor(FEATURE), y = SHAP, color = name)) +
      ggplot2::geom_boxplot(alpha = .5) +
      ggplot2::theme_minimal() +
      ggplot2::xlab("BINARY FEATURE") +
      ggplot2::facet_wrap(~name, scales = "free_x") +
      ggplot2::theme(legend.position = "none")
  } else {
    boxplots <- "no binary vars"
  }

  ## combine

  list(
    shap_tbl = preds1,
    shap_summary = shaps_sum,
    swarmplot = swarm_plot,
    scatterplots = scatterplots,
    boxplots = boxplots
  )
}

Try the autostats package in your browser

Any scripts or data that you put into this service are public.

autostats documentation built on Nov. 10, 2022, 6:13 p.m.