knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)

This vignette uses the HR churn data (https://www.kaggle.com/) to show how the breakDown package explains individual predictions from randomForest models.

The data is available in the breakDown package as HR_data.

set.seed(1313)

library(breakDown)
head(HR_data, 3)

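Now let's create a random forest classification model for churn, which is recorded in the left variable. Wrapping left in factor() makes randomForest fit a classifier rather than a regression.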

library("randomForest")
# randomForest has no family argument; the factor response already selects classification
model <- randomForest(factor(left) ~ ., data = HR_data, maxnodes = 5)

But how can we understand which factors drive the prediction for a single observation?

With the breakDown package!

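Explanations for the predicted probability of leaving (left = 1). A random forest has no linear predictor, so the explanation is computed on the class probability returned by predict.function.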

library(ggplot2)

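# Return the probability of the second class (left = 1)
predict.function <- function(model, new_observation) predict(model, new_observation, type="prob")[,2]
# Column 7 of HR_data is the response (left), so it is dropped from the observation
predict.function(model, HR_data[11,-7])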
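
The positional index [,2] above relies on the order of the factor levels. As a minimal sketch, assuming the levels of factor(left) are "0" and "1", the same probability can be selected by column name. The names rf, probs and predict_left are hypothetical and used only for this illustration.

```r
library(breakDown)
library(randomForest)

set.seed(1313)
# Same kind of model as in the text, refit here so the sketch is self-contained
rf <- randomForest(factor(left) ~ ., data = HR_data, maxnodes = 5)

# type = "prob" returns a matrix with one column per class level
probs <- predict(rf, HR_data[1:5, -7], type = "prob")
colnames(probs)

# Hypothetical helper: pick the "left = 1" column by name instead of by position
predict_left <- function(model, new_observation) {
  predict(model, new_observation, type = "prob")[, "1"]
}
predict_left(rf, HR_data[11, -7])
```

Selecting by name keeps the function correct even if the response is recoded and the level order changes.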

explain_1 <- broken(model, HR_data[11,-7], data = HR_data[,-7],
                    predict.function = predict.function, 
                    direction = "down")
explain_1
plot(explain_1) + ggtitle("breakDown plot (direction=down) for randomForest model")

explain_2 <- broken(model, HR_data[11,-7], data = HR_data[,-7],
                    predict.function = predict.function, 
                    direction = "up")
explain_2
plot(explain_2) + ggtitle("breakDown plot (direction=up) for randomForest model")
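
To give some intuition for what these plots show, here is a simplified base R sketch of a greedy "step-up" attribution. It is not the breakDown implementation: the helper step_up_sketch, its argument names, and the greedy selection rule are assumptions made for illustration, and a small lm fit on mtcars stands in for the random forest so the example runs without extra packages.

```r
# Simplified step-up attribution (illustrative sketch, not breakDown's code).
# step_up_sketch is a hypothetical helper name.
step_up_sketch <- function(model, new_obs, data, predict_fun) {
  baseline <- mean(predict_fun(model, data))  # average prediction over the data
  current <- data
  remaining <- colnames(data)
  last <- baseline
  contrib <- numeric(0)
  while (length(remaining) > 0) {
    # Mean prediction after fixing each candidate variable at the observation's value
    means <- sapply(remaining, function(v) {
      tmp <- current
      tmp[[v]] <- new_obs[[v]]
      mean(predict_fun(model, tmp))
    })
    # Assumed greedy rule: take the variable that moves the mean prediction the most
    best <- names(which.max(abs(means - last)))
    current[[best]] <- new_obs[[best]]
    contrib[best] <- means[[best]] - last
    last <- means[[best]]
    remaining <- setdiff(remaining, best)
  }
  list(baseline = baseline, contributions = contrib, final = last)
}

fit <- lm(mpg ~ wt + hp + qsec, data = mtcars)
X <- mtcars[, c("wt", "hp", "qsec")]
res <- step_up_sketch(fit, X[1, ], X, function(m, d) predict(m, d))

# Baseline plus all contributions recovers the prediction for the observation
res$baseline + sum(res$contributions)
predict(fit, X[1, ])
```

Once every variable has been fixed, all rows equal the explained observation, so the contributions always add up to the difference between its prediction and the average prediction. For a non-additive model such as a random forest, the individual contributions can depend on the order in which variables are fixed.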


pbiecek/breakDown documentation built on March 15, 2024, 10:46 a.m.