R/plots.R

Defines functions plot_shap_scatter_for_new_samples plot_shap_scatter_for_training_samples plot_shap_scatter plot_contribution_to_new_samples plot_contribution_to_training_samples plot_contribution_to_sample plot_predictions_dimplot plot_top_contributors helper_plot_top_contributors

Documented in helper_plot_top_contributors plot_contribution_to_new_samples plot_contribution_to_sample plot_contribution_to_training_samples plot_predictions_dimplot plot_shap_scatter plot_shap_scatter_for_new_samples plot_shap_scatter_for_training_samples plot_top_contributors

#' Plot top contributions (internal function)
#'
#' Builds a horizontal bar chart of the `n` features with the largest
#' contribution (`value`) in a single model's contribution table.
#'
#' @param shap_table A data frame with a `term` column (feature names) and a numeric `value` column (their contribution).
#' @param name The name of the model. When `labels_data` is NULL, the second "_"-separated word of `name` is used as the plot title.
#' @param n The number of top features (ranked by `value`) to keep. Ties may yield more than `n` bars.
#' @param color The fill color of the bars.
#' @param labels_data Optional data frame with columns `old_label` and `new_label` (plus `second_label` when `sec_label` is used). The row whose `old_label` equals `name` provides the plot title.
#' @param sec_label If not NULL, the `second_label` column of `labels_data` is used as the plot subtitle.
#' @return A ggplot object.
#' @keywords plot
#' @import ggplot2 cowplot purrr
#' @examples
#' helper_plot_top_contributors(shap_table, name, 10)
helper_plot_top_contributors <- function(shap_table, name, n, color = "#52565e",labels_data = NULL, sec_label = NULL){
  
  shap_table <- shap_table %>% top_n(n, wt = value)

  # Levels are reversed so that, after coord_flip(), the first row ends up on top
  upper_terms <- str_to_upper(shap_table$term)
  shap_table$term <- factor(upper_terms, levels = rev(upper_terms))
  shap_table$labels <- abbreviate(upper_terms, minlength = 10, dot = TRUE)

  # Name the axis labels by term so each label is matched to its own bar,
  # independent of the (reversed) level order
  axis_labels <- setNames(as.character(shap_table$labels), upper_terms)

  # The subtitle is optional; default to none so labs() never sees an
  # undefined variable
  second_label <- NULL

  # Use alternative names if provided as labels_data
  if (!is.null(labels_data)) {
    old_name <- name
    label_rows <- labels_data %>% filter(old_label == old_name)
    name <- label_rows %>% pull(new_label) %>% first()
    if (!is.null(sec_label)) {
      second_label <- label_rows %>% pull(second_label) %>% first()
    }
  } else {
    name <- word(name, 2, sep = "_")
  }
  
  p <- ggplot(shap_table, aes(x = term, y = value)) +
    geom_bar(stat = "identity", fill = color) +
    coord_flip() +
    labs(x = "Feature", y = "Contribution", title = str_to_upper(name), subtitle = second_label) +
    scale_x_discrete(labels = axis_labels)
  
  return(p)
  
}

#' Plot top contributions
#'
#' Draws one bar chart of the top contributing features per model and arranges
#' the charts in a grid.
#' @param models A list of model objects generated by make_xgb_models
#' @param models_to_use A vector of model names to restrict the plot to. If NULL or empty, all models are plotted.
#' @param data_to_use If "training" (default), the `feature_contribution` table of each model is used. If anything else, the `feature_contribution` table stored under `new_data` is used (must have used add_predictions before).
#' @param n_predictors The maximum number of predictors to show per plot.
#' @param n_columns If multiple models are being plotted, how many columns in the plot grid to use.
#' @param color The color of the bars in the plot.
#' @param labels_data Optional table of alternative labels, forwarded to helper_plot_top_contributors.
#' @param sec_label Optional subtitle switch, forwarded to helper_plot_top_contributors.
#' @keywords plot
#' @import ggplot2 cowplot purrr
#' @export
#' @examples
#' plot_top_contributors(my_models, c("ko_ctnnb1","ko_myod1"), n_predictors = 10)
plot_top_contributors <- function(models,
                                  models_to_use = NULL,
                                  data_to_use = "training", 
                                  n_predictors = 10, n_columns = 5, 
                                  color = "#52565e",
                                  labels_data = NULL, sec_label = NULL){
  
  # Narrow the model list only when a non-empty selection was given
  has_selection <- !is.null(models_to_use) && length(models_to_use) > 0
  selected_models <- if (has_selection) models[models_to_use] else models
  
  # Fetch one contribution table per model, from the training fit or from the new data
  if (data_to_use != "training") {
    contribution_tables <- selected_models %>%
      map("new_data") %>%
      map("feature_contribution")
  } else {
    contribution_tables <- selected_models %>%
      map("feature_contribution")
  }

  # imap passes each table and its model name to the helper
  panels <- imap(contribution_tables,
                 helper_plot_top_contributors,
                 n = n_predictors,
                 color = color,
                 labels_data = labels_data,
                 sec_label = sec_label)

  plot_grid(plotlist = panels, align = 'v', ncol = n_columns)
  
}


#' Overlay predictions on a dimensional reduction plot
#'
#' This function generates a dimensional reduction plot with predictions overlayed on each cell.
#' @param scRNA_data A Seurat object with predictions appended (e.g. using append_predictions_to_seurat)
#' @param perturbation The name of the perturbation whose predictions we want to plot (a single column of the object's meta.data).
#' @param group_by Optional name of a meta.data column used to group cells; needed for `show_labels`.
#' @param reduction The dimensionality reduction technique to show (default = 'umap')
#' @param pt_size The size of the points (default = 0.25).
#' @param show_labels If TRUE and `group_by` is given, each group is labelled at the median position of its cells. Default = FALSE.
#' @param label_size The text size of the group labels (default = 3).
#' @param dims A 2-element vector with the dimensions to plot (default = c(1,2))
#' @param fixed_color_scale Default = FALSE. If TRUE, the color scale will be fixed from 0 to 1.
#' @return A ggplot object.
#' @keywords plot dimplot
#' @import Seurat ggplot2 cowplot
#' @export
#' @examples
#' plot_predictions_dimplot(my_seurat_obj, "ko_ctnnb1")
plot_predictions_dimplot <- function(scRNA_data, perturbation, group_by = NULL,
                                        reduction = "umap", pt_size = 0.25, show_labels = FALSE, label_size = 3,
                                        dims = c(1,2),
                                        fixed_color_scale = FALSE){
  
  # Per-cell prediction values, keyed by cell id
  feature_data <- scRNA_data@meta.data %>% select(all_of(perturbation)) %>% as_tibble(rownames = "cell_id")
  colnames(feature_data) <- c("cell_id", "feature_data")
  
  if (!is.null(group_by)) {
    cluster_data <- scRNA_data@meta.data %>% select(all_of(group_by)) %>% as_tibble(rownames = "cell_id")
    feature_data <- feature_data %>% left_join(cluster_data, by = "cell_id")
    colnames(feature_data) <- c("cell_id", "feature_data", "cluster_data")
  } 
  
  # Embedding coordinates; column 1 is the cell id, so dimension k sits in column k + 1
  embed_data <- scRNA_data[[reduction]]@cell.embeddings %>% as_tibble(rownames = "cell_id")
  dim1_name <- colnames(embed_data)[dims[1] + 1]
  dim2_name <- colnames(embed_data)[dims[2] + 1]
  embed_data <- embed_data[, c(1, dims[1] + 1, dims[2] + 1)]
  
  plot_data <- embed_data %>% left_join(feature_data, by = "cell_id")
  
  p <- ggplot() +
    geom_point(data = plot_data,
               aes(x = .data[[dim1_name]], y = .data[[dim2_name]], color = feature_data),
               size = pt_size)
  
  if (!is.null(group_by) && show_labels) {
    # Place each group label at the median coordinates of its cells.
    # dplyr::group_by is called explicitly because the `group_by` argument
    # shadows the function name inside this body.
    label_data <- plot_data %>%
      select(all_of(c(dim1_name, dim2_name, "cluster_data"))) %>%
      dplyr::group_by(cluster_data) %>%
      summarize(label_x = median(.data[[dim1_name]]),
                label_y = median(.data[[dim2_name]]))
    
    p <- p + 
      geom_text(data = label_data,
                aes(x = label_x, y = label_y, label = cluster_data),
                size = label_size)
  }
  
  # limits = NULL lets the scale adapt to the data
  color_limits <- if (fixed_color_scale) c(0, 1) else NULL
  p <- p + scale_color_distiller(palette = "RdYlBu", direction = -1, limits = color_limits)
  
  # Axis titles use the names of the plotted embedding dimensions
  p <- p + theme_bw(base_line_size = 0) +
    labs(color = perturbation, x = dim1_name, y = dim2_name)
  
  return(p)
  
}



#' Plot feature contributions to sample predictions
#'
#' This function plots the contribution (Shapley values) of the top features to a sample's predictions.
#' Each sample gets one plot in which the contributions are stacked from the model's null prediction
#' up to the final prediction; the plots are then arranged in a grid.
#' @param model A model generated with make_xgb_models and has appended predictions with add_predictions.
#' @param model_data The training dataset used to generate the model.
#' @param name The name of the perturbation whose prediction we want to plot.
#' @param sample_names The names of the samples we want to generate a plot for.
#' @param n_features The number of top contributors to show their individual contribution. All other predictors will have their contribution combined.
#' @param n_columns Number of columns to plot in a grid when plotting more than one sample.
#' @param short_title If using many columns, set to TRUE to shorten the title of the plot.
#' @param fixed_axis If TRUE, the plot will be set to a fixed scale given by `axis_limits`. Default = FALSE.
#' @param axis_limits The limits used when `fixed_axis` is TRUE. Default = c(-0.05, 1).
#' @param nudges_lr A 2-element vector of extra offsets added to the position of the value labels: the first element applies to negative contributions, the second to positive ones.
#' @param values_are_percentages If TRUE (default), contributions and the axis are labelled as percentages; otherwise contributions are rounded to 2 decimals.
#' @param decreasing If TRUE, the colors used for positive and negative contributions are swapped. Default = FALSE.
#' @param replace_names If TRUE, the sample name will be replaced using get_cell_line_name. Must supply sample_info.
#' @param sample_info If replace_names is TRUE, this must be supplied.
#' @param show_error If TRUE, a shaded area will be used to visualize prediction interval. Default = TRUE.
#' @param highlight_significant If TRUE, a border will be shown to indicate the prediction interval is above 0.5. Default = FALSE.
#' @param plot_new_data Set to TRUE if plotting the contribution to a new sample (not used in training).
#' @param labels_data Optional data frame with columns `old_label` and `new_label` (plus `second_label` when `sec_label` is used), used to replace the perturbation name in the title.
#' @param sec_label If not NULL, the `second_label` column of `labels_data` is used as subtitle and the title omits the probability.
#' @return A cowplot grid with one plot per sample.
#' @keywords plot shapley contribution
#' @import Seurat ggplot2 cowplot data.table
#' @export
#' @examples
#' plot_contribution_to_sample(my_models, model_dataset, "ko_ctnnb1", "my_sample")
plot_contribution_to_sample <- function(model, model_data, name, sample_names, 
                                        n_features = 5, 
                                        n_columns = 1,
                                        short_title = FALSE,
                                        fixed_axis = FALSE, axis_limits = c(-0.05,1),
                                        nudges_lr = c(0,0),
                                        values_are_percentages = TRUE, decreasing = FALSE,
                                        replace_names = FALSE,  sample_info = NULL,
                                        show_error = TRUE, 
                                        highlight_significant = FALSE,
                                        plot_new_data = FALSE, labels_data = NULL, sec_label = NULL){
  
  # This will hold the final list of plots
  p_list <- list()
  
  # Starting point of every plot, as stored on the model
  null_prediction <- model$null_prediction
  
  feature_data <- get_original_data(model, model_data)
  
  # Title / subtitle labels depend only on `name`, so resolve them once.
  # The subtitle defaults to none so labs() never sees an undefined variable.
  second_label <- NULL
  if (!is.null(labels_data)) {
    old_name <- name
    label_rows <- labels_data %>% filter(old_label == old_name)
    new_name <- label_rows %>% pull(new_label) %>% first()
    if (!is.null(sec_label)) {
      second_label <- label_rows %>% pull(second_label) %>% first()
    }
  } else {
    new_name <- word(name, 2, sep = "_")
  }
  clean_perturb_name <- str_to_upper(new_name)
  
  # Reverse colors if decreasing scale
  if (!decreasing) {
    negative_color <- "#3591d1"
    positive_color <- "#f04546"
  } else {
    negative_color <- "#f04546"
    positive_color <- "#3591d1"
  }
  
  # Predictions, errors and Shapley values come from the training fit or the new data
  prediction_source <- if (plot_new_data) model$new_data else model
  
  for (sample_name in sample_names) {
    
    # We store the final prediction value and its estimated error
    final_prediction <- prediction_source$predictions[sample_name]
    prediction_error <- prediction_source$predictions_error[sample_name]
    error_margin <- 1.96 * prediction_error
    error_lower <- final_prediction - error_margin
    
    # Create a long table of terms and their shapley values (called shap_value)
    term_values <- prediction_source$shap_values[sample_name,]
    
    term_values <- term_values %>% data.table::transpose(keep.names = "term") %>% 
      rename("shap_value" = "V1") %>% arrange(desc(abs(shap_value))) %>% filter(abs(shap_value) > 0)
    
    # Add the expression level of each term (called feature_value)
    if (!plot_new_data) {
      feature_values <- feature_data[sample_name,]
    } else {
      feature_values <- model$new_data$data[sample_name,]
    }
    
    feature_values <- feature_values %>% t %>% as_tibble(rownames = "term")
    colnames(feature_values) <- c("term", "feature_value")
    feature_values <- feature_values %>% mutate(feature_value = round(feature_value, 2))
    term_values <- term_values %>% left_join(feature_values, by = "term")
    
    # Keep the top terms, and combine the contribution of all other terms (we also count how many terms we combined)
    top_terms <- term_values %>% top_n(n_features, wt = abs(shap_value)) %>% pull(term)
    other_values <- term_values %>% filter(!term %in% top_terms) %>% pull(shap_value)
    other_count <- length(other_values)
    other_sum <- sum(other_values)
    term_word <- if (other_count == 1) "term" else "terms"
    other_row <- tibble(term = glue::glue("{other_count} other {term_word}"), shap_value = other_sum)
    none_row <- tibble(term = glue::glue("starting_value"), shap_value = null_prediction)
    
    # We generate all the information we need to make the plot: rows are reversed so
    # the cumulative sum starts at the null prediction, which is then dropped
    term_values <- term_values %>% filter(term %in% top_terms) %>% bind_rows(other_row) %>% bind_rows(none_row) %>%
      rowid_to_column("id") %>% arrange(desc(id)) %>% 
      mutate(end = cumsum(shap_value)) %>% mutate(start = lag(end, default = 0)) %>%
      filter(term != "starting_value") %>%
      mutate(term = factor(term, levels=term)) %>% mutate(id = seq_along(term)) %>%
      mutate(term_label = if_else(!is.na(feature_value),glue::glue("{str_to_upper(term)} = {feature_value}"),str_to_upper(term))) %>%
      mutate(term_label = factor(term_label, levels=term_label)) %>%
      mutate(term_direction = if_else(shap_value > 0, "positive", "negative")) %>%
      mutate(shap_label = scales::percent(shap_value,accuracy=0.01))
    
    # Revert the percentage label
    if (!values_are_percentages) {
      term_values <- term_values %>% mutate(shap_label = as.character(round(shap_value, 2)))
    }
    
    if (replace_names) sample_name <- get_cell_line_name(sample_name, sample_info)
    
    # Where the stacked contributions start and end
    stack_start <- first(term_values$start)
    stack_end <- last(term_values$end)
    
    # Make the plot    
    p <- ggplot(term_values, aes(x = term_label, fill = term_direction)) + 
      geom_point(alpha=0, y=0)+
      geom_hline(yintercept = stack_start, linetype = "dashed", color = "lightgray")+
      geom_hline(yintercept = stack_end, linetype = "dashed", color = "black")
    
    if (show_error) {
      p <- p +
        geom_rect(xmin = -Inf, xmax = Inf, ymin = stack_end - error_margin, ymax = stack_end + error_margin, alpha = 0.025, fill="gray")+
        geom_hline(yintercept = stack_end + error_margin, linetype = "dashed", color = "lightgray")+
        geom_hline(yintercept = stack_end - error_margin, linetype = "dashed", color = "lightgray")
    } 
    
    # Later assignments take precedence: short_title wins over sec_label
    plot_title <- glue::glue("{sample_name} ({scales::percent(final_prediction,accuracy = 1)} probability of {clean_perturb_name} dependency)")
    if (!is.null(sec_label)) plot_title <- glue::glue("{sample_name} ({clean_perturb_name})")
    if (short_title) plot_title <- glue::glue("{sample_name} ({scales::percent(final_prediction,accuracy = 1)} {clean_perturb_name})")
    
    p <- p + geom_rect(aes(xmin = id - 0.45, xmax = id + 0.45, ymin = start, ymax = end))+
      geom_segment(aes(x = id, xend = id, y = start, yend = end), lineend = "round", linejoin ="round", size = 0.7, arrow = arrow(angle = 30, length = unit(0.5, "cm")), color = "white")+
      geom_text(aes(x = id,
                    y = end, 
                    color = if_else(shap_value > 0, "positive","negative"), 
                    label = if_else(shap_value > 0, paste0("+",shap_label),shap_label)), 
                nudge_y = if_else(term_values$shap_value > 0, 0.04 + nudges_lr[2], -0.04 + nudges_lr[1]),
                size = 3, hjust = 0.5)+
      scale_color_manual(values = c("negative" = negative_color, "positive" = positive_color))+
      scale_fill_manual(values = c("negative" = negative_color, "positive" = positive_color))+
      labs(title = plot_title, subtitle = second_label,
           x = NULL, y = "Outcome")+
      theme_bw()+
      theme(text=element_text(size=14), legend.position="none")
    
    if (!fixed_axis) {
      value_range <- c(term_values$start, term_values$end)
      p <- p + coord_flip(ylim = c(min(value_range) - 0.05, 0.05 + max(value_range)))
    } else {
      p <- p + coord_flip(ylim = axis_limits)
    }
    
    if (values_are_percentages) p <- p + scale_y_continuous(labels = scales::percent)
    
    # isTRUE() keeps a missing or unavailable error estimate from aborting the plot
    if (highlight_significant && isTRUE(unname(error_lower >= 0.5))) {
      p <- p + theme(panel.border = element_rect(color = "red", fill = NA, size = 1))
    }
    
    p_list[[sample_name]] <- p
    
  }
  
  p_final <- plot_grid(plotlist = p_list, ncol=n_columns, align='v')
  
  return(p_final)
  
}


#' Plot feature contributions to sample predictions from the training set.
#'
#' This function plots the contribution (Shapley values) of the top features to a sample's predictions.
#' It calls plot_contribution_to_sample once per selected model and returns the resulting plots as a list.
#' @param models A list of models generated with make_xgb_models and has appended predictions with add_predictions.
#' @param models_to_use A vector of model names to plot.
#' @param model_data The training dataset used to generate the model.
#' @param samples_to_use Optional sample names, passed to get_demo_samples to choose the samples to plot.
#' @param lineage_to_use Optional lineage, passed to get_demo_samples to choose the samples to plot.
#' @param n_features The number of top contributors to show their individual contribution. All other predictors will have their contribution combined.
#' @param n_columns Number of columns to plot in a grid when plotting more than one sample.
#' @param fixed_axis If TRUE (default), the plot will be set to a fixed scale given by `axis_limits`.
#' @param axis_limits The limits used when `fixed_axis` is TRUE. Default = c(-0.05, 1).
#' @param nudges_lr A 2-element vector of extra offsets for the value labels, forwarded to plot_contribution_to_sample.
#' @param show_error If TRUE, a shaded area will be used to visualize prediction interval. Default = TRUE.
#' @param highlight_significant If TRUE, a border will be shown to indicate the prediction interval is above 0.5. Default = FALSE.
#' @param replace_names If TRUE, the sample name will be replaced using get_cell_line_name. Must supply sample_info.
#' @param sample_info If replace_names is TRUE, this must be supplied.
#' @param labels_data Optional table of alternative labels, forwarded to plot_contribution_to_sample.
#' @param sec_label Optional subtitle switch, forwarded to plot_contribution_to_sample.
#' @param values_are_percentages If TRUE (default), values are labelled as percentages.
#' @param decreasing If TRUE, the colors of positive and negative contributions are swapped. Default = FALSE.
#' @param short_title If using many columns, set to TRUE to shorten the title of the plot. Default = FALSE.
#' @return A list with one plot grid per model.
#' @keywords plot shapley contribution
#' @import Seurat ggplot2 cowplot data.table
#' @export
#' @examples
#' plot_contribution_to_training_samples(my_models, c("ko_ctnnb1","ko_myod1"), model_dataset, lineage_to_use = "soft_tissue")
plot_contribution_to_training_samples <- function(models, models_to_use, model_data, 
                                             samples_to_use = NULL, lineage_to_use = NULL, 
                                             n_features = 5, n_columns = 1, 
                                             fixed_axis = TRUE, axis_limits = c(-0.05,1), 
                                             nudges_lr = c(0,0),
                                             show_error = TRUE,
                                             highlight_significant = FALSE,
                                             replace_names = FALSE, sample_info = NULL,
                                             labels_data = NULL, sec_label = NULL,
                                             values_are_percentages = TRUE, decreasing = FALSE,
                                             short_title = FALSE){
  
  
  # Restrict the list of models to only those we wish to plot
  demo_models <- models[models_to_use]
  
  # Restrict the sample list to only those we wish to plot
  samples_to_plot <- map(demo_models, get_demo_samples, samples = samples_to_use, lineage = lineage_to_use, model_data = model_data)

  # One entry per model: the model, its name and the samples to plot for it
  inputs <- list(model = demo_models,
                 name = names(demo_models),
                 sample_names = samples_to_plot)
  
  # Iterate through each input and add the plot to the list
  pl <- pmap(inputs, plot_contribution_to_sample, model_data = model_data, 
             n_features = n_features, n_columns = n_columns,
             fixed_axis = fixed_axis, axis_limits = axis_limits, nudges_lr = nudges_lr,
             values_are_percentages = values_are_percentages, decreasing = decreasing,
             replace_names = replace_names, sample_info = sample_info, 
             highlight_significant = highlight_significant, short_title = short_title,
             show_error = show_error, labels_data = labels_data, sec_label = sec_label)
  
  return(pl)
}





#' Plot feature contributions to new sample predictions
#'
#' Plots the contribution (Shapley values) of the top features to the predictions of new samples
#' (samples stored under each model's `new_data`). One plot grid is produced per selected model.
#' @param models A list of models generated with make_xgb_models and has appended predictions with add_predictions.
#' @param models_to_use A vector of model names to plot.
#' @param model_data The training dataset used to generate the model.
#' @param new_samples_to_use Optional vector of new sample names to restrict the plots to. If NULL or empty, all new samples are plotted.
#' @param n_features The number of top contributors to show their individual contribution. All other predictors will have their contribution combined.
#' @param n_columns Number of columns to plot in a grid when plotting more than one sample.
#' @param fixed_axis If TRUE (default), the plot will be set to a fixed scale given by `axis_limits`.
#' @param axis_limits The limits used when `fixed_axis` is TRUE. Default = c(-0.05, 1).
#' @param nudges_lr A 2-element vector of extra offsets for the value labels, forwarded to plot_contribution_to_sample.
#' @param replace_names If TRUE, the sample name will be replaced using get_cell_line_name. Must supply sample_info.
#' @param show_error If TRUE, a shaded area will be used to visualize prediction interval. Default = TRUE.
#' @param values_are_percentages If TRUE (default), values are labelled as percentages.
#' @param decreasing If TRUE, the colors of positive and negative contributions are swapped. Default = FALSE.
#' @param highlight_significant If TRUE (default), a border will be shown to indicate the prediction interval is above 0.5.
#' @param short_title If using many columns, set to TRUE (default) to shorten the title of the plot.
#' @param sample_info If replace_names is TRUE, this must be supplied.
#' @param labels_data Optional table of alternative labels, forwarded to plot_contribution_to_sample.
#' @param sec_label Optional subtitle switch, forwarded to plot_contribution_to_sample.
#' @keywords plot shapley contribution
#' @import Seurat ggplot2 cowplot data.table
#' @export
#' @examples
#' plot_contribution_to_new_samples(my_models, c("ko_ctnnb1","ko_myod1"), model_dataset)
plot_contribution_to_new_samples <- function(models, models_to_use, model_data, 
                                             new_samples_to_use = NULL,
                                             n_features = 5, n_columns = 4, 
                                             fixed_axis = TRUE, axis_limits = c(-0.05,1),
                                             nudges_lr = c(0,0), replace_names = FALSE,
                                             show_error = TRUE, values_are_percentages = TRUE, decreasing = FALSE,
                                             highlight_significant = TRUE, short_title = TRUE,
                                             sample_info = NULL, labels_data = NULL, sec_label = NULL){
  
  # Keep only the requested models
  selected_models <- models[models_to_use]
  
  # For every model, the candidate samples are the row names of its new data
  new_sample_names <- selected_models %>%
    map("new_data") %>%
    map("data") %>%
    map(rownames)
  
  # Optionally narrow the candidates down to the requested samples
  restrict_samples <- !is.null(new_samples_to_use) && length(new_samples_to_use) > 0
  if (restrict_samples) {
    new_sample_names <- map(new_sample_names, function(candidates) intersect(candidates, new_samples_to_use))
  }
  
  # Parallel arguments for pmap: one model, one name and one sample set per call
  per_model_args <- list(model = selected_models,
                         name = names(selected_models),
                         sample_names = new_sample_names)
  
  # Build one plot grid per model; every other setting is shared across calls
  pmap(per_model_args, plot_contribution_to_sample,
       model_data = model_data,
       n_features = n_features,
       n_columns = n_columns,
       fixed_axis = fixed_axis,
       axis_limits = axis_limits,
       nudges_lr = nudges_lr,
       replace_names = replace_names,
       sample_info = sample_info,
       highlight_significant = highlight_significant,
       short_title = short_title,
       show_error = show_error,
       values_are_percentages = values_are_percentages,
       decreasing = decreasing,
       labels_data = labels_data,
       sec_label = sec_label,
       plot_new_data = TRUE)
  
}


  


#' Plot Shapley values vs feature values
#'
#' This function plots a scatter of the contribution (Shapley values) of a feature against the value of that feature for each sample.
#' One panel is drawn per top feature; the panels are arranged in a grid with the legends underneath.
#' @param model A model generated with make_xgb_models and has appended predictions with add_predictions.
#' @param name The name of the perturbation whose prediction we want to plot.
#' @param model_data The training dataset used to generate the model.
#' @param n_features The number of top contributing features to plot (one panel each).
#' @param n_columns Number of columns in the grid of panels.
#' @param overlay_predictions If TRUE, the top features are taken from the new data and the new samples are overlaid as labels on the training scatter. Default = FALSE.
#' @param sample_names The names of the samples we want to highlight.
#' @param remove_prefix If TRUE (default), overlaid samples are labelled with the second "_"-separated word of their name; otherwise with the full name.
#' @param sample_colors If highlighting samples (with sample_names), this is a vector of colors to use. If its length does not match sample_names, colors are generated with gg_color_hue.
#' @param sample_info Optional; if supplied (and not overlaying predictions), legend labels are produced with get_cell_line_name.
#' @return A cowplot grid.
#' @keywords plot shapley contribution
#' @import Seurat ggplot2 cowplot data.table
#' @export
#' @examples
#' plot_shap_scatter(my_models[["ko_ctnnb1"]], "ko_ctnnb1", model_dataset, sample_names = "my_sample", sample_colors = "red")
plot_shap_scatter <- function(model, name, model_data,
                              n_features = 4, 
                              n_columns = 2,
                              overlay_predictions = FALSE,
                              sample_names = NULL, remove_prefix = TRUE,
                              sample_colors = NULL,
                              sample_info = NULL){     
  p <- list()
  
  if (!overlay_predictions) {
    top_terms <- model$feature_contribution %>% top_n(n_features, wt = value) %>% pull(term)
  } else {
    top_terms <- model$new_data$feature_contribution %>% top_n(n_features, wt = value) %>% pull(term)
  }
  
  # Without any term there is no panel and no gradient legend to build
  if (length(top_terms) == 0) {
    stop("No features available to plot for model '", name, "'.", call. = FALSE)
  }
  
  feature_data <- get_original_data(model, model_data)
  
  # The sample selection and its legend do not depend on the term, so build them once
  selection_df <- NULL
  color_legend <- NULL
  if (!is.null(sample_names)) {
    
    if (length(sample_colors) != length(sample_names)) sample_colors <- gg_color_hue(length(sample_names))
    
    selection_df <- tibble(sample = sample_names, pt_color = sample_colors)
    
    # Make a dummy plot to grab the color legend
    legend_p <- ggplot(selection_df, aes(color = sample)) +
      geom_point(x = 1, y = 1)
    
    if (!overlay_predictions) {
      legend_labels <- if (!is.null(sample_info)) {
        get_cell_line_name(selection_df$sample, sample_info = sample_info)
      } else {
        selection_df$sample
      }
      legend_p <- legend_p + scale_color_manual(values = sample_colors, labels = legend_labels)
    }
    
    legend_p <- legend_p +
      guides(color = guide_legend(title = if (overlay_predictions) "Sample" else "Cell Line", 
                                  ncol = 8, 
                                  title.position = "left"))
    
    color_legend <- get_legend(legend_p)
  }
  
  perturbation_label <- name %>% word(2, sep = "_") %>% str_to_upper()
  gradient_legend <- NULL
  
  for (term in top_terms) {
    
    # Prepare data
    df <- tibble(sample = rownames(model$shap_values),
                 x = feature_data[,term],
                 y = model$shap_values[,term],
                 y_value = feature_data[,"y_value"],
                 source = "model")
    
    if (overlay_predictions) {
      
      new_sample_ids <- rownames(model$new_data$shap_values)
      
      df_pred <- tibble(sample = new_sample_ids,
                        x = model$new_data$data[,term],
                        y = model$new_data$shap_values[,term],
                        sample_label = if (remove_prefix) word(new_sample_ids, 2, sep = "_") else new_sample_ids,
                        source = "prediction")
      
      if (!is.null(sample_names)) {
        df_pred <- df_pred %>% filter(sample %in% sample_names)
        df_pred <- df_pred %>% left_join(selection_df, by = "sample")
      }
      
    } else {
      
      if (!is.null(sample_names)) df <- df %>% left_join(selection_df, by = "sample")
      
    }
    
    # Make the real plot
    p[[term]] <- ggplot(df, aes(x = x, y = y, color = y_value)) +
      geom_point(size = 1, alpha = 0.5)
    
    # A smoothed trend line is only drawn for features whose class is "numeric"
    if (inherits(feature_data[[term]], "numeric")) {
      p[[term]] <- p[[term]] + geom_smooth(formula = y ~ x, method = "loess", se = FALSE, size = 1, color = "black")
    }
    
    if (overlay_predictions) {
      if (!is.null(sample_names)) p[[term]] <- p[[term]] + geom_label(data = df_pred, x = df_pred$x, y = df_pred$y, label = df_pred$sample_label, color = df_pred$pt_color, size = 2, na.rm = TRUE)
      # Widen the view so that the overlaid samples are always visible
      p[[term]] <- p[[term]] + coord_cartesian(xlim = c(min(c(df_pred$x,df$x)), max(c(df_pred$x,df$x))), ylim = c(min(c(df_pred$y, df$y)),max(c(df_pred$y,df$y))))
    } else {
      if (!is.null(sample_names)) p[[term]] <- p[[term]] + geom_point(color = df$pt_color, size = 3, na.rm = TRUE)
    }
    
    p[[term]] <- p[[term]] + scale_color_viridis_c() + 
      labs(x = str_to_upper(term), y = "Shapley Value", color = str_to_upper(term))+
      guides(color = guide_colorbar(title = glue::glue("Observed {perturbation_label} Dependency"),
                                    title.position = "left", 
                                    label.theme = element_text(size = 8, angle = 45, vjust = 1, hjust = 1),
                                    barheight = 0.5,
                                    direction = "horizontal",
                                    title.vjust = 1,
                                    label = TRUE))
    
    # Grab the legend of the first plot and hide all legends
    if (match(term, top_terms) == 1) gradient_legend <- get_legend(p[[term]])
    
    p[[term]] <- p[[term]] + theme(legend.position = "none")
    
  } 
  
  p <- plot_grid(plotlist = p, ncol = n_columns, align = 'v')
  
  p_legend <- plot_grid(gradient_legend, color_legend, ncol = 1, align = 'v')
  
  # Add the legend at the bottom before returning
  p <- plot_grid(p, p_legend, ncol = 1, rel_heights = c(8,2))
  
  return(p)
}


#' Plot SHAP vs Feature Value scatter for training data.
#'
#' Plots the contribution (Shapley values) against the value of a feature, once per selected model,
#' by calling plot_shap_scatter for each of them.
#' @param models A list of models generated with make_xgb_models and has appended predictions with add_predictions.
#' @param models_to_use A vector of model names to plot.
#' @param model_data The training dataset used to generate the model.
#' @param samples_to_use Optional to highlight specific samples.
#' @param lineage_to_use Optional to highlight specific samples of certain lineage.
#' @param sample_colors Color vector to use for samples_to_use.
#' @param n_features The number of top contributing features to plot per model.
#' @param n_columns Number of columns in the grid of panels.
#' @param sample_info Optional sample information, forwarded to plot_shap_scatter.
#' @keywords plot shapley contribution scatter
#' @import Seurat ggplot2 cowplot data.table
#' @export
#' @examples
#' plot_shap_scatter_for_training_samples(my_models, c("ko_ctnnb1","ko_myod1"), model_dataset, samples_to_use = "my_sample")
plot_shap_scatter_for_training_samples <- function(models, models_to_use, model_data, 
                                          samples_to_use = NULL, lineage_to_use = NULL, sample_colors = NULL,
                                          n_features = 6, n_columns = 3, sample_info = NULL){
  
  # Keep only the requested models
  selected_models <- models[models_to_use]
  
  # Let get_demo_samples decide, per model, which samples get highlighted
  highlighted_samples <- map(selected_models,
                             get_demo_samples,
                             samples = samples_to_use,
                             lineage = lineage_to_use,
                             model_data = model_data)
  
  # Parallel arguments for pmap: one model, one name and one sample set per call
  per_model_args <- list(model = selected_models,
                         name = names(selected_models),
                         sample_names = highlighted_samples)
  
  # One scatter grid per model; the remaining settings are shared across calls
  pmap(per_model_args, plot_shap_scatter,
       model_data = model_data,
       n_features = n_features,
       n_columns = n_columns,
       sample_colors = sample_colors,
       sample_info = sample_info)
  
}




#' Plot SHAP vs Feature Value scatter for new data.
#'
#' Plots the contribution (Shapley values) against the value of a feature, once per selected model,
#' with the new samples (stored under each model's `new_data`) overlaid on the scatter.
#' @param models A list of models generated with make_xgb_models and has appended predictions with add_predictions.
#' @param models_to_use A vector of model names to plot.
#' @param model_data The training dataset used to generate the model.
#' @param new_samples_to_use Optional vector of new sample names to restrict the overlay to. If NULL or empty, all new samples are used.
#' @param remove_prefix Forwarded to plot_shap_scatter; controls how the overlaid samples are labelled. Default = TRUE.
#' @param sample_colors Color vector to use for the overlaid samples.
#' @param n_features The number of top contributing features to plot per model.
#' @param n_columns Number of columns in the grid of panels.
#' @keywords plot shapley contribution scatter
#' @import Seurat ggplot2 cowplot data.table
#' @export
#' @examples
#' plot_shap_scatter_for_new_samples(my_models, c("ko_ctnnb1","ko_myod1"), model_dataset, new_samples_to_use = "my_sample")
plot_shap_scatter_for_new_samples <- function(models, models_to_use, model_data, 
                                              new_samples_to_use = NULL, remove_prefix = TRUE, sample_colors = NULL,
                                              n_features = 6, n_columns = 3){
  
  # Keep only the requested models
  selected_models <- models[models_to_use]
  
  # For every model, the candidate samples are the row names of its new data
  new_sample_names <- selected_models %>%
    map("new_data") %>%
    map("data") %>%
    map(rownames)
  
  # Optionally narrow the candidates down to the requested samples
  restrict_samples <- !is.null(new_samples_to_use) && length(new_samples_to_use) > 0
  if (restrict_samples) {
    new_sample_names <- map(new_sample_names, function(candidates) intersect(candidates, new_samples_to_use))
  }
  
  # Parallel arguments for pmap: one model, one name and one sample set per call
  per_model_args <- list(model = selected_models,
                         name = names(selected_models),
                         sample_names = new_sample_names)
  
  # One scatter grid per model, with the new samples overlaid
  pmap(per_model_args, plot_shap_scatter,
       model_data = model_data,
       n_features = n_features,
       n_columns = n_columns,
       sample_colors = sample_colors,
       overlay_predictions = TRUE,
       remove_prefix = remove_prefix)
  
}
Mushriq/mixmap documentation built on Jan. 28, 2024, 7:22 p.m.