R/plot_celltype_severity.R

Defines functions plot_celltype_severity

Documented in plot_celltype_severity

#' Plot cell type severity
#'
#' Plot the phenotype severity scores (generated by GPT-4) aggregated by
#' the cell types each phenotype is significantly associated with.
#' @param top_n Top and bottom number of cell types to show per annotation
#'  (used in dot plot only).
#' @param types Which types of plots to create.
#'  One or both of \code{"bar"} and \code{"dot"}.
#' @inheritParams prioritise_targets
#' @inheritParams plot_bar_dendro
#' @returns Named list with one element per requested plot type
#'  (\code{"bar"} and/or \code{"dot"}). Each element is itself a list holding
#'  the \code{plot} object and the \code{data} (data.table) it was drawn from.
#'
#' @export
#' @examples
#' results <- load_example_results()
#' out <- plot_celltype_severity(results)
plot_celltype_severity <- function(results,
                                   cl = KGExplorer::get_ontology("cl")|>
                                     KGExplorer::filter_ontology(
                                       keep_descendants = "cell"
                                     ),
                                   q_threshold=.05,
                                   top_n=3,
                                   types=c("bar","dot")){

  ## Fail early with a clear message rather than ignoring a FALSE return value.
  if(!requireNamespace("ggplot2", quietly = TRUE)){
    stop("Package 'ggplot2' is required by plot_celltype_severity().",
         call. = FALSE)
  }
  ## Placeholders for column names used via non-standard evaluation.
  severity_score_gpt <- cl_name <- cl_id <- value <- variable <- NULL

  results <- MSTExplorer::map_celltype(results)
  results_gpt <- HPOExplorer::add_gpt_annotations(results)

  #### Mean severity score per cell type (significant results only) ####
  celltypes_gpt <- (
    results_gpt[q<q_threshold,
                list(severity_score_gpt=mean(severity_score_gpt,na.rm=TRUE)),
                by=c("cl_id","cl_name")]|>
      data.table::setorderv("severity_score_gpt",-1, na.last = TRUE)
  )
  ## Rows are sorted by decreasing severity; levels are reversed so that,
  ## after coord_flip(), the most severe cell type is drawn at the top.
  celltypes_gpt[,cl_name:=factor(cl_name,
                                 levels=rev(unique(celltypes_gpt$cl_name)),
                                 ordered = TRUE)]

  #### Long-format table of the individual annotation columns ####
  agg_gpt <- (
    results_gpt[q<q_threshold,
                c("cl_id","cl_name",
                  "death","intellectual_disability",
                  "impaired_mobility","physical_malformations",
                  "blindness","sensory_impairments",
                  "immunodeficiency","cancer",
                  "reduced_fertility","congenital_onset")]|>
      data.table::melt.data.table(id.vars=c("cl_id","cl_name"))
  )[,cl_name:=factor(cl_name,
                     levels=levels(celltypes_gpt$cl_name),
                     ordered = TRUE)
  ][,value:=factor(value,levels = c(0:3),ordered = TRUE)]
  ## Line-break the annotation names so they fit in the facet strips.
  agg_gpt[,variable:=gsub("_","\n",variable)]
  agg_gpt[,variable:=factor(variable,levels = unique(variable),ordered = TRUE)]

  out <- list()
  if("bar" %in% types){
    ## Loading the patchwork namespace also makes its `|` method for
    ## ggplot objects available for the composition step below.
    if(!requireNamespace("patchwork", quietly = TRUE)){
      stop("Package 'patchwork' is required to create the bar plot.",
           call. = FALSE)
    }
    messager("Creating bar plot.")
    gg_severity <- ggplot2::ggplot(celltypes_gpt,
                                   ggplot2::aes(x=cl_name,y=severity_score_gpt,
                                                fill=severity_score_gpt)) +
      ggplot2::scale_fill_viridis_c(option="plasma") +
      ggplot2::geom_bar(stat="identity")+
      ggplot2::labs(y="GPT Severity Score", x=NULL,
                    fill="GPT\nseverity\nscore") +
      ggplot2::coord_flip()+
      ggplot2::theme_minimal()
    ## One facet per annotation; each bar shows the share of associated
    ## phenotypes at each annotation level (0-3).
    gg_annot <- ggplot2::ggplot(agg_gpt[!is.na(value)],
                                ggplot2::aes(x=cl_name, y=1, fill=value))+
      ggplot2::facet_grid(.~variable, scales="free_y")+
      ggplot2::geom_bar(stat="identity", position="fill")+
      ggplot2::scale_fill_viridis_d(option="plasma",
                                    labels=c(`0`="never",
                                             `1`="rarely",
                                             `2`="often",
                                             `3`="always")) +
      ggplot2::scale_y_continuous(breaks = c(0,.5,1), labels=c("0","0.5","1")) +
      ggplot2::labs(y="Proportion of associated phenotypes", x="Cell type") +
      ggplot2::coord_flip()+
      ggplot2::theme_minimal()

    out[["bar"]][["plot"]] <- (gg_annot | gg_severity) +
      patchwork::plot_layout(axes = "collect", widths = c(1,.1))
    out[["bar"]][["data"]] <- celltypes_gpt
  }

  if("dot" %in% types){
    messager("Creating dot plot.")
    ## Which cell type most often causes death when disrupted?
    gpt_celltypes <- (lapply(unique(agg_gpt$variable), function(x){
      ## `value` is an ordered factor with levels 0:3. as.numeric() on a factor
      ## returns the internal codes (1..4), so go through as.character() to
      ## recover the real 0-3 annotation values before averaging.
      agg <- agg_gpt[variable==x,
                     list(value=mean(as.numeric(as.character(value)),
                                     na.rm=TRUE)),
                     by=c("cl_name","cl_id")]|>
        data.table::setorderv("value",-1)
      ## Keep the `top_n` highest- and `top_n` lowest-scoring cell types.
      cbind(
        variable=x,
        data.table::rbindlist(
          list(
            top=utils::head(agg, top_n),
            bottom=utils::tail(agg, top_n)
          ), idcol = "group"
        )
      )
    })|>data.table::rbindlist()
    )
    ## Underscores were already turned into line breaks above, so match both
    ## characters here to get single-line, space-separated axis labels.
    gpt_celltypes[,variable:=factor(gsub("[_\n]"," ",variable),
                        levels = gsub("[_\n]"," ",
                                      levels(gpt_celltypes$variable)),
                        ordered = TRUE)]
    cl_dend <- KGExplorer::ontology_to(cl,
                                       to="dendrogram")
    ### Get leaf order from dendrogram
    gpt_celltypes[,cl_id:=factor(cl_id,
                                 levels=labels(cl_dend),
                                 ordered = TRUE)]
    ## order cl_name the same as cl_id
    gpt_celltypes[,cl_name:=factor(
      cl_name,
      levels=unique(gpt_celltypes[order(cl_id),cl_name]),
      ordered = TRUE)]

    out[["dot"]][["data"]] <- gpt_celltypes
    out[["dot"]][["plot"]] <- ggplot2::ggplot(gpt_celltypes,
                    ggplot2::aes(x=variable, y=cl_name,
                                 color=group, size=value))+
      ggplot2::geom_point(alpha=.7) +
      ggplot2::scale_color_manual(values=c("bottom"="red","top"="blue"),
                                  breaks = c("top","bottom"))+
      ggplot2::scale_size_continuous(name="Mean\nannotation\nvalue",
                                     limits=c(0,3),
                                     range=c(-1,5),
                                     breaks = c(0:3),
                                     labels=c(`0`="never (0)",
                                              `1`="rarely (1)",
                                              `2`="often (2)",
                                              `3`="always (3)"))+
      ggplot2::scale_x_discrete(position = "top") +
      ggplot2::labs(x="GPT annotation")+
      ggplot2::theme_bw()+
      ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 45, hjust = 0),
                     panel.background = ggplot2::element_blank(),
                     axis.line = ggplot2::element_line(colour = "black"))
  }
  return(out)
}
neurogenomics/MultiEWCE documentation built on Aug. 24, 2024, 1:41 a.m.