R/local_conditional_expectations.R

Defines functions local_conditional_expectations

Documented in local_conditional_expectations

#' Local Conditional Expectation Explainer
#'
#' This explainer works for individual observations.
#' For each observation it calculates Local Conditional Expectation (LCE) profiles for selected variables.
#'
#' @param explainer a model to be explained, preprocessed by function `DALEX::explain()`.
#' @param observations set of observations for which profiles are to be calculated
#' @param y true labels for `observations`. If specified then will be added to local conditional expectations plots.
#' @param variable_splits named list of splits for variables, in most cases created with `calculate_variable_splits()`. If NULL then they will be calculated based on validation data available in the `explainer`.
#' @param grid_points number of points for profile. Will be passed to `calculate_variable_splits()`.
#' @param variables names of variables for which profiles shall be calculated. Will be passed to `calculate_variable_splits()`. If NULL then all variables from the validation data will be used.
#'
#'
#' @return An object of the class 'ceteris_paribus_explainer'.
#' A data frame with calculated LCE profiles.
#' @export
#'
#' @examples
#' library("DALEX")
#'  \dontrun{
#' library("randomForest")
#' set.seed(59)
#'
#' apartments_rf_model <- randomForest(m2.price ~ construction.year + surface + floor +
#'       no.rooms + district, data = apartments)
#'
#' explainer_rf <- explain(apartments_rf_model,
#'       data = apartments[,2:6], y = apartments$m2.price)
#'
#' new_apartment <- apartments[1, ]
#'
#' cp_rf <- ceteris_paribus(explainer_rf, new_apartment)
#' lce_rf <- local_conditional_expectations(explainer_rf, new_apartment)
#' lce_rf
#'
#' lce_rf <- local_conditional_expectations(explainer_rf, new_apartment, y = new_apartment$m2.price)
#' lce_rf
#' 
#' # Plot LCE
#' sel_vars <- c("surface", "no.rooms")
#' plot(lce_rf, selected_variables = sel_vars)
#' 
#' # Compare ceteris paribus profiles with LCE profiles 
#' plot(cp_rf, selected_variables = sel_vars) + 
#'    ceteris_paribus_layer(lce_rf, selected_variables = sel_vars, color = "red")
#'
#' }


local_conditional_expectations <- function(explainer, observations, y = NULL, variable_splits = NULL, variables = NULL, grid_points = 101) {
  # Computes LCE profiles for `observations` using the model, predict function,
  # label and validation data stored in `explainer`. Returns the profiles data
  # frame classed as "ceteris_paribus_explainer", with the augmented
  # observations attached as the "observations" attribute.
  if (!inherits(explainer, "explainer")) {
    stop("The local_conditional_expectations() function requires an object created with explain() function.")
  }

  predict_function <- explainer$predict_function
  model <- explainer$model

  # Splits are only derived here when the caller did not supply them.
  if (is.null(variable_splits)) {
    # Deriving splits needs the validation data stored in the explainer.
    if (is.null(explainer$data)) {
      stop("The local_conditional_expectations() function requires explainers created with specified 'data'.")
    }
    # Default to the variables present in both the validation data and the
    # observations; columns found in only one of them are ignored.
    if (is.null(variables)) {
      variables <- intersect(colnames(explainer$data), colnames(observations))
    }

    variable_splits <- calculate_variable_splits(explainer$data, variables = variables, grid_points = grid_points)
  }

  # NOTE(review): explainer$data is passed as `dataset` even when
  # variable_splits was supplied by the caller, but it is only NULL-checked in
  # the branch above -- confirm calculate_profiles_lce() copes with NULL.
  profiles <- calculate_profiles_lce(observations, variable_splits, model, dataset = explainer$data, predict_function)
  profiles$`_label_` <- explainer$label

  # Points of interest: the observations extended with the model prediction,
  # the optional true response and the explainer label.
  observations$`_yhat_` <- predict_function(model, observations)
  if (!is.null(y)) {
    observations$`_y_` <- y
  }
  observations$`_label_` <- explainer$label

  # Attach the points of interest and tag the result for downstream methods.
  attr(profiles, "observations") <- observations
  class(profiles) <- c("ceteris_paribus_explainer", "data.frame")
  profiles
}

Try the ceterisParibus package in your browser

Any scripts or data that you put into this service are public.

ceterisParibus documentation built on March 31, 2020, 5:22 p.m.