# R/extract_variables.R
#
# Defines functions lime_extractor live_extractor breakDown_extractor
#
# Documented in breakDown_extractor lime_extractor live_extractor

# Extract the relevant variables from a local explainer built with the lime,
# live or breakDown package.

#' LIME wrapper
#'
#' Explains a single observation with `lime` and returns the features whose
#' weights argue against the true label or in favour of the predicted label.
#'
#' @param instance data frame that contains the observation to explain
#'   (including the `target` column)
#' @param target name of the response variable
#' @param training_set data on which the explained `model` was trained
#' @param model model to be explained
#' @param predict_function predict function for `model` which returns classes.
#' Defaults to `predict`.
#' @param ... additional parameters passed on to `lime::explain`. `n_labels`
#'   and `n_features` are set by this function and must not be supplied.
#'
#' @return list that consists of
#' \item{explanation_model}{object returned by `lime::explain` function}
#' \item{variables}{character vector of names of variables that contributed to the wrong prediction}
#'
#' @importFrom dplyr filter select n_distinct
#'
#' @export
#'

lime_extractor <- function(instance, target, training_set, model,
                           predict_function = predict, ...) {
  # These names are used unquoted inside filter()/select(); binding them to
  # NULL is presumably meant to silence R CMD check "no visible binding" notes.
  label <- feature <- feature_weight <- NULL

  target_col <- which(colnames(training_set) == target)
  if (length(target_col) != 1L) {
    stop("`target` must name exactly one column of `training_set`.")
  }
  # drop = FALSE keeps a data frame even when only one feature column remains.
  features <- instance[, -target_col, drop = FALSE]
  true_label <- as.character(instance[[target]])
  predicted_label <- as.character(predict_function(model, features))

  lime_object <- lime::lime(training_set[, -target_col, drop = FALSE], model)
  lime_explainer <- lime::explain(features, lime_object,
                                  n_labels = n_distinct(training_set[[target]]),
                                  n_features = ncol(training_set) - 1,
                                  ...)

  # Keep features that push away from the true label or towards the
  # (wrong) predicted label.
  variables <- filter(lime_explainer,
                      (label == true_label & feature_weight < 0) |
                        (label == predicted_label & feature_weight > 0))
  variables <- select(variables, feature)
  variables <- unique(unlist(variables, use.names = FALSE))

  list(explanation_model = lime_explainer,
       variables = as.character(variables))
}


#' LIVE wrapper
#'
#' Explains a misclassified observation with `live` (a GLM / multinomial
#' white box fitted on a locally sampled neighbourhood) and returns the names
#' of the variables selected from the white box's coefficient signs.
#' Factor/character predictors are reported by their column name rather than
#' by their dummy-coded term names.
#'
#' @param instance data frame that contains the observation to explain
#'   (including the `target` column); it must be misclassified by `model`,
#'   otherwise an error is raised
#' @param target name of the response variable
#' @param training_set data on which the explained `model` was trained
#' @param model model to be explained
#' @param predict_function predict function for `model` which returns classes,
#' defaults to `predict`
#' @param ... additional parameters to `live::sample_locally2` function,
#' `size` argument must be included
#'
#' @return list that consists of
#' \item{explanation_model}{object returned by `live::fit_explanation2` function}
#' \item{variables}{character vector of names of variables that contributed to the wrong prediction}
#'
#' @importFrom dplyr filter select n_distinct
#' @importFrom live sample_locally2 add_predictions2 fit_explanation2
#' @importFrom mlr getLearnerModel
#' @importFrom stats coef
#'
#' @export
#'

live_extractor <- function(instance, target, training_set, model,
                           predict_function = predict, ...) {
  target_col <- which(colnames(training_set) == target)
  if (length(target_col) != 1L) {
    stop("`target` must name exactly one column of `training_set`.")
  }
  # drop = FALSE keeps a data frame even when only one feature column remains.
  features <- instance[, -target_col, drop = FALSE]
  true_label <- as.character(instance[[target]])
  predicted_label <- as.character(predict_function(model, features))
  if (true_label == predicted_label) {
    stop("This function works only for misclassified instances.")
  }

  response <- training_set[[target]]
  if (n_distinct(response) == 2) {
    family <- "binomial"
    expl_model <- "classif.binomial"
  } else {
    family <- "multinomial"
    expl_model <- "classif.multinom"
  }

  neighbourhood <- sample_locally2(training_set, instance, target,
                                   method = "lime", ...)
  with_predictions <- add_predictions2(neighbourhood, model,
                                       predict_fun = predict_function)
  live_explainer <- fit_explanation2(with_predictions, white_box = expl_model,
                                     response_family = family)
  coefs <- coef(getLearnerModel(live_explainer$model))

  # as.factor() also handles character targets, for which levels() is NULL.
  true_is_first_level <- true_label == levels(as.factor(response))[1]

  if (family == "binomial") {
    # Named coefficient vector; element 1 is the intercept.
    terms <- names(coefs)[-1]
    weights <- coefs[-1]
  } else {
    # Coefficient matrix indexed by class label (rows) and term (columns);
    # column 1 is the intercept. When the true class is the first level the
    # predicted class's row is read instead (presumably because the first
    # level is the baseline category and has no row of its own).
    class_row <- if (true_is_first_level) predicted_label else true_label
    terms <- colnames(coefs)[-1]
    weights <- coefs[class_row, -1]
  }
  # Sign rule as in the original implementation; which() drops NA
  # coefficients (e.g. aliased terms) instead of yielding NA names.
  selected <- if (true_is_first_level) which(weights < 0) else which(weights > 0)
  variables <- terms[selected]

  # Categorical predictors appear as dummy terms that are not column names;
  # map each such term back to every categorical column whose name it contains.
  is_categorical <- vapply(training_set,
                           function(column) is.character(column) || is.factor(column),
                           logical(1))
  factors <- colnames(training_set)[is_categorical]
  if (length(factors) > 0) {
    dummy_terms <- setdiff(variables, colnames(training_set))
    plain_terms <- setdiff(variables, dummy_terms)
    # fixed = TRUE: column names are matched literally, not as regexes.
    has_dummy <- vapply(factors,
                        function(name) any(grepl(name, dummy_terms, fixed = TRUE)),
                        logical(1))
    variables <- c(plain_terms, factors[has_dummy])
  }

  list(explanation_model = live_explainer,
       variables = as.character(variables))
}


#' breakDown wrapper
#'
#' Explains a single observation with `breakDown::broken` (baseline 0) and
#' returns the variables with a positive contribution, excluding the
#' intercept.
#'
#' @param instance data frame that contains the observation to explain
#'   (including the `target` column)
#' @param target name of the response variable
#' @param training_set data on which the explained `model` was trained
#' @param model model to be explained
#' @param predict_function predict function for `model` which returns classes,
#' defaults to `predict`
#' @param ... additional parameters passed on to `breakDown::broken`
#'
#' @return list that consists of
#' \item{explanation_model}{object returned by `breakDown::broken.default` function}
#' \item{variables}{character vector of names of variables that contributed to the wrong prediction}
#'
#' @importFrom dplyr filter select n_distinct
#' @importFrom stats predict
#'
#' @export
#'

breakDown_extractor <- function(instance, target, training_set, model,
                                predict_function = predict, ...) {
  target_col <- which(colnames(training_set) == target)
  if (length(target_col) != 1L) {
    stop("`target` must name exactly one column of `training_set`.")
  }

  # drop = FALSE keeps a data frame even when only one feature column remains.
  breakdown_explainer <- breakDown::broken(model,
                                           instance[, -target_col, drop = FALSE],
                                           training_set, baseline = 0,
                                           predict.function = predict_function,
                                           ...)

  # The last entry of the explanation is not scored as a variable, so it is
  # removed from BOTH vectors. Dropping it only from the contributions would
  # make the logical mask one element short, and R would recycle it onto the
  # final name.
  n_rows <- length(breakdown_explainer$contribution)
  keep <- seq_len(max(n_rows - 1L, 0L))
  variables <- as.character(breakdown_explainer$variable_name[keep])
  scores <- breakdown_explainer$contribution[keep]

  selected <- variables[which(scores > 0 & variables != "Intercept")]

  list(explanation_model = breakdown_explainer,
       variables = selected)
}
# mstaniak/egalitaRian documentation built on Aug. 26, 2019, 11:11 p.m.