#' Generate Feature Importance Score-Rank Plot
#'
#' This function generates a figure which shows the true (observed) feature importance scores plotted against the permuted scores, by rank. The point where the observed curve first drops below the upper bound of the permuted band (`upper`, or `logupper` on the log scale) is marked with a red line and a label, which we recommend as a significance cutoff.
#' Note: There are ample parameters for controlling the axis scales, label location, and zoom; because of data variability, you will almost certainly have to adjust these to fit your plot.
#'
#' @param permutedvalues a tibble of feature ranks, feature names, feature importance scores, log feature importance scores and permutation # generated by the load_permuted_fws function
#' @param quantiledata a tibble containing: feature ranks, mean, lower, upper, and observed as well as logmean, loglower, logupper, logobserved generated by the calculateQuantiles function
#' @param xlimitmin a numerical value for the minimum x value in the figure, default is 1 because ranking starts at 1. Ignored when `focusedView = TRUE`.
#' @param xlimitmax a numerical value for the maximum x value in the figure, default is 500 (this is effectively the total number of features to include). Ignored when `focusedView = TRUE`.
#' @param ylimitmin a numerical value for the minimum y value in the figure, default is -10
#' @param ylimitmax a numerical value for the maximum y value in the figure, default is 0, but this will vary largely depending on your dataset, set it larger then narrow down. In the focused view it is also the y position of the cutoff label.
#' @param xtickbreaks a numerical value for the break spacing between label ticks on the x axis, default is 10
#' @param labelverticaladjust a numerical value used as `vjust` for the cutoff label (vertical, y-axis adjustment)
#' @param labelhorizontaladjust a numerical value used as `hjust` for the cutoff label (horizontal, x-axis adjustment)
#' @param indvPermScoresOn logical TRUE or FALSE, whether or not to include the lines for each individual permutation feature importance score (shown in light blue)
#' @param logOn logical TRUE or FALSE, whether to plot the log-scaled or the unlogged data. Defaults to TRUE.
#' @param focusedView logical TRUE or FALSE, whether to show the full plot or a focused view of only the features above the cutoff line. In the focused view the x axis runs from the smallest plotted rank to the cutoff, and an error is raised if no cutoff can be found.
#' @import ggplot2
#' @import dplyr
#' @examples
#' fiplot_full <- generate_fi_rank_plot(feat_importances$permuted_importances, quantile_data, xlimitmin=1, xlimitmax = 10, ylimitmin= -5, ylimitmax = 0, labelhorizontaladjust = -0.05,labelverticaladjust = 1.5,focusedView = FALSE,logOn = TRUE)
#' fiplot_focused <- generate_fi_rank_plot(feat_importances$permuted_importances, quantile_data, xlimitmin=1, xlimitmax = 10, ylimitmin= -5, ylimitmax = 0, labelhorizontaladjust = 1.05,labelverticaladjust = 1.5,focusedView = TRUE,logOn = TRUE)
#' @return A ggplot object: feature importance score by feature rank, with the permuted band, the observed curve and the recommended cutoff.
#' @export
generate_fi_rank_plot <- function(permutedvalues, quantiledata, xlimitmin = 1, xlimitmax = 500,
                                  ylimitmin = -10, ylimitmax = 0, xtickbreaks = 10,
                                  labelverticaladjust = 1.05, labelhorizontaladjust = 1.05,
                                  indvPermScoresOn = TRUE, logOn = TRUE, focusedView = FALSE) {
  # Pick the columns for the requested scale and drop rows that cannot be drawn.
  if (logOn) {
    cols <- list(center = "logmean", lower = "loglower", upper = "logupper",
                 observed = "logobserved", perm = "log_feature_importance")
    y_label <- "Feature Importance Score (Log-Scaled)"
    qdata2 <- quantiledata %>%
      dplyr::filter(mean > 0) %>%
      dplyr::filter(is.finite(logobserved))
    permdata2 <- permutedvalues %>%
      dplyr::filter(feature_rank %in% qdata2$feature_rank) %>%
      dplyr::filter(is.finite(log_feature_importance))
  } else {
    cols <- list(center = "mean", lower = "lower", upper = "upper",
                 observed = "observed", perm = "feature_importance")
    y_label <- "Feature Importance Score"
    qdata2 <- quantiledata %>% dplyr::filter(mean > 0)
    # NOTE(review): the focused view keeps ranks with any finite observed
    # score while the full view keeps only positive ones; preserved as-is.
    qdata2 <- if (focusedView) {
      dplyr::filter(qdata2, is.finite(observed))
    } else {
      dplyr::filter(qdata2, observed > 0)
    }
    permdata2 <- permutedvalues %>%
      dplyr::filter(feature_rank %in% qdata2$feature_rank) %>%
      dplyr::filter(feature_importance > 0)
  }

  # Turn a column key into a symbol so it can be unquoted into aes(); the
  # resulting mappings are identical to writing the bare column name.
  col_sym <- function(key) as.name(cols[[key]])

  # Position of the first row whose observed score is below the upper bound,
  # minus one: the number of features above the cutoff (NA if none is found).
  n_above <- which(qdata2[[cols[["observed"]]]] < qdata2[[cols[["upper"]]]])[1] - 1

  if (focusedView) {
    if (is.na(n_above)) {
      stop("Red line x-intercept could not be calculated. Please check your data.")
    }
    x_from <- min(qdata2$feature_rank, na.rm = TRUE)
    # Without this check seq() below fails with an opaque "wrong sign" error.
    if (n_above < x_from) {
      stop("No features lie above the cutoff, so a focused view cannot be drawn.",
           call. = FALSE)
    }
    x_to <- n_above
    label_y <- ylimitmax
    label_text <- paste0("No. Features above alpha threshold: ", n_above)
  } else {
    x_from <- xlimitmin
    x_to <- xlimitmax
    label_y <- +Inf
    label_text <- paste0("No. Features above \nalpha threshold: ", n_above)
  }

  message("Warnings regarding rows containing missing values (`geom_line()`) are related to values plotted outside axis limits and can be ignored once you are happy with the way your plot looks.")

  # Individual permutation curves go first so they sit beneath the band;
  # NULL (no layer) when switched off.
  perm_layer <- if (indvPermScoresOn) {
    ggplot2::geom_line(
      data = permdata2,
      ggplot2::aes(x = feature_rank, group = permutation, y = !!col_sym("perm")),
      alpha = 0.8, col = "lightblue"
    )
  }

  ggplot2::ggplot(qdata2, ggplot2::aes(x = feature_rank, y = !!col_sym("center"))) +
    ggplot2::xlab("Feature Rank") +
    ggplot2::ylab(y_label) +
    perm_layer +
    ggplot2::geom_ribbon(
      ggplot2::aes(ymin = !!col_sym("lower"), ymax = !!col_sym("upper")),
      fill = "grey20", alpha = 0.6
    ) +
    ggplot2::geom_line() +
    ggplot2::geom_line(ggplot2::aes(x = feature_rank, y = !!col_sym("observed")), colour = "gold3") +
    ggplot2::ylim(ylimitmin, ylimitmax) +
    ggplot2::geom_vline(xintercept = n_above, color = "red") +
    ggplot2::scale_x_continuous(
      breaks = seq(from = x_from, to = x_to, by = xtickbreaks),
      limits = c(x_from, x_to)
    ) +
    ggplot2::annotate("label", x = n_above, y = label_y, label = label_text,
                      vjust = labelverticaladjust, hjust = labelhorizontaladjust,
                      size = 3.5) +
    ggplot2::theme_linedraw()
}
#' Generate pi Histogram Plot
#'
#' Creates a histogram showing the sum of absolute deviations by count, the measures used to calculate the pi statistics used in the p-value calculation for the set.
#'
#' This function generates a histogram illustrating the sum of absolute deviations in the permuted set vs the true (observed) set. It is a visualization of the data used to calculate the p-value for the entire feature set.
#' For each permutation, every feature importance score is compared with the per-rank mean in `quantiledata`, matched by `feature_rank`, and the absolute deviations are summed. The observed statistic (red line) is the sum of absolute deviations between `observed` and `mean`.
#'
#' @param permutedvalues a tibble of feature ranks, feature names, feature importance scores, log feature importance scores and permutation # generated by the load_permuted_fws function
#' @param quantiledata a tibble containing: feature ranks, mean, lower, upper, and observed as well as logmean, loglower, logupper, logobserved generated by the calculateQuantiles function
#' @return A histogram (ggplot object) of count by sum of absolute deviations, showing the differences between the null and observed datasets.
#' @examples
#' pihist_plot<-generate_pi_histogram(feat_importances$permuted_importances, quantile_data)
#' @import ggplot2
#' @import dplyr
#' @export
generate_pi_histogram <- function(permutedvalues, quantiledata) {
  # Look up each permuted row's null mean by feature rank, so the result does
  # not depend on how the rows of either input happen to be ordered.
  rank_index <- match(permutedvalues$feature_rank, quantiledata$feature_rank)
  if (anyNA(rank_index)) {
    stop("Every feature_rank in `permutedvalues` must also appear in `quantiledata`.",
         call. = FALSE)
  }
  # One histogram bin per permutation.
  numberofPermutations <- max(permutedvalues$permutation)
  d <- permutedvalues %>%
    dplyr::ungroup() %>%
    dplyr::mutate(Mean = quantiledata$mean[rank_index],
                  Dev = feature_importance - Mean)
  # Null pi statistics: total absolute deviation from the mean, per permutation.
  pi_permuted <- d %>%
    dplyr::group_by(permutation) %>%
    dplyr::summarise(sum_abs_deviations = sum(abs(Dev)))
  # Observed pi statistic.
  pi_obs <- sum(abs(quantiledata$observed - quantiledata$mean))
  ggplot2::ggplot(pi_permuted, ggplot2::aes(x = sum_abs_deviations)) +
    ggplot2::geom_histogram(bins = numberofPermutations) +
    ggplot2::xlab("Sum of Absolute Deviations") +
    ggplot2::ylab("Count") +
    ggplot2::geom_vline(xintercept = pi_obs, color = "red") +
    ggplot2::annotate(x = pi_obs, y = +Inf, label = paste0("pi_obs: ", round(pi_obs, 2)),
                      vjust = 2, hjust = 1.2, geom = "label") +
    ggplot2::theme_linedraw()
}
#' Generate ECDF Plot
#'
#' Creates a plot of the empirical cumulative distribution function showing the sum of absolute deviations by fraction of data.
#' For each permutation, every feature importance score is compared with the per-rank mean in `quantiledata`, matched by `feature_rank`, and the absolute deviations are summed. The observed statistic (red line) is the sum of absolute deviations between `observed` and `mean`.
#'
#' @param permutedvalues a tibble of feature ranks, feature names, feature importance scores, log feature importance scores and permutation # generated by the load_permuted_fws function
#' @param quantiledata a tibble containing: feature ranks, mean, lower, upper, and observed as well as logmean, loglower, logupper, logobserved generated by the calculateQuantiles function
#' @return A plot (ggplot object) of the empirical cumulative distribution function showing the sum of absolute deviations by fraction of data.
#' @examples
#' ecdf_plot<-generate_pi_ECDF_plot(feat_importances$permuted_importances, quantile_data)
#' @import ggplot2
#' @import dplyr
#' @export
generate_pi_ECDF_plot <- function(permutedvalues, quantiledata) {
  # Look up each permuted row's null mean by feature rank, so the result does
  # not depend on how the rows of either input happen to be ordered.
  rank_index <- match(permutedvalues$feature_rank, quantiledata$feature_rank)
  if (anyNA(rank_index)) {
    stop("Every feature_rank in `permutedvalues` must also appear in `quantiledata`.",
         call. = FALSE)
  }
  d <- permutedvalues %>%
    dplyr::ungroup() %>%
    dplyr::mutate(Mean = quantiledata$mean[rank_index],
                  Dev = feature_importance - Mean)
  # Null pi statistics: total absolute deviation from the mean, per permutation.
  pi_permuted <- d %>%
    dplyr::group_by(permutation) %>%
    dplyr::summarise(Sum_abs_deviations = sum(abs(Dev)))
  # Observed pi statistic.
  pi_obs <- sum(abs(quantiledata$observed - quantiledata$mean))
  ggplot2::ggplot(pi_permuted, ggplot2::aes(x = Sum_abs_deviations)) +
    ggplot2::stat_ecdf(geom = "step") +
    ggplot2::geom_vline(xintercept = pi_obs, color = "red") +
    ggplot2::annotate(x = pi_obs, y = +Inf, label = paste0("pi_obs: ", round(pi_obs, 2)),
                      vjust = 4, hjust = 1.1, geom = "label") +
    ggplot2::labs(y = "Fraction of Data", x = "Sum of Absolute Deviations") +
    ggplot2::theme_linedraw()
}
#' Generate SHAP Plots
#'
#' Creates a combined bar and beeswarm plot to show global and local feature importance
#'
#' Both panels share the same feature axis, ordered by decreasing `abs_mean_shap`.
#'
#' @param mean_shap_values A dataframe containing mean SHAP values and features (columns `feature`, `abs_mean_shap`, `mean_shap`)
#' @param long_shap_data A dataframe containing individual SHAP values for features (columns `feature`, `shap_value`)
#' @param title_global The title for the global feature importance plot
#' @param title_local The title for the local feature explanation plot
#' @param fill_colors A vector of colors to use for filling the bar plot (used for `mean_shap > 0` FALSE and TRUE respectively)
#' @param gradient_colors A vector of colors to use for the color gradient in the beeswarm plot (low, high)
#' @return A combined ggplot object with a bar plot and a beeswarm plot
#' @examples
#' shapplot<-generate_shap_plots(mean_shap_values = shapvals$significant_features,long_shap_data = shapvals$long_shap_data)
#' @import ggplot2
#' @import dplyr
#' @import cowplot
#' @import ggbeeswarm
#' @export
generate_shap_plots <- function(mean_shap_values, long_shap_data, title_global = "SHAP Global Feature Importance",
                                title_local = "SHAP Local Feature Explanation", fill_colors = c("blue", "red"),
                                gradient_colors = c("blue", "red")) {
  # Ensure consistent factor levels for the 'feature' variable in both panels.
  # unique() is needed because factor() errors on duplicated levels.
  feature_levels <- unique(mean_shap_values$feature[order(-mean_shap_values$abs_mean_shap)])
  long_shap_data$feature <- factor(long_shap_data$feature, levels = feature_levels)
  mean_shap_values$feature <- factor(mean_shap_values$feature, levels = feature_levels)
  # Global feature importance plot (Bar plot)
  bar_plot <- ggplot(mean_shap_values, aes(x = feature, y = abs_mean_shap, fill = mean_shap > 0)) +
    geom_col() +
    coord_flip() +
    scale_fill_manual(values = fill_colors, name = "mean_shap > 0") +
    theme_minimal() +
    theme(
      axis.text.y = element_text(hjust = 1),
      axis.ticks.y = element_blank(),
      panel.grid.major = element_blank(),
      panel.grid.minor = element_blank(),
      plot.margin = margin(0, 0, 0, 0),
      axis.title.y = element_text(margin = margin(t = 0, r = 0, b = 0, l = 10)),
      legend.position = "bottom"
    ) +
    labs(x = "Feature", y = "Mean Abs. SHAP value", title = title_global) +
    theme(plot.title = element_text(hjust = 0.5, vjust = -2)) # Adjust title position
  # Local explanation summary plot (Beeswarm plot)
  # NOTE(review): points are coloured by shap_value, although the legend reads
  # "Feature Value"; the legend text is kept unchanged.
  beeswarm_plot <- ggplot(long_shap_data, aes(x = feature, y = shap_value, color = shap_value)) +
    geom_quasirandom(size = 1, alpha = 0.5) +
    coord_flip() +
    scale_color_gradient(low = gradient_colors[1], high = gradient_colors[2], name = "Feature Value") +
    theme_minimal() +
    theme(
      axis.text.y = element_blank(),
      axis.ticks.y = element_blank(),
      panel.grid.major = element_blank(),
      panel.grid.minor = element_blank(),
      plot.margin = margin(0, 0, 0, 0),
      legend.position = "bottom",
      legend.key.width = unit(1, "cm"),
      legend.key.height = unit(0.5, "cm"),
      legend.title = element_text(size = 10),
      legend.text = element_text(size = 8)
    ) +
    labs(x = "", y = "SHAP value", title = title_local) +
    theme(plot.title = element_text(hjust = 0.5, vjust = -2)) # Adjust title position
  # Combine the plots side by side using cowplot; the beeswarm panel is twice as wide.
  plot_grid(bar_plot, beeswarm_plot, align = "h", ncol = 2, rel_widths = c(0.5, 1))
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.