# R/plot_shapley.R

#' Plots the difference for a single observation. Works only with `task.type` "regr".
#'
#' @description This method draws a plot for the data.mean, the prediction
#'   response of one observation and describes the influence of all
#'   features/variables for this difference.
#' @param shap.values A shapley object (generated by the shapley function) that contains
#'   the shapley.values and other important information about the task and model.
#' @param shap.id (optional) Determines which observation is taken for plotting if
#'   shap.values holds multiple observations. The default `-1` means that no
#'   observation was chosen explicitly.
#' @return A ggplot object.
#' @export
plot.shapley.singleValue = function(shap.values, shap.id=-1) {
  ids = getShapleyIds(shap.values)
  values = getShapleyValues(shap.values)

  # `&&` is the scalar, short-circuiting operator that belongs in an `if`.
  if (shap.id != -1 && !(shap.id %in% ids)) {
    # An id was requested explicitly but does not exist: fall back to the
    # first observation and tell the caller through a real warning condition.
    warning(paste("Could not find _Id <", shap.id, "> in shap.values!",
                  "First observation with _Id <", ids[1],
                  "> is used from given shap.values."), call. = FALSE)
    shap.id = ids[1]
    at = which(values$"_Id" == shap.id)
  } else if (nrow(values) == 1) {
    # Only one observation available, nothing to choose from.
    at = 1
  } else if (shap.id %in% ids) {
    at = which(values$"_Id" == shap.id)
  } else {
    # No id was given although several observations exist.
    shap.id = ids[1]
    at = 1
    warning(paste("shap.values contains too many observations..",
                  "First observation with _Id <", shap.id,
                  "> is used from given shap.values."), call. = FALSE)
  }

  data.mean = getShapleyDataMean(shap.values)
  shap.values = getShapleySubsetByResponseClass(shap.values)
  # drop = FALSE keeps the feature name even if there is only one feature
  # (assumes getShapleyValues returns a data.frame -- TODO confirm).
  data = getShapleyValues(shap.values)[at, getShapleyFeatureNames(shap.values), drop = FALSE]
  points = compute.shapley.positions(data, data.mean)

  # Labels alternate above / below the bar so that neighbours overlap less.
  label.nudge = rep_len(c(.1, -.1), nrow(points))
  prediction = getShapleyPredictionResponse(shap.values)[at]

  plot = ggplot(points, aes(x = values, y = 0)) +
    coord_cartesian(ylim = c(-.4, .4)) +
    scale_colour_gradient2(low = "#832424FF", high = "#3A3A98FF", mid = "lightgrey", midpoint = data.mean) +
    geom_line(aes(colour = values), size = 30) +
    geom_label(aes(label = names), angle = 70, nudge_y = label.nudge) +
    geom_point(aes(x = prediction, y = 0.1), colour = "black", size = 3) +
    theme(axis.title.y = element_blank(),
          axis.text.y = element_blank(),
          axis.ticks.y = element_blank(),
          legend.position = "none")

  return(plot)
}

#' Plots a graph that shows the expected values, observed values and their difference.
#'
#' @description This method draws a plot that shows the mean of all observations
#'   together with the summed positive and the summed negative shapley values
#'   of every observation. The band between the data mean and the data mean
#'   plus the sum of all shapley values is shaded.
#' @param shap.values A shapley object (generated by shapley(...)) that contains
#' the shapley.values and other important information about the task and model.
#' @return A ggplot object.
#' @export
plot.shapley.multipleValues = function(shap.values) {
  # drop = FALSE keeps a two-dimensional object when there is only one feature;
  # without it the selection collapses to a vector and the row sums fail.
  values = as.matrix(
    getShapleyValues(shap.values)[, getShapleyFeatureNames(shap.values), drop = FALSE]
  )
  data.mean = getShapleyDataMean(shap.values)

  # pmax / pmin clip the whole matrix at once and keep its dimensions, which
  # replaces a per-cell apply() call.
  data = data.frame(
    response.plus = rowSums(pmax(values, 0)),
    response.minus = rowSums(pmin(values, 0)),
    position = as.numeric(getShapleyIds(shap.values)),
    row.names = NULL
  )
  # Not mapped by any layer below; kept so the plot data keeps this column.
  data$color = ifelse(data$response.plus < abs(data$response.minus), "red", "green")
  # Net effect per observation, used as the second border of the ribbon.
  data$response.total = rowSums(values)

  ggplot() +
    geom_line(data = data, aes(x = position, y = data.mean, colour = "data mean")) +
    geom_line(data = data, aes(x = position, y = data.mean + response.plus,
                               colour = "positive effects")) +
    geom_line(data = data, aes(x = position, y = data.mean + response.minus,
                               colour = "negative effects")) +
    geom_ribbon(data = data, aes(x = position, ymax = data.mean,
                                 ymin = data.mean + response.total), fill = "blue", alpha = .2)
}

#' Computes where each feature contribution ends for plot.shapley.singleValue.
#'
#' @description Splits the values by sign, orders both groups by decreasing
#'   magnitude, accumulates each group outwards from zero and shifts the
#'   sorted positions.
#' @param points Named shapley.values of a single row.
#' @param shift Offset added to every position (the data.mean).
#' @return A data frame with the columns `values` (shifted positions in
#'   ascending order), `names` (feature name belonging to each position, "0"
#'   for the origin) and `align` ("right" if the value is greater than
#'   `shift`, "left" otherwise).
compute.shapley.positions = function(points, shift = 0) {
  # Most negative first / most positive first.
  negative = sort(points[which(points < 0)])
  positive = sort(points[which(points >= 0)], decreasing = TRUE)

  # Running totals of both groups with the origin in between; sorting puts
  # the negative totals into ascending order.
  cumulated = c(cumsum(t(negative)), 0, cumsum(t(positive)))
  shifted = sort(cumulated) + shift

  # Reversing the negative names makes them line up with the sorted totals.
  labels = c(names(rev(negative)), "0", names(positive))

  data.frame(
    values = shifted,
    names = labels,
    align = ifelse(shifted > shift, "right", "left"),
    stringsAsFactors = FALSE
  )
}

#' Plots a graph that shows the effect of several features over multiple observations.
#' @description This method draws one line per selected feature that shows the
#'   shapley value of this feature for every observation.
#' @param shap.values A shapley object (generated by shapley(...)) that contains
#' the shapley.values and other important information about the task and model.
#' @param features A vector of the interesting feature names. A single name is
#'   allowed as well.
#' @return A ggplot object.
#' @export
plot.shapley.multipleFeatures = function(shap.values, features = c("crim", "lstat")) {
  # drop = FALSE keeps the column structure when only one feature is selected;
  # otherwise the selection collapses to a vector and ncol() returns NULL.
  features.values = getShapleyValues(shap.values)[, features, drop = FALSE]

  # One column per selected feature plus the observation id as x position.
  data = data.frame(
    features.values,
    position = as.numeric(getShapleyIds(shap.values)),
    check.names = FALSE,
    row.names = NULL
  )
  # Long format: one row per (observation, feature) pair.
  plot.data = melt(data, id.vars = "position")

  plot = ggplot(plot.data) +
    geom_line(aes(x = position, y = value, group = variable, color = variable))

  return(plot)
}
# redichh/ShapleyR documentation built on May 28, 2019, 7:49 a.m.