#' Heatmap of LIME Selected Features
#'
#' Creates a heatmap of the features selected by lime for all
#' observations in the test set across all of the different
#' LIME implementations.
#'
#' @param explanations Explain dataframe from the list returned by apply_lime.
#' @param feature_nums A vector of integer values from 1 to
#' \code{nfeatures} (specified in \code{apply_lime}) to
#' determine which features selected by LIME should be
#' included in the plot.
#' @param facet_var A categorical variable that is the same length as
#' the data input to apply_lime for the test argument that
#' will be used to facet the heatmap. (NULL by default)
#' @param order_method Method for ordering the predictions: either
#' "obs_num" which uses the order from the explanation
#' dataframe (default), "sort_features" which sorts by
#' the factors within a feature using the dplyr "arrange" function,
#' or one of the options from the package seriation for matrices
#' (see seriation::list_seriation_methods("matrix") for the options
#' available.)
#'
#' @importFrom checkmate expect_data_frame expect_character
#' @importFrom cluster daisy
#' @importFrom dplyr arrange_at desc left_join pull vars
#' @importFrom ggplot2 aes facet_grid geom_point geom_tile ggplot labs scale_color_grey scale_color_manual scale_fill_grey theme theme_bw
#' @importFrom seriation list_seriation_methods seriate
#' @importFrom tidyr pivot_wider
#' @importFrom tidyselect all_of
#'
#' @export plot_feature_heatmap
#'
#' @examples
#'
#' # Prepare training and testing data
#' x_train = sine_data_train[c("x1", "x2", "x3")]
#' y_train = factor(sine_data_train$y)
#' x_test = sine_data_test[1:5, c("x1", "x2", "x3")]
#'
#' # Fit a random forest model
#' rf <- randomForest::randomForest(x = x_train, y = y_train)
#'
#' # Run apply_lime
#' res <- apply_lime(train = x_train,
#' test = x_test,
#' model = rf,
#' label = "1",
#' n_features = 2,
#' sim_method = c('quantile_bins',
#' 'kernel_density'),
#' nbins = 2:4)
#'
#' # Plot heatmap of selected features across LIME implementations
#' plot_feature_heatmap(res$explain)
#'
#' # Return a heatmap with only the features selected first by LIME
#' plot_feature_heatmap(res$explain, feature_num = 1)
plot_feature_heatmap <- function(explanations, feature_nums = NULL,
facet_var = NULL, order_method = "obs_num"){
# Checks
checkmate::expect_data_frame(explanations)
if (!is.null(feature_nums)) checkmate::expect_double(feature_nums)
# Prepare the explanation data for plotting
heatmap_data <- explanations %>%
select(.data$sim_method, .data$nbins, .data$gower_pow,
.data$case, .data$feature, .data$feature_weight) %>%
mutate(feature_magnitude = abs(.data$feature_weight)) %>%
group_by(.data$sim_method, .data$nbins, .data$gower_pow, .data$case) %>%
arrange(.data$sim_method, .data$nbins, .data$gower_pow,
.data$case, desc(.data$feature_magnitude)) %>%
mutate(feature_num = 1:n()) %>%
ungroup() %>%
mutate(nbins = factor(.data$nbins),
gower_pow = factor(paste("Gower \nPower \nof", .data$gower_pow)),
feature = factor(.data$feature),
sim_method =
ifelse(.data$sim_method == "quantile_bins", "Quantile Bins",
ifelse(.data$sim_method == "equal_bins", "Equal Bins",
ifelse(.data$sim_method == "kernel_density", "Kernel",
"Normal"))) %>% factor(),
sim_method_plot = factor(ifelse(.data$sim_method %in% c("Kernel", "Normal"),
"Density",
as.character(.data$sim_method))),
nbins_plot = factor(ifelse(is.na(.data$nbins),
as.character(.data$sim_method),
as.character(.data$nbins)))) %>%
mutate(case = as.character(.data$case))
# If requested add faceting variable
if (!is.null(facet_var)) {
heatmap_data <- heatmap_data %>%
left_join(data.frame(case = unique(explanations$case),
facet_var = facet_var) %>%
mutate(case = as.character(.data$case)),
by = "case")
}
# Subset the data to only keep the requested features
if (!(is.null(feature_nums))) {
min_feat_num <- min(feature_nums)
heatmap_data <- heatmap_data %>%
filter(.data$feature_num %in% feature_nums) %>%
mutate(feature_num = factor(.data$feature_num),
feature_num = paste("LIME Feature", .data$feature_num))
} else {
min_feat_num <- min(heatmap_data$feature_num)
heatmap_data <- heatmap_data %>%
mutate(feature_num = factor(.data$feature_num),
feature_num = paste("LIME Feature", .data$feature_num))
}
# If requested, determine an order for the cases using seriation
if (order_method == "obs_num") {
# Turn case into a factor and order the levels numerically
cases_order = sort(as.numeric(as.character(unique(heatmap_data$case))))
heatmap_data <- heatmap_data %>%
mutate(case = factor(.data$case, levels = cases_order))
} else if (order_method == "sort_features") {
# Prepare features for ordering
sim_features <- heatmap_data %>%
filter(.data$feature_num == paste("LIME Feature", min_feat_num)) %>%
mutate(method = paste(.data$sim_method, .data$nbins, .data$gower_pow)) %>%
select(-.data$feature_weight, -.data$feature_magnitude, -.data$sim_method,
-.data$nbins, -.data$gower_pow, -.data$sim_method_plot,
-.data$nbins_plot, -facet_var) %>%
tidyr::pivot_wider(names_from = "method", values_from = "feature")
# Determine the order of the cases
cases_order <- sim_features %>%
select(-.data$feature_num) %>%
arrange_at(vars(-.data$case)) %>%
pull(.data$case)
# Add the order to the heatmap data
heatmap_data <- heatmap_data %>%
mutate(case = as.character(.data$case)) %>%
mutate(case = factor(.data$case, levels = cases_order))
} else if (order_method %in% seriation::list_seriation_methods("matrix")) {
# Arrange the data by case and extract the unique cases
heatmap_data <- heatmap_data %>% arrange(.data$case)
cases <- unique(heatmap_data$case)
# Prepare features for ordering
feature_matrix <- heatmap_data %>%
dplyr::filter(.data$feature_num == paste("LIME Feature", min_feat_num)) %>%
dplyr::mutate(method = paste(.data$sim_method, .data$nbins, .data$gower_pow)) %>%
dplyr::select(.data$method, .data$feature, .data$case) %>%
mutate(feature = as.numeric(.data$feature)) %>%
tidyr::pivot_wider(names_from = "method", values_from = "feature") %>%
select(-.data$case) %>%
as.matrix()
# Determine the order of the cases using seriation
# default method is PCA,
# seriation::list_seriation_methods("matrix")
# lists all
features_ordered <- seriate(feature_matrix, method = order_method)
# Add the order to the heatmap data
heatmap_data <- heatmap_data %>%
mutate(case = factor(.data$case, levels = cases[features_ordered[[1]]]))
} else {
stop("order_method not specified correctly. See ?plot_feature_heatmap for available options.")
}
# Create vertical facet variable
if (length(unique(heatmap_data$gower_pow)) == 1) {
heatmap_data <- heatmap_data %>%
dplyr::mutate(vert_facet = .data$sim_method_plot)
} else {
heatmap_data <- heatmap_data %>%
dplyr::mutate(vert_facet = paste0(.data$sim_method_plot, " \n", .data$gower_pow))
}
# Create vertical facet variable
if (is.null(facet_var)) {
heatmap_data <- heatmap_data %>%
dplyr::mutate(horz_facet = .data$feature_num)
} else {
heatmap_data <- heatmap_data %>%
dplyr::mutate(horz_facet = paste0(.data$feature_num, " \n", facet_var))
}
# Create the heatmap
plot <-
ggplot(heatmap_data,
aes(x = .data$nbins_plot, y = .data$case, fill = .data$feature, color = .data$feature)) +
geom_tile() +
theme_grey() +
labs(x = "Number of Bins",
y = "Prediction Number",
fill = "Complex \nModel \nFeature",
color = "Complex \nModel \nFeature")
# Facet, add grey scale colors, and return the plot
plot +
facet_grid(.data$horz_facet ~ .data$vert_facet,
scales = "free", space = "free") +
scale_fill_grey() +
scale_color_grey()
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.