# Global knitr options applied to every code chunk of this document.
knitr::opts_chunk$set(
  cache   = TRUE,   # reuse stored chunk results on later builds
  comment = "#>",   # prefix placed in front of printed output
  message = FALSE,  # do not show messages in the output
  warning = FALSE   # do not show warnings in the output
)

Here we will use the dragons data to present the ceterisParibus2 package for regression models.

# devtools::install_github("ModelOriented/DALEX2")
library(DALEX2)
library(ggplot2)
library(ceterisParibus2)

# Preview the dragons data used for model fitting below.
head(dragons)

# The first row of dragons_test is the observation explained throughout.
new_observation <- dragons_test[1, ]
new_observation

Select neighbourhood sample, random sample, specific variables.

# Ten rows of dragons_test picked by select_neighbours() for new_observation.
similar_dragons <- select_neighbours(
  dragons_test,
  new_observation,
  n = 10
)
similar_dragons

# Ten rows of dragons_test picked by select_sample().
random_dragons <- select_sample(
  dragons_test,
  n = 10
)
random_dragons

# Split points restricted to the two variables of interest.
variable_splits <- calculate_variable_split(
  dragons,
  variables = c("scars", "weight")
)

Linear regression

First, we fit a model.

# Linear model of life_length on all remaining columns of dragons.
m_lm <- lm(life_length ~ ., data = dragons)

To calculate individual variable profiles (ceteris paribus profiles), i.e. series of predictions from a model calculated for observations with a single altered coordinate, we use the ceterisParibus2 package.

We create an object of the ceteris_paribus_explainer class.

# Profiles of the linear model for new_observation, with dragons_test as data.
ivp_lm <- individual_variable_profile(
  m_lm,
  data = dragons_test,
  new_observation = new_observation
)

Now we can plot individual profiles

# Draw the individual profiles held in ivp_lm.
plot(ivp_lm)

Calculate oscillations for individual variable profiles.

# Compute oscillations for the profiles held in ivp_lm.
individual_variable_oscillations(ivp_lm)

Ceteris paribus plots for neighbourhood sample, random sample, selected variables.

# Same linear model, but with the neighbourhood sample passed as data.
ivp_lm_neighbours <- individual_variable_profile(
  m_lm,
  data = similar_dragons,
  new_observation = new_observation
)
plot(ivp_lm_neighbours)

# Random sample passed as data, with the precomputed splits for
# "scars" and "weight".
ivp_lm_random <- individual_variable_profile(
  m_lm,
  data = random_dragons,
  new_observation = new_observation,
  variable_splits = variable_splits
)
plot(ivp_lm_random)

Data frame with profiles.

# Profiles of m_lm for new_observation over the precomputed splits;
# head() shows the first rows of the result.
profiles <- calculate_variable_profile(
  data = new_observation,
  variable_splits = variable_splits,
  model = m_lm
)
head(profiles)

For other types of models we proceed analogously.

randomForest

library(randomForest)

# Random forest of life_length on all remaining columns of dragons.
m_rf <- randomForest(life_length ~ ., data = dragons)

# Profiles of the random forest for new_observation.
ivp_rf <- individual_variable_profile(
  m_rf,
  data = dragons_test,
  new_observation = new_observation
)

# Plot the profiles, then compute their oscillations.
plot(ivp_rf)
individual_variable_oscillations(ivp_rf)

SVM

library(e1071)

# Support vector machine of life_length on all remaining columns of dragons.
m_svm <- svm(life_length ~ ., data = dragons)

# Profiles of the SVM for new_observation.
ivp_svm <- individual_variable_profile(
  m_svm,
  data = dragons_test,
  new_observation = new_observation
)

# Plot the profiles, then compute their oscillations.
plot(ivp_svm)
individual_variable_oscillations(ivp_svm)

knn

library(caret)

# k-nearest-neighbours regression (k = 5) of life_length on all remaining
# columns of dragons.
m_knn <- knnreg(life_length ~ ., data = dragons, k = 5)

# Profiles of the knn model for new_observation.
ivp_knn <- individual_variable_profile(
  m_knn,
  data = dragons_test,
  new_observation = new_observation
)

# Plot the profiles, then compute their oscillations.
plot(ivp_knn)
individual_variable_oscillations(ivp_knn)

nnet

When you use the nnet package for regression, remember to normalize the response variable so that it lies in the interval $(0,1)$.

In this case, you also need to supply your own predict function.

library(nnet)

# Scale the response by a "round" upper bound on its largest absolute value.
# The bound is the maximum rounded UP to one significant digit
# (e.g. 3900 -> 4000, 1400 -> 2000). Rounding to nearest could fall below the
# maximum (1400 -> 1000) and leave scaled responses greater than 1.
x <- max(abs(dragons$life_length))
# log10() below needs a positive, finite maximum.
stopifnot(is.finite(x), x > 0)
digits <- floor(log10(x))
normalizing_factor <- ceiling(x / 10^digits) * 10^digits

# Fit the network on the scaled response; predictions must be multiplied by
# normalizing_factor to return to the original scale.
m_nnet <- nnet(
  life_length / normalizing_factor ~ .,
  data = dragons,
  size = 10,
  linout = TRUE
)

# Prediction wrapper for a model fitted on a response divided by
# normalizing_factor: multiplies the raw predictions by that same factor.
#
# model            fitted model object handed to predict()
# new_observation  data passed to predict() as newdata
# Returns the rescaled predictions.
p_fun <- function(model, new_observation) {
  scaled_prediction <- predict(model, newdata = new_observation)
  scaled_prediction * normalizing_factor
}

# Profiles of the neural network for new_observation; p_fun is supplied as
# the predict function in place of the default.
ivp_nnet <- individual_variable_profile(
  m_nnet,
  data = dragons_test,
  new_observation = new_observation,
  predict_function = p_fun
)

# Plot the profiles, then compute their oscillations.
plot(ivp_nnet)
individual_variable_oscillations(ivp_nnet)

Several models at once

To plot several models in one graph, use the argument color = "_label_".

# Draw the profiles of all five models in one graph. color = "_label_" is
# presumably what separates the curves by model -- confirm against the
# plot method of the explainer class.
plot(ivp_lm, ivp_rf, ivp_svm, ivp_knn, ivp_nnet, color = "_label_")


pbiecek/ceterisParibus2 documentation built on Sept. 16, 2019, 6:26 p.m.