observation_plot: SHAP Observation Plot


View source: R/observation_plot.R

Description

This function plots the given contributions for a single observation and shows how the model arrived at its prediction for that observation.

Usage

observation_plot(
  variable_values,
  shap_values,
  expected_value,
  names = NULL,
  num_vars = 10,
  fill_colors = c("#A54657", "#0D3B66"),
  connect_color = "#849698",
  expected_color = "#849698",
  predicted_color = "#EE964B",
  title = "Individual Observation Explanation",
  font_family = "Times New Roman"
)

Arguments

variable_values

A data frame of the values of the variables that produced the given SHAP values. This will generally be the same data frame or matrix that was passed to the model for prediction.

shap_values

A data frame of SHAP values, either returned by mshap() or obtained from the Python {shap} module.

expected_value

The expected value of the SHAP explainer, either returned by mshap() or obtained from the Python {shap} module.

names

A character vector of variable names, in the same order as the columns of both variable_values and shap_values. If NULL (default), the column names of variable_values are used.

num_vars

An integer specifying the number of variables to show in the plot; defaults to the 10 most important.

fill_colors

A character vector of length 2. The first element specifies the fill of a negative SHAP value and the second element specifies the fill of a positive SHAP value.

connect_color

A string specifying the color of the line segment that connects the SHAP value bars.

expected_color

A string specifying the color of the line that marks the baseline value, or the expected model output.

predicted_color

A string specifying the color of the line that marks the value predicted by the model.

title

A string specifying the title of the plot.

font_family

A string specifying the font family; defaults to Times New Roman.
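As a rough illustration of what num_vars might select, the sketch below ranks the variables of a single observation by the absolute size of their SHAP values and keeps the largest ones. Treating "most important" as "largest absolute SHAP value" is an assumption made here for illustration, not something this page confirms, and the data values and the helper names (shap_row, abs_shap, top_vars) are hypothetical.

```r
# Hypothetical one-row SHAP data frame; the values are made up.
shap_row <- data.frame(
  age = -12.5,
  prop_domestic = 40.2,
  model = -55.1,
  maintain = 3.4
)
num_vars <- 2

# Assumption: importance is measured by the absolute SHAP value.
abs_shap <- abs(unlist(shap_row[1, ]))

# Keep the names of the num_vars largest contributions.
n_keep <- min(num_vars, length(abs_shap))
top_vars <- names(sort(abs_shap, decreasing = TRUE))[seq_len(n_keep)]
top_vars
```

With four variables and the default num_vars = 10, a selection like this would simply keep every variable.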

Details

This function allows the user to pass a single row from a data frame of SHAP values and the matching row of variable values, along with an expected model output. It returns a plot showing how each variable's value, through its SHAP value, moves the prediction away from the expected value. The plot is created with {ggplot2}, and the returned value is a {ggplot2} object that can be modified with custom themes and colors.
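The arithmetic behind such a plot can be sketched in base R. SHAP values are additive: the model output for an observation equals the expected value plus the sum of that observation's SHAP values. The numbers below are made up, and the left-to-right order in which the contributions are accumulated is an assumption for illustration; the function may order the bars differently.

```r
# Hypothetical baseline and one-row SHAP values; the numbers are made up.
expected_value <- 1000
shap_row <- data.frame(
  age = -12.5,
  prop_domestic = 40.2,
  model = -55.1,
  maintain = 3.4
)

# Flatten the single row into a named numeric vector.
contributions <- unlist(shap_row[1, ])

# Additivity: prediction = expected value + sum of contributions.
prediction <- expected_value + sum(contributions)

# Running total after each contribution, starting from the baseline.
running_total <- expected_value + cumsum(contributions)

prediction
running_total
```

The last element of running_total matches prediction, which is the relationship between the expected-value line and the predicted-value line in the plot.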

Please note that the columns of the variable_values and shap_values arguments, both of which are data frames, must be in the same order. This is essential to ensure that the variable values and labels are matched to the correct SHAP values.
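One way to guard against mismatched columns before plotting is to reorder shap_values by name, as in the minimal sketch below. It assumes both data frames carry the same column names, which may not hold for SHAP values imported from Python without names; in that case the order has to be verified by hand. The data values are hypothetical.

```r
# Hypothetical one-row inputs whose columns are in different orders.
variable_values <- data.frame(age = 5, prop_domestic = 0.3, model = 1, maintain = 250)
shap_values <- data.frame(model = 48, age = -3, maintain = -20, prop_domestic = 41)

# Fail early if the two data frames do not describe the same variables.
stopifnot(setequal(names(variable_values), names(shap_values)))

# Reorder the SHAP columns to match variable_values; drop = FALSE keeps
# the result a data frame even if only one column is present.
shap_values <- shap_values[, names(variable_values), drop = FALSE]

identical(names(shap_values), names(variable_values))
```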

Value

A {ggplot2} object

Examples

if (interactive()) {
library(mshap)
library(ggplot2)

# Generate fake data
set.seed(18)
dat <- data.frame(
  age = runif(1000, min = 0, max = 20),
  prop_domestic = runif(1000),
  model = sample(c(0, 1), 1000, replace = TRUE),
  maintain = rexp(1000, .01) + 200
)
shap <- data.frame(
  age = rexp(1000, 1/dat$age) * (-1)^(rbinom(1000, 1, dat$prop_domestic)),
  prop_domestic = -200 * rnorm(1000, dat$prop_domestic, 0.02) + 100,
  model = ifelse(dat$model == 0, rnorm(1000, -50, 30), rnorm(1000, 50, 30)),
  maintain = (rnorm(1000, dat$maintain, 100) - 400) * 0.2
)
expected_value <- 1000

# A basic summary plot
summary_plot(
  variable_values = dat,
  shap_values = shap
)

# A customized summary plot
summary_plot(
  variable_values = dat,
  shap_values = shap,
  legend.position = "bottom",
  names = c("Age", "% Domestic", "Model", "Maintenance Hours"),
  colorscale = c("blue", "purple", "red"),
  font_family = "Arial",
  title = "A Custom Title"
)

# A basic observation plot
observation_plot(
  variable_values = dat[1,],
  shap_values = shap[1,],
  expected_value = expected_value
)

# A customized observation plot
observation_plot(
  variable_values = dat[1,],
  shap_values = shap[1,],
  expected_value = expected_value,
  names = c("Age", "% Domestic", "Model", "Maintenance Hours"),
  font_family = "Arial",
  title = "A Custom Title",
  fill_colors = c("red", "blue"),
  connect_color = "black",
  expected_color = "purple",
  predicted_color = "yellow"
)

# Add elements to the returned object
# see vignette("mshap_plots") for more information
observation_plot(
  variable_values = dat[1,],
  shap_values = shap[1,],
  expected_value = expected_value,
  names = c("Age", "% Domestic", "Model", "Maintenance Hours"),
  font_family = "Arial",
  title = "A Custom Title"
) +
  geom_label(
    aes(y = 950, x = 4, label = "This is a really big bar!"),
    color = "#FFFFFF",
    fill = NA
  ) +
  theme(
    plot.background = element_rect(fill = "grey"),
    panel.background = element_rect(fill = "lightyellow")
  )
}

mshap documentation built on June 17, 2021, 9:07 a.m.