R/ceteris_paribus.R

Defines functions individual_variable_profile individual_variable_profile.explainer individual_variable_profile.default

Documented in individual_variable_profile individual_variable_profile.default individual_variable_profile.explainer

#' Calculates Individual Variable Profiles aka Ceteris Paribus Profiles
#'
#' This explainer works for individual observations.
#' For each observation it calculates Individual Variable Profiles for selected variables.
#' For this reason it is also called 'Local Profile Plot'.
#'
#' @param x a model to be explained, or an explainer created with function `DALEX::explain()`.
#' @param data validation dataset, will be extracted from `x` if it's an explainer
#' @param predict_function predict function, will be extracted from `x` if it's an explainer
#' @param new_observation a new observation with columns that correspond to variables used in the model
#' @param y true labels for `new_observation`. If specified then will be added to ceteris paribus plots.
#' @param variables names of variables for which profiles shall be calculated. Will be passed to `calculate_variable_splits()`. If NULL then all variables from the validation data will be used.
#' @param ... other parameters
#' @param variable_splits named list of splits for variables, in most cases created with `calculate_variable_splits()`. If NULL then it will be calculated based on validation data available in the `explainer`.
#' @param grid_points number of points for profile. Will be passed to `calculate_variable_splits()`.
#' @param label name of the model. By default it's extracted from the 'class' attribute of the model
#'
#'
#' @return An object of the class 'individual_variable_profile_explainer'.
#' It's a data frame with calculated profiles; the observations of interest are stored in its `observations` attribute.
#' @export
#'
#' @examples
#' library("DALEX2")
#' library("ceterisParibus2")
#'  \dontrun{
#' library("randomForest")
#' set.seed(59)
#'
#' apartments_rf <- randomForest(m2.price ~ construction.year + surface + floor +
#'                                 no.rooms + district, data = apartments)
#'
#' explainer_rf <- explain(apartments_rf,
#'                         data = apartments_test[,2:6], y = apartments_test$m2.price)
#'
#' my_apartment <- apartments_test[1, ]
#'
#' lp_rf <- individual_variable_profile(explainer_rf, my_apartment)
#' head(lp_rf)
#'
#' plot(lp_rf)
#'
#' # --------
#' # multiclass
#'
#' HR_rf <- randomForest(status ~ . , data = HR)
#' explainer_rf <- explain(HR_rf, data = HRTest, y = HRTest)
#'
#' my_HR <- HRTest[1, ]
#'
#' lp_rf <- individual_variable_profile(explainer_rf,  my_HR)
#' head(lp_rf)
#'
#' plot(lp_rf, color = "_label_")
#'
#' }
#' @export
#' @rdname local_profile
individual_variable_profile <- function(x, ...) {
  # S3 generic: the method to run is selected from the class of `x`.
  UseMethod("individual_variable_profile")
}

#' @export
#' @rdname local_profile
individual_variable_profile.explainer <- function(x, new_observation, y = NULL, variables = NULL,
                                    variable_splits = NULL, grid_points = 101,
                                    ...) {
  # Method for explainer objects: unpacks the model, the validation data, the
  # predict function and the label stored in `x` and delegates to the default
  # method. Every argument of this method is forwarded, including
  # `variable_splits`; without forwarding it, splits supplied by the caller
  # would be ignored and recomputed from the data.
  individual_variable_profile.default(x$model,
                             data = x$data,
                             predict_function = x$predict_function,
                             new_observation = new_observation,
                             label = x$label,
                             variables = variables,
                             variable_splits = variable_splits,
                             grid_points = grid_points,
                             y = y,
                             ...)
}


#' @export
#' @rdname local_profile
individual_variable_profile.default <- function(x, data, predict_function = predict,
                                      new_observation, y = NULL, variables = NULL,
                                      variable_splits = NULL,
                                      grid_points = 101,
                                      label = class(x)[1], ...) {
  # Calculates profiles of `new_observation` for the model `x`.
  #
  # Steps:
  #   1. if `data` is a data frame, keep only the columns shared by `data`
  #      and `new_observation`,
  #   2. compute `variable_splits` with calculate_variable_split() unless
  #      they were supplied (this requires non-NULL `data`),
  #   3. compute profiles with calculate_variable_profile(),
  #   4. if the profiles carry several `_yhat_*` columns, stack them into a
  #      single `_yhat_` column with one `_label_` per original column,
  #   5. attach the observations of interest as the "observations" attribute.
  #
  # Returns a data frame of class 'individual_variable_profile_explainer'.

  # just in case only some variables are specified
  # this will work only for data.frames
  if (inherits(data, "data.frame")) {
    common_variables <- intersect(colnames(new_observation), colnames(data))
    new_observation <- new_observation[, common_variables, drop = FALSE]
    data <- data[, common_variables, drop = FALSE]
  }

  # calculate splits
  # if splits are not provided, then they will be calculated
  if (is.null(variable_splits)) {
    # need validation data from the explainer
    if (is.null(data)) {
      stop("The individual_variable_profile() function requires explainers created with specified 'data'.")
    }
    # need variables, if not provided, will be extracted from data
    if (is.null(variables)) {
      variables <- colnames(data)
    }

    variable_splits <- calculate_variable_split(data, variables = variables, grid_points = grid_points)
  }

  # calculate profiles
  profiles <- calculate_variable_profile(new_observation,
                                 variable_splits, x, predict_function)

  n_obs <- nrow(new_observation)

  # if there is more than one column with `_yhat_`
  # then we need to convert them to a single column
  col_yhat <- grep(colnames(profiles), pattern = "^_yhat_")
  if (length(col_yhat) == 1) {
    profiles$`_label_` <- label

    # add points of interest
    new_observation$`_yhat_` <- predict_function(x, new_observation)
    new_observation$`_label_` <- label
    new_observation$`_ids_` <- seq_len(n_obs)
    if (!is.null(y)) new_observation$`_y_` <- y
  } else {
    n_yhat <- length(col_yhat)
    n_profiles <- nrow(profiles)

    # we need to recreate _yhat_ and create proper labels;
    # the rows are stacked once per `_yhat_*` column.
    # drop = FALSE keeps a data frame even if a single column remains
    new_profiles <- profiles[rep(seq_len(n_profiles), times = n_yhat), -col_yhat, drop = FALSE]
    new_profiles$`_yhat_` <- unlist(c(profiles[, col_yhat]))
    stripped_names <- gsub(colnames(profiles)[col_yhat], pattern = "_yhat_", replacement = "")
    new_profiles$`_label_` <- paste0(label, rep(stripped_names, each = n_profiles))
    profiles <- new_profiles

    # add points of interest, stacked in the same order as the profiles:
    # all observations for the first output, then all for the second, ...
    new_observation_ext <- new_observation[rep(seq_len(n_obs), times = n_yhat), , drop = FALSE]
    predict_obs <- predict_function(x, new_observation)
    new_observation_ext$`_yhat_` <- unlist(c(predict_obs))
    new_observation_ext$`_label_` <- paste0(label, rep(stripped_names, each = n_obs))
    # ids must follow the row replication above (`times`, not `each`),
    # so that every row keeps the id of the observation it was copied from
    new_observation_ext$`_ids_` <- rep(seq_len(n_obs), times = n_yhat)

    # add y
    if (!is.null(y)) new_observation_ext$`_y_` <- rep(y, times = n_yhat)

    new_observation <- new_observation_ext
  }

  # prepare final object
  attr(profiles, "observations") <- new_observation
  class(profiles) <- c("individual_variable_profile_explainer", "data.frame")
  profiles
}
pbiecek/ceterisParibus2 documentation built on Sept. 16, 2019, 6:26 p.m.