feature_importance: Feature Importance

View source: R/feature_importance.R

Feature Importance

Description

This function calculates permutation-based feature importance. Because each explanatory variable is permuted ("dropped out") in turn, the resulting plot is also called a Variable Dropout Plot.
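
In essence, the procedure measures how much the model's loss grows when the values of a single variable are shuffled, breaking its relationship with the target. Below is a minimal hand-rolled sketch of that idea (RMSE loss, one permutation round per variable); the lm model, the fare target, and the chosen variables are illustrative only, not part of the package:

library("DALEX")  # provides the titanic_imputed dataset

set.seed(1)
model <- lm(fare ~ age + sibsp + parch, data = titanic_imputed)

# root mean square error, mirroring the default loss_root_mean_square
rmse <- function(observed, predicted) sqrt(mean((observed - predicted)^2))

# loss of the full, unperturbed model
loss_full <- rmse(titanic_imputed$fare, predict(model, titanic_imputed))

# shuffle one variable at a time and recompute the loss
dropout_loss <- sapply(c("age", "sibsp", "parch"), function(v) {
  perturbed <- titanic_imputed
  perturbed[[v]] <- sample(perturbed[[v]])
  rmse(titanic_imputed$fare, predict(model, perturbed))
})

dropout_loss - loss_full  # the "difference" transformation: larger = more important

feature_importance() follows the same logic, but averages over B permutation rounds and supports arbitrary loss functions, sampling (N), and variable groups.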

Usage

feature_importance(x, ...)

## S3 method for class 'explainer'
feature_importance(
  x,
  loss_function = DALEX::loss_root_mean_square,
  ...,
  type = c("raw", "ratio", "difference"),
  n_sample = NULL,
  B = 10,
  variables = NULL,
  variable_groups = NULL,
  N = n_sample,
  label = NULL
)

## Default S3 method:
feature_importance(
  x,
  data,
  y,
  predict_function = predict,
  loss_function = DALEX::loss_root_mean_square,
  ...,
  label = class(x)[1],
  type = c("raw", "ratio", "difference"),
  n_sample = NULL,
  B = 10,
  variables = NULL,
  N = n_sample,
  variable_groups = NULL
)

Arguments

x

an explainer created with function DALEX::explain(), or a model to be explained.

...

other parameters

loss_function

a function that will be used to assess variable importance

type

character, type of transformation that should be applied to the dropout loss: "raw" returns the raw drop losses, "ratio" returns drop_loss/drop_loss_full_model, and "difference" returns drop_loss - drop_loss_full_model (see the sketch after this argument list)

n_sample

alias for N, kept for backwards compatibility: the number of observations to be sampled for the calculation of variable importance

B

integer, number of permutation rounds to perform on each variable. By default it's 10.

variables

vector of variable names. If NULL then variable importance will be tested for each variable from the data separately. By default NULL

variable_groups

list of character vectors of variable names, for testing the joint importance of groups of variables (see the grouped calls in the Examples section). If NULL then variable importance will be tested separately for each variable. By default NULL. If specified, it overrides variables

N

number of observations that should be sampled for the calculation of variable importance. If NULL then variable importance will be calculated on the whole dataset (no sampling); see the sketch after this argument list

label

name of the model. By default it's extracted from the class attribute of the model

data

validation dataset; will be extracted from x if it's an explainer. NOTE: it is best when the target variable is not present in the data

y

true labels for data, will be extracted from x if it's an explainer

predict_function

predict function, will be extracted from x if it's an explainer
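
As a quick illustration of the type and N arguments described above, the calls below are hypothetical but use only the documented interface; explain_titanic_glm and explain_titanic_rf refer to the explainers built in the Examples section:

fi_raw  <- feature_importance(explain_titanic_glm, type = "raw")         # raw drop-out loss
fi_rat  <- feature_importance(explain_titanic_glm, type = "ratio")       # loss / full-model loss
fi_dif  <- feature_importance(explain_titanic_glm, type = "difference")  # loss - full-model loss

# sampling keeps large data tractable: importance from 500 sampled rows, 5 rounds
fi_fast <- feature_importance(explain_titanic_rf, N = 500, B = 5)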

Details

Find more details in the Feature Importance chapter of Explanatory Model Analysis: https://ema.drwhy.ai/featureImportance.html

Value

an object of class feature_importance

References

Explanatory Model Analysis. Explore, Explain, and Examine Predictive Models. https://ema.drwhy.ai/

Examples

library("DALEX")
library("ingredients")

model_titanic_glm <- glm(survived ~ gender + age + fare,
                         data = titanic_imputed, family = "binomial")

explain_titanic_glm <- explain(model_titanic_glm,
                               data = titanic_imputed[, -8],  # drop the target (column 8, survived)
                               y = titanic_imputed[, 8])

fi_glm <- feature_importance(explain_titanic_glm, B = 1)  # a single permutation round (fast but less stable)
plot(fi_glm)

fi_glm_joint1 <- feature_importance(explain_titanic_glm,
                   variable_groups = list("demographics" = c("gender", "age"),
                                          "ticket_type" = c("fare")),
                   label = "lm 2 groups")

plot(fi_glm_joint1)

fi_glm_joint2 <- feature_importance(explain_titanic_glm,
                   variable_groups = list("demographics" = c("gender", "age"),
                                          "wealth" = c("fare", "class"),
                                          "family" = c("sibsp", "parch"),
                                          "embarked" = "embarked"),
                   label = "lm 5 groups")

plot(fi_glm_joint2, fi_glm_joint1)

library("ranger")
model_titanic_rf <- ranger(survived ~ ., data = titanic_imputed, probability = TRUE)

explain_titanic_rf <- explain(model_titanic_rf,
                              data = titanic_imputed[,-8],
                              y = titanic_imputed[,8],
                              label = "ranger forest",
                              verbose = FALSE)

fi_rf <- feature_importance(explain_titanic_rf)
plot(fi_rf)

fi_rf <- feature_importance(explain_titanic_rf, B = 6) # 6 permutation rounds
plot(fi_rf)
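
Since B > 1 repeats the permutations, the variability across rounds can also be inspected; this sketch assumes the show_boxplots argument of plot.feature_importance() is available in your version of ingredients:

plot(fi_rf, show_boxplots = TRUE)  # boxplots summarise the 6 permutation rounds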

fi_rf_group <- feature_importance(explain_titanic_rf,
                   variable_groups = list("demographics" = c("gender", "age"),
                                          "wealth" = c("fare", "class"),
                                          "family" = c("sibsp", "parch"),
                                          "embarked" = "embarked"),
                   label = "rf 4 groups")

plot(fi_rf_group, fi_rf)

HR_rf_model <- ranger(status ~ ., data = HR, probability = TRUE)

explainer_rf <- explain(HR_rf_model, data = HR, y = HR$status,
                        model_info = list(type = "multiclass"))

fi_rf <- feature_importance(explainer_rf, type = "raw",
                            loss_function = DALEX::loss_cross_entropy)
head(fi_rf)
plot(fi_rf)

HR_glm_model <- glm(status == "fired" ~ ., data = HR, family = "binomial")
explainer_glm <- explain(HR_glm_model, data = HR, y = as.numeric(HR$status == "fired"))
fi_glm <- feature_importance(explainer_glm, type = "raw",
                             loss_function = DALEX::loss_root_mean_square)
head(fi_glm)
plot(fi_glm)
