R/vi_boots.R

Defines functions vi_single_boot vi_boots

Documented in vi_boots

#' Fit and estimate variable importance from a workflow using many bootstrap resamples.
#'
#' Generate variable importances from a tidymodel workflow using bootstrap resampling.
#' `vi_boots()` generates `n` bootstrap resamples, fits a model to each (creating
#' `n` models), then creates `n` estimates of variable importance for each variable
#' in the model.
#'
#' @details Since `vi_boots()` fits a new model to each resample, the
#'  argument `workflow` must not yet be fit. Any tuned hyperparameters must be
#'  finalized prior to calling `vi_boots()`.
#'
#' @return A tibble with a column indicating each variable in the model and a
#'  nested list of variable importances for each variable. The shape of the list
#'  may vary by model type. For example, linear models return two nested columns:
#'  the absolute value of each variable's importance and the sign (POS/NEG),
#'  whereas tree-based models return a single nested column of variable importance.
#'  Similarly, the number of nested rows may vary by model type as some models
#'  may not utilize every possible predictor.
#'
#' @param workflow An un-fitted workflow object.
#' @param training_data A tibble or dataframe of data to be resampled and used for training.
#' @param n An integer for the number of bootstrap resampled models that will be created.
#' @param verbose A logical. Defaults to `FALSE`. If set to `TRUE`, prints progress
#'   of training to console.
#' @param ... Additional params passed to `rsample::bootstraps()`.
#'
#' @export
#'
#' @importFrom rlang warn
#' @importFrom rsample bootstraps
#' @importFrom purrr map_dfr
#' @importFrom dplyr rename_with
#' @importFrom tidyr nest
#'
#' @examples
#' \dontrun{
#' library(tidymodels)
#'
#' # setup a workflow without fitting
#' wf <-
#'   workflow() %>%
#'   add_recipe(recipe(qsec ~ wt, data = mtcars)) %>%
#'   add_model(linear_reg())
#'
#' # fit and estimate variable importance from 125 bootstrap resampled models
#' set.seed(123)
#' wf %>%
#'   vi_boots(n = 2000, training_data = mtcars)
#' }
vi_boots <- function(workflow,
                     n = 2000,
                     training_data,
                     verbose = FALSE,
                     ...) {

  # check arguments
  assert_workflow(workflow)
  assert_n(n)
  assert_pred_data(workflow, training_data, "training")

  # warn if low n
  if (n < 2000) {

    rlang::warn(
      paste0("At least 2000 resamples recommended for stable results.")
    )

  }

  # create resamples from training set
  training_boots <-
    rsample::bootstraps(
      training_data,
      times = n,
      ...
    )

  # map sequence of indices to `vi_single_boot()`
  # returns a variable + importance for each model (number of cols may vary by model type)
  bootstrap_vi <-
    purrr::map_dfr(
      seq(1, n),
      ~vi_single_boot(
        workflow = workflow,
        boot_splits = training_boots,
        verbose = verbose,
        index = .x
      )
    )

  # rename cols
  bootstrap_vi <- dplyr::rename_with(bootstrap_vi, tolower)
  bootstrap_vi <- dplyr::rename(bootstrap_vi, model.importance = importance)

  # return a nested tibble
  bootstrap_vi <- tidyr::nest(bootstrap_vi, .importances = -variable)

  return(bootstrap_vi)

}

# -------------------------------internals--------------------------------------

#' Fit a model and get the variable importance based on a single bootstrap resample
#'
#' @param workflow passed from `vi_boots()`
#' @param boot_splits passed from `vi_boots()`
#' @param verbose passed from `vi_boots()`
#' @param index passed from `vi_boots()`
#'
#' @importFrom rsample training
#' @importFrom generics fit
#' @importFrom vip vi
#' @importFrom workflows extract_fit_engine
#'
#' @noRd
#'
vi_single_boot <- function(workflow,
                           boot_splits,
                           verbose,
                           index) {

  # get training data from bootstrap resample split
  boot_train <-
    rsample::training(
      boot_splits$splits[[index]]
    )

  # fit workflow to to the training data
  model <- generics::fit(workflow, boot_train)

  # get the variable importance from the model
  vi_boot <- vip::vi(workflows::extract_fit_engine(model))

  # add model name
  vi_boot <- dplyr::mutate(vi_boot, model = paste0(".importance_", index))

  # print progress when verbose is set to TRUE
  verbose_print(
    verbose = verbose,
    index = index,
    total = nrow(boot_splits)
  )

  return(vi_boot)

}

Try the workboots package in your browser

Any scripts or data that you put into this service are public.

workboots documentation built on May 16, 2022, 9:05 a.m.