knitr::opts_chunk$set(
  collapse = FALSE,
  comment = "#>",
  warning = FALSE,
  message = FALSE
)

Data for HR example

In this vignette, we walk through a multilabel (more precisely, multiclass) classification example with DALEX. Several DALEX functions assume binary classification by default, so a few of them need extra arguments or custom code to work with more than two classes. All examples use the HR dataset available in DALEX; its target column is status, a three-level factor. The model in every case is ranger.

library("DALEX")
old_theme = set_theme_dalex("ema") 

data(HR)
head(HR)

Creation of model and explainer

Ok, now it is time to create a model.

library("ranger")
# probability = TRUE makes ranger return class probabilities instead of labels
model_HR_ranger <- ranger(status ~ ., data = HR, probability = TRUE, num.trees = 50)
model_HR_ranger
# column 6 is the target (status), so it is dropped from data
explain_HR_ranger <- explain(model_HR_ranger,
                              data = HR[, -6],
                              y = HR$status,
                              label = "Ranger Multilabel Classification",
                              colorize = FALSE)

The sixth column, which we omitted when creating the explainer, is the target column (status). It is good practice not to include it in data. Keep in mind that the default yhat function for ranger, and for any other package supported by DALEX, enforces probability output. Therefore, residuals cannot be the standard $y - \hat{y}$. Since DALEX 1.2.2, the default residual function for multiclass classification is one minus the predicted probability of the true class.
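The residual rule described above can be illustrated in base R. This is a minimal sketch rather than the DALEX implementation: the helper `residual_multiclass()` is hypothetical, the probabilities are made up, and the level names `fired`, `ok`, and `promoted` are assumed to match those of `HR$status`.

```r
# Made-up predicted probabilities: one row per observation,
# one column per (assumed) level of the target.
probs <- matrix(
  c(0.7, 0.2, 0.1,
    0.1, 0.6, 0.3,
    0.2, 0.3, 0.5),
  nrow = 3, byrow = TRUE,
  dimnames = list(NULL, c("fired", "ok", "promoted"))
)

# Observed classes, with levels in the same order as the columns above.
y <- factor(c("fired", "promoted", "promoted"), levels = colnames(probs))

# Hypothetical helper: residual = 1 - predicted probability of the true class.
residual_multiclass <- function(probs, y) {
  # Matrix indexing picks, for each row, the column of the observed class.
  true_class_prob <- probs[cbind(seq_along(y), as.integer(y))]
  1 - true_class_prob
}

res <- residual_multiclass(probs, y)
res
```

A residual close to 0 means the model assigned high probability to the observed class; a residual close to 1 means it assigned almost none.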

Model Parts

To use the model_parts() function (formerly variable_importance()), the default loss_function argument must be switched to one that handles multiple classes. DALEX implements one such function, loss_cross_entropy(). To use it, the y parameter passed to explain() must have exactly the same format as the target vector used for training (i.e., the same number of levels and the same level names).

We also need probability outputs, so the default predict_function parameter can stay as it is.
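To make the requirement on y concrete, here is a minimal base R sketch of a multiclass cross-entropy loss. It is an illustration under stated assumptions, not the code behind loss_cross_entropy(): the function name `cross_entropy_sketch()`, the clipping constant `p_min`, and the choice to sum rather than average over observations are all assumptions, and the level names are assumed to match `HR$status`.

```r
# Hypothetical cross-entropy loss for a factor target and a probability matrix.
cross_entropy_sketch <- function(observed, predicted, p_min = 1e-4) {
  # The factor levels must match the probability columns, name for name.
  stopifnot(identical(levels(observed), colnames(predicted)))
  # Probability that the model assigned to the observed class of each row.
  p_true <- predicted[cbind(seq_along(observed), as.integer(observed))]
  # Clip small probabilities so that log(0) never occurs.
  -sum(log(pmax(p_true, p_min)))
}

lv <- c("fired", "ok", "promoted")
observed <- factor(c("fired", "ok", "promoted"), levels = lv)

# A model that spreads probability evenly over the three classes.
uniform <- matrix(1 / 3, nrow = 3, ncol = 3, dimnames = list(NULL, lv))
loss_uniform <- cross_entropy_sketch(observed, uniform)

# A model that puts all probability on the observed class.
perfect <- diag(3)
colnames(perfect) <- lv
loss_perfect <- cross_entropy_sketch(observed, perfect)

c(uniform = loss_uniform, perfect = loss_perfect)
```

The level check at the top of the sketch is the reason y must keep the same levels as the training target: the loss looks up a probability column by the observed class, so a missing or renamed level has no column to look up.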

# y is the original factor, with the same levels as the training target
explain_HR_ranger_new_y <- explain(model_HR_ranger,
                              data = HR[, -6],
                              y = HR$status,
                              label = "Ranger Multilabel Classification",
                              colorize = FALSE)

Now we can use model_parts():

mp <- model_parts(explain_HR_ranger_new_y, loss_function = loss_cross_entropy)
plot(mp)

As the plot shows, this yields a standard variable importance plot.

Model Profile

No tricks are needed to use model_profile() (formerly variable_effect()). The target is one-hot encoded, and the explanations are computed for each class separately.
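For readers unfamiliar with one-hot encoding, the following base R sketch shows what it does to a three-level factor. It only illustrates the idea and makes no claim about how DALEX builds the encoding internally; the toy vector `status` and its level names are assumptions chosen to resemble `HR$status`.

```r
# Toy target with the (assumed) three levels of HR$status.
status <- factor(c("fired", "ok", "promoted", "ok"),
                 levels = c("fired", "ok", "promoted"))

# One indicator column per level: 1 where the observation belongs to that class.
one_hot <- sapply(levels(status), function(lv) as.integer(status == lv))
one_hot
```

Each column can then be treated as a separate binary target, which is why the profile plots show one curve per class.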

partial_dependency

mp_p <- model_profile(explain_HR_ranger, variables = "salary", type = "partial")
mp_p$color <- "_label_"
plot(mp_p)

accumulated_dependency

mp_a <- model_profile(explain_HR_ranger, variables = "salary", type = "accumulated")
mp_a$color <- "_label_"
plot(mp_a)

Instance level explanations

As above, predict_parts() (formerly variable_attribution()) works with multiclass classification and the default explainer. As before, the target is split into one variable per factor level, and the computations are then performed for each of them.

break_down

bd <- predict_parts(explain_HR_ranger, HR[1,], type = "break_down")
plot(bd)

shap

shap <- predict_parts(explain_HR_ranger, HR[1,], type = "shap")
plot(shap)

model_performance and predict_diagnostics

These two functions are described together because they need the same thing to work with multiclass classification. Both are based on residuals. Since DALEX 1.2.2, the explain() function recognizes a multiclass classification task and uses a dedicated residual function by default.

Model Performance

(mp <- model_performance(explain_HR_ranger))
plot(mp)

Predict diagnostics

pd_all <- predict_diagnostics(explain_HR_ranger, HR[1,])
plot(pd_all)
pd_salary <- predict_diagnostics(explain_HR_ranger, HR[1,], variables = "salary")
plot(pd_salary)

Session info

sessionInfo()


ModelOriented/DALEX documentation built on Feb. 29, 2024, 6:55 a.m.