knitr::opts_chunk$set(collapse = TRUE, comment = "#>")
In this vignette, we present a local variable importance measure based on Ceteris Paribus profiles for a random forest regression model.
library("ggplot2")
We work on the apartments dataset from the DALEX package.
library("DALEX")
data(apartments)
head(apartments)
Now we fit a random forest regression model and wrap it in an explainer with the explain() function from DALEX.
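library("randomForest")
# Random forest for the price per square meter
apartments_rf_model <- randomForest(m2.price ~ construction.year + surface + floor + no.rooms,
                                    data = apartments)
# Explainer built on the test set (columns 2:5 are the four predictors)
explainer_rf <- explain(apartments_rf_model,
                        data = apartmentsTest[, 2:5],
                        y = apartmentsTest$m2.price)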
We need to specify the observation to explain. Let us consider a new apartment with the following attributes and compute the model's prediction for it.
new_apartment <- data.frame(construction.year = 1998, surface = 88, floor = 2L, no.rooms = 3)
predict(apartments_rf_model, new_apartment)
Let us look at the Ceteris Paribus (CP) profiles calculated with the DALEX::predict_profile() function. The CP profiles can also be calculated with DALEX::individual_profile() or ingredients::ceteris_paribus().
library("ingredients")
profiles <- predict_profile(explainer_rf, new_apartment)
plot(profiles) + show_observations(profiles)
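A CP profile answers the question "what would the prediction be if only one variable changed?". The base-R sketch that follows illustrates that idea with a hypothetical helper, cp_profile(), which is not part of DALEX or ingredients: it replicates the observation once per grid value, replaces a single variable with the grid and calls predict(). The toy data and linear model are made up for illustration only.

```r
# Hypothetical helper (not the DALEX/ingredients API): Ceteris Paribus
# profile of one variable for one observation.
cp_profile <- function(model, observation, variable, grid) {
  # Replicate the observation once per grid value ...
  newdata <- observation[rep(1, length(grid)), , drop = FALSE]
  # ... and change only the selected variable, keeping all others fixed
  newdata[[variable]] <- grid
  data.frame(value = grid, prediction = unname(predict(model, newdata)))
}

# Toy data with an exact linear relationship y = 2 * x1 + x2
toy <- data.frame(x1 = 1:10, x2 = c(3, 1, 4, 1, 5, 9, 2, 6, 5, 3))
toy$y <- 2 * toy$x1 + toy$x2
toy_model <- lm(y ~ x1 + x2, data = toy)

obs <- data.frame(x1 = 5, x2 = 3)
cp_profile(toy_model, obs, "x1", grid = 0:2)
# prediction column: 3, 5, 7 (up to floating-point error)
```

For a linear model the profile is a straight line whose slope is the coefficient of the varied variable; for a random forest, as in the plot above, it is a step function.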
Now we calculate a measure of local variable importance based on the oscillations of the Ceteris Paribus profiles. We use the variant with all three parameters (absolute_deviation, point and density) set to TRUE.
library("vivo")
measure <- local_variable_importance(profiles, apartments[, 2:5],
                                     absolute_deviation = TRUE, point = TRUE, density = TRUE)
plot(measure)
For the new apartment, the most important variable is surface, followed by floor, construction.year and no.rooms.
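To make the idea of "oscillation" concrete, here is a minimal base-R sketch for a single profile. It is an illustration under stated assumptions, not vivo's implementation: the helper cp_oscillation() is hypothetical, and we assume that point switches the reference level between the prediction at the observation and the mean of the profile, absolute_deviation switches between a weighted mean absolute deviation and the root of a weighted mean square, and density weights the grid points by a kernel density estimate of the variable in the data.

```r
# Hypothetical helper (not vivo's code): oscillation of one CP profile,
# under our assumed reading of the three flags.
cp_oscillation <- function(grid, profile, y_hat, x_values,
                           absolute_deviation = TRUE, point = TRUE,
                           density = TRUE) {
  # Reference level: prediction at the observation (point = TRUE)
  # or the average of the profile (point = FALSE)
  reference <- if (point) y_hat else mean(profile)
  deviation <- profile - reference

  # Weights: kernel density of the observed variable evaluated at the
  # grid points (density = TRUE), or equal weights (density = FALSE)
  if (density) {
    kde <- stats::density(x_values)
    w <- stats::approx(kde$x, kde$y, xout = grid, rule = 2)$y
  } else {
    w <- rep(1, length(grid))
  }
  w <- w / sum(w)

  # Weighted mean absolute deviation, or root of the weighted mean square
  if (absolute_deviation) {
    sum(w * abs(deviation))
  } else {
    sqrt(sum(w * deviation^2))
  }
}

# Toy profile f(z) = 3 * z + 0.5 with the observation at z = 1, so f(x) = 3.5
grid <- c(0, 1, 2)
profile <- 3 * grid + 0.5
cp_oscillation(grid, profile, y_hat = 3.5, x_values = grid, density = FALSE)
# deviations -3, 0, 3 give a mean absolute deviation of 2
cp_oscillation(grid, profile, y_hat = 3.5, x_values = grid,
               absolute_deviation = FALSE, density = FALSE)
# root mean square: sqrt(6), about 2.449
```

Under this reading, a variable whose profile is flat around the observation gets a value near zero, while a steep profile gets a large value, which is why the ranking follows the shapes of the CP profiles.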
We can also calculate local variable importance with different parameter settings and plot the results together, either as a bar plot or as a line plot.
measure_2 <- local_variable_importance(profiles, apartments[, 2:5],
                                       absolute_deviation = FALSE, point = TRUE, density = TRUE)
measure_3 <- local_variable_importance(profiles, apartments[, 2:5],
                                       absolute_deviation = FALSE, point = TRUE, density = FALSE)
plot(measure, measure_2, measure_3, color = "_label_method_")
plot(measure, measure_2, measure_3, color = "_label_method_", type = "lines")
Let us now create a linear regression model and its explainer object.
apartments_lm_model <- lm(m2.price ~ construction.year + surface + floor + no.rooms,
                          data = apartments)
explainer_lm <- explain(apartments_lm_model,
                        data = apartmentsTest[, 2:5],
                        y = apartmentsTest$m2.price)
We calculate the Ceteris Paribus profiles and the local variable importance measure for the linear model.
profiles_lm <- predict_profile(explainer_lm, new_apartment)
measure_lm <- local_variable_importance(profiles_lm, apartments[, 2:5],
                                        absolute_deviation = TRUE, point = TRUE, density = TRUE)
plot(measure, measure_lm, color = "_label_model_", type = "lines")
Now we can compare how the two models rank the variables by importance for the selected observation.
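For the linear model the comparison has a simple interpretation. Assuming the same reading of the parameters as above (deviations measured from the prediction at the observation, absolute values, grid points $z_1, \dots, z_m$ with weights $w_i$ summing to one), every CP profile of $f(x) = \beta_0 + \sum_k \beta_k x_k$ is a straight line, and the measure for variable $j$ (written $\mathrm{osc}_j$ here, a hypothetical notation) reduces to:

```latex
h_j(z) = f(x) + \beta_j \,(z - x_j)

\mathrm{osc}_j = \sum_{i=1}^{m} w_i \,\bigl| h_j(z_i) - f(x) \bigr|
               = |\beta_j| \sum_{i=1}^{m} w_i \,\bigl| z_i - x_j \bigr|
```

So, under this assumption, the linear model's importance of a variable grows with the absolute size of its coefficient and with the weighted average distance of the grid from the observed value $x_j$, whereas the random forest's profiles need not be linear.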