#' Plots the difference for a single observation. Works only with `task.type` "regr".
#'
#' @description Draws a one-dimensional plot for a single observation: the
#' positions returned by `compute.shapley.positions()` (the cumulated shapley
#' values of all features, shifted by the data mean) are drawn as a coloured
#' band with one label per feature, and the predicted response of the
#' observation is marked with a black point.
#' @param shap.values A shapley object (generated by the shapley function) that
#'   contains the shapley.values and other important information about the task
#'   and model.
#' @param shap.id (optional) Determines which observation is plotted if
#'   `shap.values` holds multiple observations. If the id cannot be found, or
#'   if no id is given although there are several observations, the first
#'   observation is used and a warning is raised.
#' @return A ggplot object.
#' @export
plot.shapley.singleValue = function(shap.values, shap.id = -1) {
  ids = getShapleyIds(shap.values)
  id.known = shap.id %in% ids

  # Resolve the row index `at` of the observation to plot. The order of the
  # checks matters: an explicitly requested but unknown id is reported first,
  # a single stored observation never needs an id at all.
  if (shap.id != -1 && !id.known) {
    requested = shap.id
    shap.id = ids[1]
    warning("Could not find _Id <", requested, "> in shap.values! ",
      "First observation with _Id <", shap.id, "> is used instead.",
      call. = FALSE)
    at = which(getShapleyValues(shap.values)$"_Id" == shap.id)
  } else if (nrow(getShapleyValues(shap.values)) == 1) {
    at = 1
  } else if (id.known) {
    at = which(getShapleyValues(shap.values)$"_Id" == shap.id)
  } else {
    shap.id = ids[1]
    at = 1
    warning("shap.values contains too many observations. ",
      "First observation with _Id <", shap.id, "> is used.",
      call. = FALSE)
  }

  # The data mean is read from the full object, i.e. before
  # getShapleySubsetByResponseClass() is applied.
  data.mean = getShapleyDataMean(shap.values)
  shap.values = getShapleySubsetByResponseClass(shap.values)
  contributions = getShapleyValues(shap.values)[at, getShapleyFeatureNames(shap.values)]
  points = compute.shapley.positions(contributions, data.mean)
  prediction = getShapleyPredictionResponse(shap.values)[at]

  # Alternate the labels above and below the band so neighbours do not overlap.
  label.nudge = rep_len(c(.1, -.1), nrow(points))

  ggplot(points, aes(x = values, y = 0)) +
    coord_cartesian(ylim = c(-.4, .4)) +
    # Colour gradient is centred on the data mean.
    scale_colour_gradient2(low = "#832424FF", high = "#3A3A98FF",
      mid = "lightgrey", midpoint = data.mean) +
    geom_line(aes(colour = values), size = 30) +
    geom_label(aes(label = names), angle = 70, nudge_y = label.nudge) +
    geom_point(aes(x = prediction, y = 0.1), colour = "black", size = 3) +
    theme(axis.title.y = element_blank(),
      axis.text.y = element_blank(),
      axis.ticks.y = element_blank(),
      legend.position = "none")
}
#' Plots a graph that shows the expected values, observed values and their difference.
#'
#' @description Draws, over all observations, three lines (the data mean, the
#' data mean plus the summed positive shapley values, and the data mean plus
#' the summed negative shapley values) and a ribbon between the data mean and
#' the data mean plus the sum of all shapley values of an observation.
#' @param shap.values A shapley object (generated by shapley(...)) that contains
#'   the shapley.values and other important information about the task and model.
#' @return A ggplot object.
#' @export
plot.shapley.multipleValues = function(shap.values) {
  # drop = FALSE keeps a two-dimensional object even when there is only one
  # feature; otherwise the subset collapses to a vector and nrow() is NULL.
  values = getShapleyValues(shap.values)[, getShapleyFeatureNames(shap.values), drop = FALSE]
  data.mean = getShapleyDataMean(shap.values)

  # Split every observation into its positive and its negative contributions.
  value.matrix = as.matrix(values)
  data = data.frame(
    response.plus = unname(rowSums(pmax(value.matrix, 0))),
    response.minus = unname(rowSums(pmin(value.matrix, 0))),
    position = as.numeric(getShapleyIds(shap.values)),
    stringsAsFactors = FALSE
  )
  # "red" when the negative effects outweigh the positive ones, else "green".
  # NOTE(review): this column is not mapped to any aesthetic below.
  data$color = ifelse(data$response.plus < abs(data$response.minus), "red", "green")

  # `data.mean` and `values` are not columns of `data`; aes() picks them up
  # from this function's environment.
  ggplot() +
    geom_line(data = data, aes(x = position, y = data.mean, colour = "data mean")) +
    geom_line(data = data, aes(x = position, y = data.mean + response.plus,
      colour = "positive effects")) +
    geom_line(data = data, aes(x = position, y = data.mean + response.minus,
      colour = "negative effects")) +
    geom_ribbon(data = data, aes(x = position, ymax = data.mean,
      ymin = data.mean + rowSums(values)), fill = "blue", alpha = .2)
}
#' Calculates the positions of the features influence for plot.singleValue.
#'
#' @description Splits the values by sign, orders each group by decreasing
#' magnitude, cumulates them outwards from zero and shifts the result.
#' @param points A named vector of shapley.values for a single row. A one-row
#'   data.frame is accepted as well and is flattened to a named vector.
#' @param shift Value added to every position (the data.mean in
#'   plot.shapley.singleValue).
#' @return A data.frame with one row per value plus one row for the origin,
#'   ordered by position, with the columns `values` (shifted cumulated
#'   position), `names` (name of the value that ends at this position, "0" for
#'   the origin) and `align` ("right" if the position lies above `shift`,
#'   otherwise "left").
compute.shapley.positions = function(points, shift = 0) {
  if (is.data.frame(points) && nrow(points) > 1) {
    stop("`points` must hold the shapley values of a single observation.",
      call. = FALSE)
  }
  # Work on a plain named vector: sort()/rev() on a data.frame go through
  # xtfrm(), which warns ("cannot xtfrm data frames") in recent R versions.
  points = unlist(points)

  # Largest magnitude first on both sides of zero.
  negative = sort(points[which(points < 0)])
  positive = sort(points[which(points >= 0)], decreasing = TRUE)

  # Names are dropped before cumulating so they do not end up as row names.
  positions = sort(c(cumsum(unname(negative)), 0, cumsum(unname(positive))))
  # After sorting, the outermost negative position belongs to the smallest
  # negative value, hence the reversed order of the negative names.
  labels = c(rev(names(negative)), "0", names(positive))

  result = data.frame(values = positions + shift)
  result$names = labels
  result$align = ifelse(result$values > shift, "right", "left")
  result
}
#' Plots a graph that shows the effect of several features over multiple observations.
#'
#' @description Draws one line per selected feature that shows the shapley
#' value of this feature for every observation.
#' @param shap.values A shapley object (generated by shapley(...)) that contains
#'   the shapley.values and other important information about the task and model.
#' @param features A vector of the interesting feature names. The defaults
#'   "crim" and "lstat" only work if the shapley values contain columns with
#'   these names.
#' @return A ggplot object.
#' @export
plot.shapley.multipleFeatures = function(shap.values, features = c("crim", "lstat")) {
  # drop = FALSE keeps a data.frame even if a single feature is requested;
  # otherwise the subset collapses to a vector and ncol() is NULL.
  features.values = getShapleyValues(shap.values)[, features, drop = FALSE]

  # Wide table: one column per feature plus the observation id as x position.
  data = as.data.frame(features.values)
  row.names(data) = NULL
  data$position = as.numeric(getShapleyIds(shap.values))

  # Long format (position, variable, value) so every feature gets its own line.
  plot.data = melt(data, id.vars = "position")
  ggplot(plot.data) +
    geom_line(aes(x = position, y = value, group = variable, color = variable))
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.