R/WhatIfClassif.R

#' WhatIf for Classification Tasks
#' 
#' @description 
#' 
#' WhatIf returns the `n_counterfactuals` most similar observations to `x_interest` from the observations in
#' `predictor$data$X` whose prediction for the `desired_class` is in the `desired_prob` interval.
#' 
#' @details
#' By default, the dissimilarities are computed using Gower's dissimilarity measure (Gower 1971). \cr
#' Only observations whose feature values lie between the corresponding values in `lower` and `upper` are considered
#' counterfactual candidates.
#' 
#' @references 
#' 
#' Gower, J. C. (1971), "A general coefficient of similarity and some of its properties". Biometrics, 27, 623–637.
#' 
#' Wexler, J., Pushkarna, M., Bolukbasi, T., Wattenberg, M., Viégas, F., & Wilson, J. (2019). The what-if tool: 
#' Interactive probing of machine learning models. IEEE transactions on visualization and computer graphics, 26(1), 56–65.
#' 
#' @examples 
#' if (require("randomForest")) {
#'   # Train a model
#'   rf = randomForest(Species ~ ., data = iris)
#'   # Create a predictor object
#'   predictor = iml::Predictor$new(rf, type = "prob")
#'   # Find counterfactuals for x_interest
#'   wi_classif = WhatIfClassif$new(predictor, n_counterfactuals = 5L)
#'   cfactuals = wi_classif$find_counterfactuals(
#'     x_interest = iris[150L, ], desired_class = "versicolor", desired_prob = c(0.5, 1)
#'   )
#'   # Print the results
#'   cfactuals$data
#' }
#' 
#' @export
WhatIfClassif = R6::R6Class("WhatIfClassif", inherit = CounterfactualMethodClassif,
  
  public = list(
    #' @description Create a new WhatIfClassif object.
    #' @template predictor
    #' @param n_counterfactuals (`integerish(1)`)\cr
    #'   The number of counterfactuals to return. Default is `1L`.
    #' @template lower_upper
    #' @param distance_function (`function()` | `'gower'` | `'gower_c'`)\cr 
    #'  The distance function used to compute the distances between `x_interest`
    #'  and the candidate observations in order to determine the most similar ones.
    #'  Either the name of an already implemented distance function
    #'  ('gower' or 'gower_c') or a function.
    #'  If set to 'gower' (default), then Gower's distance (Gower 1971) is used;
    #'  if set to 'gower_c', a C-based more efficient version of Gower's distance is used.
    #'  A function must have three arguments  `x`, `y`, and `data` and should
    #'  return a `double` matrix with `nrow(x)` rows and maximum `nrow(y)` columns.
    initialize = function(predictor, n_counterfactuals = 1L, lower = NULL, upper = NULL, distance_function = "gower") {
      
      # Validate up front: `n_counterfactuals` is captured by the 'gower_c' wrapper
      # below and is needed for the candidate check at the end, so an invalid value
      # should fail before any other setup work is done.
      assert_integerish(n_counterfactuals, lower = 1L, any.missing = FALSE, len = 1L)
      
      if (is.character(distance_function)) {
        # isTRUE() keeps the `if` conditions scalar: a character vector of length
        # 0 or > 1 (or NA) no longer aborts here with a cryptic "condition" error
        # but is handed on unchanged to the parent initializer, exactly like any
        # other unrecognised string.
        if (isTRUE(distance_function == "gower")) {
          distance_function = gower_dist
        } else if (isTRUE(distance_function == "gower_c")) {
          # The C-based variant lives in an optional dependency.
          if (!requireNamespace("gower", quietly = TRUE)) {
            stop("Package 'gower' needed for distance_function = 'gower_c'. Please install it.", call. = FALSE)
          }
          # Wrap gower_dist_c so that it exposes the (x, y, data) signature while
          # fixing `k` to the requested number of counterfactuals.
          distance_function = function(x, y, data) {
            gower_dist_c(x, y, data, k = n_counterfactuals, idx = TRUE)
          }
          # Give the wrapper the same class attribute as gower_dist_c
          # (presumably so that downstream code can recognise it -- TODO confirm).
          class(distance_function) = class(gower_dist_c)
        }
      }
      
      super$initialize(predictor, lower, upper, distance_function)
      
      private$n_counterfactuals = n_counterfactuals
      
      # Restrict the candidate pool to rows that respect the bounds. For every
      # bounded feature a logical vector is built column-wise (Map) and the
      # vectors are AND-ed together (Reduce) into a single row filter.
      # NOTE(review): the `with = FALSE` column selection assumes
      # `predictor$data$X` is a data.table -- confirm.
      X_search = private$predictor$data$X
      if (!is.null(lower)) {
        X_search = X_search[Reduce(`&`, Map(`>=`, X_search[, names(lower), with = FALSE], lower))]
      }
      if (!is.null(upper)) {
        X_search = X_search[Reduce(`&`, Map(`<=`, X_search[, names(upper), with = FALSE], upper))]
      }
      # Fewer candidates than requested is not an error; only warn about it.
      if (nrow(X_search) < n_counterfactuals) {
        warning(sprintf("Could only find %s candidate(s) with feature values between `lower` and `upper`.", nrow(X_search)))
      }
      private$X_search = X_search
    }
  ),
  
  private = list(
    # Number of counterfactuals to return; set in `initialize()`.
    n_counterfactuals = NULL,
    # Candidate observations left after applying `lower`/`upper`; set in `initialize()`.
    X_search = NULL,
    
    # Runs the search by delegating to `whatif_algo()` with the stored settings.
    # `x_interest` and `desired_prob` are read from private fields that are not
    # assigned in this class (presumably set by the parent class -- TODO confirm).
    run = function() {
      pred_column = private$get_pred_column()
      whatif_algo(
        predictor = private$predictor, 
        n_cfactuals = private$n_counterfactuals, 
        x_interest = private$x_interest, 
        pred_column = pred_column, 
        desired_y_hat_range = private$desired_prob,
        X_search = private$X_search,
        distance_function = private$distance_function
      )
    },
    
    # Writes the method-specific parameter (`n_counterfactuals`) to the console.
    print_parameters = function() {
      cat(" - n_counterfactuals: ", private$n_counterfactuals)
    }
  )
)

Try the counterfactuals package in your browser

Any scripts or data that you put into this service are public.

counterfactuals documentation built on March 31, 2023, 7:17 p.m.