# R/GeneModuleClassify.R

#' GeneModuleClassify
#'
#' Classifies gene modules by the metadata group in which they score highest.
#' For every cell, the scaled expression (RNA assay, `scale.data` slot) of the
#' genes in each module is summed. These per-cell module scores are averaged
#' within each level of `metadata`, and each module is assigned to the level
#' with the highest mean score. A boxplot of the per-cell module scores,
#' faceted by gene module, is written to `<plot_path>/<metadata>.png`.
#'
#' @param seurat_obj Seurat object
#' @param gene_modules List of gene modules generated by antler package. Each
#'   element is used to subset the columns (genes) of the transposed
#'   `scale.data` matrix.
#' @param metadata name of a single `seurat_obj@meta.data` column to group
#'   cells by e.g. stage, scHelper_cell_state, run. Required: the `NULL`
#'   default is kept only for backward compatibility and raises an error.
#' @param plot_path directory to plot log files (boxplots); created if missing
#' @return data frame (tibble, grouped by `gene_module`) holding the `metadata`
#'   column and `gene_module`, with one row per gene module (several rows if
#'   groups tie for the highest mean score)
#' @export
GeneModuleClassify <- function (seurat_obj, gene_modules, 
                             metadata = NULL, plot_path = "scHelper_log/GM_classification/") 
{
  # Fail early with a clear message instead of an opaque tidy-eval error later
  if (!is.character(metadata) || length(metadata) != 1 || is.na(metadata)) {
    stop("`metadata` must be a single column name from `seurat_obj@meta.data`",
         call. = FALSE)
  }
  if (!metadata %in% colnames(seurat_obj@meta.data)) {
    stop("`metadata` column '", metadata, "' not found in `seurat_obj@meta.data`",
         call. = FALSE)
  }
  if (length(gene_modules) == 0) {
    stop("`gene_modules` must contain at least one gene module", call. = FALSE)
  }

  # Cells x genes matrix; fetched and transposed once rather than per module
  scaled_expr <- t(GetAssayData(object = seurat_obj, assay = 'RNA', slot = 'scale.data'))

  # Per-cell score for each module: sum of scaled expression over its genes.
  # drop = FALSE keeps a matrix so single-gene modules do not break rowSums.
  # NOTE(review): scores are sums, while the plot's y axis is labelled
  # 'Average scaled expression' -- confirm which of the two is intended.
  module_scores <- lapply(gene_modules, function(genes) {
    rowSums(scaled_expr[, genes, drop = FALSE])
  })
  module_scores <- do.call('cbind', module_scores)

  # Attach the grouping column by cell name (row names) and reshape to long
  # format: one row per cell x gene module
  classified_gms <- merge(module_scores, seurat_obj@meta.data[, metadata, drop = FALSE],
                          by = 0, all = TRUE)

  classified_gms <- classified_gms %>% column_to_rownames('Row.names') %>%
    pivot_longer(cols = !(!!metadata)) %>%
    dplyr::rename(gene_module = name) %>%
    group_by(gene_module, !!sym(metadata))

  n_groups <- length(unique(classified_gms[[metadata]]))
  # Fewer facet columns when there are many groups per panel; never below 1,
  # which facet_wrap would reject
  facet_cols <- max(1, round(30 / n_groups))

  p <- ggplot(classified_gms, aes(x = !!sym(metadata), y = value, fill = !!sym(metadata))) +
    geom_boxplot() +
    scale_fill_manual(values = colorRampPalette(brewer.pal(8, "Dark2"))(n_groups)) +
    facet_wrap(~gene_module, ncol = facet_cols) +
    xlab(NULL) +
    ylab('Average scaled expression') +
    theme(legend.position = "none",
          panel.background = element_blank(),
          axis.line = element_line(colour = "black"),
          strip.background = element_rect(colour = "white", fill = "white"),
          strip.text = element_text(size = 10),
          axis.text = element_text(angle = 90, hjust = 1, vjust = 0.5),
          axis.title = element_text(size = 10)) +
    # Segments along the bottom and left edge of every facet panel
    annotate("segment", x=-Inf, xend=Inf, y=-Inf, yend=-Inf)+
    annotate("segment", x=-Inf, xend=-Inf, y=-Inf, yend=Inf)

  dir.create(plot_path, recursive = TRUE, showWarnings = FALSE)

  # file.path() places the file inside plot_path with or without a trailing slash
  png(file.path(plot_path, paste0(metadata, ".png")), width = 40,
      height = 30, units = "cm", res = 200)
  # Close only the device opened above (even if printing fails); graphics.off()
  # would also close any devices the caller had open
  tryCatch(print(p), finally = dev.off())

  # Select top classification: mean score per module x group, then keep the
  # group(s) with the highest mean for each module
  classified_gms <- classified_gms %>%
    summarise(value = mean(value), .groups = 'keep') %>%
    group_by(gene_module) %>%
    filter(value == max(value)) %>%
    dplyr::select(!value)

  return(classified_gms)
}
# alexthiery/scHelper documentation built on Aug. 26, 2023, 3:42 p.m.