Model Distillation

We can distill a general machine learning model into an interpretable additive combination of the model's partial dependence (PDP) functions. This is done by fitting a linear combination of the univariate PDP curves, with nonnegative weights, to the predictions of the original model. To create a surrogate model from the original model, we first create an Interpreter object.
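Concretely, the surrogate takes the additive form below (our notation: w_0 denotes an intercept term and PDP_j the univariate partial dependence function of feature j):

$$\hat{f}_{\text{surrogate}}(x) = w_0 + \sum_{j=1}^{p} w_j \, \widehat{\mathrm{PDP}}_j(x_j), \qquad w_j \ge 0.$$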

# Load the required packages
library(distillML)
library(Rforestry)
library(ggplot2)
library(dplyr)

# Load in data 
data("iris")
set.seed(491)
data <- iris

# Split data into train and test sets
train <- data[1:100, ]
test <- data[101:150, ]

# Train a random forest on the data set
forest <- forestry(x = train[, -1],
                   y = train[, 1])

# Create a predictor wrapper for the forest
forest_predictor <- Predictor$new(model = forest,
                                  data = data,
                                  y = "Sepal.Length",
                                  task = "regression")

forest_interpret <- Interpreter$new(predictor = forest_predictor)

To create a surrogate model from the Interpreter object, we call the distill function, which has a set of important arguments that determine how the surrogate model makes its predictions. The weights are fit with the glmnet package, which lets us regularize them with lasso or ridge penalties (the lasso inducing sparsity) and select the penalty by cross-validation. The key arguments that control this weight-fitting process are discussed below.

# Distilling the Model With Default Settings
distilled_model <- distill(forest_interpret)

The distill method returns a Surrogate object, which can be used for further predictions. As we see in the coefficients, each level of a categorical variable is one-hot encoded in the weight-fitting process. In this case, we get three different coefficients for the Species feature, one for each of the three species of iris.

print(distilled_model)

# get weights for each PDP curve
print(distilled_model$weights)

Inducing Sparsity in Distilled Surrogate Models

To induce sparsity in the distilled surrogate models, we can pass additional arguments to the weight-fitting process through the params.cv.glmnet argument, which takes a list of arguments forwarded to cv.glmnet. An equivalent argument, params.glmnet, allows user-specified weight fitting when cv = FALSE, in which case glmnet is called directly without cross-validated selection of the penalty (a sketch of this variant appears at the end of this section).

# Sparse Model
sparse_model <- distill(forest_interpret,
                        cv = TRUE,
                        params.cv.glmnet = list(lower.limits = 0,
                                                intercept = FALSE,
                                                alpha = 1))
print(distilled_model$weights)
print(sparse_model$weights)

Although no coefficient is set to 0 in this case, the coefficients in the sparse model tend to be smaller than those in the default distilled model.
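
As mentioned above, the non-cross-validated variant is requested with cv = FALSE together with params.glmnet. The sketch below assumes that params.glmnet is forwarded to glmnet in the same way params.cv.glmnet is forwarded to cv.glmnet:

# Sketch: fit the surrogate weights without cross-validation,
# passing the same constraints directly to glmnet via params.glmnet
noncv_model <- distill(forest_interpret,
                       cv = FALSE,
                       params.glmnet = list(lower.limits = 0,
                                            intercept = FALSE,
                                            alpha = 1))
print(noncv_model$weights)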

Predictions Using Surrogate Models

We use the predict function to make new predictions with the distilled surrogate models. Note that the output of the surrogate model is a one-column data frame of predictions, so we extract the first column below.

# make predictions
original.preds <- predict(forest, test[,-1])
distill.preds <- predict(distilled_model, test[,-1])[,1]
sparse.preds <- predict(sparse_model, test[,-1])[,1]

# compare RMSE
rmse <- function(preds){
  return(sqrt(mean((test[,1] - preds)^2)))
}

print(data.frame(original = rmse(original.preds),
                 distill = rmse(distill.preds),
                 distill.sparse = rmse(sparse.preds)))

As measured on the test set, both distilled models achieve a slightly lower RMSE than the original random forest, suggesting that the distilled models generalize better out of sample here; this will vary by use case. We can compare how closely the distilled predictions match the original predictions by plotting them against each other below:

# create dataframe of data we'd like to plot
plot.data <- data.frame(original.preds,
                        distill.preds,
                        sparse.preds)

# plots for both default distilled model and sparse model
default.plot <- ggplot(data = plot.data, aes(x = original.preds,
                                             y = distill.preds)) +
                xlim(min(c(original.preds, distill.preds)),
                     max(c(original.preds, distill.preds))) +
                ylim(min(c(original.preds, distill.preds)),
                     max(c(original.preds, distill.preds))) +
                geom_point() + geom_abline(col = "red")

sparse.plot <- ggplot(data = plot.data, aes(x = original.preds,
                                            y = sparse.preds)) +
               xlim(min(c(original.preds, sparse.preds)),
                    max(c(original.preds, sparse.preds))) +
               ylim(min(c(original.preds, sparse.preds)),
                    max(c(original.preds, sparse.preds))) +
               geom_point() + geom_abline(col = "red")

plot(default.plot)
plot(sparse.plot)

As the plots show, neither surrogate matches the original model's predictions exceptionally well. Still, the regularized, sparse distilled model tracks the original random forest's predictions more closely than the default distilled model, which is consistent with their similar out-of-sample performance.
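
To complement the visual comparison, we can quantify how closely each surrogate agrees with the original forest. The snippet below is an additional illustration that reuses the prediction vectors defined above:

# quantify agreement between surrogate and original predictions
agreement <- function(preds) {
  c(rmse.vs.forest = sqrt(mean((original.preds - preds)^2)),
    correlation = cor(original.preds, preds))
}

print(rbind(distill = agreement(distill.preds),
            distill.sparse = agreement(sparse.preds)))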

For further details on the distill method, please refer to the documentation provided in the "References" section.


