# Global knitr chunk options for this document: keep source and output in
# separate blocks, prefix printed output with "#>", and hide warnings and
# messages emitted by the chunks.
knitr::opts_chunk$set(collapse = FALSE, comment = "#>",
                      warning = FALSE, message = FALSE)

Data for Titanic survival

Let's see an example of using the DALEX package with classification models for the Titanic survival problem. Here we use the titanic_imputed dataset available in the DALEX package. Note that this data was copied from the stablelearner package.

# Attach DALEX (which ships the titanic_imputed data) and preview the first
# six rows of the dataset.
library("DALEX")
head(titanic_imputed, n = 6L)

Model for Titanic survival

OK, now it's time to create a model. Let's use a random forest model.

# Prepare the model: a ranger random forest predicting `survived` from the
# passenger attributes in titanic_imputed.
# NOTE(review): classification = TRUE grows a classification forest; if
# probability predictions are wanted downstream, ranger's probability = TRUE
# may be the better choice -- confirm before changing.
library("ranger")
model_titanic_rf <- ranger(
  survived ~ gender + age + class + embarked + fare + sibsp + parch,
  data = titanic_imputed,
  classification = TRUE
)
# Auto-print the fitted model.
model_titanic_rf

Explainer for Titanic survival

The third step (it's optional but useful) is to create a DALEX explainer for the random forest model.

# Wrap the random forest in a DALEX explainer. `data` and `y` are taken from
# titanic_imputed; colorize = FALSE is passed through to explain().
library("DALEX")
explain_titanic_rf <- explain(
  model_titanic_rf,
  data = titanic_imputed,
  y = titanic_imputed$survived,
  label = "Random Forest v7",
  colorize = FALSE
)

Variable importance plots

Use the model_parts() function to present the importance of particular features. Note that setting type = "difference" normalizes the dropout losses, so that they all start at 0.

# Compute model-level variable importance for the random forest explainer,
# show the first rows of the result, then plot it.
vi_rf <- model_parts(explainer = explain_titanic_rf)
head(vi_rf)
plot(vi_rf)

Variable effects

As we can see, the most important feature is gender. The next three most important features are class, age and fare. Let's see the link between the model response and these features.

Such univariate relations can be calculated with model_profile().

Age

Kids 5 years old and younger have a much higher survival probability.

# Profile of the random forest response against `age`; preview and plot it.
vr_age <- model_profile(explain_titanic_rf, variables = "age")
head(vr_age)
plot(vr_age)

Passenger class

Passengers in the first-class have much higher survival probability.

# Profile of the random forest response against passenger `class`, then plot.
vr_class <- model_profile(explain_titanic_rf, variables = "class")
plot(vr_class)

Fare

Very cheap tickets are linked with lower chances.

# Profile of the random forest response against `fare`, then plot.
# Uses model_profile() like the other variable-effect chunks in this document
# (instead of the older variable_profile() name) so all profiles are computed
# through the same entry point.
vr_fare <- model_profile(explain_titanic_rf, variables = "fare")
plot(vr_fare)

Embarked

Passengers that embarked from C have the highest survival.

# Profile of the random forest response against the `embarked` port, then plot.
vr_embarked <- model_profile(explain_titanic_rf, variables = "embarked")
plot(vr_embarked)

Instance level explanations

Let's see a break-down explanation of the model prediction for an 8-year-old male from 1st class who embarked from Southampton.

# A single hypothetical passenger used for instance-level explanations:
# an 8-year-old male travelling 1st class, no siblings/spouse or
# parents/children aboard, fare 72, embarked at Southampton.
# Factor columns list their levels explicitly.
new_passanger <- data.frame(
  class = factor(
    "1st",
    levels = c("1st", "2nd", "3rd", "deck crew", "engineering crew",
               "restaurant staff", "victualling crew")
  ),
  gender = factor("male", levels = c("female", "male")),
  age = 8,
  sibsp = 0,
  parch = 0,
  fare = 72,
  embarked = factor(
    "Southampton",
    levels = c("Belfast", "Cherbourg", "Queenstown", "Southampton")
  )
)

# Instance-level explanation of the random forest prediction for
# new_passanger, followed by its plot.
sp_rf <- predict_parts(explain_titanic_rf, new_observation = new_passanger)
plot(sp_rf)

It looks like the most important features for this passenger are age and gender. Overall, his odds of survival are higher than for the average passenger, mainly because of his young age and despite being male.

More models

Let's train more models for survival.

Logistic regression

# Logistic regression via rms::lrm, with a restricted cubic spline on age.
library("rms")
model_titanic_lmr <- lrm(
  survived ~ class + gender + rcs(age) + sibsp + parch + fare + embarked,
  data = titanic_imputed
)

# lrm needs an explicit predict function: type = "fitted" is requested so the
# explainer works with fitted values from the model.
explain_titanic_lmr <- explain(
  model_titanic_lmr,
  data = titanic_imputed,
  y = titanic_imputed$survived,
  predict_function = function(m, x) {
    predict(m, x, type = "fitted")
  },
  label = "Logistic regression"
)

Generalized Boosted Models (GBM)

# Generalized boosted model for survival.
# The loss distribution is stated explicitly ("bernoulli" for the binary
# `survived` outcome) instead of relying on gbm guessing it from the response.
library("gbm")
model_titanic_gbm <- gbm(
  survived ~ class + gender + age + sibsp + parch + fare + embarked,
  data = titanic_imputed,
  distribution = "bernoulli",
  n.trees = 15000
)

# The predict function must use the same number of trees the model was
# trained with (15000); type = "response" asks for predictions on the
# response scale.
explain_titanic_gbm <- explain(
  model_titanic_gbm,
  data = titanic_imputed,
  y = titanic_imputed$survived,
  predict_function = function(m, x) {
    predict(m, x, n.trees = 15000, type = "response")
  },
  label = "Generalized Boosted Models",
  colorize = FALSE
)

Support Vector Machines (SVM)

# Support vector machine (e1071) fitted as a C-classification with
# probability estimates enabled.
library("e1071")
model_titanic_svm <- svm(
  survived ~ class + gender + age + sibsp + parch + fare + embarked,
  data = titanic_imputed,
  type = "C-classification",
  probability = TRUE
)

# No custom predict function is given here; explain() uses its default for
# this model.
explain_titanic_svm <- explain(
  model_titanic_svm,
  data = titanic_imputed,
  y = titanic_imputed$survived,
  label = "Support Vector Machines",
  colorize = FALSE
)

k-Nearest Neighbors (kNN)

# k-nearest neighbours (caret::knn3) with k = 5.
library("caret")
model_titanic_knn <- knn3(
  survived ~ class + gender + age + sibsp + parch + fare + embarked,
  data = titanic_imputed,
  k = 5
)

# The predict function keeps only the second column of the prediction matrix.
explain_titanic_knn <- explain(
  model_titanic_knn,
  data = titanic_imputed,
  y = titanic_imputed$survived,
  predict_function = function(m, x) {
    predict(m, x)[, 2]
  },
  label = "k-Nearest Neighbors",
  colorize = FALSE
)

Variable performance

# Variable importance for each of the five explainers, one object per model.
vi_rf  <- model_parts(explainer = explain_titanic_rf)
vi_lmr <- model_parts(explainer = explain_titanic_lmr)
vi_gbm <- model_parts(explainer = explain_titanic_gbm)
vi_svm <- model_parts(explainer = explain_titanic_svm)
vi_knn <- model_parts(explainer = explain_titanic_knn)

# Draw all five importance results in a single plot call.
plot(
  vi_rf, vi_lmr, vi_gbm, vi_svm, vi_knn,
  bar_width = 4
)

Single variable

# Age profile for each of the five explainers.
vr_age_rf  <- model_profile(explain_titanic_rf,  variables = "age")
vr_age_lmr <- model_profile(explain_titanic_lmr, variables = "age")
vr_age_gbm <- model_profile(explain_titanic_gbm, variables = "age")
vr_age_svm <- model_profile(explain_titanic_svm, variables = "age")
vr_age_knn <- model_profile(explain_titanic_knn, variables = "age")

# Plot the aggregated profiles (the agr_profiles element of each result)
# together in one call.
plot(
  vr_age_rf$agr_profiles,
  vr_age_lmr$agr_profiles,
  vr_age_gbm$agr_profiles,
  vr_age_svm$agr_profiles,
  vr_age_knn$agr_profiles
)

Instance level explanations

# Instance-level explanations of the prediction for new_passanger, computed
# and plotted model by model.

# Random forest
sp_rf <- predict_parts(explain_titanic_rf, new_observation = new_passanger)
plot(sp_rf)

# Logistic regression
sp_lmr <- predict_parts(explain_titanic_lmr, new_observation = new_passanger)
plot(sp_lmr)

# Generalized boosted model
sp_gbm <- predict_parts(explain_titanic_gbm, new_observation = new_passanger)
plot(sp_gbm)

# Support vector machine
sp_svm <- predict_parts(explain_titanic_svm, new_observation = new_passanger)
plot(sp_svm)

# k-nearest neighbours
sp_knn <- predict_parts(explain_titanic_knn, new_observation = new_passanger)
plot(sp_knn)

Session info

# Record the R version and attached packages used to build this document.
utils::sessionInfo()


ModelOriented/DALEX documentation built on Feb. 29, 2024, 6:55 a.m.