aggregate_profiles: Aggregates Ceteris Paribus Profiles

aggregate_profiles    R Documentation

Description

The function aggregate_profiles() calculates an aggregate of Ceteris Paribus Profiles. The aggregate can be a Partial Dependence Profile (the average across Ceteris Paribus Profiles), a Conditional Dependence Profile (a locally weighted average across Ceteris Paribus Profiles), or an Accumulated Local Dependence Profile (accumulated average local changes in Ceteris Paribus Profiles).
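To make the "average across Ceteris Paribus Profiles" idea concrete, here is a minimal base R sketch. It does not call ingredients; the toy data, the toy_predict() function and the grid are hypothetical names introduced only for illustration, so treat it as a sketch of the idea rather than the package's implementation.

```r
# Hypothetical toy data and model; not part of ingredients.
toy_data <- data.frame(age = c(10, 30, 50), fare = c(5, 20, 80))
toy_predict <- function(d) 0.01 * d$age + 0.002 * d$fare

# Ceteris paribus profiles: vary `age` over a grid for each observation,
# keeping all other columns at their observed values.
grid <- c(0, 20, 40, 60)
cp <- do.call(rbind, lapply(seq_len(nrow(toy_data)), function(i) {
  d <- toy_data[rep(i, length(grid)), ]
  d$age <- grid
  data.frame(id = i, x = grid, yhat = toy_predict(d))
}))

# Partial dependence profile: plain average of the profiles at each grid point.
pd <- tapply(cp$yhat, cp$x, mean)
pd
```

Each observation contributes one profile (one row per grid point), and the partial dependence profile is the pointwise mean of those profiles.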

Usage

aggregate_profiles(
  x,
  ...,
  variable_type = "numerical",
  groups = NULL,
  type = "partial",
  variables = NULL,
  span = 0.25,
  center = FALSE
)

Arguments

x

a ceteris paribus explainer produced with function ceteris_paribus()

...

other explainers to be aggregated together

variable_type

a character string. If "numerical", only numerical variables are aggregated. If "categorical", only categorical variables are aggregated.

groups

the name of a variable used for grouping. By default NULL, which means that no groups are calculated.

type

one of "partial", "conditional" or "accumulated", for partial dependence profiles, conditional profiles or accumulated local effects, respectively

variables

if not NULL, aggregates are calculated only for the selected variables

span

smoothing coefficient, by default 0.25. It is the standard deviation of the Gaussian kernel.

center

by default accumulated profiles start at 0. If center = TRUE, they are centered around the mean prediction, which is calculated on the observations used in ceteris_paribus().
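The role of span as the standard deviation of a Gaussian kernel can be sketched in base R as follows. This is an illustration only: conditional_profile() and its inputs are hypothetical, and rescaling distances by the range of the grid is an assumption made here, not something this page states about the package internals.

```r
# Hypothetical helper, not part of ingredients.
# yhat[i, j] is the prediction for observation i with the variable set to grid[j];
# original_x[i] is the observed value of the variable for observation i.
conditional_profile <- function(yhat, original_x, grid, span = 0.25) {
  # Assumption: distances are rescaled by the range of the grid before weighting.
  x_range <- diff(range(grid))
  sapply(seq_along(grid), function(j) {
    # Gaussian kernel weights with standard deviation `span`.
    w <- dnorm((original_x - grid[j]) / x_range, sd = span)
    # Locally weighted average of the profiles at this grid point.
    sum(w * yhat[, j]) / sum(w)
  })
}

original_x <- c(10, 30, 50)
grid <- c(0, 20, 40, 60)
yhat <- rbind(c(0.1, 0.2, 0.3, 0.4),
              c(0.2, 0.3, 0.4, 0.5),
              c(0.5, 0.6, 0.7, 0.8))

conditional <- conditional_profile(yhat, original_x, grid, span = 0.25)
conditional
```

A small span gives most of the weight to observations whose original value is close to the grid point; a very large span makes the weights nearly equal, so the result approaches the plain average used for partial dependence.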

Value

an object of the class aggregated_profiles_explainer

References

Explanatory Model Analysis. Explore, Explain, and Examine Predictive Models. https://ema.drwhy.ai/

Examples

library("DALEX")
library("ingredients")
library("ranger")
head(titanic_imputed)

model_titanic_rf <- ranger(survived ~ ., data = titanic_imputed, probability = TRUE)

explain_titanic_rf <- explain(model_titanic_rf,
                              data = titanic_imputed[,-8],
                              y = titanic_imputed[,8],
                              label = "ranger forest",
                              verbose = FALSE)

# draw a random sample of 100 passengers and compute their ceteris paribus profiles
selected_passengers <- select_sample(titanic_imputed, n = 100)
cp_rf <- ceteris_paribus(explain_titanic_rf, selected_passengers)
head(cp_rf)

# continuous variable
pdp_rf_p <- aggregate_profiles(cp_rf, variables = "age", type = "partial")
pdp_rf_p$`_label_` <- "RF_partial"
pdp_rf_c <- aggregate_profiles(cp_rf, variables = "age", type = "conditional")
pdp_rf_c$`_label_` <- "RF_conditional"
pdp_rf_a <- aggregate_profiles(cp_rf, variables = "age", type = "accumulated")
pdp_rf_a$`_label_` <- "RF_accumulated"

plot(pdp_rf_p, pdp_rf_c, pdp_rf_a, color = "_label_")

pdp_rf <- aggregate_profiles(cp_rf, variables = "age",
                             groups = "gender")

head(pdp_rf)
plot(cp_rf, variables = "age") +
  show_observations(cp_rf, variables = "age") +
  show_rugs(cp_rf, variables = "age", color = "red") +
  show_aggregated_profiles(pdp_rf, size = 3, color = "_label_")

# categorical variable
pdp_rf_p <- aggregate_profiles(cp_rf, variables = "class",
                               variable_type = "categorical",  type = "partial")
pdp_rf_p$`_label_` <- "RF_partial"
pdp_rf_c <- aggregate_profiles(cp_rf, variables = "class",
                               variable_type = "categorical", type = "conditional")
pdp_rf_c$`_label_` <- "RF_conditional"
pdp_rf_a <- aggregate_profiles(cp_rf, variables = "class",
                               variable_type = "categorical", type = "accumulated")
pdp_rf_a$`_label_` <- "RF_accumulated"
plot(pdp_rf_p, pdp_rf_c, pdp_rf_a, color = "_label_")

# the same plot with flipped coordinates (coord_flip() comes from ggplot2)
library(ggplot2)
plot(pdp_rf_p, pdp_rf_c, pdp_rf_a, color = "_label_") + coord_flip()

pdp_rf <- aggregate_profiles(cp_rf, variables = "class", variable_type = "categorical",
                             groups = "gender")
head(pdp_rf)
plot(pdp_rf, variables = "class")
# the same plot with flipped coordinates
plot(pdp_rf, variables = "class") + coord_flip()




ingredients documentation built on Jan. 15, 2023, 5:09 p.m.