# R/plot_shap.R
#
# Defines functions: var_importance, plot_shap_summary, shap_prep, std1, shap_score_rank, plot_shap
#
# Documented in: plot_shap, plot_shap_summary, shap_prep, shap_score_rank, std1, var_importance

#' Plot a SHAP summary for an xgboost model
#'
#' Runs the three steps of the SHAP summary workflow in order: compute the
#' SHAP contributions with `shap_score_rank()` (with `shap_approx = FALSE`),
#' reshape the `top_n` highest-ranked variables with `shap_prep()`, and draw
#' the result with `plot_shap_summary()`.
#'
#' @param xgb_model xgboost object
#' @param mat_train train data as a matrix
#' @param top_n top n most influential variables to print
#' @return The plot object returned by `plot_shap_summary()`.
plot_shap <- function(xgb_model, mat_train, top_n = 7) {

  ## Prepare data for top N variables
  shap_result <- shap_score_rank(
    xgb_model = xgb_model,
    X_train = mat_train,
    # Spelled out as FALSE: `F` is an ordinary variable that can be reassigned.
    shap_approx = FALSE
  )

  shap_long <- shap_prep(
    shap = shap_result,
    X_train = mat_train,
    top_n = top_n
  )

  ## Plot shap overall metrics
  plot_shap_summary(data_long = shap_long)
}

#' Return SHAP scores and the ranked mean absolute SHAP score per feature
#'
#' Calls `predict()` on the model with `predcontrib = TRUE`, converts the
#' result to a data.table, removes the `BIAS` column, and ranks the remaining
#' columns by their mean absolute value.
#'
#' @param xgb_model xgboost object. The default refers to an object named
#'   `xgb_mod` that is not defined in this function; pass the model
#'   explicitly.
#' @param shap_approx logical, forwarded to `predict()` as `approxcontrib`
#' @param X_train train data as a matrix. The default refers to an object
#'   named `mat_train` that is not defined in this function; pass the data
#'   explicitly.
#' @return A list with two elements: `shap_score`, a data.table holding the
#'   contributions without the `BIAS` column, and `mean_shap_score`, a named
#'   numeric vector of column-wise mean absolute contributions in decreasing
#'   order.
#' @export
shap_score_rank <- function(xgb_model = xgb_mod, shap_approx = TRUE,
                            X_train = mat_train) {
  # library() stops with an error when a package is missing, whereas
  # require() only returns FALSE and lets the failure surface later in an
  # unrelated call.
  library(xgboost)
  library(data.table)

  shap_contrib <- predict(xgb_model, X_train,
                          predcontrib = TRUE, approxcontrib = shap_approx)
  shap_contrib <- as.data.table(shap_contrib)

  # NOTE(review): the intercept column is assumed to be called "BIAS";
  # confirm against the installed xgboost version.
  shap_contrib[, BIAS := NULL]

  # Compute the column means once and reuse them for both values and ordering.
  mean_abs_shap <- colMeans(abs(shap_contrib))
  mean_shap_score <- mean_abs_shap[order(mean_abs_shap, decreasing = TRUE)]

  list(shap_score = shap_contrib,
       mean_shap_score = mean_shap_score)
}

#' Standardize feature values into the same range
#'
#' Min-max rescaling: the smallest non-missing value maps to 0 and the largest
#' to 1. Missing values are ignored when finding the range and stay missing in
#' the result. When all non-missing values are equal the range is zero and the
#' result is `NaN` for those elements.
#'
#' @param x vector
#' @return A numeric vector of the same length as `x`.
#' @export
std1 <- function(x) {
  # `TRUE` rather than `T`: `T` is an ordinary variable that can be reassigned.
  rng <- range(x, na.rm = TRUE)
  (x - rng[1L]) / (rng[2L] - rng[1L])
}


#' Prepare SHAP data in long format for plotting
#'
#' Keeps the `top_n` highest-ranked variables (in the order given by
#' `shap$mean_shap_score`), melts both the SHAP scores and the matching
#' feature values to long format, and binds them column-wise.
#'
#' @param shap shap_result, a list with elements `shap_score` and
#'   `mean_shap_score`. The default refers to an object named `shap_result`
#'   that is not defined in this function; pass it explicitly.
#' @param X_train train data as a matrix. The default refers to an object
#'   named `mat_train` that is not defined in this function; pass it
#'   explicitly.
#' @param top_n top n variables to print in plot. When missing, or not one of
#'   `1..ncol(X_train)`, all features are used.
#' @return A data.table keyed by `variable` with columns `variable`, `value`
#'   (SHAP value), `rfvalue` (raw feature value), `stdfvalue` (feature value
#'   rescaled per variable by `std1()`) and `mean_value` (mean absolute SHAP
#'   value per variable).
#' @export
shap_prep <- function(shap = shap_result, X_train = mat_train, top_n) {
  # library() stops with an error when data.table is missing; require() would
  # only return FALSE. ggforce is not loaded here because nothing in this
  # function uses it.
  library(data.table)

  # By default, or when top_n is out of range, use all features.
  n_features <- ncol(X_train)
  if (missing(top_n) || !top_n %in% seq_len(n_features)) {
    top_n <- n_features
  }

  # mean_shap_score is already in descending order, so its first top_n names
  # are the most influential variables.
  top_vars <- names(shap$mean_shap_score)[seq_len(top_n)]

  shap_score_sub <- as.data.table(shap$shap_score)[, top_vars, with = FALSE]
  shap_score_long <- melt.data.table(shap_score_sub,
                                     measure.vars = colnames(shap_score_sub))

  # feature values: the values in the original dataset
  fv_sub <- as.data.table(X_train)[, top_vars, with = FALSE]
  fv_sub_long <- melt.data.table(fv_sub, measure.vars = colnames(fv_sub))
  # standardize feature values within each variable
  fv_sub_long[, stdfvalue := std1(value), by = "variable"]

  # SHAP value: value
  # raw feature value: rfvalue
  # standardized: stdfvalue
  setnames(fv_sub_long, c("variable", "rfvalue", "stdfvalue"))

  # Both long tables were melted from the same columns in the same order, so
  # their rows line up and can be bound side by side.
  shap_long2 <- cbind(shap_score_long,
                      fv_sub_long[, c("rfvalue", "stdfvalue")])
  shap_long2[, mean_value := mean(abs(value)), by = variable]
  setkey(shap_long2, variable)
  shap_long2
}

#' Plot shap summary
#'
#' Draws a sina plot of SHAP values per variable, coloured by the standardized
#' feature value, and prints each variable's mean absolute SHAP value next to
#' it.
#'
#' @param data_long data; a data.table with the columns `variable`, `value`,
#'   `stdfvalue` and `mean_value` (the format returned by `shap_prep()`)
#' @return A ggplot object.
#' @export
plot_shap_summary <- function(data_long) {
  # library() stops with an error when ggforce is missing; require() would
  # only return FALSE and the failure would surface later at `geom_sina`.
  library(ggforce) # for `geom_sina`

  # Symmetric limits around zero for the SHAP value axis.
  x_bound <- max(abs(data_long$value))

  # One row per variable, used for the mean absolute value labels.
  mean_labels <- unique(data_long[, c("variable", "mean_value"), with = FALSE])

  ggplot(data = data_long) +
    coord_flip() +
    # sina plot:
    geom_sina(aes(x = variable, y = value, color = stdfvalue)) +
    # print the mean absolute value:
    geom_text(data = mean_labels,
              aes(x = variable, y = -Inf, label = sprintf("%.3f", mean_value)),
              size = 3, alpha = 0.7,
              hjust = -0.2,
              fontface = "bold") + # bold
    # # add a "SHAP" bar notation
    # annotate("text", x = -Inf, y = -Inf, vjust = -0.2, hjust = 0, size = 3,
    #          label = expression(group("|", bar(SHAP), "|"))) +
    scale_color_gradient(low = "#FFCC33", high = "#6600CC",
                         breaks = c(0, 1), labels = c("Low", "High")) +
    theme_bw() +
    theme(axis.line.y = element_blank(),
          axis.ticks.y = element_blank(), # remove axis line
          legend.position = "bottom") +
    geom_hline(yintercept = 0) + # the vertical line
    scale_y_continuous(limits = c(-x_bound, x_bound)) +
    # reverse the order of features
    scale_x_discrete(limits = rev(levels(data_long$variable))) +
    labs(y = "SHAP value (impact on model output)", x = "",
         color = "Feature value")
}

#' Importance plot
#'
#' Horizontal bar chart of the mean absolute SHAP score of the first `top_n`
#' variables in `shap_result$mean_shap_score`, with the largest bar on top.
#'
#' @param shap_result shap_result, a list whose `mean_shap_score` element is a
#'   named numeric vector
#' @param top_n top n variables to print in plot. Values larger than the
#'   number of available variables are capped at that number.
#' @return A ggplot object.
#' @export
var_importance <- function(shap_result, top_n = 10) {
  importance_tbl <- tibble(
    var = names(shap_result$mean_shap_score),
    importance = shap_result$mean_shap_score
  )

  # Cap at the number of rows: `1:top_n` would index past the end of the table
  # when there are fewer than `top_n` variables, and would select row 1 for
  # `top_n = 0`.
  n_keep <- min(top_n, nrow(importance_tbl))
  importance_tbl <- importance_tbl[seq_len(n_keep), ]

  ggplot(importance_tbl, aes(x = reorder(var, importance), y = importance)) +
    geom_bar(stat = "identity") +
    coord_flip() +
    theme_light() +
    theme(axis.title.y = element_blank())
}
# kristian-bak/kb.modelling documentation built on Dec. 21, 2021, 7:46 a.m.