View source: R/feature_importance.R
feature_importance | R Documentation |
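This function calculates permutation-based feature importance. Importance is measured as the change in loss after a variable's values are permuted (its information is "dropped out"), which is why the resulting plot is also called the Variable Dropout Plot.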
feature_importance(x, ...)

## S3 method for class 'explainer'
feature_importance(
  x,
  loss_function = DALEX::loss_root_mean_square,
  ...,
  type = c("raw", "ratio", "difference"),
  n_sample = NULL,
  B = 10,
  variables = NULL,
  variable_groups = NULL,
  N = n_sample,
  label = NULL
)

## Default S3 method:
feature_importance(
  x,
  data,
  y,
  predict_function = predict,
  loss_function = DALEX::loss_root_mean_square,
  ...,
  label = class(x)[1],
  type = c("raw", "ratio", "difference"),
  n_sample = NULL,
  B = 10,
  variables = NULL,
  N = n_sample,
  variable_groups = NULL
)
x |
an explainer created with function `DALEX::explain()`, or a model to be explained. |
... |
other parameters |
loss_function |
a function that will be used to assess variable importance |
type |
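character, type of transformation that should be applied to the dropout loss.
"raw" returns raw dropout losses, "ratio" returns `drop_loss / drop_loss_full_model`, while "difference" returns `drop_loss - drop_loss_full_model`. |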
n_sample |
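alias for `N`, kept for backwards compatibility: the number of observations that should be sampled for calculation of variable importance. |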
B |
integer, number of permutation rounds to perform on each variable. By default it's `10`. |
variables |
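vector of variables. If `NULL` then variable importance will be tested for each variable from the `data` separately. By default `NULL`. |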
variable_groups |
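list of vectors of variable names. This is for testing joint variable importance.
If `NULL` then variable importance will be tested separately for `variables`. By default `NULL`. If specified, it overrides `variables`. |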
N |
number of observations that should be sampled for calculation of variable importance.
If `NULL` then variable importance will be calculated on the whole dataset (no sampling). |
label |
name of the model. By default it's extracted from the `class` attribute of the model. |
data |
validation dataset, will be extracted from `x` if it's an explainer. |
y |
true labels for `data`, will be extracted from `x` if it's an explainer. |
predict_function |
predict function, will be extracted from `x` if it's an explainer. |
Find more details in the Feature Importance Chapter.
an object of the class feature_importance
Explanatory Model Analysis. Explore, Explain, and Examine Predictive Models. https://ema.drwhy.ai/
# Examples for feature_importance(): permutation-based variable importance
# computed for several models, for single variables and for variable groups.
library("DALEX")
library("ingredients")

# Logistic regression on the Titanic data ----
model_titanic_glm <- glm(
  survived ~ gender + age + fare,
  data = titanic_imputed,
  family = "binomial"
)

# Column 8 is supplied as the target `y` and excluded from `data`.
explain_titanic_glm <- explain(
  model_titanic_glm,
  data = titanic_imputed[, -8],
  y = titanic_imputed[, 8]
)

# A single permutation round per variable.
fi_glm <- feature_importance(explain_titanic_glm, B = 1)
plot(fi_glm)

# Joint importance of two groups of variables.
fi_glm_joint1 <- feature_importance(
  explain_titanic_glm,
  variable_groups = list(
    "demographics" = c("gender", "age"),
    "ticket_type" = c("fare")
  ),
  label = "lm 2 groups"
)
plot(fi_glm_joint1)

# Joint importance of four groups of variables. The label previously read
# "lm 5 groups" although four groups are listed.
fi_glm_joint2 <- feature_importance(
  explain_titanic_glm,
  variable_groups = list(
    "demographics" = c("gender", "age"),
    "wealth" = c("fare", "class"),
    "family" = c("sibsp", "parch"),
    "embarked" = "embarked"
  ),
  label = "lm 4 groups"
)
plot(fi_glm_joint2, fi_glm_joint1)

# Random forest (ranger) on the Titanic data ----
library("ranger")

model_titanic_rf <- ranger(
  survived ~ .,
  data = titanic_imputed,
  probability = TRUE
)

explain_titanic_rf <- explain(
  model_titanic_rf,
  data = titanic_imputed[, -8],
  y = titanic_imputed[, 8],
  label = "ranger forest",
  verbose = FALSE
)

# Default number of permutation rounds.
fi_rf <- feature_importance(explain_titanic_rf)
plot(fi_rf)

# 6 replications
fi_rf <- feature_importance(explain_titanic_rf, B = 6)
plot(fi_rf)

# Grouped importance, plotted next to the per-variable importance.
fi_rf_group <- feature_importance(
  explain_titanic_rf,
  variable_groups = list(
    "demographics" = c("gender", "age"),
    "wealth" = c("fare", "class"),
    "family" = c("sibsp", "parch"),
    "embarked" = "embarked"
  ),
  label = "rf 4 groups"
)
plot(fi_rf_group, fi_rf)

# Multiclass random forest on the HR data, cross-entropy loss ----
# NOTE(review): `data = HR` still contains the `status` target column;
# confirm that this is intended.
HR_rf_model <- ranger(status ~ ., data = HR, probability = TRUE)

explainer_rf <- explain(
  HR_rf_model,
  data = HR,
  y = HR$status,
  model_info = list(type = 'multiclass')
)

fi_rf <- feature_importance(
  explainer_rf,
  type = "raw",
  loss_function = DALEX::loss_cross_entropy
)
head(fi_rf)
plot(fi_rf)

# Binary logistic regression on the HR data ("fired" vs. rest) ----
HR_glm_model <- glm(status == "fired" ~ ., data = HR, family = "binomial")

explainer_glm <- explain(
  HR_glm_model,
  data = HR,
  y = as.numeric(HR$status == "fired")
)

fi_glm <- feature_importance(
  explainer_glm,
  type = "raw",
  loss_function = DALEX::loss_root_mean_square
)
head(fi_glm)
plot(fi_glm)
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.