R/generateFigures.R

Defines functions generate_shap_plots generate_pi_ECDF_plot generate_pi_histogram generate_fi_rank_plot

Documented in generate_fi_rank_plot generate_pi_ECDF_plot generate_pi_histogram generate_shap_plots

#' Generate Feature Importance Score-Rank Plot
#'
#' This function generates a figure which shows the true (observed) feature importance scores plotted against the permuted scores, by rank. The permuted scores are summarised as a quantile band (grey ribbon) around their mean (black line). The observed scores are drawn in gold. A red vertical line marks the point where the observed curve first drops below the upper bound of the band, which we recommend as a significance cutoff.
#' Note: there are ample parameters for controlling the axis scales, label location, and zoom. Because of data variability, you will almost certainly have to adjust these to fit your plot.
#'
#' @param permutedvalues a tibble of feature ranks, feature names, feature importance scores, log feature importance scores and permutation # generated by the load_permuted_fws function
#' @param quantiledata a tibble containing: feature ranks, mean, lower, upper, and observed as well as logmean, loglower, logupper, logobserved generated by the calculateQuantiles function
#' @param xlimitmin a numerical value for the minimum x value in the figure, default is 1 because ranking starts at 1. Ignored when `focusedView = TRUE`.
#' @param xlimitmax a numerical value for the maximum x value in the figure, default is 500 (this is effectively the total number of features to include). Ignored when `focusedView = TRUE`.
#' @param ylimitmin a numerical value for the minimum y value in the figure, default is -10
#' @param ylimitmax a numerical value for the maximum y value in the figure, default is 0. This will vary largely depending on your dataset, so set it larger and then narrow it down.
#' @param xtickbreaks a numerical value for the break spacing between label ticks on the x axis, default is 10
#' @param labelverticaladjust a numerical value specifying the vertical (y-axis) adjustment of the cutoff label
#' @param labelhorizontaladjust a numerical value specifying the horizontal (x-axis) adjustment of the cutoff label
#' @param indvPermScoresOn logical TRUE or FALSE, whether or not to include the lines for each individual permutation's feature importance scores (shown in light blue)
#' @param logOn logical TRUE or FALSE, whether to plot the log-scaled or the unlogged data. Defaults to TRUE.
#' @param focusedView logical TRUE or FALSE, whether to show the full plot or a focused view of only the features above the cutoff line. The focused view stops with an error when no cutoff can be found or no feature lies above it.
#' @import ggplot2
#' @import dplyr
#' @examples
#' fiplot_full <- generate_fi_rank_plot(feat_importances$permuted_importances, quantile_data, xlimitmin=1, xlimitmax = 10, ylimitmin= -5, ylimitmax = 0, labelhorizontaladjust = -0.05,labelverticaladjust = 1.5,focusedView = FALSE,logOn = TRUE)
#' fiplot_focused <- generate_fi_rank_plot(feat_importances$permuted_importances, quantile_data, xlimitmin=1, xlimitmax = 10, ylimitmin= -5, ylimitmax = 0, labelhorizontaladjust = 1.05,labelverticaladjust = 1.5,focusedView = TRUE,logOn = TRUE)
#' @return A ggplot object of feature importance score (log-scaled when `logOn = TRUE`) by feature rank. A red line marks the number of top-ranked features whose observed score stays at or above the upper quantile bound, which we recommend using as a cutoff.
#' @export
generate_fi_rank_plot <- function(permutedvalues, quantiledata, xlimitmin = 1, xlimitmax = 500,
                                  ylimitmin = -10, ylimitmax = 0, xtickbreaks = 10,
                                  labelverticaladjust = 1.05, labelhorizontaladjust = 1.05,
                                  indvPermScoresOn = TRUE, logOn = TRUE, focusedView = FALSE) {
  cols <- .fi_rank_columns(logOn)

  # Keep only rows that can be drawn on the chosen scale. The same rule is
  # used for the full and focused views so both report the same cutoff.
  qdata2 <- quantiledata %>%
    dplyr::filter(mean > 0, .fi_rank_drawable(.data[[cols$observed]], logOn))
  permdata2 <- permutedvalues %>%
    dplyr::filter(feature_rank %in% qdata2$feature_rank,
                  .fi_rank_drawable(.data[[cols$permuted]], logOn))

  # Number of leading rows before the observed curve first falls below the
  # upper band: index of the first such row minus one. NA if it never does.
  n_above <- which(qdata2[[cols$observed]] < qdata2[[cols$upper]])[1] - 1

  if (focusedView) {
    if (is.na(n_above)) {
      stop("Red line x-intercept could not be calculated. Please check your data.")
    }
    x_from <- min(qdata2$feature_rank, na.rm = TRUE)
    # Without this guard, seq() below fails with an opaque "wrong sign" error.
    if (n_above < x_from) {
      stop("No features lie above the alpha threshold, so the focused view is empty. Use focusedView = FALSE instead.")
    }
    x_to <- n_above
  } else {
    x_from <- xlimitmin
    x_to <- xlimitmax
  }

  message("Warnings regarding rows containing missing values (`geom_line()`) are related to values plotted outside axis limits and can be ignored once you are happy with the way your plot looks.")

  # NULL when disabled; adding NULL to a ggplot leaves it unchanged.
  perm_layer <- if (isTRUE(indvPermScoresOn)) {
    geom_line(data = permdata2,
              aes(x = feature_rank, group = permutation, y = .data[[cols$permuted]]),
              alpha = .8, col = "lightblue")
  }

  label_layer <- if (focusedView) {
    annotate("label", x = n_above, y = ylimitmax,
             label = paste0("No. Features above alpha threshold: ", n_above),
             vjust = labelverticaladjust, hjust = labelhorizontaladjust, size = 3.5)
  } else {
    annotate(x = n_above, y = +Inf,
             label = paste0("No. Features above \nalpha threshold: ", n_above),
             vjust = labelverticaladjust, hjust = labelhorizontaladjust,
             geom = "label", size = 3.5)
  }

  ggplot2::ggplot(qdata2, ggplot2::aes(x = feature_rank, y = .data[[cols$centre]])) +
    xlab("Feature Rank") +
    ylab(cols$ylab) +
    perm_layer +
    geom_ribbon(aes(ymin = .data[[cols$lower]], ymax = .data[[cols$upper]]),
                fill = "grey20", alpha = 0.6) +
    geom_line() +
    geom_line(aes(x = feature_rank, y = .data[[cols$observed]]), colour = "gold3") +
    ylim(ylimitmin, ylimitmax) +
    geom_vline(xintercept = n_above, color = "red") +
    scale_x_continuous(
      breaks = seq(from = x_from, to = x_to, by = xtickbreaks),
      limits = c(x_from, x_to)) +
    label_layer +
    theme_linedraw()
}

# Column names and y-axis label used by generate_fi_rank_plot for the log or
# linear scale.
.fi_rank_columns <- function(logOn) {
  if (logOn) {
    list(centre = "logmean", lower = "loglower", upper = "logupper",
         observed = "logobserved", permuted = "log_feature_importance",
         ylab = "Feature Importance Score (Log-Scaled)")
  } else {
    list(centre = "mean", lower = "lower", upper = "upper",
         observed = "observed", permuted = "feature_importance",
         ylab = "Feature Importance Score")
  }
}

# Row filter for generate_fi_rank_plot: finite values on the log scale,
# strictly positive values on the linear scale.
.fi_rank_drawable <- function(x, logOn) {
  if (logOn) is.finite(x) else x > 0
}

#' Generate pi Histogram Plot
#'
#' Creates a histogram of the sum of absolute deviations per permutation, the measures used to calculate the pi statistics used in the p-value calculation for the set.
#'
#' This function generates a histogram illustrating the sum of absolute deviations in the permuted set vs the true (observed) set. For each permutation, every permuted score is compared with the mean for the same feature rank in `quantiledata`, and the absolute deviations are summed. The observed value (`pi_obs`) is the sum of absolute deviations between `quantiledata$observed` and `quantiledata$mean`, drawn as a red line. It is a visualization of the data used to calculate the p-value for the entire feature set.
#'
#' @param permutedvalues a tibble of feature ranks, feature names, feature importance scores, log feature importance scores and permutation # generated by the load_permuted_fws function
#' @param quantiledata a tibble containing: feature ranks, mean, lower, upper, and observed as well as logmean, loglower, logupper, logobserved generated by the calculateQuantiles function
#' @return A histogram of count by sum of absolute deviations, showing the differences between the null and observed datasets.
#' @examples
#' pihist_plot<-generate_pi_histogram(feat_importances$permuted_importances, quantile_data)
#' @import ggplot2
#' @import dplyr
#' @export
generate_pi_histogram <- function(permutedvalues, quantiledata) {
  # Align each permuted score with the mean for its own feature rank instead
  # of relying on the row order of permutedvalues.
  mean_idx <- match(permutedvalues$feature_rank, quantiledata$feature_rank)
  if (anyNA(mean_idx)) {
    stop("Some feature ranks in permutedvalues have no matching row in quantiledata.",
         call. = FALSE)
  }
  rank_means <- quantiledata$mean[mean_idx]
  numberofPermutations <- dplyr::n_distinct(permutedvalues$permutation)

  pi_permuted <- permutedvalues %>%
    dplyr::ungroup() %>%
    dplyr::mutate(Mean = rank_means, Dev = feature_importance - Mean) %>%
    dplyr::group_by(permutation) %>%
    dplyr::summarise(sum_abs_deviations = sum(abs(Dev)))
  pi_obs <- sum(abs(quantiledata$observed - quantiledata$mean))

  ggplot2::ggplot(pi_permuted, aes(x = sum_abs_deviations)) +
    geom_histogram(bins = numberofPermutations) +
    xlab("Sum of Absolute Deviations") +
    ylab("Count") +
    geom_vline(xintercept = pi_obs, color = "red") +
    annotate(x = pi_obs, y = +Inf, label = paste0("pi_obs: ", round(pi_obs, 2)),
             vjust = 2, hjust = 1.2, geom = "label") +
    theme_linedraw()
}

#' Generate ECDF Plot
#'
#' Creates a plot of the empirical cumulative distribution function of the per-permutation sums of absolute deviations.
#'
#' For each permutation, every permuted score is compared with the mean for the same feature rank in `quantiledata`, and the absolute deviations are summed. The observed value (`pi_obs`, the sum of absolute deviations between `quantiledata$observed` and `quantiledata$mean`) is drawn as a red line.
#'
#' @param permutedvalues a tibble of feature ranks, feature names, feature importance scores, log feature importance scores and permutation # generated by the load_permuted_fws function
#' @param quantiledata a tibble containing: feature ranks, mean, lower, upper, and observed as well as logmean, loglower, logupper, logobserved generated by the calculateQuantiles function
#' @return A plot of the empirical cumulative distribution function showing the sum of absolute deviations by fraction of data.
#' @examples
#' ecdf_plot<-generate_pi_ECDF_plot(feat_importances$permuted_importances, quantile_data)
#' @import ggplot2
#' @import dplyr
#' @export
generate_pi_ECDF_plot <- function(permutedvalues, quantiledata) {
  # Align each permuted score with the mean for its own feature rank instead
  # of relying on the row order of permutedvalues.
  mean_idx <- match(permutedvalues$feature_rank, quantiledata$feature_rank)
  if (anyNA(mean_idx)) {
    stop("Some feature ranks in permutedvalues have no matching row in quantiledata.",
         call. = FALSE)
  }
  rank_means <- quantiledata$mean[mean_idx]

  pi_permuted <- permutedvalues %>%
    dplyr::ungroup() %>%
    dplyr::mutate(Mean = rank_means, Dev = feature_importance - Mean) %>%
    dplyr::group_by(permutation) %>%
    dplyr::summarise(Sum_abs_deviations = sum(abs(Dev)))
  pi_obs <- sum(abs(quantiledata$observed - quantiledata$mean))

  ggplot2::ggplot(pi_permuted, aes(x = Sum_abs_deviations)) +
    stat_ecdf(geom = "step") +
    geom_vline(xintercept = pi_obs, color = "red") +
    annotate(x = pi_obs, y = +Inf, label = paste0("pi_obs: ", round(pi_obs, 2)),
             vjust = 4, hjust = 1.1, geom = "label") +
    labs(y = "Fraction of Data", x = "Sum of Absolute Deviations") +
    theme_linedraw()
}

#' Generate SHAP Plots
#'
#' Creates a combined bar and beeswarm plot to show global and local feature importance.
#'
#' Features are ordered by decreasing `abs_mean_shap` in both panels so that their rows line up. The bar plot (left) shows `abs_mean_shap` per feature, filled by whether `mean_shap > 0`. The beeswarm plot (right) shows each individual `shap_value`, coloured by that same value.
#'
#' @param mean_shap_values A data frame with one row per feature, containing at least the columns `feature`, `abs_mean_shap` and `mean_shap`.
#' @param long_shap_data A data frame of individual SHAP values in long format, containing at least the columns `feature` and `shap_value`.
#' @param title_global The title for the global feature importance (bar) plot.
#' @param title_local The title for the local feature explanation (beeswarm) plot.
#' @param fill_colors A vector of colors for the bar fill. The colors are mapped, in order, to the `mean_shap > 0` groups present (FALSE before TRUE).
#' @param gradient_colors A vector of two colors giving the low and high ends of the color gradient in the beeswarm plot.
#' @return A combined ggplot object with a bar plot and a beeswarm plot side by side.
#' @examples
#' shapplot<-generate_shap_plots(mean_shap_values = shapvals$significant_features,long_shap_data = shapvals$long_shap_data)
#' @import ggplot2
#' @import dplyr
#' @import cowplot
#' @import ggbeeswarm
#' @export
generate_shap_plots <- function(mean_shap_values, long_shap_data, title_global = "SHAP Global Feature Importance",
                                title_local = "SHAP Local Feature Explanation", fill_colors = c("blue", "red"),
                                gradient_colors = c("blue", "red")) {

  # Shared factor levels (most important first) keep the rows of both panels aligned.
  feature_levels <- mean_shap_values$feature[order(-mean_shap_values$abs_mean_shap)]
  long_shap_data$feature <- factor(long_shap_data$feature, levels = feature_levels)
  mean_shap_values$feature <- factor(mean_shap_values$feature, levels = feature_levels)

  # Global feature importance plot (Bar plot)
  bar_plot <- ggplot(mean_shap_values, aes(x = feature, y = abs_mean_shap, fill = mean_shap > 0)) +
    geom_col() +
    coord_flip() +
    scale_fill_manual(values = fill_colors, name = "mean_shap > 0") +
    theme_minimal() +
    theme(
      axis.text.y = element_text(hjust = 1),
      axis.ticks.y = element_blank(),
      panel.grid.major = element_blank(),
      panel.grid.minor = element_blank(),
      plot.margin = margin(0, 0, 0, 0),
      axis.title.y = element_text(margin = margin(t = 0, r = 0, b = 0, l = 10)),
      legend.position = "bottom"
    ) +
    labs(x = "Feature", y = "Mean Abs. SHAP value", title = title_global) +
    theme(plot.title = element_text(hjust = 0.5, vjust = -2)) # Adjust title position

  # Local explanation summary plot (Beeswarm plot).
  # NOTE(review): the colour scale maps `shap_value`, yet its legend is titled
  # "Feature Value" -- confirm whether a feature-value column was intended.
  beeswarm_plot <- ggplot(long_shap_data, aes(x = feature, y = shap_value, color = shap_value)) +
    geom_quasirandom(size = 1, alpha = 0.5) +
    coord_flip() +
    scale_color_gradient(low = gradient_colors[1], high = gradient_colors[2], name = "Feature Value") +
    theme_minimal() +
    theme(
      axis.text.y = element_blank(), # feature names are already shown on the bar plot
      axis.ticks.y = element_blank(),
      panel.grid.major = element_blank(),
      panel.grid.minor = element_blank(),
      plot.margin = margin(0, 0, 0, 0),
      legend.position = "bottom",
      legend.key.width = unit(1, "cm"),
      legend.key.height = unit(0.5, "cm"),
      legend.title = element_text(size = 10),
      legend.text = element_text(size = 8)
    ) +
    labs(x = "", y = "SHAP value", title = title_local) +
    theme(plot.title = element_text(hjust = 0.5, vjust = -2)) # Adjust title position

  # Combine the plots side by side using cowplot; the bar plot gets a third of the width.
  combined_plot <- plot_grid(bar_plot, beeswarm_plot, align = 'h', ncol = 2, rel_widths = c(0.5, 1))

  return(combined_plot)
}
tkolisnik/Rf2pval documentation built on Feb. 20, 2024, 5:39 a.m.