perm_importance | R Documentation

View source: R/perm_importance.R
Description

Calculates permutation importance for a set of features or a set of feature groups. By default, importance is calculated for all columns in X (except column names used as response y or case weight w).
Usage

perm_importance(object, ...)

## Default S3 method:
perm_importance(
  object,
  X,
  y,
  v = NULL,
  pred_fun = stats::predict,
  loss = "squared_error",
  m_rep = 4L,
  agg_cols = FALSE,
  normalize = FALSE,
  n_max = 10000L,
  w = NULL,
  verbose = TRUE,
  ...
)

## S3 method for class 'ranger'
perm_importance(
  object,
  X,
  y,
  v = NULL,
  pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions,
  loss = "squared_error",
  m_rep = 4L,
  agg_cols = FALSE,
  normalize = FALSE,
  n_max = 10000L,
  w = NULL,
  verbose = TRUE,
  ...
)

## S3 method for class 'explainer'
perm_importance(
  object,
  X = object[["data"]],
  y = object[["y"]],
  v = NULL,
  pred_fun = object[["predict_function"]],
  loss = "squared_error",
  m_rep = 4L,
  agg_cols = FALSE,
  normalize = FALSE,
  n_max = 10000L,
  w = object[["weights"]],
  verbose = TRUE,
  ...
)
Arguments

object: Fitted model object.

...: Additional arguments passed to pred_fun().

X: A data.frame or matrix serving as background dataset.

y: Vector/matrix of the response, or the corresponding column names in X.

v: Vector of feature names, or named list of feature groups. The default (NULL) uses all columns of X except the column names of y and w.

pred_fun: Prediction function of the form function(object, X, ...) (a short sketch with a custom pred_fun follows the argument list).

loss: One of "squared_error", "logloss", "mlogloss", "poisson", "gamma", "absolute_error", "classification_error". Alternatively, a loss function can be provided that turns observed and predicted values into a numeric vector or matrix of unit losses of the same length as the input (see the "Losses" section below).

m_rep: Number of permutations (default 4).

agg_cols: Should multivariate losses be summed up? Default is FALSE.

normalize: Should importance statistics be divided by the average loss? Default is FALSE.

n_max: If X has more than n_max rows, a random sample of n_max rows is used. Default is 10000.

w: Optional vector of case weights. Can also be a column name of X.

verbose: Should a progress bar be shown? The default is TRUE.
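As referenced above, a custom pred_fun can adapt models whose default predict() output is not on the desired scale. A minimal sketch (not from the original examples), assuming a logistic regression where probabilities are needed together with the "logloss" loss:

iris2 <- transform(iris, virginica = as.numeric(Species == "virginica"))
fit <- glm(virginica ~ Sepal.Length + Sepal.Width, data = iris2, family = binomial())
s <- perm_importance(
  fit,
  X = iris2[c("Sepal.Length", "Sepal.Width")],
  y = iris2$virginica,
  pred_fun = function(m, X, ...) predict(m, X, type = "response", ...),
  loss = "logloss"
)
s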
Details

The permutation importance of a feature is defined as the increase in the average loss when shuffling the corresponding feature values before calculating predictions. By default, the process is repeated m_rep = 4 times, and the results are averaged. In most cases, importance values should be derived from an independent test dataset. Set normalize = TRUE to get relative increases in average loss.
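To make this definition concrete, here is a minimal hand-rolled sketch of a single permutation round for one feature (not part of the package; it only illustrates shuffling a column and comparing average losses):

fit <- lm(Sepal.Length ~ ., data = iris)
mse <- function(y, pred) mean((y - pred)^2)
base_loss <- mse(iris$Sepal.Length, predict(fit, iris))

X_perm <- iris
set.seed(1)
X_perm$Petal.Length <- sample(X_perm$Petal.Length)  # shuffle one feature
perm_loss <- mse(iris$Sepal.Length, predict(fit, X_perm))

perm_loss - base_loss  # single-round permutation importance of Petal.Length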
Value

An object of class "hstats_matrix" containing these elements:

M: Matrix of statistics (one column per prediction dimension), or NULL.

SE: Matrix with standard errors of M, or NULL. Multiply with sqrt(m_rep) to get standard deviations instead. Currently only supported for perm_importance().

m_rep: The number of repetitions behind standard errors SE, or NULL. Currently only supported for perm_importance().

statistic: Name of the function that generated the statistic.

description: Description of the statistic.
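For example, standard deviations can be derived from the stored standard errors as described above. A minimal sketch, assuming the m_rep element can be accessed in the same way as M and SE:

fit <- lm(Sepal.Length ~ ., data = iris)
s <- perm_importance(fit, X = iris, y = "Sepal.Length", verbose = FALSE)
s$M                   # importance statistics
s$SE * sqrt(s$m_rep)  # standard deviations instead of standard errors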
Methods (by class)

perm_importance(default): Default method.

perm_importance(ranger): Method for "ranger" models.

perm_importance(explainer): Method for DALEX "explainer".
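For the "explainer" method, X, y, pred_fun, and w default to the corresponding components of the explainer object. A minimal sketch, assuming the DALEX package is installed:

if (requireNamespace("DALEX", quietly = TRUE)) {
  fit <- lm(Sepal.Length ~ ., data = iris)
  ex <- DALEX::explain(
    fit, data = iris[, -1], y = iris$Sepal.Length, verbose = FALSE
  )
  perm_importance(ex)  # X, y, pred_fun, and w are taken from the explainer
}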
Losses

The default loss is the "squared_error". Other choices:

"absolute_error": The absolute error is the loss corresponding to median regression.

"poisson": Unit Poisson deviance, i.e., the loss function used in Poisson regression. Actual values y and predictions must be non-negative.

"gamma": Unit gamma deviance, i.e., the loss function of Gamma regression. Actual values y and predictions must be positive.

"logloss": The Log Loss is the loss function used in logistic regression, and the top choice in probabilistic binary classification. Responses y and predictions must be between 0 and 1. Predictions represent probabilities of having a "1".

"mlogloss": Multi-Log-Loss is the natural loss function in probabilistic multi-class situations. If there are K classes and n observations, the predictions form an (n x K) matrix of probabilities (with row sums 1). The observed values y are either passed as an (n x K) dummy matrix, or as a discrete vector with corresponding levels. The latter case is turned into a dummy matrix by a fast version of model.matrix(~ as.factor(y) + 0).

"classification_error": Misclassification error. Both the observed values y and the predictions can be character/factor. This loss function can be used in non-probabilistic classification settings. BUT: Probabilistic classification (with "mlogloss") is clearly preferred in most situations.

Alternatively, a function with signature f(actual, predicted) can be passed, returning a numeric vector or matrix of the same length as the input (see the sketch below).
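As an illustration of a custom loss, here is a sketch passing an explicit squared-error function instead of the string shortcut (equivalent to the default loss):

sq_err <- function(actual, predicted) (actual - predicted)^2  # unit losses

fit <- lm(Sepal.Length ~ ., data = iris)
perm_importance(fit, X = iris, y = "Sepal.Length", loss = sq_err, verbose = FALSE)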
References

Fisher A., Rudin C., Dominici F. (2018). All Models are Wrong, but Many are Useful: Variable Importance for Black-Box, Proprietary, or Misspecified Prediction Models, using Model Class Reliance. arXiv.
Examples

# MODEL 1: Linear regression
fit <- lm(Sepal.Length ~ ., data = iris)
s <- perm_importance(fit, X = iris, y = "Sepal.Length")
s
s$M
s$SE # Standard errors are available thanks to repeated shuffling
plot(s)
plot(s, err_type = "SD") # Standard deviations instead of standard errors
# Groups of features can be passed as named list
v <- list(petal = c("Petal.Length", "Petal.Width"), species = "Species")
s <- perm_importance(fit, X = iris, y = "Sepal.Length", v = v, verbose = FALSE)
s
plot(s)
# MODEL 2: Multi-response linear regression
fit <- lm(as.matrix(iris[, 1:2]) ~ Petal.Length + Petal.Width + Species, data = iris)
s <- perm_importance(fit, X = iris[, 3:5], y = iris[, 1:2], normalize = TRUE)
s
plot(s)
plot(s, swap_dim = TRUE, top_m = 2)
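# MODEL 3 (a sketch, not part of the original examples): ranger random forest.
# Assumes the 'ranger' package is installed; the "ranger" method supplies a
# suitable pred_fun automatically.
if (requireNamespace("ranger", quietly = TRUE)) {
  fit <- ranger::ranger(Sepal.Length ~ ., data = iris, seed = 1)
  s <- perm_importance(fit, X = iris, y = "Sepal.Length")
  plot(s)
}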