vi    R Documentation
Compute variable importance scores for the predictors in a model.
vi(object, ...)
## Default S3 method:
vi(
  object,
  method = c("model", "firm", "permute", "shap"),
  feature_names = NULL,
  abbreviate_feature_names = NULL,
  sort = TRUE,
  decreasing = TRUE,
  scale = FALSE,
  rank = FALSE,
  ...
)
object |
A fitted model object (e.g., a randomForest object) or an object that inherits from class |
... |
Additional optional arguments to be passed on to vi_model, vi_firm, vi_permute, or vi_shap; see their respective help pages for details. |
method |
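Character string specifying the type of variable importance (VI) to compute. Current options are:
"model" (the default), for model-specific VI scores (see vi_model for details);
"firm", for variance-based VI scores (see vi_firm for details);
"permute", for permutation-based VI scores (see vi_permute for details);
"shap", for Shapley-based VI scores (see vi_shap for details). |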
feature_names |
Character vector giving the names of the predictor variables (i.e., features) of interest. |
abbreviate_feature_names |
Integer specifying the length at which to abbreviate feature names. Default is NULL, which results in no abbreviation (i.e., the full feature names are used). |
sort |
Logical indicating whether or not to sort the variable importance scores. Default is TRUE. |
decreasing |
Logical indicating whether the variable importance scores should be sorted in descending (TRUE) or ascending (FALSE) order. Default is TRUE. |
scale |
Logical indicating whether or not to scale the variable importance scores so that the largest is 100. Default is FALSE. |
rank |
Logical indicating whether or not to rank the variable importance scores (i.e., convert to integer ranks). Default is FALSE. |
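The following base-R sketch illustrates what abbreviate_feature_names, sort, decreasing, scale, and rank do to a table of scores. The helper postprocess_vi() is hypothetical and not part of vip. Scaling by x / max(x) * 100, ranking with 1 = most important, and sorting on the raw scores before any scaling or ranking are assumptions chosen to match the argument descriptions above, not necessarily what vi() does internally.

```r
# Hypothetical helper, NOT vip's implementation: applies the documented
# effects of abbreviate_feature_names, sort, decreasing, scale, and rank
# to a data frame with columns Variable and Importance.
postprocess_vi <- function(scores, abbreviate_feature_names = NULL,
                           sort = TRUE, decreasing = TRUE,
                           scale = FALSE, rank = FALSE) {
  if (!is.null(abbreviate_feature_names)) {
    # Shorten feature names to the requested length (base R abbreviate())
    scores$Variable <- unname(abbreviate(scores$Variable,
                                         minlength = abbreviate_feature_names))
  }
  if (sort) {
    # Order rows by the raw scores (largest first when decreasing = TRUE)
    scores <- scores[order(scores$Importance, decreasing = decreasing), ,
                     drop = FALSE]
  }
  if (scale) {
    # Rescale so that the largest score becomes 100
    scores$Importance <- scores$Importance / max(scores$Importance) * 100
  }
  if (rank) {
    # Integer ranks with 1 = most important; ties share the smallest rank
    scores$Importance <- rank(-scores$Importance, ties.method = "min")
  }
  rownames(scores) <- NULL
  scores
}

# Toy scores, made up purely for illustration
vis <- data.frame(Variable = c("cylinders", "displacement", "horsepower"),
                  Importance = c(0.5, 2, 1), stringsAsFactors = FALSE)
postprocess_vi(vis, scale = TRUE)  # Importance: 100, 50, 25
postprocess_vi(vis, rank = TRUE)   # Importance: 1, 2, 3
```

Because sorting happens first in this sketch, the ranked column reads 1, 2, 3 from top to bottom when decreasing = TRUE.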
A tidy data frame (i.e., a tibble object) with (at least) two columns:
Variable - the corresponding feature name;
Importance - the associated importance score; when method = "permute", this is the average change in performance after a random permutation (or permutations, if nsim > 1) of the feature in question.
For lm/glm-like objects, whenever method = "model", the sign (i.e., POS/NEG) of the original coefficient is also included in a column called Sign.
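As a rough illustration of the Sign column, the sketch below builds a model-based score table for an lm fit in base R. It assumes that importance is the absolute t-statistic of each non-intercept coefficient; vi_model() may use a different statistic, and the object names fit and vi_lm are hypothetical.

```r
# Hypothetical sketch: model-based scores for an lm fit, ASSUMING the score is
# the absolute t-statistic of each coefficient (intercept excluded) and Sign
# records whether the original estimate is positive or negative.
fit <- lm(mpg ~ wt + hp + am, data = mtcars)
coefs <- summary(fit)$coefficients[-1, , drop = FALSE]  # drop the intercept row

vi_lm <- data.frame(
  Variable   = rownames(coefs),
  Importance = unname(abs(coefs[, "t value"])),
  Sign       = unname(ifelse(coefs[, "Estimate"] >= 0, "POS", "NEG")),
  stringsAsFactors = FALSE
)
# Largest |t| first, mirroring sort = TRUE, decreasing = TRUE
vi_lm <- vi_lm[order(vi_lm$Importance, decreasing = TRUE), ]
vi_lm
```

Taking the absolute value makes the scores comparable across predictors, while Sign keeps the direction of each effect that the absolute value would otherwise hide.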
If method = "permute" and nsim > 1, then an additional column (StDev) containing the standard deviation of the individual permutation scores for each feature is also returned; this helps assess the stability of each feature's permutation importance.
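To make the permutation-based Importance and StDev columns concrete, here is a minimal base-R sketch (not vi_permute() itself) for an lm fit on mtcars: each feature is shuffled nsim times, Importance is the mean increase in RMSE over the unpermuted baseline, and StDev is the standard deviation of those increases. The helpers rmse() and permute_importance() are hypothetical, and measuring the change as a difference (rather than, say, a ratio) is an assumption.

```r
# Hypothetical sketch of permutation importance (not vi_permute() itself).
rmse <- function(actual, predicted) sqrt(mean((actual - predicted)^2))

permute_importance <- function(fit, train, target, nsim = 10) {
  features <- setdiff(names(train), target)
  baseline <- rmse(train[[target]], predict(fit, newdata = train))
  res <- lapply(features, function(f) {
    # nsim increases in RMSE, one per random permutation of feature f
    scores <- replicate(nsim, {
      shuffled <- train
      shuffled[[f]] <- sample(shuffled[[f]])  # break the feature/target link
      rmse(train[[target]], predict(fit, newdata = shuffled)) - baseline
    })
    data.frame(Variable = f, Importance = mean(scores), StDev = sd(scores),
               stringsAsFactors = FALSE)
  })
  out <- do.call(rbind, res)
  out <- out[order(out$Importance, decreasing = TRUE), ]
  rownames(out) <- NULL
  out
}

set.seed(101)  # for reproducibility
fit <- lm(mpg ~ ., data = mtcars)
permute_importance(fit, train = mtcars, target = "mpg", nsim = 10)
```

With nsim = 1, sd() returns NA, which is consistent with StDev only being reported when nsim > 1.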
#
# A projection pursuit regression example
#
library(vip) # for vi() and vip()
# Load the sample data
data(mtcars)
# Fit a projection pursuit regression model
mtcars.ppr <- ppr(mpg ~ ., data = mtcars, nterms = 1)
# Prediction wrapper that tells vi() how to obtain new predictions from your
# fitted model
pfun <- function(object, newdata) predict(object, newdata = newdata)
# Compute permutation-based variable importance scores
set.seed(1434) # for reproducibility
(vis <- vi(mtcars.ppr, method = "permute", target = "mpg", nsim = 10,
           metric = "rmse", pred_wrapper = pfun, train = mtcars))
# Plot variable importance scores
vip(vis, include_type = TRUE, all_permutations = TRUE,
    geom = "point", aesthetics = list(color = "forestgreen", size = 3))
#
# A binary classification example
#
## Not run:
library(rpart) # for classification and regression trees
# Load Wisconsin breast cancer data; see ?mlbench::BreastCancer for details
data(BreastCancer, package = "mlbench")
bc <- subset(BreastCancer, select = -Id) # for brevity
# Fit a standard classification tree
set.seed(1032) # for reproducibility
tree <- rpart(Class ~ ., data = bc, cp = 0)
# Prune using 1-SE rule (e.g., use `plotcp(tree)` for guidance)
cp <- tree$cptable
cp <- cp[cp[, "nsplit"] == 2L, "CP"]
tree2 <- prune(tree, cp = cp) # tree with two splits (three terminal nodes)
# Default tree-based VIP
vip(tree2)
# Computing permutation importance requires a prediction wrapper. For
# classification, the return value depends on the chosen metric; see
# `?vip::vi_permute` for details.
pfun <- function(object, newdata) {
# Need vector of predicted class probabilities when using log-loss metric
predict(object, newdata = newdata, type = "prob")[, "malignant"]
}
# Permutation-based importance (note that only the predictors that show up
# in the final tree have non-zero importance)
set.seed(1046) # for reproducibility
vi(tree2, method = "permute", nsim = 10, target = "Class", train = bc,
   metric = "logloss", pred_wrapper = pfun, reference_class = "malignant")
# Equivalent (but not sorted)
set.seed(1046) # for reproducibility
vi_permute(tree2, nsim = 10, target = "Class", train = bc, metric = "logloss",
           pred_wrapper = pfun, reference_class = "malignant")
## End(Not run)