cluster_profiles: Cluster Ceteris Paribus Profiles


cluster_profiles    R Documentation

Cluster Ceteris Paribus Profiles

Description

This function calculates aggregates of ceteris paribus profiles based on hierarchical clustering.
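
The idea can be illustrated with a minimal base R sketch on toy data. It is not the package's implementation: the toy profile matrix, the variable names, and the "ward.D2" linkage are assumptions made for illustration only. Profiles are clustered hierarchically, the tree is cut into k groups, and each group is summarised pointwise with an aggregate function.

```r
# Toy ceteris paribus profiles (hypothetical data): one row per observation,
# one column per grid point of the explained variable.
set.seed(1)
grid <- seq(0, 1, length.out = 11)
profiles <- rbind(
  t(replicate(5, 0.2 + 0.5 * grid + rnorm(11, sd = 0.01))),  # rising profiles
  t(replicate(5, 0.8 - 0.5 * grid + rnorm(11, sd = 0.01)))   # falling profiles
)

# Hierarchical clustering of the profiles, cut into k groups.
# The linkage method is an assumption chosen for this sketch.
k <- 2
tree <- hclust(dist(profiles), method = "ward.D2")
cluster_id <- cutree(tree, k = k)

# Aggregate the profiles within each cluster, pointwise over the grid.
aggregate_function <- mean
aggregated <- t(sapply(
  split(seq_len(nrow(profiles)), cluster_id),
  function(idx) apply(profiles[idx, , drop = FALSE], 2, aggregate_function)
))

# One aggregated profile per cluster, one column per grid point.
dim(aggregated)
```

Swapping aggregate_function for median in this sketch gives a summary that is less sensitive to a few atypical profiles within a cluster.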

Usage

cluster_profiles(
  x,
  ...,
  aggregate_function = mean,
  variable_type = "numerical",
  center = FALSE,
  k = 3,
  variables = NULL
)

Arguments

x

a ceteris paribus explainer produced by the ceteris_paribus() function

...

other ceteris paribus explainers to be processed together with x

aggregate_function

a function used to aggregate profiles within each cluster. Defaults to mean.

variable_type

a character string. If "numerical", only numerical variables are processed. If "categorical", only categorical variables are processed.

center

logical; should profiles be centered before clustering? Defaults to FALSE.

k

the number of clusters to extract from the hierarchical clustering (hclust). Defaults to 3.

variables

if not NULL, only the variables named here are used. Defaults to NULL (all variables of the selected type).
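
The effect of center can be shown with a small base R sketch. It assumes that centering means subtracting each profile's own mean before distances are computed; the toy profiles and their names are hypothetical, and the sketch does not call the package itself. Without centering, profiles tend to group by their level; with centering, they group by their shape.

```r
# Four toy profiles over a common grid (hypothetical data):
# two rising and two falling, each at a low and a high level.
grid <- seq(0, 1, length.out = 11)
profiles <- rbind(
  low_rising   = 0.1 + 0.3 * grid,
  high_rising  = 0.6 + 0.3 * grid,
  low_falling  = 0.4 - 0.3 * grid,
  high_falling = 0.9 - 0.3 * grid
)

# Clustering the raw profiles: the level dominates the distance.
raw_clusters <- cutree(hclust(dist(profiles)), k = 2)

# Centering (assumed here to be subtraction of each profile's mean)
# removes the level, so only the shape of each profile remains.
centered <- profiles - rowMeans(profiles)
centered_clusters <- cutree(hclust(dist(centered)), k = 2)

raw_clusters       # low profiles together, high profiles together
centered_clusters  # rising profiles together, falling profiles together
```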

Details

Find more details in the Clustering Profiles Chapter.

Value

an object of the class aggregated_profiles_explainer

References

Explanatory Model Analysis. Explore, Explain, and Examine Predictive Models. https://ema.drwhy.ai/

Examples

library("DALEX")
library("ingredients")

selected_passangers <- select_sample(titanic_imputed, n = 100)
model_titanic_glm <- glm(survived ~ gender + age + fare,
                         data = titanic_imputed, family = "binomial")

explain_titanic_glm <- explain(model_titanic_glm,
                               data = titanic_imputed[,-8],
                               y = titanic_imputed[,8])

# ceteris paribus profiles for the GLM, clustered into 3 groups for "age"
cp_glm <- ceteris_paribus(explain_titanic_glm, selected_passangers)
clust_glm <- cluster_profiles(cp_glm, k = 3, variables = "age")
plot(clust_glm)


library("ranger")
model_titanic_rf <- ranger(survived ~ ., data = titanic_imputed, probability = TRUE)

explain_titanic_rf <- explain(model_titanic_rf,
                              data = titanic_imputed[,-8],
                              y = titanic_imputed[,8],
                              label = "ranger forest",
                              verbose = FALSE)

cp_rf <- ceteris_paribus(explain_titanic_rf, selected_passangers)
cp_rf

pdp_rf <- aggregate_profiles(cp_rf, variables = "age")
head(pdp_rf)
clust_rf <- cluster_profiles(cp_rf, k = 3, variables = "age")
head(clust_rf)

plot(clust_rf, color = "_label_") +
  show_aggregated_profiles(pdp_rf, color = "black", size = 3)

plot(cp_rf, color = "grey", variables = "age") +
  show_aggregated_profiles(clust_rf, color = "_label_", size = 2)

clust_rf <- cluster_profiles(cp_rf, k = 3, center = TRUE, variables = "age")
head(clust_rf)


ingredients documentation built on Jan. 15, 2023, 5:09 p.m.