We can distill a general machine learning model into an interpretable additive
combination of the model's partial dependence plot (PDP) functions. The surrogate
is a linear combination of the univariate PDP curves with nonnegative weights,
fit to the predictions of the original model. To create a surrogate model from
the original model, we first create an Interpreter object.

    # Load the required packages
    library(distillML)
    library(Rforestry)
    library(ggplot2)
    library(dplyr)

    # Load in data
    data("iris")
    set.seed(491)
    data <- iris

    # Split data into train and test sets
    train <- data[1:100, ]
    test <- data[101:150, ]

    # Train a random forest on the data set
    forest <- forestry(x = train[, -1], y = train[, 1])

    # Create a predictor wrapper for the forest
    forest_predictor <- Predictor$new(model = forest, data = data,
                                      y = "Sepal.Length", task = "regression")

    forest_interpret <- Interpreter$new(predictor = forest_predictor)

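To make the idea of a nonnegative linear combination of PDP curves concrete, the following is a minimal base R sketch on a toy problem. It is not how distillML is implemented: the black_box function, the pdp_at helper, and the use of optim with box constraints are all assumptions made for illustration only.

```r
# Toy illustration in base R; this is NOT the distillML implementation.
set.seed(1)
n <- 200
X <- data.frame(x1 = runif(n), x2 = runif(n))

# Hypothetical stand-in for a fitted black-box model (includes an interaction)
black_box <- function(d) 2 * d$x1 + sin(3 * d$x2) + 0.5 * d$x1 * d$x2
y_hat <- black_box(X)

# Univariate PDP of feature `name` at value v: average the model output
# with that feature fixed to v in every row
pdp_at <- function(name, v) {
  d <- X
  d[[name]] <- v
  mean(black_box(d))
}

# Evaluate each PDP curve at the observed feature values, then center it
pdp_matrix <- sapply(names(X), function(name) {
  vals <- vapply(X[[name]], function(v) pdp_at(name, v), numeric(1))
  vals - mean(vals)
})

# Squared-error loss for an intercept b[1] and PDP weights b[-1]
loss <- function(b) mean((y_hat - b[1] - pdp_matrix %*% b[-1])^2)

# Box constraints keep the PDP weights nonnegative; the intercept is free
fit <- optim(c(mean(y_hat), 1, 1), loss, method = "L-BFGS-B",
             lower = c(-Inf, 0, 0))

weights <- setNames(fit$par[-1], names(X))
surrogate_preds <- fit$par[1] + drop(pdp_matrix %*% fit$par[-1])
print(weights)
```

Each column of pdp_matrix plays the role of one PDP curve, and the fitted weights say how much each curve contributes to the surrogate's prediction.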
To create a surrogate model from the Interpreter object, we call the function
distill, which has several important arguments that determine how the surrogate
model makes its predictions. The weights are fit with the glmnet package, which
lets us regularize them with cross-validated lasso (which can induce sparsity)
or ridge regression. Below, we give brief explanations of the key arguments of
the distill method.

- features: A vector of feature indices. The index of a feature is the position
  of its name in the Interpreter object's vector of features, which can be
  accessed in our example through forest_interpret$features. By default, the
  indices of all valid features are included.
- cv: A boolean value that determines whether we cross-validate the weights
  derived for each PDP curve. This is FALSE by default and should only be set
  to TRUE if we want to induce sparsity through lasso or ridge regression.
- snap.grid: A boolean value that determines whether the surrogate model makes
  predictions using approximations based on previously calculated values. By
  default, this is TRUE. Because PDP curves are computationally expensive to
  evaluate, this should only be switched to FALSE for small training datasets
  or when a small user-specified samples argument was given when declaring the
  Interpreter object.
- snap.train: A boolean value that determines whether the approximations use
  the marginal distribution of the subsampled training data or equally spaced
  grid points. By default, this is TRUE, which means that the surrogate model
  uses the marginal distribution of the subsampled training data. If switched
  to FALSE, the distillation process becomes more computationally expensive
  because predictions must be made at the equally spaced grid points.

    # Distilling the Model With Default Settings
    distilled_model <- distill(forest_interpret)

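The snapping behavior described for snap.grid and snap.train can be pictured with a small base R sketch. This only illustrates the general idea of reusing cached values at a nearest grid point; the pdp_curve stand-in, the snap_predict helper, and the grid size of 10 are hypothetical and do not reflect the package's internals.

```r
# Illustrative sketch in base R; names are hypothetical and this is
# not the distillML implementation.
set.seed(491)
x_train <- iris$Petal.Length

# Stand-in for an expensive-to-evaluate PDP curve of one feature
pdp_curve <- function(v) log1p(v)

# Two candidate grids of 10 points each:
# (a) values subsampled from the training data's marginal distribution
grid_train <- sort(sample(x_train, 10))
# (b) equally spaced points over the observed range
grid_equal <- seq(min(x_train), max(x_train), length.out = 10)

# Evaluate the curve once per grid point and cache the results
cache_train <- pdp_curve(grid_train)
cache_equal <- pdp_curve(grid_equal)

# Snap each new value to its nearest grid point and reuse the cached value
snap_predict <- function(new_x, grid, cache) {
  nearest <- vapply(new_x, function(v) which.min(abs(grid - v)), integer(1))
  cache[nearest]
}

new_x <- c(1.4, 4.5, 6.0)
approx_train <- snap_predict(new_x, grid_train, cache_train)
approx_equal <- snap_predict(new_x, grid_equal, cache_equal)
exact <- pdp_curve(new_x)

print(rbind(exact, approx_train, approx_equal))
```

The trade-off is the usual one for lookup tables: snapping avoids re-evaluating the curve for every new point, at the cost of an approximation error that shrinks as the grid gets denser.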
The distill method returns a Surrogate object, which can be used for further
predictions. As the coefficients show, each level of a categorical variable is
one-hot encoded in the weight-fitting process. In this case, we get three
different coefficients for the Species feature, one for each of the three
species of iris.

    print(distilled_model)

    # get weights for each PDP curve
    print(distilled_model$weights)

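As a quick illustration of why Species yields three coefficients, one-hot encoding a three-level factor produces three indicator columns. The sketch below uses base R's model.matrix; how distill names or builds these columns internally may differ, so treat this as an assumption-laden picture of the idea only.

```r
data("iris")

# One indicator column per level of Species; removing the intercept (- 1)
# keeps all three levels instead of dropping a reference level
species_onehot <- model.matrix(~ Species - 1, data = iris)

# Three columns, one per species
print(colnames(species_onehot))

# Each row has exactly one indicator set to 1
print(table(rowSums(species_onehot)))
```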
To induce sparsity in the distilled surrogate models, we can pass arguments to
the weight-fitting process with cv.glmnet through the params.cv.glmnet argument.
This argument is a list of arguments that the function cv.glmnet can take. An
equivalent argument, params.glmnet, exists in the distill method for
user-specified fitting of the weights with no penalty factor when cv = F.

    # Sparse Model
    sparse_model <- distill(forest_interpret,
                            cv = TRUE,
                            params.cv.glmnet = list(lower.limits = 0,
                                                    intercept = FALSE,
                                                    alpha = 1))

    print(distilled_model$weights)
    print(sparse_model$weights)

Although no coefficient is set to 0 in this case, the coefficients in the sparse model tend to be smaller than those in the original distilled model.
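For readers unfamiliar with the glmnet arguments passed above, here is a standalone sketch that calls cv.glmnet directly on synthetic data. The columns of pdp_cols are hypothetical stand-ins for centered PDP curves, not output from distillML. In glmnet, alpha = 1 selects the lasso penalty and alpha = 0 selects ridge, lower.limits = 0 constrains coefficients to be nonnegative, and intercept = FALSE fits without an intercept.

```r
library(glmnet)

set.seed(491)
n <- 150

# Three hypothetical "PDP curve" columns; the third is unrelated to the target
pdp_cols <- cbind(pdp_a = rnorm(n), pdp_b = rnorm(n), pdp_noise = rnorm(n))
target <- 1.5 * pdp_cols[, "pdp_a"] + 0.5 * pdp_cols[, "pdp_b"] +
  rnorm(n, sd = 0.1)

# Cross-validated lasso with nonnegative coefficients and no intercept
lasso_fit <- cv.glmnet(x = pdp_cols, y = target,
                       alpha = 1, lower.limits = 0, intercept = FALSE)

# Same constraints with a ridge penalty instead
ridge_fit <- cv.glmnet(x = pdp_cols, y = target,
                       alpha = 0, lower.limits = 0, intercept = FALSE)

# Drop the (zero) intercept row and keep the per-column weights
lasso_weights <- coef(lasso_fit, s = "lambda.min")[-1, 1]
ridge_weights <- coef(ridge_fit, s = "lambda.min")[-1, 1]

print(rbind(lasso = lasso_weights, ridge = ridge_weights))
```

Whether any weight is driven exactly to zero depends on the data and on the penalty strength chosen by cross-validation, so results on real PDP curves may look different.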
We use the predict function to make new predictions with the distilled surrogate
models. Note that the output of the surrogate model is a one-column dataframe
of the new predictions.

    # make predictions
    original.preds <- predict(forest, test[, -1])
    distill.preds <- predict(distilled_model, test[, -1])[, 1]
    sparse.preds <- predict(sparse_model, test[, -1])[, 1]

    # compare RMSE
    rmse <- function(preds) {
      return(sqrt(mean((test[, 1] - preds)^2)))
    }

    print(data.frame(original = rmse(original.preds),
                     distill = rmse(distill.preds),
                     distill.sparse = rmse(sparse.preds)))

As measured on the test set, both distilled models outperform the original random forest model, which suggests that the distilled models generalize better out of sample in this example. This varies by use case. We can compare how closely the distilled predictions match the original predictions by plotting them against each other below:

    # create dataframe of data we'd like to plot
    plot.data <- data.frame(original.preds, distill.preds, sparse.preds)

    # plots for both default distilled model and sparse model
    default.plot <- ggplot(data = plot.data,
                           aes(x = original.preds, y = distill.preds)) +
      xlim(min(c(original.preds, distill.preds)), max(c(original.preds, distill.preds))) +
      ylim(min(c(original.preds, distill.preds)), max(c(original.preds, distill.preds))) +
      geom_point() +
      geom_abline(col = "red")

    sparse.plot <- ggplot(data = plot.data,
                          aes(x = original.preds, y = sparse.preds)) +
      xlim(min(c(original.preds, sparse.preds)), max(c(original.preds, sparse.preds))) +
      ylim(min(c(original.preds, sparse.preds)), max(c(original.preds, sparse.preds))) +
      geom_point() +
      geom_abline(col = "red")

    plot(default.plot)
    plot(sparse.plot)

As these plots show, neither distilled model matches the original model's predictions exceptionally well. In this case, the regularized, sparse distilled model matches the original random forest's predictions more closely than the default distilled model does, which is consistent with their similar out-of-sample performance.
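The comparisons above mix two different questions: how accurate each model is against the true responses, and how faithfully a surrogate reproduces the original model. The sketch below separates the two using base R only. The linear models are hypothetical stand-ins for the random forest and the distilled surrogate, chosen so the example runs without extra packages.

```r
data("iris")
train <- iris[1:100, ]
test <- iris[101:150, ]

# Stand-in "original" model (a linear model instead of a random forest)
original <- lm(Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width,
               data = train)

# Stand-in surrogate: a simpler model fit to the original model's predictions
train$original_pred <- predict(original, train)
surrogate <- lm(original_pred ~ Petal.Length, data = train)

rmse <- function(a, b) sqrt(mean((a - b)^2))

original_preds <- predict(original, test)
surrogate_preds <- predict(surrogate, test)

# Accuracy: distance from the true test responses
accuracy <- c(original = rmse(test$Sepal.Length, original_preds),
              surrogate = rmse(test$Sepal.Length, surrogate_preds))

# Fidelity: distance between surrogate and original predictions
fidelity <- rmse(original_preds, surrogate_preds)

print(accuracy)
print(fidelity)
```

A surrogate can score well on accuracy while having poor fidelity, or the reverse, so it is worth reporting both when judging a distilled model.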
For further details on the distill method, please refer to the documentation
provided in the "References" section.