Nothing
Variable <- NULL
Importance <- NULL
#' Variable importance plots
#'
#' Plot variable importance scores for the predictors in a model.
#'
#' @param object A fitted model (e.g., of class
#' [randomForest][randomForest::randomForest] object) or a [vi][vip::vi] object.
#'
#' @param num_features Integer specifying the number of variable importance
#' scores to plot. Default is `10`.
#'
#' @param geom Character string specifying which type of plot to construct.
#' The currently available options are described below.
#'
#' * `geom = "col"` uses [geom_col][ggplot2::geom_col] to construct a bar chart
#' of the variable importance scores.
#'
#' * `geom = "point"` uses [geom_point][ggplot2::geom_point] to construct a
#' Cleveland dot plot of the variable importance scores.
#'
#' * `geom = "boxplot"` uses [geom_boxplot][ggplot2::geom_boxplot] to
#' construct a boxplot plot of the variable importance scores. This option can
#' only for the permutation-based importance method with \code{nsim > 1} and
#' `keep = TRUE`; see [vi_permute][vip::vi_permute] for details.
#'
#' * `geom = "violin"` uses [geom_violin][ggplot2::geom_violin] to
#' construct a violin plot of the variable importance scores. This option can
#' only for the permutation-based importance method with \code{nsim > 1} and
#' `keep = TRUE`; see [vi_permute][vip::vi_permute] for details.
#'
#' @param mapping Set of aesthetic mappings created by
#' [aes][ggplot2::aes]-related functions and/or tidy eval helpers. See example
#' usage below.
#'
#' @param aesthetics List specifying additional arguments passed on to
#' [layer][ggplot2::layer]. These are often aesthetics, used to set an aesthetic
#' to a fixed value, like`colour = "red"` or `size = 3`. See example usage
#' below.
#'
#' @param horizontal Logical indicating whether or not to plot the importance
#' scores on the x-axis (`TRUE`). Default is `TRUE`.
#'
#' @param all_permutations Logical indicating whether or not to plot all
#' permutation scores along with the average. Default is `FALSE`. (Only used for
#' permutation scores when `nsim > 1`.)
#'
#' @param jitter Logical indicating whether or not to jitter the raw permutation
#' scores. Default is `FALSE`. (Only used when `all_permutations = TRUE`.)
#'
#' @param include_type Logical indicating whether or not to include the type of
#' variable importance computed in the axis label. Default is `FALSE`.
#'
#' @param ... Additional optional arguments to be passed on to [vi][vip::vi].
#'
#' @importFrom stats reorder
#'
#' @rdname vip
#'
#' @export
#'
#' @examples
#' #
#' # A projection pursuit regression example using permutation-based importance
#' #
#'
#' # Load the sample data
#' data(mtcars)
#'
#' # Fit a projection pursuit regression model
#' model <- ppr(mpg ~ ., data = mtcars, nterms = 1)
#'
#' # Construct variable importance plot (permutation importance, in this case)
#' set.seed(825) # for reproducibility
#' pfun <- function(object, newdata) predict(object, newdata = newdata)
#' vip(model, method = "permute", train = mtcars, target = "mpg", nsim = 10,
#' metric = "rmse", pred_wrapper = pfun)
#'
#' # Better yet, store the variable importance scores and then plot
#' set.seed(825) # for reproducibility
#' vis <- vi(model, method = "permute", train = mtcars, target = "mpg",
#' nsim = 10, metric = "rmse", pred_wrapper = pfun)
#' vip(vis, geom = "point", horiz = FALSE)
#' vip(vis, geom = "point", horiz = FALSE, aesthetics = list(size = 3))
#'
#' # Plot unaggregated permutation scores (boxplot colored by feature)
#' library(ggplot2) # for `aes()`-related functions and tidy eval helpers
#' vip(vis, geom = "boxplot", all_permutations = TRUE, jitter = TRUE,
#' #mapping = aes_string(fill = "Variable"), # for ggplot2 (< 3.0.0)
#' mapping = aes(fill = .data[["Variable"]]), # for ggplot2 (>= 3.0.0)
#' aesthetics = list(color = "grey35", size = 0.8))
#'
#' #
#' # A binary classification example
#' #
#' \dontrun{
#' library(rpart) # for classification and regression trees
#'
#' # Load Wisconsin breast cancer data; see ?mlbench::BreastCancer for details
#' data(BreastCancer, package = "mlbench")
#' bc <- subset(BreastCancer, select = -Id) # for brevity
#'
#' # Fit a standard classification tree
#' set.seed(1032) # for reproducibility
#' tree <- rpart(Class ~ ., data = bc, cp = 0)
#'
#' # Prune using 1-SE rule (e.g., use `plotcp(tree)` for guidance)
#' cp <- tree$cptable
#' cp <- cp[cp[, "nsplit"] == 2L, "CP"]
#' tree2 <- prune(tree, cp = cp) # tree with three splits
#'
#' # Default tree-based VIP
#' vip(tree2)
#'
#' # Computing permutation importance requires a prediction wrapper. For
#' # classification, the return value depends on the chosen metric; see
#' # `?vip::vi_permute` for details.
#' pfun <- function(object, newdata) {
#' # Need vector of predicted class probabilities when using log-loss metric
#' predict(object, newdata = newdata, type = "prob")[, "malignant"]
#' }
#'
#' # Permutation-based importance (note that only the predictors that show up
#' # in the final tree have non-zero importance)
#' set.seed(1046) # for reproducibility
#' vip(tree2, method = "permute", nsim = 10, target = "Class",
#' metric = "logloss", pred_wrapper = pfun, reference_class = "malignant")
#' }
vip <- function(object, ...) {
UseMethod("vip")
}
#' @rdname vip
#'
#' @export
vip.default <- function(
object,
num_features = 10L,
geom = c("col", "point", "boxplot", "violin"),
mapping = NULL,
aesthetics = list(),
horizontal = TRUE,
all_permutations = FALSE,
jitter = FALSE,
include_type = FALSE,
...
) {
# Character string specifying which type of plot to construct
geom <- match.arg(geom, several.ok = FALSE)
# Extract or compute importance scores
imp <- if (inherits(object, what = "vi")) {
object
} else {
vi(object = object, ...) # compute variable importance scores
}
# Character string specifying the type of VI that was computed
vi_type <- attr(imp, which = "type") # subsetting removes this attribute!
# Integer specifying the number of features to include in the plot
num_features <- as.integer(num_features)[1L] # make sure num_features is a single integer
if (num_features > nrow(imp) || num_features < 1L) {
num_features <- nrow(imp)
}
imp <- sort_importance_scores(imp, decreasing = TRUE) # make sure these are sorted first!
imp <- imp[seq_len(num_features), ] # only retain num_features variable importance scores
# x.string <- "reorder(Variable, Importance)"
# Clean up raw scores for permutation-based VI scores
if (!is.null(attr(imp, which = "raw_scores"))) {
raw_scores <- as.data.frame(attr(imp, which = "raw_scores"))
raw_scores$Variable <- rownames(raw_scores)
raw_scores <- stats::reshape(
data = raw_scores,
varying = (1L:(ncol(raw_scores) - 1)),
v.names = "Importance",
direction = "long",
sep = "_"
)
raw_scores <- raw_scores[raw_scores$Variable %in% imp$Variable, ]
}
# Initialize plot
# p <- ggplot2::ggplot(imp, ggplot2::aes_string(x = x.string, y = "Importance"))
p <- ggplot2::ggplot(imp, ggplot2::aes(
x = reorder(Variable, Importance),
y = Importance
))
# Construct a barplot
if (geom == "col") {
p <- p + do.call(
what = ggplot2::geom_col,
args = c(list(mapping = mapping), aesthetics)
)
}
# Construct a (Cleveland) dotplot
if (geom == "point") {
p <- p + do.call(
what = ggplot2::geom_point,
args = c(list(mapping = mapping), aesthetics)
)
}
# Construct a boxplot
if (geom == "boxplot") {
if (!is.null(attr(imp, which = "raw_scores"))) {
p <- p + do.call(
what = ggplot2::geom_boxplot,
args = c(list(data = raw_scores, mapping = mapping), aesthetics)
)
} else {
stop("To construct boxplots for permutation-based importance scores you ",
"must specify `keep = TRUE` in the call `vi()` or `vi_permute()`. ",
"Additionally, you also need to set `nsim >= 2`.",
call. = FALSE)
}
}
# Construct a violin plot
if (geom == "violin") {
if (!is.null(attr(imp, which = "raw_scores"))) {
p <- p + do.call(
what = ggplot2::geom_violin,
args = c(list(data = raw_scores, mapping = mapping), aesthetics)
)
} else {
stop("To construct violin plots for permutation-based importance scores ",
"you must specify `keep = TRUE` in the call `vi()` or ",
"`vi_permute()`. Additionally, you also need to set `nsim >= 2`.",
call. = FALSE)
}
}
# Plot raw permutation scores (if available and requested)
if (!is.null(attr(imp, which = "raw_scores")) && all_permutations) {
p <- if (jitter) {
p + ggplot2::geom_jitter(data = raw_scores)
} else {
p + ggplot2::geom_point(data = raw_scores)
}
}
# Add labels, titles, etc.
p <- p + ggplot2::theme(legend.position = "none")
p <- p + ggplot2::xlab("") # no need for x-axis label
if (horizontal) {
p <- p + ggplot2::coord_flip()
}
if (isTRUE(include_type)) {
p + ggplot2::ylab(paste0("Importance (", vi_type, ")"))
} else {
p + ggplot2::ylab("Importance")
}
}
#' @rdname vip
#'
#' @export
vip.model_fit <- function(object, ...) {
vip(parsnip::extract_fit_engine(object), ...)
}
#' @rdname vip
#'
#' @export
vip.workflow <- function(object, ...) {
vip(workflows::extract_fit_engine(object), ...)
}
#' @rdname vip
#'
#' @export
vip.WrappedModel <- function(object, ...) { # package: mlr
vip(object$learner.model, ...)
}
#' @rdname vip
#'
#' @export
vip.Learner <- function(object, ...) { # package: mlr3
if (is.null(object$model)) {
stop("No fitted model found. Did you forget to call ",
deparse(substitute(object)), "$train()?",
call. = FALSE)
}
vip(object$model, ...)
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.