#' Plot top contributions (internal function)
#'
#' Builds a horizontal bar chart of the highest-contributing features of a
#' single model. Called once per model by `plot_top_contributors()`.
#'
#' @param shap_table A data frame with a `term` column (feature name) and a
#'   `value` column (contribution).
#' @param name The model name, shown upper-cased as the plot title. When
#'   `labels_data` is NULL only the second "_"-separated word of `name` is used.
#' @param n Number of rows to keep, selected with `top_n()` on `value`.
#' @param color Fill color of the bars.
#' @param labels_data Optional data frame with `old_label` and `new_label`
#'   columns. The first `new_label` whose `old_label` equals `name` replaces
#'   the title.
#' @param sec_label If not NULL (and `labels_data` is supplied), the matching
#'   `second_label` entry of `labels_data` is used as the plot subtitle.
#' @return A ggplot object.
#' @keywords plot
#' @import ggplot2 cowplot purrr
#' @examples
#' \dontrun{
#' helper_plot_top_contributors(shap_table, name, 10)
#' }
helper_plot_top_contributors <- function(shap_table, name, n, color = "#52565e",
                                         labels_data = NULL, sec_label = NULL) {
  shap_table <- shap_table %>% top_n(n, wt = value)

  # Reversed levels so that, after coord_flip(), the first row is drawn on top.
  upper_terms <- str_to_upper(shap_table$term)
  shap_table$term <- factor(upper_terms, levels = rev(upper_terms))
  shap_table$labels <- abbreviate(as.character(shap_table$term), minlength = 10, dot = TRUE)

  # The subtitle stays NULL unless a secondary label is explicitly requested.
  # (Previously it was left undefined when labels_data was given without
  # sec_label, which made labs() fail with "object not found".)
  subtitle_text <- NULL
  if (!is.null(labels_data)) {
    # Use alternative names if provided as labels_data
    old_name <- name
    matching_labels <- labels_data %>% filter(old_label == old_name)
    name <- matching_labels %>% pull(new_label) %>% first()
    if (!is.null(sec_label)) {
      subtitle_text <- matching_labels %>% pull(second_label) %>% first()
    }
  } else {
    name <- word(name, 2, sep = "_")
  }

  # Name the axis labels by term so they are matched to the breaks by name
  # rather than by position (the factor levels are in reversed row order).
  axis_labels <- setNames(as.character(shap_table$labels), as.character(shap_table$term))

  ggplot(shap_table, aes(x = term, y = value)) +
    geom_bar(stat = "identity", fill = color) +
    coord_flip() +
    labs(x = "Feature", y = "Contribution", title = str_to_upper(name), subtitle = subtitle_text) +
    scale_x_discrete(labels = axis_labels)
}
#' Plot top contributions
#'
#' Draws one bar chart of the top contributing features per model and arranges
#' the charts in a grid.
#' @param models A list of model objects generated by make_xgb_models
#' @param models_to_use A vector of model names to restrict the plot to. If NULL or empty, all models are plotted.
#' @param data_to_use If "training" (default), each model's `feature_contribution` is used. If anything else, the `feature_contribution` stored under each model's `new_data` is used (must have used add_predictions before).
#' @param n_predictors The maximum number of predictors to show per plot.
#' @param n_columns If multiple models are being plotted, how many columns in the plot grid to use.
#' @param color The color of the bars in the plot.
#' @param labels_data Optional label table forwarded to the per-model plotting helper.
#' @param sec_label Optional secondary-label switch forwarded to the per-model plotting helper.
#' @return A cowplot grid of the per-model plots.
#' @keywords plot
#' @import ggplot2 cowplot purrr
#' @export
#' @examples
#' \dontrun{
#' plot_top_contributors(my_models, c("ko_ctnnb1", "ko_myod1"), n_predictors = 10)
#' }
plot_top_contributors <- function(models,
                                  models_to_use = NULL,
                                  data_to_use = "training",
                                  n_predictors = 10, n_columns = 5,
                                  color = "#52565e",
                                  labels_data = NULL, sec_label = NULL) {
  # Optionally narrow the model list to the requested names
  selected_models <- if (length(models_to_use) > 0) models[models_to_use] else models

  # Contributions live directly on the model for training data and under
  # `new_data` otherwise
  contribution_source <- if (data_to_use == "training") {
    selected_models
  } else {
    map(selected_models, "new_data")
  }
  shap_tables <- map(contribution_source, "feature_contribution")

  # One panel per model; imap passes the model name as the second argument
  panels <- imap(shap_tables, helper_plot_top_contributors,
                 n = n_predictors, color = color,
                 labels_data = labels_data, sec_label = sec_label)
  plot_grid(plotlist = panels, align = "v", ncol = n_columns)
}
#' Overlay predictions on a dimensional reduction plot
#'
#' This function generates a dimensional reduction plot with predictions overlayed on each cell.
#' @param scRNA_data A Seurat object with predictions appended (e.g. using append_predictions_to_seurat)
#' @param perturbation The name of the perturbation whose predictions we want to plot. Must be a column of the object's meta.data.
#' @param group_by Optional name of a meta.data column used to group cells when placing labels.
#' @param reduction The dimensionality reduction technique to show (default = 'umap')
#' @param pt_size Point size passed to geom_point (default = 0.25).
#' @param show_labels If TRUE and `group_by` is given, each group is labelled at the median of its coordinates. Default = FALSE.
#' @param label_size Text size of the group labels (default = 3).
#' @param dims A 2-element vector with the dimensions to plot (default = c(1,2))
#' @param fixed_color_scale Default = FALSE. If TRUE, the color scale will be fixed from 0 to 1.
#' @return A ggplot object.
#' @keywords plot dimplot
#' @import Seurat ggplot2 cowplot
#' @export
#' @examples
#' \dontrun{
#' plot_predictions_dimplot(my_seurat_obj, "ko_ctnnb1")
#' }
plot_predictions_dimplot <- function(scRNA_data, perturbation, group_by = NULL,
                                     reduction = "umap", pt_size = 0.25, show_labels = FALSE, label_size = 3,
                                     dims = c(1, 2),
                                     fixed_color_scale = FALSE) {
  # Per-cell prediction, renamed to a fixed column name for plotting
  feature_data <- scRNA_data@meta.data %>%
    select(all_of(perturbation)) %>%
    as_tibble(rownames = "cell_id")
  colnames(feature_data) <- c("cell_id", "feature_data")

  if (!is.null(group_by)) {
    cluster_data <- scRNA_data@meta.data %>%
      select(all_of(group_by)) %>%
      as_tibble(rownames = "cell_id")
    feature_data <- feature_data %>% left_join(cluster_data, by = "cell_id")
    colnames(feature_data) <- c("cell_id", "feature_data", "cluster_data")
  }

  # Embedding coordinates; column 1 is cell_id, so dimension k sits in column k + 1
  embeddings <- scRNA_data[[reduction]]@cell.embeddings %>% as_tibble(rownames = "cell_id")
  dim1_name <- colnames(embeddings)[dims[1] + 1]
  dim2_name <- colnames(embeddings)[dims[2] + 1]
  embed_data <- embeddings %>% select(all_of(c("cell_id", dim1_name, dim2_name)))

  plot_data <- embed_data %>% left_join(feature_data, by = "cell_id")

  p <- ggplot() +
    geom_point(data = plot_data,
               aes(x = .data[[dim1_name]], y = .data[[dim2_name]], color = feature_data),
               size = pt_size)

  if (!is.null(group_by) && show_labels) {
    # `group_by` below resolves to the dplyr function: R skips the
    # non-function argument of the same name when looking up a call.
    label_data <- plot_data %>%
      select(all_of(c(dim1_name, dim2_name, "cluster_data"))) %>%
      group_by(cluster_data) %>%
      summarize(mean_x = median(.data[[dim1_name]]),
                mean_y = median(.data[[dim2_name]]))
    p <- p +
      geom_text(data = label_data,
                aes(x = mean_x, y = mean_y, label = cluster_data),
                size = label_size)
  }

  # limits = NULL lets the scale adapt to the data
  color_limits <- if (fixed_color_scale) c(0, 1) else NULL
  p <- p + scale_color_distiller(palette = "RdYlBu", direction = -1, limits = color_limits)

  # Axis titles use the embedding column names. Indexing
  # colnames(embed_data) by `dims` was off by one because of the leading
  # cell_id column, so the x axis was titled "cell_id".
  p + theme_bw(base_line_size = 0) +
    labs(color = perturbation, x = dim1_name, y = dim2_name)
}
#' Plot feature contributions to sample predictions
#'
#' This function plots the contribution (Shapley values) of the top features to a sample's predictions.
#' One waterfall-style plot is built per sample and the plots are arranged in a grid.
#' @param model A model generated with make_xgb_models and has appended predictions with add_predictions.
#' @param model_data The training dataset used to generate the model.
#' @param name The name of the perturbation whose prediction we want to plot.
#' @param sample_names The names of the samples we want to generate a plot for.
#' @param n_features The number of top contributors to show their individual contribution. All other predictors will have their contribution combined.
#' @param n_columns Number of columns to plot in a grid when plotting more than one sample.
#' @param short_title If using many columns, set to TRUE to shorten the title of the plot.
#' @param fixed_axis If TRUE, the value axis is fixed to `axis_limits`; otherwise it is fitted to the plotted bars with a 0.05 margin. Default = FALSE.
#' @param axis_limits Two-element vector with the axis range used when `fixed_axis` is TRUE. Default = c(-0.05, 1).
#' @param nudges_lr Two-element vector of extra offsets for the value labels: the first is added for negative contributions, the second for positive ones.
#' @param values_are_percentages If TRUE (default), contributions and the value axis are formatted as percentages; otherwise contributions are rounded to 2 decimals.
#' @param decreasing If TRUE, the colors of positive and negative contributions are swapped.
#' @param replace_names If TRUE, the sample name will be replaced using get_cell_line_name. Must supply sample_info.
#' @param sample_info If replace_names is TRUE, this must be supplied.
#' @param show_error If TRUE, a shaded area will be used to visualize prediction interval. Default = TRUE.
#' @param highlight_significant If TRUE, a red border is drawn when the lower bound of the prediction interval is at or above 0.5. Default = FALSE.
#' @param plot_new_data Set to TRUE if plotting the contribution to a new sample (not used in training).
#' @param labels_data Optional data frame with `old_label` and `new_label` columns used to replace the perturbation name in the title.
#' @param sec_label If not NULL (and `labels_data` is supplied), the matching `second_label` entry of `labels_data` is used as subtitle and the title omits the probability.
#' @return A cowplot grid with one plot per sample.
#' @keywords plot shapley contribution
#' @import Seurat ggplot2 cowplot data.table
#' @export
#' @examples
#' \dontrun{
#' plot_contribution_to_sample(my_model, model_dataset, "ko_ctnnb1", "my_sample")
#' }
plot_contribution_to_sample <- function(model, model_data, name, sample_names,
                                        n_features = 5,
                                        n_columns = 1,
                                        short_title = FALSE,
                                        fixed_axis = FALSE, axis_limits = c(-0.05, 1),
                                        nudges_lr = c(0, 0),
                                        values_are_percentages = TRUE, decreasing = FALSE,
                                        replace_names = FALSE, sample_info = NULL,
                                        show_error = TRUE,
                                        highlight_significant = FALSE,
                                        plot_new_data = FALSE, labels_data = NULL, sec_label = NULL) {
  # This will hold the final list of plots
  p_list <- list()
  # Baseline value every waterfall starts from
  null_prediction <- model$null_prediction
  feature_data <- get_original_data(model, model_data)
  # Predictions, their errors and the Shapley values are stored either on the
  # model itself (training samples) or under `new_data` (new samples)
  prediction_source <- if (plot_new_data) model$new_data else model

  # Resolve the display name and optional subtitle once; they do not depend on
  # the sample. The subtitle defaults to NULL so that supplying labels_data
  # without sec_label no longer fails with an undefined variable.
  subtitle_text <- NULL
  if (!is.null(labels_data)) {
    # Use alternative names if provided as labels_data
    old_name <- name
    matching_labels <- labels_data %>% filter(old_label == old_name)
    new_name <- matching_labels %>% pull(new_label) %>% first()
    if (!is.null(sec_label)) {
      subtitle_text <- matching_labels %>% pull(second_label) %>% first()
    }
  } else {
    new_name <- word(name, 2, sep = "_")
  }
  clean_perturb_name <- str_to_upper(new_name)

  # Reverse colors if decreasing scale
  if (!decreasing) {
    negative_color <- "#3591d1"
    positive_color <- "#f04546"
  } else {
    negative_color <- "#f04546"
    positive_color <- "#3591d1"
  }
  direction_colors <- c("negative" = negative_color, "positive" = positive_color)

  for (sample_name in sample_names) {
    # Final prediction and the lower bound of its 1.96 * error interval
    final_prediction <- prediction_source$predictions[sample_name]
    prediction_error <- prediction_source$predictions_error[sample_name]
    error_lower <- final_prediction - 1.96 * prediction_error

    # Create a long table of terms and their shapley values (called shap_value)
    term_values <- prediction_source$shap_values[sample_name, ] %>%
      data.table::transpose(keep.names = "term") %>%
      rename("shap_value" = "V1") %>%
      arrange(desc(abs(shap_value))) %>%
      filter(abs(shap_value) > 0)

    # Add the expression level of each term (called feature_value)
    if (!plot_new_data) {
      feature_values <- feature_data[sample_name, ]
    } else {
      feature_values <- model$new_data$data[sample_name, ]
    }
    feature_values <- feature_values %>% t %>% as_tibble(rownames = "term")
    colnames(feature_values) <- c("term", "feature_value")
    feature_values <- feature_values %>% mutate(feature_value = round(feature_value, 2))
    term_values <- term_values %>% left_join(feature_values, by = "term")

    # Keep the top terms, and combine the contribution of all other terms
    # (we also count how many terms we combined)
    top_terms <- term_values %>% top_n(n_features, wt = abs(shap_value)) %>% pull(term)
    other_shap <- term_values %>% filter(!term %in% top_terms) %>% pull(shap_value)
    other_count <- length(other_shap)
    other_sum <- sum(other_shap)
    term_word <- if (other_count == 1) "term" else "terms"
    other_row <- tibble(term = glue::glue("{other_count} other {term_word}"), shap_value = other_sum)
    none_row <- tibble(term = glue::glue("starting_value"), shap_value = null_prediction)

    # We generate all the information we need to make the plot: rows are
    # reversed so the cumulative sum starts at the baseline, and each bar spans
    # from the previous running total (start) to the new one (end)
    term_values <- term_values %>% filter(term %in% top_terms) %>% bind_rows(other_row) %>% bind_rows(none_row) %>%
      rowid_to_column("id") %>% arrange(desc(id)) %>%
      mutate(end = cumsum(shap_value)) %>% mutate(start = lag(end, default = 0)) %>%
      filter(term != "starting_value") %>%
      mutate(term = factor(term, levels = term)) %>% mutate(id = seq_along(term)) %>%
      mutate(term_label = if_else(!is.na(feature_value), glue::glue("{str_to_upper(term)} = {feature_value}"), str_to_upper(term))) %>%
      mutate(term_label = factor(term_label, levels = term_label)) %>%
      mutate(term_direction = if_else(shap_value > 0, "positive", "negative")) %>%
      mutate(shap_label = scales::percent(shap_value, accuracy = 0.01))
    # Revert the percentage label
    if (!values_are_percentages) {
      term_values <- term_values %>% mutate(shap_label = as.character(round(shap_value, 2)))
    }

    if (replace_names) sample_name <- get_cell_line_name(sample_name, sample_info)

    # Make the plot
    p <- ggplot(term_values, aes(x = term_label, fill = term_direction)) +
      geom_point(alpha = 0, y = 0) +
      geom_hline(yintercept = first(term_values$start), linetype = "dashed", color = "lightgray") +
      geom_hline(yintercept = last(term_values$end), linetype = "dashed", color = "black")
    if (show_error) {
      interval_center <- last(term_values$end)
      interval_half_width <- 1.96 * prediction_error
      p <- p +
        geom_rect(xmin = -Inf, xmax = Inf, ymin = interval_center - interval_half_width, ymax = interval_center + interval_half_width, alpha = 0.025, fill = "gray") +
        geom_hline(yintercept = interval_center + interval_half_width, linetype = "dashed", color = "lightgray") +
        geom_hline(yintercept = interval_center - interval_half_width, linetype = "dashed", color = "lightgray")
    }

    # Later assignments take precedence: short_title wins over sec_label
    plot_title <- glue::glue("{sample_name} ({scales::percent(final_prediction,accuracy = 1)} probability of {clean_perturb_name} dependency)")
    if (!is.null(sec_label)) plot_title <- glue::glue("{sample_name} ({clean_perturb_name})")
    if (short_title) plot_title <- glue::glue("{sample_name} ({scales::percent(final_prediction,accuracy = 1)} {clean_perturb_name})")

    p <- p + geom_rect(aes(xmin = id - 0.45, xmax = id + 0.45, ymin = start, ymax = end)) +
      geom_segment(aes(x = id, xend = id, y = start, yend = end), lineend = "round", linejoin = "round", size = 0.7, arrow = arrow(angle = 30, length = unit(0.5, "cm")), color = "white") +
      geom_text(aes(x = id,
                    y = end,
                    color = if_else(shap_value > 0, "positive", "negative"),
                    label = if_else(shap_value > 0, paste0("+", shap_label), shap_label)),
                nudge_y = if_else(term_values$shap_value > 0, 0.04 + nudges_lr[2], -0.04 + nudges_lr[1]),
                size = 3, hjust = 0.5) +
      scale_color_manual(values = direction_colors) +
      scale_fill_manual(values = direction_colors) +
      labs(title = plot_title, subtitle = subtitle_text,
           x = NULL, y = "Outcome") +
      theme_bw() +
      theme(text = element_text(size = 14), legend.position = "none")

    if (!fixed_axis) {
      bar_extent <- c(term_values$start, term_values$end)
      p <- p + coord_flip(ylim = c(min(bar_extent) - 0.05, 0.05 + max(bar_extent)))
    } else {
      p <- p + coord_flip(ylim = axis_limits)
    }
    if (values_are_percentages) p <- p + scale_y_continuous(labels = scales::percent)
    if (highlight_significant && error_lower >= 0.5) {
      p <- p + theme(panel.border = element_rect(color = "red", fill = NA, size = 1))
    }
    p_list[[sample_name]] <- p
  }

  plot_grid(plotlist = p_list, ncol = n_columns, align = "v")
}
#' Plot feature contributions to sample predictions from the training set.
#'
#' Calls plot_contribution_to_sample once per selected model, using training samples
#' chosen with get_demo_samples.
#' @param models A list of models generated with make_xgb_models and has appended predictions with add_predictions.
#' @param models_to_use A vector of model names to plot.
#' @param model_data The training dataset used to generate the model.
#' @param samples_to_use Passed to get_demo_samples as `samples`.
#' @param lineage_to_use Passed to get_demo_samples as `lineage`.
#' @param n_features The number of top contributors to show their individual contribution. All other predictors will have their contribution combined.
#' @param n_columns Number of columns to plot in a grid when plotting more than one sample.
#' @param fixed_axis If TRUE (default), the value axis of every plot is fixed to `axis_limits`.
#' @param axis_limits Axis range used when `fixed_axis` is TRUE. Default = c(-0.05, 1).
#' @param nudges_lr Label offsets forwarded to plot_contribution_to_sample.
#' @param show_error If TRUE, a shaded area will be used to visualize prediction interval. Default = TRUE.
#' @param highlight_significant Forwarded to plot_contribution_to_sample. Default = FALSE.
#' @param replace_names If TRUE, the sample name will be replaced using get_cell_line_name. Must supply sample_info.
#' @param sample_info If replace_names is TRUE, this must be supplied.
#' @param labels_data,sec_label Optional label overrides forwarded to plot_contribution_to_sample.
#' @param values_are_percentages,decreasing Formatting options forwarded to plot_contribution_to_sample.
#' @return A list with one plot grid per model.
#' @keywords plot shapley contribution
#' @import Seurat ggplot2 cowplot data.table
#' @export
#' @examples
#' \dontrun{
#' plot_contribution_to_training_samples(my_models, c("ko_ctnnb1","ko_myod1"), model_dataset, lineage_to_use = "soft_tissue")
#' }
plot_contribution_to_training_samples <- function(models, models_to_use, model_data,
                                                  samples_to_use = NULL, lineage_to_use = NULL,
                                                  n_features = 5, n_columns = 1,
                                                  fixed_axis = TRUE, axis_limits = c(-0.05, 1),
                                                  nudges_lr = c(0, 0),
                                                  show_error = TRUE,
                                                  highlight_significant = FALSE,
                                                  replace_names = FALSE, sample_info = NULL,
                                                  labels_data = NULL, sec_label = NULL,
                                                  values_are_percentages = TRUE, decreasing = FALSE) {
  # Only the requested models are plotted
  selected_models <- models[models_to_use]

  # For every model, work out which training samples to draw
  samples_per_model <- map(selected_models, get_demo_samples,
                           samples = samples_to_use,
                           lineage = lineage_to_use,
                           model_data = model_data)

  # pmap walks these three parallel arguments; everything else is constant
  per_model_args <- list(
    model = selected_models,
    name = names(selected_models),
    sample_names = samples_per_model
  )

  pmap(per_model_args, plot_contribution_to_sample,
       model_data = model_data,
       n_features = n_features, n_columns = n_columns,
       fixed_axis = fixed_axis, axis_limits = axis_limits, nudges_lr = nudges_lr,
       values_are_percentages = values_are_percentages, decreasing = decreasing,
       replace_names = replace_names, sample_info = sample_info,
       highlight_significant = highlight_significant,
       show_error = show_error, labels_data = labels_data, sec_label = sec_label)
}
#' Plot feature contributions to new sample predictions
#'
#' Calls plot_contribution_to_sample once per selected model with `plot_new_data = TRUE`,
#' using the samples stored as row names of each model's `new_data$data`.
#' @param models A list of models generated with make_xgb_models and has appended predictions with add_predictions.
#' @param models_to_use A vector of model names to plot.
#' @param model_data The training dataset used to generate the model.
#' @param new_samples_to_use Optional vector of sample names; when non-empty, only new samples in this vector are plotted.
#' @param n_features The number of top contributors to show their individual contribution. All other predictors will have their contribution combined.
#' @param n_columns Number of columns to plot in a grid when plotting more than one sample.
#' @param fixed_axis If TRUE (default), the value axis of every plot is fixed to `axis_limits`.
#' @param axis_limits Axis range used when `fixed_axis` is TRUE. Default = c(-0.05, 1).
#' @param nudges_lr Label offsets forwarded to plot_contribution_to_sample.
#' @param replace_names If TRUE, the sample name will be replaced using get_cell_line_name. Must supply sample_info.
#' @param show_error If TRUE, a shaded area will be used to visualize prediction interval. Default = TRUE.
#' @param values_are_percentages,decreasing Formatting options forwarded to plot_contribution_to_sample.
#' @param highlight_significant Forwarded to plot_contribution_to_sample. Default = TRUE.
#' @param short_title If using many columns, set to TRUE to shorten the title of the plot. Default = TRUE.
#' @param sample_info If replace_names is TRUE, this must be supplied.
#' @param labels_data,sec_label Optional label overrides forwarded to plot_contribution_to_sample.
#' @return A list with one plot grid per model.
#' @keywords plot shapley contribution
#' @import Seurat ggplot2 cowplot data.table
#' @export
#' @examples
#' \dontrun{
#' plot_contribution_to_new_samples(my_models, c("ko_ctnnb1","ko_myod1"), model_dataset)
#' }
plot_contribution_to_new_samples <- function(models, models_to_use, model_data,
                                             new_samples_to_use = NULL,
                                             n_features = 5, n_columns = 4,
                                             fixed_axis = TRUE, axis_limits = c(-0.05, 1),
                                             nudges_lr = c(0, 0), replace_names = FALSE,
                                             show_error = TRUE, values_are_percentages = TRUE, decreasing = FALSE,
                                             highlight_significant = TRUE, short_title = TRUE,
                                             sample_info = NULL, labels_data = NULL, sec_label = NULL) {
  # Only the requested models are plotted
  selected_models <- models[models_to_use]

  # New samples are identified by the row names of each model's new data
  samples_per_model <- map(selected_models, c("new_data", "data")) %>% map(rownames)
  if (length(new_samples_to_use) > 0) {
    samples_per_model <- map(samples_per_model, function(ids) intersect(ids, new_samples_to_use))
  }

  # pmap walks these three parallel arguments; everything else is constant
  per_model_args <- list(
    model = selected_models,
    name = names(selected_models),
    sample_names = samples_per_model
  )

  pmap(per_model_args, plot_contribution_to_sample,
       model_data = model_data,
       n_features = n_features, n_columns = n_columns,
       fixed_axis = fixed_axis, axis_limits = axis_limits, nudges_lr = nudges_lr,
       replace_names = replace_names, sample_info = sample_info,
       highlight_significant = highlight_significant, short_title = short_title,
       show_error = show_error, values_are_percentages = values_are_percentages, decreasing = decreasing,
       labels_data = labels_data, sec_label = sec_label, plot_new_data = TRUE)
}
#' Plot Shapley values vs feature values
#'
#' This function plots a scatter of the contribution (Shapley values) of a feature against the value of that feature for each sample.
#' One panel is drawn per top feature, with shared legends placed below the grid.
#' @param model A model generated with make_xgb_models and has appended predictions with add_predictions.
#' @param name The name of the perturbation whose prediction we want to plot.
#' @param model_data The training dataset used to generate the model.
#' @param n_features The number of top contributing features to plot (one panel each).
#' @param n_columns Number of columns in the panel grid.
#' @param overlay_predictions If TRUE, the top features are taken from the model's new data and the selected new samples are overlaid as labels. Default = FALSE.
#' @param sample_names The names of the samples we want to highlight.
#' @param remove_prefix If TRUE (default), overlaid labels show only the second "_"-separated word of the sample name.
#' @param sample_colors If highlighting samples (with sample_names), this is a vector of colors to use. Replaced by gg_color_hue colors when its length does not match sample_names.
#' @param sample_info Optional; when supplied (and not overlaying predictions), legend labels are produced with get_cell_line_name.
#' @return A cowplot grid.
#' @keywords plot shapley contribution
#' @import Seurat ggplot2 cowplot data.table
#' @export
#' @examples
#' \dontrun{
#' plot_shap_scatter(my_model, "ko_ctnnb1", model_dataset, sample_names = "my_sample", sample_colors = "red")
#' }
plot_shap_scatter <- function(model, name, model_data,
                              n_features = 4,
                              n_columns = 2,
                              overlay_predictions = FALSE,
                              sample_names = NULL, remove_prefix = TRUE,
                              sample_colors = NULL,
                              sample_info = NULL) {
  contribution_table <- if (overlay_predictions) {
    model$new_data$feature_contribution
  } else {
    model$feature_contribution
  }
  top_terms <- contribution_table %>% top_n(n_features, wt = value) %>% pull(term)
  # Without any term there is no panel and no gradient legend to draw; fail
  # with a clear message instead of an "object not found" error further down.
  if (length(top_terms) == 0) {
    stop("No features available to plot.", call. = FALSE)
  }

  feature_data <- get_original_data(model, model_data)
  highlight_samples <- !is.null(sample_names)
  perturbation_label <- name %>% word(2, sep = "_") %>% str_to_upper()

  # The sample color table and its legend do not depend on the term, so they
  # are built once instead of on every loop iteration.
  color_legend <- NULL
  if (highlight_samples) {
    if (length(sample_colors) != length(sample_names)) {
      sample_colors <- gg_color_hue(length(sample_names))
    }
    selection_df <- tibble(sample = sample_names, pt_color = sample_colors)

    # Make a dummy plot to grab the color legend
    legend_p <- ggplot(selection_df, aes(color = sample)) +
      geom_point(x = 1, y = 1)
    if (!overlay_predictions) {
      legend_labels <- if (!is.null(sample_info)) {
        get_cell_line_name(selection_df$sample, sample_info = sample_info)
      } else {
        selection_df$sample
      }
      legend_p <- legend_p + scale_color_manual(values = sample_colors, labels = legend_labels)
    }
    # NOTE(review): when overlaying predictions the legend keeps ggplot's
    # default colors while the labels use sample_colors - confirm intended.
    legend_title <- if (overlay_predictions) "Sample" else "Cell Line"
    legend_p <- legend_p +
      guides(color = guide_legend(title = legend_title,
                                  ncol = 8,
                                  title.position = "left"))
    color_legend <- get_legend(legend_p)
  }

  panels <- list()
  gradient_legend <- NULL
  for (term in top_terms) {
    # Prepare data
    scatter_data <- tibble(sample = rownames(model$shap_values),
                           x = feature_data[, term],
                           y = model$shap_values[, term],
                           y_value = feature_data[, "y_value"],
                           source = "model")

    if (overlay_predictions) {
      new_sample_ids <- rownames(model$new_data$shap_values)
      new_sample_labels <- if (remove_prefix) word(new_sample_ids, 2, sep = "_") else new_sample_ids
      df_pred <- tibble(sample = new_sample_ids,
                        x = model$new_data$data[, term],
                        y = model$new_data$shap_values[, term],
                        sample_label = new_sample_labels,
                        source = "prediction")
      if (highlight_samples) {
        df_pred <- df_pred %>%
          filter(sample %in% sample_names) %>%
          left_join(selection_df, by = "sample")
      }
    } else if (highlight_samples) {
      scatter_data <- scatter_data %>% left_join(selection_df, by = "sample")
    }

    # Make the real plot
    panel <- ggplot(scatter_data, aes(x = x, y = y, color = y_value)) +
      geom_point(size = 1, alpha = 0.5)
    # identical() is safe for columns with several classes (e.g. ordered
    # factors), where `class(x) == "numeric"` gives a length-2 condition.
    if (identical(class(feature_data[[term]]), "numeric")) {
      panel <- panel + geom_smooth(formula = y ~ x, method = "loess", se = FALSE, size = 1, color = "black")
    }
    if (overlay_predictions) {
      if (highlight_samples) {
        panel <- panel + geom_label(data = df_pred, x = df_pred$x, y = df_pred$y, label = df_pred$sample_label, color = df_pred$pt_color, size = 2, na.rm = TRUE)
      }
      # Widen the view so that both training and new samples are visible
      all_x <- c(df_pred$x, scatter_data$x)
      all_y <- c(df_pred$y, scatter_data$y)
      panel <- panel + coord_cartesian(xlim = c(min(all_x), max(all_x)), ylim = c(min(all_y), max(all_y)))
    } else if (highlight_samples) {
      panel <- panel + geom_point(color = scatter_data$pt_color, size = 3, na.rm = TRUE)
    }

    panel <- panel + scale_color_viridis_c() +
      labs(x = str_to_upper(term), y = "Shapley Value", color = str_to_upper(term)) +
      guides(color = guide_colorbar(title = glue::glue("Observed {perturbation_label} Dependency"),
                                    title.position = "left",
                                    label.theme = element_text(size = 8, angle = 45, vjust = 1, hjust = 1),
                                    barheight = 0.5,
                                    direction = "horizontal",
                                    title.vjust = 1,
                                    label = TRUE))

    # Grab the legend of the first plot and hide all legends
    if (match(term, top_terms) == 1) gradient_legend <- get_legend(panel)
    panels[[term]] <- panel + theme(legend.position = "none")
  }

  panel_grid <- plot_grid(plotlist = panels, ncol = n_columns, align = "v")
  legend_grid <- plot_grid(gradient_legend, color_legend, ncol = 1, align = "v")
  # Add the legend at the bottom before returning
  plot_grid(panel_grid, legend_grid, ncol = 1, rel_heights = c(8, 2))
}
#' Plot SHAP vs Feature Value scatter for training data.
#'
#' Calls plot_shap_scatter once per selected model, highlighting training samples
#' chosen with get_demo_samples.
#' @param models A list of models generated with make_xgb_models and has appended predictions with add_predictions.
#' @param models_to_use A vector of model names to plot.
#' @param model_data The training dataset used to generate the model.
#' @param samples_to_use Optional to highlight specific samples.
#' @param lineage_to_use Optional to highlight specific samples of certain lineage.
#' @param sample_colors Color vector to use for samples_to_use.
#' @param n_features The number of top contributing features to plot per model.
#' @param n_columns Number of columns in each model's panel grid.
#' @param sample_info Optional sample annotation forwarded to plot_shap_scatter.
#' @return A list with one plot grid per model.
#' @keywords plot shapley contribution scatter
#' @import Seurat ggplot2 cowplot data.table
#' @export
#' @examples
#' \dontrun{
#' plot_shap_scatter_for_training_samples(my_models, c("ko_ctnnb1","ko_myod1"), model_dataset, samples_to_use = "my_sample")
#' }
plot_shap_scatter_for_training_samples <- function(models, models_to_use, model_data,
                                                   samples_to_use = NULL, lineage_to_use = NULL, sample_colors = NULL,
                                                   n_features = 6, n_columns = 3, sample_info = NULL) {
  # Only the requested models are plotted
  selected_models <- models[models_to_use]

  # For every model, work out which training samples to highlight
  samples_per_model <- map(selected_models, get_demo_samples,
                           samples = samples_to_use,
                           lineage = lineage_to_use,
                           model_data = model_data)

  # pmap walks these three parallel arguments; everything else is constant
  per_model_args <- list(
    model = selected_models,
    name = names(selected_models),
    sample_names = samples_per_model
  )

  pmap(per_model_args, plot_shap_scatter,
       model_data = model_data,
       n_features = n_features, n_columns = n_columns,
       sample_colors = sample_colors, sample_info = sample_info)
}
#' Plot SHAP vs Feature Value scatter for new data.
#'
#' Calls plot_shap_scatter once per selected model with `overlay_predictions = TRUE`,
#' using the samples stored as row names of each model's `new_data$data`.
#' @param models A list of models generated with make_xgb_models and has appended predictions with add_predictions.
#' @param models_to_use A vector of model names to plot.
#' @param model_data The training dataset used to generate the model.
#' @param new_samples_to_use Optional vector of sample names; when non-empty, only new samples in this vector are overlaid.
#' @param remove_prefix Forwarded to plot_shap_scatter. Default = TRUE.
#' @param sample_colors Color vector to use for the overlaid samples.
#' @param n_features The number of top contributing features to plot per model.
#' @param n_columns Number of columns in each model's panel grid.
#' @return A list with one plot grid per model.
#' @keywords plot shapley contribution scatter
#' @import Seurat ggplot2 cowplot data.table
#' @export
#' @examples
#' \dontrun{
#' plot_shap_scatter_for_new_samples(my_models, c("ko_ctnnb1","ko_myod1"), model_dataset, new_samples_to_use = "my_sample")
#' }
plot_shap_scatter_for_new_samples <- function(models, models_to_use, model_data,
                                              new_samples_to_use = NULL, remove_prefix = TRUE, sample_colors = NULL,
                                              n_features = 6, n_columns = 3) {
  # Only the requested models are plotted
  selected_models <- models[models_to_use]

  # New samples are identified by the row names of each model's new data
  samples_per_model <- map(selected_models, c("new_data", "data")) %>% map(rownames)
  if (length(new_samples_to_use) > 0) {
    samples_per_model <- map(samples_per_model, function(ids) intersect(ids, new_samples_to_use))
  }

  # pmap walks these three parallel arguments; everything else is constant
  per_model_args <- list(
    model = selected_models,
    name = names(selected_models),
    sample_names = samples_per_model
  )

  pmap(per_model_args, plot_shap_scatter,
       model_data = model_data,
       n_features = n_features, n_columns = n_columns,
       sample_colors = sample_colors, overlay_predictions = TRUE, remove_prefix = remove_prefix)
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.