R/feature.selection.R

Defines functions feature.selection

Documented in feature.selection

#' @title Select the top features with the highest weighted mean SHAP values based on the
#'        specified criteria
#' @description This function specifies the top features and prepares the data
#'              for plotting SHAP contributions for each row, or summary of absolute
#'              SHAP contributions for each feature.
#' @param shapley shapley object
#' @param method Character. The column name in \code{summaryShaps} used
#'                           for feature selection. Default is \code{"mean"}, which
#'                           selects features whose weighted mean SHAP
#'                           ratio (WMSHAP) is higher than the specified cutoff. The
#'                           alternative is \code{"lowerCI"}, which selects features
#'                           whose lower confidence-interval bound is higher than
#'                           the cutoff and requires more than one model.
#' @param cutoff Numeric. The cutoff for the method used for selecting
#'               the top features. The default is zero, which means that all
#'               features with the \code{method} criterion above zero will be selected.
#' @param top_n_features Integer. If specified, the top n features with the
#'                       highest weighted mean SHAP values will be selected, overriding
#'                       the 'cutoff' and 'method' arguments.
#' @param features Character vector specifying the features to be plotted. If
#'                 \code{NULL} (default), all features in \code{summaryShaps} are used.
#' @author E. F. Haghish
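#' @return A list with three elements: \code{shapley} (the shapley object with
#'         filtered \code{summaryShaps} and contribution plot data),
#'         \code{features} (feature names sorted by decreasing mean), and
#'         \code{mean} (the corresponding weighted mean SHAP values).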


feature.selection <- function(shapley,
                              method = "mean",
                              cutoff = 0.0,
                              top_n_features = NULL,
                              features = NULL) {

  # variables
  # ============================================================
  DATA <- shapley$contributionPlot$data
  if (is.null(features)) features <- as.character(shapley$summaryShaps$feature)

  # Select the features that meet the criteria
  # ============================================================
  if (length(shapley[["ids"]]) >= 1) {
    if (!is.null(top_n_features)) {
      shapley$summaryShaps <- shapley$summaryShaps[order(
        shapley$summaryShaps$mean, decreasing = TRUE), ]
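      # Cap at the number of rows so a large top_n_features does not add NA rows
      n_keep <- min(top_n_features, nrow(shapley$summaryShaps))
      shapley$summaryShaps <- shapley$summaryShaps[seq_len(n_keep), ]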

      shapley$contributionPlot$data <- DATA[
        DATA$feature %in% features, ]
    } else if (method == "mean") {
      shapley$summaryShaps <- shapley$summaryShaps[shapley$summaryShaps$mean > cutoff, ]
      shapley$contributionPlot$data <- DATA[DATA$feature %in% features, ]

    } else if (method == "lowerCI") {
      if (length(shapley[["ids"]]) == 1) stop("shapley object includes a single model and lowerCI cannot be used")
      shapley$summaryShaps <- shapley$summaryShaps[shapley$summaryShaps$lowerCI > cutoff, ]
      shapley$contributionPlot$data <- DATA[DATA$feature %in% features, ]

    } else {
      stop("method must be one of 'mean' or 'lowerCI'")
    }
  } else {
    stop("at least 1 model must be included in the shapley object")
  }

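  # Sort the features based on their mean SHAP values
  # ============================================================
  # Take the names from the filtered summaryShaps so that 'features' and
  # 'mean' refer to the same rows
  index <- order(-shapley$summaryShaps$mean)
  features <- as.character(shapley$summaryShaps$feature)[index]
  mean <- shapley$summaryShaps$mean[index]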

  return(list(shapley = shapley,
              features = features,
              mean = mean))

}
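
The three selection rules used above can be tried in isolation. The following is a minimal sketch, not the package's own example: `summaryShaps` here is a hypothetical stand-in data frame with made-up values, and the cutoff of 0.1 and n of 2 are arbitrary choices for illustration.

```r
# Hypothetical stand-in for shapley$summaryShaps; values are invented
# purely for illustration.
summaryShaps <- data.frame(
  feature = c("age", "income", "height", "noise"),
  mean    = c(0.40, 0.35, 0.20, 0.05),
  lowerCI = c(0.30, 0.08, 0.15, -0.02),
  stringsAsFactors = FALSE
)

# method = "mean": keep features whose mean exceeds the cutoff
by_mean <- summaryShaps[summaryShaps$mean > 0.1, ]

# method = "lowerCI": keep features whose lower confidence bound
# exceeds the cutoff
by_lower <- summaryShaps[summaryShaps$lowerCI > 0.1, ]

# top_n_features: sort by decreasing mean, then keep at most n rows
sorted <- summaryShaps[order(summaryShaps$mean, decreasing = TRUE), ]
n_keep <- min(2, nrow(sorted))
top2   <- sorted[seq_len(n_keep), ]

print(by_mean$feature)
print(by_lower$feature)
print(top2$feature)
```

Note that a feature can pass the mean rule yet fail the lowerCI rule, as "income" does in this made-up data, which is why the two methods can return different feature sets for the same cutoff.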

shapley documentation built on April 12, 2025, 2:16 a.m.