R/var_imp.R

Defines functions var_imp.train var_imp.maboost var_imp.randomForest var_imp.default var_imp

Documented in var_imp var_imp.default var_imp.maboost var_imp.randomForest var_imp.train

#' Variable Importance
#'
#' Methods to calculate variable importance for different classifiers. The
#' generic dispatches on the class of `mod`:
#' * The default method (also used for `randomForest` fits) delegates to
#'   \code{vip::vi()}.
#' * `maboost` fits use the scores from \code{maboost::varplot.maboost()}.
#' * `train` objects use SHAP-based importance from \code{vip::vi_shap()}.
#' @inheritParams prediction
#' @return A table of importance scores.
#' * The `maboost` and `train` methods return columns `Variable` and
#'   `Importance`.
#' * The `train` method sorts rows by decreasing `Importance`.
#' * The default method returns whatever \code{vip::vi()} produces, or `NULL`
#'   if that call fails.
#'
#' @author Derek Chiu
#' @export
#' @examples
#' data(hgsc)
#' class <- attr(hgsc, "class.true")
#' mod <- classification(hgsc, class, "xgboost")
#' var_imp(mod)
var_imp <- function(mod, ...) {
  # Select the method from the class of the fitted model object itself.
  UseMethod("var_imp", mod)
}

#' @rdname var_imp
#' @export
var_imp.default <- function(mod, ...) {
  # Best effort: if vip cannot compute importance for this model type (or the
  # call otherwise errors), return NULL instead of propagating the error.
  # Extra arguments in `...` are accepted but not used.
  on_failure <- function(cnd) NULL
  tryCatch(expr = vip::vi(mod), error = on_failure)
}

#' @rdname var_imp
#' @export
var_imp.randomForest <- function(mod, ...) {
  # No randomForest-specific logic: hand the model and any extra arguments
  # straight to the default method.
  var_imp.default(mod = mod, ...)
}

#' @rdname var_imp
#' @export
var_imp.maboost <- function(mod, ...) {
  # maboost is an optional dependency. Check for it up front so that a
  # missing install produces the same actionable message style as
  # var_imp.train(). A bare loadNamespace() failure is less helpful.
  if (!requireNamespace("maboost", quietly = TRUE)) {
    stop("Package \"maboost\" is needed. Please install it.", call. = FALSE)
  }
  # Request the importance scores for every variable (max.var.show = Inf)
  # without drawing a plot (plot.it = FALSE). Then reshape the result into a
  # two-column table. Extra arguments in `...` are accepted but not used.
  mod %>%
    maboost::varplot.maboost(plot.it = FALSE,
                             type = "scores",
                             max.var.show = Inf) %>%
    tibble::enframe(name = "Variable", value = "Importance")
}

#' @rdname var_imp
#' @export
var_imp.train <- function(mod, ...) {
  # Both optional dependencies must be present. Check them in order and stop
  # at the first one that is missing.
  if (!requireNamespace("caret", quietly = TRUE)) {
    stop("Package \"caret\" is needed. Please install it.",
         call. = FALSE)
  }
  if (!requireNamespace("fastshap", quietly = TRUE)) {
    stop("Package \"fastshap\" is needed. Please install it.",
         call. = FALSE)
  }

  # Prediction wrapper used for the SHAP computation. It returns the first
  # column of caret's class-probability predictions (presumably the first
  # class level; confirm against the model's level ordering).
  first_class_prob <- function(object, newdata) {
    probs <- caret::predict.train(object, newdata = newdata, type = "prob")
    probs[, 1]
  }

  # Extra arguments in `...` are accepted but not used.
  shap_scores <- vip::vi_shap(mod, pred_wrapper = first_class_prob)
  dplyr::arrange(shap_scores, dplyr::desc(.data$Importance))
}
AlineTalhouk/splendid documentation built on Feb. 23, 2024, 9:37 p.m.