knitr::opts_chunk$set( collapse = TRUE, comment = "#>" )
Here we use the HR churn data (https://www.kaggle.com/) to show how the breakDown package explains predictions from randomForest models.
The data is available in the breakDown package.
set.seed(1313)
library(breakDown)
head(HR_data, 3)
Now let's fit a random forest classification model for churn, using the left variable as the response.
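library("randomForest")
# factor(left) makes this a classification forest; maxnodes = 5 keeps the trees small.
# The original family = "binomial" is dropped: it is not a randomForest argument.
model <- randomForest(factor(left) ~ ., data = HR_data, maxnodes = 5)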
But how can we understand which factors drive the prediction for a single observation?
With the breakDown package!
Explanations for the predicted probability of leaving.
library(ggplot2)

# Return the predicted probability of the second class (left = 1)
predict.function <- function(model, new_observation)
  predict(model, new_observation, type = "prob")[, 2]

# Column 7 of HR_data is the response (left), so it is excluded from the inputs
predict.function(model, HR_data[11, -7])

explain_1 <- broken(model, HR_data[11, -7], data = HR_data[, -7],
                    predict.function = predict.function, direction = "down")
explain_1
plot(explain_1) + ggtitle("breakDown plot (direction=down) for randomForest model")

explain_2 <- broken(model, HR_data[11, -7], data = HR_data[, -7],
                    predict.function = predict.function, direction = "up")
explain_2
plot(explain_2) + ggtitle("breakDown plot (direction=up) for randomForest model")
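To build some intuition for what the plots above show, here is a simplified sketch of a greedy "step up" attribution in base R. It is an illustration under stated assumptions, not the breakDown implementation: the toy data, the logistic regression `fit` (standing in for the forest so that no extra packages are needed), and the helpers `predict_prob` and `step_up` are all hypothetical names introduced here. The idea is to start from the mean prediction over the data, then repeatedly fix the variable whose value from the new observation shifts the mean prediction the most, and record each shift as that variable's contribution.

```r
# Toy data and model (hypothetical; they stand in for HR_data and the forest).
set.seed(1313)
n <- 200
toy <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))
toy$y <- rbinom(n, 1, plogis(1.5 * toy$x1 - 1.0 * toy$x2))
fit <- glm(y ~ x1 + x2 + x3, data = toy, family = binomial())

# Prediction function with the same (model, new_observation) shape
# as the predict.function passed to broken().
predict_prob <- function(model, newdata) {
  as.numeric(predict(model, newdata, type = "response"))
}

# Greedy step-up attribution (illustrative helper, not part of breakDown).
step_up <- function(model, new_obs, data, predict_fun) {
  baseline <- mean(predict_fun(model, data))
  current <- baseline
  remaining <- colnames(data)
  contributions <- numeric(0)
  while (length(remaining) > 0) {
    # Mean prediction after fixing each candidate variable at the new value.
    means <- sapply(remaining, function(v) {
      tmp <- data
      tmp[[v]] <- new_obs[[v]]
      mean(predict_fun(model, tmp))
    })
    # Pick the variable that moves the mean prediction the most.
    best <- remaining[which.max(abs(means - current))]
    contributions[best] <- means[[best]] - current
    data[[best]] <- new_obs[[best]]  # keep this variable fixed from now on
    current <- means[[best]]
    remaining <- setdiff(remaining, best)
  }
  list(baseline = baseline, contributions = contributions, final = current)
}

X <- toy[, c("x1", "x2", "x3")]
res <- step_up(fit, X[11, ], X, predict_prob)
res
```

By construction the contributions are additive: the baseline plus the sum of all contributions equals the model's prediction for the chosen observation, which is the property that makes a break-down plot readable as a waterfall.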