Description Usage Arguments Details Value References Examples

View source: R/feature_importance.R

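This function calculates permutation-based feature importance: the values of a variable (or of a group of variables) are randomly permuted, and the resulting change in the model's loss is recorded. Because each variable is in effect "dropped out" of the model, the resulting plot is also called the Variable Dropout Plot.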

```
## S3 generic
feature_importance(x, ...)
## S3 method for class 'explainer'
## (x is an explainer; there are no separate data / y / predict_function
## arguments in this method)
feature_importance(
x,
loss_function = DALEX::loss_root_mean_square,
...,
type = c("raw", "ratio", "difference"),
n_sample = NULL,
B = 10,
variables = NULL,
variable_groups = NULL,
# N defaults to its alias n_sample, so an explicitly supplied N wins
N = n_sample,
label = NULL
)
## Default S3 method:
## (x is the model itself; data and y have no defaults and must be given)
feature_importance(
x,
data,
y,
predict_function = predict,
loss_function = DALEX::loss_root_mean_square,
...,
label = class(x)[1],
type = c("raw", "ratio", "difference"),
n_sample = NULL,
B = 10,
variables = NULL,
# N defaults to its alias n_sample, so an explicitly supplied N wins
N = n_sample,
variable_groups = NULL
)
```

`x` |
an explainer created with function `DALEX::explain()`, or a model to be explained (default method). |

`...` |
other parameters |

`loss_function` |
a function that will be used to assess variable importance |

`type` |
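character, type of transformation that should be applied to the dropout loss.
"raw" returns the raw dropout losses, "ratio" returns `drop_loss / drop_loss_full_model`, and "difference" returns `drop_loss - drop_loss_full_model`. |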

`n_sample` |
alias for `N`, kept for backwards compatibility: `N` defaults to the value of `n_sample`. |

`B` |
integer, number of permutation rounds to perform on each variable. By default it's `10`. |

`variables` |
vector of variables. If `NULL`, variable importance is tested for each variable in `data` separately. By default `NULL`. |

`variable_groups` |
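list of vectors of variable names. This is for testing joint variable importance: the variables in each group are permuted together.
If `NULL`, variable importance is tested separately for each variable. By default `NULL`. If specified, it overrides `variables`. |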

`N` |
number of observations that should be sampled for calculation of variable importance.
If `NULL`, variable importance is calculated on the whole dataset (no sampling). |

`label` |
name of the model. By default it's extracted from the `class` attribute of the model. |

`data` |
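validation dataset, will be extracted from `x` if it's an explainer. NOTE: it is best when the target variable is not present in `data`, otherwise the target itself is scored as a feature. |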

`y` |
true labels for `data`, will be extracted from `x` if it's an explainer. |

`predict_function` |
predict function, will be extracted from `x` if it's an explainer. |

Find more details in the feature importance (variable-importance) chapter of *Explanatory Model Analysis*, listed under References below.

an object of the class `feature_importance`

Explanatory Model Analysis. Explore, Explain, and Examine Predictive Models. https://ema.drwhy.ai/

```
library("DALEX")
library("ingredients")

# --- Titanic: logistic regression --------------------------------------------
# The target `survived` is selected by name and kept out of the explainer's
# `data`, so it is never permuted and reported as a "feature" itself.
titanic_x <- titanic_imputed[, names(titanic_imputed) != "survived"]
titanic_y <- titanic_imputed$survived

model_titanic_glm <- glm(survived ~ gender + age + fare,
                         data = titanic_imputed, family = "binomial")
explain_titanic_glm <- explain(model_titanic_glm,
                               data = titanic_x,
                               y = titanic_y)

# A single permutation round (B = 1) is quick but noisier than the
# default of B = 10 rounds.
fi_glm <- feature_importance(explain_titanic_glm, B = 1)
plot(fi_glm)

# Joint importance: all variables within a group are permuted together.
fi_glm_joint1 <- feature_importance(
  explain_titanic_glm,
  variable_groups = list("demographics" = c("gender", "age"),
                         "ticket_type" = c("fare")),
  label = "glm 2 groups"
)
plot(fi_glm_joint1)

# The model above uses only gender, age and fare, so groups made solely of
# other variables ("family", "embarked") should show no loss increase.
fi_glm_joint2 <- feature_importance(
  explain_titanic_glm,
  variable_groups = list("demographics" = c("gender", "age"),
                         "wealth" = c("fare", "class"),
                         "family" = c("sibsp", "parch"),
                         "embarked" = "embarked"),
  label = "glm 4 groups"
)
plot(fi_glm_joint2, fi_glm_joint1)

# --- Titanic: random forest --------------------------------------------------
library("ranger")
model_titanic_rf <- ranger(survived ~ ., data = titanic_imputed,
                           probability = TRUE)
explain_titanic_rf <- explain(model_titanic_rf,
                              data = titanic_x,
                              y = titanic_y,
                              label = "ranger forest",
                              verbose = FALSE)
fi_rf <- feature_importance(explain_titanic_rf)  # default B = 10 rounds
plot(fi_rf)
fi_rf <- feature_importance(explain_titanic_rf, B = 6)  # 6 replications
plot(fi_rf)
fi_rf_group <- feature_importance(
  explain_titanic_rf,
  variable_groups = list("demographics" = c("gender", "age"),
                         "wealth" = c("fare", "class"),
                         "family" = c("sibsp", "parch"),
                         "embarked" = "embarked"),
  label = "rf 4 groups"
)
plot(fi_rf_group, fi_rf)

# --- HR: multiclass random forest --------------------------------------------
# `status` is the target, so it is excluded from the explainer's `data`.
hr_x <- HR[, names(HR) != "status"]

HR_rf_model <- ranger(status ~ ., data = HR, probability = TRUE)
explainer_rf <- explain(HR_rf_model, data = hr_x, y = HR$status,
                        model_info = list(type = "multiclass"))
# Distinct names so the Titanic results above are not overwritten.
fi_hr_rf <- feature_importance(explainer_rf, type = "raw",
                               loss_function = DALEX::loss_cross_entropy)
head(fi_hr_rf)
plot(fi_hr_rf)

# --- HR: binary logistic regression ("fired" vs. everything else) ------------
HR_glm_model <- glm(status == "fired" ~ ., data = HR, family = "binomial")
explainer_glm <- explain(HR_glm_model, data = hr_x,
                         y = as.numeric(HR$status == "fired"))
fi_hr_glm <- feature_importance(explainer_glm, type = "raw",
                                loss_function = DALEX::loss_root_mean_square)
head(fi_hr_glm)
plot(fi_hr_glm)
```

Embedding an R snippet on your website

Add the following code to your website.

For more information on customizing the embed code, read Embedding Snippets.