R/plot_metrics.R

Defines functions plot_metrics

Documented in plot_metrics

#' Plot of LIME Comparison Metrics
#'
#' Plots the specified comparison metrics versus LIME tuning parameters.
#'
#' @param explanations Explain data frame from the list returned by apply_lime.
#' @param metrics Character vector specifying metrics to compute. Default is
#'   'all'. See details for metrics available.
#' @param add_lines Logical flag. If TRUE, draw lines between tuning
#'   parameters with the same gower power.
#' @param rank_legend Specifies whether the legend for rank is treated as
#'   'continuous' or 'discrete'.
#' @param point_size Specifies the size of the points.
#' @param line_size Specifies the size of the lines (if add_lines is TRUE).
#' @param line_alpha Specifies the alpha of the lines (if add_lines is TRUE).
#'
#' @details The metrics available are listed below.
#'
#' \itemize{
#'   \item \code{ave_r2}: Average explainer model R2 value computed over all explanations in the test set.
#'   \item \code{msee}: Mean square explanation error computed over all explanations in the test set.
#'   \item \code{ave_fidelity}: Average fidelity metric (Ribeiro et al. 2016) computed over all explanations in the test set.
#' }
#'
#' Within each metric the tuning parameter combinations are ranked. The
#' average R2 is ranked on its negated value, so the largest average R2
#' receives rank 1; for the other metrics the smallest value receives rank 1.
#'
#' @return A ggplot object with one row of panels per metric and one column
#'   of panels per simulation method group (the kernel density and normal
#'   methods share a single "Density" column).
#'
#' @references Ribeiro, M. T., S. Singh, and C. Guestrin, 2016:
#'   "Why should I trust you?": Explaining the predictions of any classifier.
#'   Proceedings of the 22nd ACM SIGKDD International Conference on
#'   Knowledge Discovery and Data Mining, San Francisco, CA, USA, August
#'   13-17, 2016, 1135–1144.
#'
#' @importFrom ggplot2 element_blank scale_color_gradient scale_colour_grey geom_line
#' @importFrom tidyr gather
#' @importFrom scales seq_gradient_pal
#'
#' @export plot_metrics
#'
#' @examples
#'
#' # Prepare training and testing data
#' x_train = sine_data_train[c("x1", "x2", "x3")]
#' y_train = factor(sine_data_train$y)
#' x_test = sine_data_test[1:5, c("x1", "x2", "x3")]
#'
#' # Fit a random forest model
#' rf <- randomForest::randomForest(x = x_train, y = y_train)
#'
#' # Run apply_lime
#' res <- apply_lime(train = x_train,
#'                   test = x_test,
#'                   model = rf,
#'                   label = "1",
#'                   n_features = 2,
#'                   sim_method = 'quantile_bins',
#'                   nbins = 2:3,
#'                   gower_pow = c(1, 5))
#'
#' # Plot metrics to compare LIME implementations
#' plot_metrics(res$explain)
#'
#' # Return a plot with only the MSEE values
#' plot_metrics(res$explain, metrics = "msee")
plot_metrics <- function(explanations, metrics = 'all', add_lines = FALSE,
                         rank_legend = 'continuous', point_size = 2,
                         line_size = 0.5, line_alpha = 1) {

  # Validate every argument up front so that a bad value fails before any
  # metrics are computed. The assert_* family is used because checkmate's
  # expect_* functions are meant for use inside testthat unit tests.
  valid_metrics <- c("ave_r2", "msee", "ave_fidelity")
  checkmate::assert_data_frame(explanations)
  checkmate::assert_character(metrics, min.len = 1)
  if (!all(metrics %in% c("all", valid_metrics))) {
    stop("metrics specified incorrectly. Must be a character vector with options of 'ave_r2', 'msee', 'ave_fidelity'.")
  }
  checkmate::assert_flag(add_lines)
  if (!isTRUE(rank_legend %in% c("continuous", "discrete"))) {
    stop("'rank_legend' specified incorrectly. Must be 'continuous' or 'discrete'.")
  }

  # 'all' expands to every available metric; duplicates are dropped so each
  # metric column is requested and reshaped once
  if ("all" %in% metrics) {
    metrics <- valid_metrics
  } else {
    metrics <- unique(metrics)
  }

  # Obtain the comparison metrics
  metric_data <- compute_metrics(explanations, metrics)

  # Prepare the data for the plot: one row per tuning-parameter combination
  # and metric, with display labels for the metric and simulation method
  plot_data <- metric_data %>%
    tidyr::pivot_longer(cols = tidyr::all_of(metrics),
                        names_to = "metric",
                        values_to = "value") %>%
    mutate(
      metric = ifelse(.data$metric == "ave_r2", "Average R2",
                      ifelse(.data$metric == "msee", "MSEE",
                             "Average Fidelity")),
      metric = factor(.data$metric,
                      levels = c("Average R2", "Average Fidelity", "MSEE")),
      nbins = factor(.data$nbins),
      sim_method = factor(
        ifelse(.data$sim_method == "quantile_bins", "Quantile Bins",
               ifelse(.data$sim_method == "equal_bins", "Equal Bins",
                      ifelse(.data$sim_method == "kernel_density", "Kernel",
                             "Normal")))
      ),
      # Kernel and normal simulation share one facet column
      sim_method_plot = factor(
        ifelse(.data$sim_method %in% c("Kernel", "Normal"),
               "Density",
               as.character(.data$sim_method))
      ),
      # x-axis value: the number of bins, or the method name when there are
      # no bins. Levels reuse the order of the nbins factor so that bin
      # counts are not re-sorted alphabetically (e.g. "10" before "2").
      nbins_plot = factor(
        ifelse(is.na(.data$nbins),
               as.character(.data$sim_method),
               as.character(.data$nbins)),
        levels = unique(c(levels(.data$nbins), levels(.data$sim_method)))
      ),
      # Negate the average R2 so the largest value receives rank 1
      ranking_value = ifelse(.data$metric == "Average R2",
                             -.data$value,
                             .data$value),
      gower_pow = factor(.data$gower_pow)
    ) %>%
    group_by(.data$metric) %>%
    mutate(rank = rank(.data$ranking_value)) %>%
    dplyr::ungroup() %>%
    arrange(.data$metric, .data$value)

  # Create the comparison plot based on whether 1 or more gower power
  # was specified (the shape aesthetic is only useful with several powers)
  if (length(unique(plot_data$gower_pow)) == 1) {
    plot <- ggplot(plot_data,
                   aes(x = .data$nbins_plot, y = .data$value))
  } else {
    plot <- ggplot(plot_data,
                   aes(x = .data$nbins_plot,
                       y = .data$value,
                       shape = .data$gower_pow)) +
      labs(shape = "Gower \nPower")
  }

  # Add lines to the plot if requested
  if (add_lines) {
    plot <- plot +
      geom_line(aes(group = .data$gower_pow,
                    linetype = .data$gower_pow),
                size = line_size,
                alpha = line_alpha) +
      labs(linetype = "Gower \nPower")
  }

  # Add points and color them by rank, on a continuous or discrete scale
  if (rank_legend == "continuous") {
    plot <- plot +
      geom_point(aes(color = .data$rank), size = point_size) +
      scale_color_gradient(low = "black", high = "grey85")
  } else {
    plot <- plot +
      geom_point(aes(color = factor(.data$rank)), size = point_size) +
      scale_colour_grey()
  }

  # Add the additional layers to the plot
  plot +
    facet_grid(.data$metric ~ .data$sim_method_plot,
               scales = "free", space = "free_x", switch = "y") +
    theme_grey() +
    theme(strip.placement = "outside",
          strip.background = element_blank()) +
    labs(x = "Number of Bins",
         y = "",
         color = "Rank \n(within \na metric)")

}
goodekat/limeaid documentation built on March 26, 2021, 10:45 p.m.