R/CounterfactualMethod.R

#' Base class for Counterfactual Explanation Methods
#'
#' @description
#' Abstract base class for counterfactual explanation methods.
#'
#' @section Inheritance:
#' Child classes: \link{CounterfactualMethodClassif}, \link{CounterfactualMethodRegr}
CounterfactualMethod = R6::R6Class("CounterfactualMethod",

  public = list(

    #' @description Creates a new `CounterfactualMethod` object.
    #' @template predictor
    #' @template lower_upper
    #' @param distance_function (`character(1)` | `function()`)\cr
    #'  Either the name of an already implemented distance function
    #'  (currently 'gower' or 'gower_c') or a function having three arguments:
    #'  `x`, `y`, and `data`. The function should return a `double` matrix with
    #'  `nrow(x)` rows and at most `nrow(y)` columns.
    #'
    initialize = function(predictor, lower = NULL, upper = NULL, distance_function = NULL) {
      assert_class(predictor, "Predictor")
      assert_numeric(lower, null.ok = TRUE)
      assert_numeric(upper, null.ok = TRUE)

      # Bounds are matched to features by name. `assert_subset()` accepts
      # NULL/empty names (unnamed or absent bounds) exactly like the previous
      # `all(... %in% ...)` check, but reports which names are unknown.
      feature_names = names(predictor$data$X)
      assert_subset(names(lower), feature_names)
      assert_subset(names(upper), feature_names)

      # A feature bounded from both sides needs a non-empty range; fail early
      # with a clear message. `which()` skips NA comparisons on purpose.
      shared = intersect(names(lower), names(upper))
      crossed = shared[which(lower[shared] > upper[shared])]
      if (length(crossed) > 0L) {
        stop(
          sprintf(
            "`lower` must not exceed `upper`, but it does for: %s",
            paste(crossed, collapse = ", ")
          ),
          call. = FALSE
        )
      }

      # NOTE(review): only a function with arguments (x, y, data) or NULL is
      # accepted here, although the docs also mention 'gower'/'gower_c'.
      # Presumably child classes map those names to functions before calling
      # this constructor -- confirm.
      assert_function(distance_function, args = c("x", "y", "data"), ordered = TRUE, null.ok = TRUE)

      # If the task could not be derived from the model (or was never set),
      # infer it from a prediction on a few training rows.
      if (is.null(predictor$task) || isTRUE(predictor$task == "unknown")) {
        # Needs to be set to NULL, as the predictor does not infer the task from prediction otherwise
        # See: https://github.com/christophM/iml/blob/master/R/Predictor.R#L141 of commit 409838a.
        # The task is then checked by `CounterfactualMethodRegr` or `CounterfactualMethodClassif`.
        # Note: `predictor` has reference semantics, so the caller's object is updated too.
        predictor$task = NULL
        # Probe at most two rows: `X[1:2, ]` would append an all-NA row when
        # the data contains only a single observation.
        n_probe = min(2L, nrow(predictor$data$X))
        predictor$predict(predictor$data$X[seq_len(n_probe), ])
      }

      private$predictor = predictor
      private$param_set = make_param_set(predictor$data$X, lower, upper)
      private$lower = lower
      private$upper = upper
      private$distance_function = distance_function
    },

    #' @description
    #' Prints a `CounterfactualMethod` object.
    #' The method calls a (private) `$print_parameters()` method which should be implemented by the leaf classes.
    #' @return The object itself, invisibly.
    print = function() {
      cat("Counterfactual explanation method: ", class(self)[1L], "\n", sep = "")
      cat("Parameters:\n")
      private$print_parameters()
      invisible(self)
    }
  ),

  private = list(
    # The `Predictor` passed to `initialize()` (task possibly inferred there).
    predictor = NULL,
    # Not assigned in this class; presumably set by child classes.
    x_interest = NULL,
    # Result of `make_param_set()` built from the feature data and the bounds.
    param_set = NULL,
    # User-supplied bounds as given to `initialize()` (numeric or NULL).
    lower = NULL,
    upper = NULL,
    # Distance function validated in `initialize()` (function or NULL).
    distance_function = NULL,
    # Not assigned in this class; presumably set by child classes.
    method = NULL,

    # Abstract hook: the base version only errors, so child classes must override it.
    run = function() stop("abstract"),

    # Hook called by `print()`; does nothing here, child classes override it.
    print_parameters = function() {}
  )
)

Try the counterfactuals package in your browser

Any scripts or data that you put into this service are public.

counterfactuals documentation built on March 31, 2023, 7:17 p.m.