R/vi.R

Defines functions vi.default vi

Documented in vi vi.default

#' Variable importance
#'
#' Compute variable importance scores for the predictors in a model.
#'
#' @param object A fitted model object (e.g., a
#' [randomForest][randomForest::randomForest] object) or an object that inherits
#' from class `"vi"`.
#'
#' @param method Character string specifying the type of variable importance
#' (VI) to compute. Current options are:
#'
#' * `"model"` (the default), for model-specific VI scores (see
#' [vi_model][vip::vi_model] for details).
#'
#' * `"firm"`, for variance-based VI scores (see [vi_firm][vip::vi_firm] for
#' details).
#'
#' * `"permute"`, for permutation-based VI scores (see
#' [vi_permute][vip::vi_permute] for details).
#'
#' * `"shap"`, for Shapley-based VI scores (see [vi_shap][vip::vi_shap] for
#' details).
#'
#' @param feature_names Character vector giving the names of the predictor
#' variables (i.e., features) of interest. If `NULL` (the default) and
#' `method = "firm"`, the feature names are determined from `object`.
#'
#' @param abbreviate_feature_names Integer specifying the length at which to
#' abbreviate feature names. Default is `NULL` which results in no
#' abbreviation (i.e., the full name of each feature will be printed).
#'
#' @param sort Logical indicating whether or not to sort the variable
#' importance scores. Default is `TRUE`.
#'
#' @param decreasing Logical indicating whether or not the variable importance
#' scores should be sorted in descending (`TRUE`) or ascending
#' (`FALSE`) order of importance. Default is `TRUE`.
#'
#' @param scale Logical indicating whether or not to scale the variable
#' importance scores so that the largest is 100. Default is `FALSE`.
#'
#' @param rank Logical indicating whether or not to rank the variable
#' importance scores (i.e., replace them with ranks, where 1 corresponds to the
#' most important feature and tied scores receive their average rank). Default
#' is `FALSE`. Potentially useful when comparing variable importance scores
#' across different models using different methods.
#'
#' @param ... Additional optional arguments to be passed on to
#' [vi_model][vip::vi_model], [vi_firm][vip::vi_firm],
#' [vi_permute][vip::vi_permute], or [vi_shap][vip::vi_shap]; see their
#' respective help pages for details.
#'
#' @return A tidy data frame (i.e., a [tibble][tibble::tibble] object) with two
#' columns:
#'
#' * `Variable` - the corresponding feature name;
#' * `Importance` - the associated importance score, whose interpretation
#' depends on `method`; for example, with `method = "permute"` it is the
#' average change in performance after a random permutation (or permutations,
#' if `nsim > 1`) of the feature in question.
#'
#' For [lm][stats::lm]/[glm][stats::glm]-like objects, whenever
#' `method = "model"`, the sign (i.e., POS/NEG) of the original coefficient is
#' also included in a column called `Sign`.
#'
#' If `method = "permute"` and `nsim > 1`, then an additional column (`StDev`)
#' containing the standard deviation of the individual permutation scores for
#' each feature is also returned; this helps assess the stability/variation of
#' the individual permutation importance for each feature.
#'
#' @rdname vi
#'
#' @export
#'
#' @examples
#' #
#' # A projection pursuit regression example
#' #
#'
#' # Load the sample data
#' data(mtcars)
#'
#' # Fit a projection pursuit regression model
#' mtcars.ppr <- ppr(mpg ~ ., data = mtcars, nterms = 1)
#'
#' # Prediction wrapper that tells vi() how to obtain new predictions from your
#' # fitted model
#' pfun <- function(object, newdata) predict(object, newdata = newdata)
#'
#' # Compute permutation-based variable importance scores
#' set.seed(1434)  # for reproducibility
#' (vis <- vi(mtcars.ppr, method = "permute", target = "mpg", nsim = 10,
#'            metric = "rmse", pred_wrapper = pfun, train = mtcars))
#'
#' # Plot variable importance scores
#' vip(vis, include_type = TRUE, all_permutations = TRUE,
#'     geom = "point", aesthetics = list(color = "forestgreen", size = 3))
#'
#' #
#' # A binary classification example
#' #
#' \dontrun{
#' library(rpart)  # for classification and regression trees
#'
#' # Load Wisconsin breast cancer data; see ?mlbench::BreastCancer for details
#' data(BreastCancer, package = "mlbench")
#' bc <- subset(BreastCancer, select = -Id)  # for brevity
#'
#' # Fit a standard classification tree
#' set.seed(1032)  # for reproducibility
#' tree <- rpart(Class ~ ., data = bc, cp = 0)
#'
#' # Prune using 1-SE rule (e.g., use `plotcp(tree)` for guidance)
#' cp <- tree$cptable
#' cp <- cp[cp[, "nsplit"] == 2L, "CP"]
#' tree2 <- prune(tree, cp = cp)  # tree with three splits
#'
#' # Default tree-based VIP
#' vip(tree2)
#'
#' # Computing permutation importance requires a prediction wrapper. For
#' # classification, the return value depends on the chosen metric; see
#' # `?vip::vi_permute` for details.
#' pfun <- function(object, newdata) {
#'   # Need a vector of predicted class probabilities when using the
#'   # log-loss metric
#'   predict(object, newdata = newdata, type = "prob")[, "malignant"]
#' }
#'
#' # Permutation-based importance (note that only the predictors that show up
#' # in the final tree have non-zero importance)
#' set.seed(1046)  # for reproducibility
#' vi(tree2, method = "permute", nsim = 10, target = "Class", train = bc,
#'    metric = "logloss", pred_wrapper = pfun, reference_class = "malignant")
#'
#' # Equivalent (but not sorted)
#' set.seed(1046)  # for reproducibility
#' vi_permute(tree2, nsim = 10, target = "Class", metric = "logloss",
#'            pred_wrapper = pfun, reference_class = "malignant")
#' }
vi <- function(object, ...) UseMethod("vi")  # S3 generic; see vi.default()


#' @rdname vi
#'
#' @export
vi.default <- function(
  object,
  method = c("model", "firm", "permute", "shap"),
  feature_names = NULL,
  abbreviate_feature_names = NULL,
  sort = TRUE,
  decreasing = TRUE,
  scale = FALSE,
  rank = FALSE,
  ...
) {

  # Default `vi()` method. Computes VI scores with the back end selected by
  # `method` (vi_model(), vi_firm(), vi_permute(), or vi_shap()), then
  # post-processes them in this order: drop rows containing NA, sort
  # (optional), abbreviate feature names (optional), scale so the largest
  # score is 100 (optional), and convert to ranks (optional). The "type"
  # attribute produced by the back end is kept and the result carries class
  # "vi".

  # Resolve the VI method; for method = "firm", default the feature names from
  # the fitted model when the caller did not supply any
  method <- match.arg(method)
  if (method == "firm" && is.null(feature_names)) {
    feature_names <- get_feature_names(object)
  }

  # Construct tibble of VI scores
  tib <- switch(method,
    "model" = vi_model(object, ...),
    "firm" = vi_firm(object, feature_names = feature_names, ...),
    "permute" = vi_permute(object, feature_names = feature_names, ...),
    vi_shap(object, feature_names = feature_names, ...)
  )

  # Save the "type" attribute; presumably later steps (e.g., na.omit()) can
  # drop custom attributes, so it is restored before returning
  vi_type <- attr(tib, which = "type")

  # Remove rows with NA (in any column)
  tib <- stats::na.omit(tib)

  # Sort VI scores (if requested)
  if (sort) {
    tib <- sort_importance_scores(tib, decreasing = decreasing)
  }

  # Abbreviate feature names (if requested)
  if (!is.null(abbreviate_feature_names)) {
    tib <- abbreviate_names(tib, minlength = abbreviate_feature_names)
  }

  # Scale VI scores so that the largest is 100
  # NOTE(review): produces non-finite values if the largest score is 0
  if (scale) {
    tib$Importance <- tib$Importance / max(tib$Importance) * 100
  }

  # Rank VI scores so that the most important feature gets rank 1 (ties get
  # their average rank, so ranks need not be integers). Ranking the negated
  # scores does not depend on row order, whereas rev(rank(Importance)) only
  # matches ranks to rows when the rows are already sorted (sort = TRUE).
  # `base::rank` is spelled out because the logical argument `rank` shadows it.
  if (rank) {
    tib$Importance <- base::rank(-tib$Importance, ties.method = "average")
  }

  # Restore attribute
  attr(tib, which = "type") <- vi_type

  # Add "vi" class
  if (!inherits(tib, what = "vi")) {  # In case class gets stripped?
    class(tib) <- c("vi", class(tib))
  }

  # Return results
  tib

}

Try the vip package in your browser

Any scripts or data that you put into this service are public.

vip documentation built on Aug. 21, 2023, 5:12 p.m.