Brightbox contains functions that tackle the problem of inspecting the internals of any blackbox supervised learner. The package is designed to work with the caret package, as well as with any model that is an ensemble of caret learners. As of this writing, the approaches implemented in the package are partial dependency plots (`run_partial_dependency`) and marginal variable importance (`calculate_marginal_vimp`).
```r
devtools::install_github('breather/brightbox')
```
Partial dependency plots are a technique for visualizing the effect of a single feature on the response, marginalizing over the values of all other features. They are the visual equivalent of a coefficient in linear regression (in fact, the partial dependency plot for a feature in a linear model is a straight line whose slope is that feature's coefficient).
```r
# Example of partial dependency plots for a linear model
library(data.table)
dt <- data.table(a = 1:3, b = 2:4, c = c(8, 11, 14))
lm1 <- lm(c ~ a + b - 1, dt)
lm1$coefficients
```
```r
# Note that the slopes of the plotted lines match the coefficients
library(brightbox)
pd <- run_partial_dependency(feature_dt = dt[, c("a", "b")],
                             model_list = list(linear_model = lm1))
```
The advantages of partial dependency plots shine when there is a non-linear relationship between the feature in question and the response. In such cases, a linear model hides the true relationship because of its underlying assumptions; ideally, we would like a more flexible model with fewer assumptions that still retains interpretability.
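For instance, here is a minimal sketch (simulated data with an illustrative variable name) of a non-linear relationship that a linear fit flattens away:

```r
# Simulated example: a quadratic relationship a linear model cannot capture
set.seed(42)
sim <- data.frame(x1 = runif(200, -3, 3))
sim$y <- sim$x1^2 + rnorm(200, sd = 0.5)

# The linear fit reports a near-zero slope, hiding the U-shape
coef(lm(y ~ x1, data = sim))

# A flexible smoother recovers the curvature
plot(sim$x1, fitted(loess(y ~ x1, data = sim)),
     xlab = "x1", ylab = "fitted y")
```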
The next section works through an example dataset to demonstrate `brightbox`'s features with respect to partial dependency plots.
```r
# First, load the data
library(data.table)
library(mlbench)
data(BostonHousing, package = "mlbench")
head(BostonHousing)
```
```r
# Split into features and response
boston_dt <- data.table(BostonHousing)
x <- boston_dt[, -"medv", with = FALSE]
y <- boston_dt$medv

# Prep the data (numeric columns are friendlier)
x[, chas := as.numeric(chas)]
```
The dataset `BostonHousing` contains housing data for 506 census tracts of Boston from the 1970 census. `medv` is the response variable, representing the median value of houses in each census tract (in USD 1000s). See `??BostonHousing` for additional details.
We will train two blackbox learners on this data and see if we can interpret the patterns each model learned.
```r
library(caret)

# Train an xgboost model
xgb <- train(x = x, y = y, method = "xgbTree", metric = "RMSE")

# Train a single-layer neural net
nn <- train(x = x, y = y, method = "nnet", metric = "RMSE",
            preProc = c("center", "scale"),
            tuneGrid = expand.grid(size = c(10, 15, 20),
                                   decay = c(0, 5e-4, 0.05)),
            linout = TRUE, trace = FALSE)
```
With the two trained models in hand, we can use Brightbox to inspect their internals and, just as easily, the internals of an ensemble of these models. In this example, we construct an ensemble that is the median of the xgboost and neural net predictions.
```r
library(brightbox)

# Generate partial dependency plots for each feature
pd <- run_partial_dependency(x,
                             model_list = list(xgboost = xgb, nnet = nn),
                             ensemble_colname = "ensemble",
                             ensemble_fcn = median)
```
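Here the ensemble's prediction at each point is simply the row-wise median of the individual models' predictions. As a concept sketch (not brightbox's actual internals), the ensemble could be computed by hand like so:

```r
# Hand-rolled illustration of a median ensemble at prediction time
# (a sketch of the concept, not brightbox's internal code)
preds <- cbind(xgboost = predict(xgb, newdata = x),
               nnet = predict(nn, newdata = x))
head(apply(preds, 1, median))
```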
In the returned plot we can see how `medv` changes with respect to each feature for every model, so there is quite a bit we can learn by inspecting it:

1. Some features show a clear relationship with the response (`rm` and `nox`, for example).
2. The range of the partial dependence varies substantially from feature to feature.
3. The models agree on some features but diverge on others (`age`, `rm`, `dis`).

As points 2 and 3 hint, partial dependency plots can help us determine feature importance by inspecting the partial dependence range and variance. `run_partial_dependency` returns a data.table with such calculations pre-computed.
```r
# pd was previously saved from run_partial_dependency
head(pd)
```
`pd` contains the data necessary to reconstruct the partial dependency plots (columns `feature`, `feature_val`, `model`, and `prediction`) as well as a variable importance column (`vimp`). Variable importance is calculated as the y-axis range for a given feature and model, with the ensemble model chosen by default (TODO: incorporate variance into the variable importance calculation).
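As a sanity check, we can recompute that range by hand, assuming the column layout described above:

```r
# Recompute vimp by hand as the y-axis range of the ensemble line
# (assumes pd has the feature, model, and prediction columns shown above)
pd[model == "ensemble",
   list(vimp_by_hand = diff(range(prediction))),
   by = "feature"]
```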
```r
# Inspect the variable importance of each feature
unique(pd[, list(feature, vimp)])
```
From here we may be interested in plotting just the 10 most important features.
```r
# Re-run, restricted to the 10 most important features
pd_top <- run_partial_dependency(x,
                                 model_list = list(xgboost = xgb, nnet = nn),
                                 feature_cols = unique(pd[["feature"]])[1:10],
                                 ensemble_colname = "ensemble",
                                 ensemble_fcn = median)
```
We may also wish to check the stability of the relationship between a feature and its partial dependence. How do we know the relationship isn't due to chance, an artifact of a model overfitting? We propose an approach in which many models are trained on different subsamples of the training data; both visual and automated tests then give us good insight into the variance of the partial dependence relationship.
Let's start by training 50 xgboost models on random subsamples of the `BostonHousing` dataset, storing the resulting models in a list:
```r
library(caret)

# Train 50 xgboost models
xgb_list <- list()
for (i in 1:50) {
  ind <- sample(x = 1:nrow(x), size = 404)  # randomly sample ~80% of rows
  xgb_list[[i]] <- train(x = x[ind], y = y[ind],
                         method = "xgbTree", metric = "RMSE",
                         trControl = trainControl(method = "none"),
                         tuneGrid = xgb$bestTune)
}
names(xgb_list) <- paste("xgb", 1:50, sep = "_")
```
Now that we have our 50 separate models stored in a list, we can plot the resulting output. Instead of plotting all 50 models in the same plot, which would be a mess, we plot specific quantiles of the model predictions, namely the 5th, 50th, and 95th.
```r
library(ggplot2)

# Return a list of partial dependency data.tables, one per feature
pd_list <- loop_calculate_partial_dependency(feature_dt = x, model_list = xgb_list)

# rbind the data.tables into one long data.table
pd <- do.call(rbind, pd_list)

# Summarize the 50 models with the 5th, 50th, and 95th quantiles
pd_quant <- pd[model != "ensemble",
               list(lcl = quantile(prediction, 0.05),
                    ensemble = median(prediction),
                    ucl = quantile(prediction, 0.95)),
               by = c("feature", "feature_val")]

# Note: y is mapped per layer because pd_quant has no "prediction" column
ggplot(data = pd_quant, aes(x = feature_val)) +
  geom_line(aes(y = ensemble), size = 1.2, color = "red") +
  geom_line(aes(y = ucl), size = 0.5, linetype = "dashed") +
  geom_line(aes(y = lcl), size = 0.5, linetype = "dashed") +
  facet_wrap(~ feature, scales = "free") +
  labs(x = "Value Cutpoint", y = "Partial Dependence")
```
We can see that for variables like `dis`, `tax`, and `rm`, there is little spread between the quantiles, indicating very stable results regardless of the subsample the models were trained on. Conversely, variables like `crim` and `chas` show high variance between models, at least over certain ranges of the variables; we should treat any inference about these variables with caution.
While the above is a useful visual test, we can also get an automated ranking of variable importance normalized by model variance. The function `calculate_pd_vimp_normed` performs the following operations:

1. Calculate the raw variable importance with `calculate_pd_vimp`.
2. Compute the standard deviation of the predictions across models at each feature value, yielding a vector of standard deviations.
3. Summarize that vector with a single value, the `median` by default.
4. Divide the variable importance from step 1 by the summary value of the standard deviation from step 3.

We simply pass each element of `pd_list` to `calculate_pd_vimp_normed` using `sapply` to get the variance-adjusted ranking of all variables in the dataset:
```r
# Apply calculate_pd_vimp_normed to each element of pd_list
vimp_normed_vec <- sapply(pd_list, calculate_pd_vimp_normed)

# Print variable importance in descending order
print(vimp_normed_vec[order(-vimp_normed_vec)])
```
As we can see, the ranking is consistent with what we can visually identify as being the most stable relationships.
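For intuition, the normalization can also be written out by hand. The following sketch assumes the column layout shown earlier; `vimp_normed_by_hand` is a hypothetical helper, not part of brightbox:

```r
# Hypothetical re-implementation of the normalization, for intuition only
vimp_normed_by_hand <- function(pd_one) {
  # Step 1: raw vimp as the y-axis range of the ensemble line
  vimp <- pd_one[model == "ensemble", diff(range(prediction))]
  # Step 2: sd of predictions across models at each feature value
  sd_vec <- pd_one[model != "ensemble", sd(prediction), by = "feature_val"]$V1
  # Steps 3 and 4: divide by the median of the sd vector
  vimp / median(sd_vec)
}

sapply(pd_list, vimp_normed_by_hand)
```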
While the function `run_partial_dependency` is the primary interface for partial dependency plots, it is composed of the following functions, which can be useful if you want to avoid redundant calculations:

- `calculate_partial_dependency`
- `loop_calculate_partial_dependency`
- `facet_plot_fcn`
- `loop_plot_fcn`
- `plot_partial_dependency`
- `calculate_pd_vimp`
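For example, the partial dependency tables can be computed once and reused. This is a sketch: only `loop_calculate_partial_dependency`'s arguments appear elsewhere in this vignette, and passing `pd_list` elements to `calculate_pd_vimp` mirrors the `calculate_pd_vimp_normed` usage above:

```r
# Compute the partial dependency tables once...
pd_list <- loop_calculate_partial_dependency(
  feature_dt = x,
  model_list = list(xgboost = xgb, nnet = nn)
)

# ...then reuse them for both raw and normalized variable importance
vimp_raw <- sapply(pd_list, calculate_pd_vimp)
vimp_normed <- sapply(pd_list, calculate_pd_vimp_normed)
```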
(TODO: change this to a walkthrough instead)
```r
calculate_marginal_vimp(x, y, method, loss_metric, resampling_indices,
                        tuneGrid, trControl, vars = names(x),
                        allow_parallel = FALSE, ...)
```
Arguments:

- `x`: data.table containing the predictor variables
- `y`: vector containing the target variable
- `method`: character string defining the method to pass to caret
- `loss_metric`: character. Loss metric used to evaluate the accuracy of the model
- `resampling_indices`: a list of integer vectors corresponding to the row indices used for each resampling iteration
- `tuneGrid`: a data.frame containing hyperparameter values for caret. Should contain only one value for each hyperparameter. Set to NULL if the caret method does not have any hyperparameters.
- `trControl`: a trainControl object to be passed to caret `train`
- `vars`: character vector specifying the variables for which to determine marginal importance. Defaults to all predictor variables in `x`.
- `allow_parallel`: boolean for parallel execution. If set to TRUE, the user must register a parallel backend in their R session to take advantage of multiple cores (see the sketch after this list). Defaults to FALSE.
- `...`: additional arguments to pass to caret `train`
A caret model is trained on the training data according to the resampling indices specified by the user, and error is calculated on out-of-sample data. Variable importance is determined by calculating the change in out-of-sample model performance when a variable is removed, relative to the baseline out-of-sample performance when all variables are included.
The workhorse function for marginal variable importance is `calculate_marginal_vimp`.
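Putting it together, here is a hedged usage sketch on the Boston data, following the signature above; `createResample` from caret is one way to supply the list of training-row indices:

```r
# Hedged usage sketch of calculate_marginal_vimp on the Boston data
library(caret)
idx <- createResample(y, times = 5)  # bootstrap row indices, one vector per iteration
mvimp <- calculate_marginal_vimp(x = x, y = y,
                                 method = "xgbTree",
                                 loss_metric = "RMSE",
                                 resampling_indices = idx,
                                 tuneGrid = xgb$bestTune,
                                 trControl = trainControl(method = "none"))
```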