R/importance.R

Defines functions importance

Documented in importance

#' Extract predictor variable importance from a fusion model
#'
#' @description
#' Returns predictor variable (feature) importance of underlying LightGBM models stored in a fusion model file (.fsn) on disk.
#' @param fsn Character. Path to a single fusion model file (.fsn) generated by \code{\link{train}}.
#' @details Importance metrics are computed via \code{\link[lightgbm]{lgb.importance}}. Three measures are returned ("gain", "cover", and "frequency"); "gain" is typically the preferred measure. The LightGBM model files are extracted to a temporary directory that is deleted before the function returns.
#' @return A named list containing \code{summary} and \code{detailed} importance results. The \code{detailed} data frame has one row per response variable (\code{y}), underlying model (\code{model}), and predictor (\code{x}). The \code{summary} results are most useful, as they return the average importance of each predictor across potentially multiple underlying LightGBM models; i.e. zero ("z"), mean ("m"), or quantile ("q") models. Within each \code{y}, each summary measure is rescaled to sum to 1. See Examples for suggested plotting of results.
#' @examples
#' # Build a fusion model using RECS microdata
#' # Note that "fusion_model.fsn" will be written to working directory
#' ?recs
#' fusion.vars <- c("electricity", "natural_gas", "aircon")
#' predictor.vars <- names(recs)[2:12]
#' fsn.path <- train(data = recs, y = fusion.vars, x = predictor.vars)
#'
#' # Extract predictor variable importance
#' ximp <- importance(fsn.path)
#'
#' # Plot summary results
#' library(ggplot2)
#' ggplot(ximp$summary, aes(x = x, y = gain)) +
#'   geom_bar(stat = "identity") +
#'   facet_grid(~ y) +
#'   coord_flip()
#'
#' # View detailed results
#' head(ximp$detailed)
#' @export

importance <- function(fsn) {

  stopifnot(exprs = {
    is.character(fsn)
    length(fsn) == 1L
    endsWith(fsn, ".fsn")
    file.exists(fsn)
  })

  # Temporary directory to unzip .fsn contents to; deleted when the function
  # exits (normally or via an error) so repeated calls do not leave files behind
  td <- tempfile()
  dir.create(td)
  on.exit(unlink(td, recursive = TRUE), add = TRUE)

  # Names of files within the .fsn object and the sub-directories holding them
  fsn.files <- zip::zip_list(fsn)$filename
  fdirs <- dirname(fsn.files)
  pfixes <- sort(unique(fdirs))
  pfixes <- pfixes[pfixes != "."]

  # LightGBM model files (*.txt), ordered by sub-directory. Directories are
  # matched exactly: a prefix/glob match would let "1" also pick up files in
  # "10", "11", ..., duplicating rows in the results.
  is.txt <- endsWith(fsn.files, ".txt")
  mods <- unlist(lapply(pfixes, function(p) fsn.files[fdirs == p & is.txt]),
                 use.names = FALSE)
  if (length(mods) == 0L) {
    stop("No LightGBM model files (.txt) found in ", fsn, call. = FALSE)
  }

  # Unzip only the model files to the temporary directory
  zip::unzip(zipfile = fsn, files = mods, exdir = td)

  # Extract full-detail variable importance metrics
  detail <- mods %>%
    lapply(FUN = function(m) {
      # 'y' is the file name up to its last underscore; 'model' is what
      # remains after removing "<y>_" and ".txt"
      fname <- basename(m)
      y <- sub("_[^_]+$", "", fname)
      n <- sub(".txt", "", sub(paste0(y, "_"), "", fname, fixed = TRUE), fixed = TRUE)
      mod <- lightgbm::lgb.load(filename = file.path(td, m))
      lightgbm::lgb.importance(mod) %>%
        mutate(y = y, model = n)
    }) %>%
    data.table::rbindlist() %>%
    rename_with(tolower) %>%
    rename(x = feature) %>%
    select(y, model, x, everything()) %>%
    as.data.frame()

  # Generate summary of full-detail results: average each measure across the
  # underlying models, then rescale so each measure sums to 1 within each 'y'
  smry <- detail %>%
    group_by(y, x) %>%
    summarize(across(gain:frequency, mean), .groups = "drop_last") %>%
    mutate(across(gain:frequency, ~ .x / sum(.x))) %>%
    ungroup() %>%
    arrange(y, -gain)

  # Recommended plotting order for 'x': predictors that are themselves fusion
  # variables first, then remaining predictors by ascending mean gain
  yvars <- unique(detail$y)
  xord <- smry %>%
    filter(!x %in% yvars) %>%
    group_by(x) %>%
    summarize(across(gain, mean)) %>%
    arrange(gain) %>%
    pull(x)

  # Set factor levels for summary 'y' and 'x' variables for suitable plotting
  smry$y <- factor(smry$y, levels = yvars)
  smry$x <- factor(smry$x, levels = c(rev(yvars), xord))

  list(summary = smry, detailed = detail)

}
ummel/fusionModel documentation built on June 1, 2025, 11 p.m.