R/plot_celltype_severity.R

Defines functions plot_celltype_severity

Documented in plot_celltype_severity

#' Plot cell type severity
#'
#' Plot the phenotype severity scores (generated by GPT-4) aggregated by
#' the cell types each phenotype is significantly associated with.
#' @param top_n Number of top- and bottom-ranked cell types to show per
#'  annotation. Only used for the dot plot when \code{run_enrichment=FALSE}.
#' @param types Which types of plots to create
#'  (any of \code{"dot"}, \code{"bar"}).
#' @param force_new Run a new set of enrichment tests even when cached
#' results are found at \code{save_path}.
#' Only used when \code{run_enrichment=TRUE}.
#' @param run_enrichment Instead of simply taking the top N results,
#' run a series of one-sided Wilcoxon rank-sum tests to determine whether the
#' distribution of ordinal severity values
#' (never=0, rarely=1, often=2, always=3) are significantly different between
#' a given cell type and all other cell types.
#' Tests are repeated across each GPT annotation separately using
#' \link[dplyr]{group_by} and \link[rstatix]{wilcox_test}.
#' @param nonsig_fill Fill colour for non-significant results.
#' @inheritParams prioritise_targets
#' @inheritParams plot_bar_dendro
#' @inheritParams ggplot2::theme_bw
#' @inheritParams KGExplorer::set_cores
#' @returns Named list with one element per requested plot type
#' (\code{"bar"} and/or \code{"dot"}). Each element is itself a list holding
#' the ggplot object (\code{plot}) and the data.table used to build it
#' (\code{data}).
#'
#' @export
#' @examples
#' set.seed(2025)
#' results <- load_example_results()
#' results <- results[sample(seq(nrow(results)), 5000),]
#' out <- plot_celltype_severity(results)
plot_celltype_severity <- function(results,
                                   cl = get_cl(),
                                   q_threshold=.05,
                                   run_enrichment=TRUE,
                                   top_n=3,
                                   types=c("dot","bar")[1],
                                   run_prune_ancestors=FALSE,
                                   nonsig_fill=ggplot2::alpha("grey90",.001),
                                   force_new=FALSE,
                                   base_size=8,
                                   save_path=tempfile(fileext = ".rds"),
                                   workers=1){

  requireNamespace("ggplot2")
  severity_score_gpt <- cl_name <- cl_id <- value <- variable <- p <- FDR <-
    statistic <- group <- n_samples <- n_groups <- NULL;

  results <- map_celltype(results)
  results_gpt <- HPOExplorer::add_gpt_annotations(results)
  ## Mean GPT severity score per cell type (significant associations only)
  celltypes_gpt <- (
    results_gpt[q<q_threshold,
                list(severity_score_gpt=mean(severity_score_gpt,na.rm=TRUE)),
                by=c("cl_id","cl_name")]|>
      data.table::setorderv("severity_score_gpt",-1, na.last = TRUE)
  )
  ## Long format: one row per significant association x GPT annotation.
  ## `value` is stored as an ordered factor with levels "0".."3"; convert it
  ## back with as.numeric(as.character(value)). A plain as.numeric() returns
  ## the level index (1..4) instead of the ordinal code (0..3).
  agg_gpt <- (
    results_gpt[q<q_threshold,
                c("cl_id","cl_name",
                  "death","intellectual_disability",
                  "impaired_mobility","physical_malformations",
                  "blindness","sensory_impairments","immunodeficiency","cancer",
                  "reduced_fertility","congenital_onset")]|>
      data.table::melt.data.table(id.vars=c("cl_id","cl_name"))
  )[,value:=factor(value,levels = c(0:3),ordered = TRUE)]
  agg_gpt[,variable:=gsub("_"," ",variable)]
  agg_gpt[,variable:=factor(variable,levels = unique(variable),ordered = TRUE)]
  # Prune redundant cell types
  if(run_prune_ancestors){
    agg_gpt <- KGExplorer::prune_ancestors(dat = agg_gpt,
                                           id_col = "cl_id",
                                           ont = cl)
  }

  out <- list()
  if("bar" %in% types){
    messager("Creating bar plot.")
    gg_severity <- ggplot2::ggplot(celltypes_gpt,
                                   ggplot2::aes(x=cl_name,y=severity_score_gpt,
                                                fill=severity_score_gpt)) +
      ggplot2::scale_fill_viridis_c(option="plasma") +
      ggplot2::geom_bar(stat="identity")+
      ggplot2::labs(y="GPT Severity Score", x=NULL,
                    fill="GPT\nseverity\nscore") +
      ggplot2::coord_flip()+
      ggplot2::theme_minimal()
    gg_annot <- ggplot2::ggplot(agg_gpt[!is.na(value)],
                                ggplot2::aes(x=cl_name, y=1, fill=value))+
      ggplot2::facet_grid(.~variable, scales="free_y")+
      ggplot2::geom_bar(stat="identity", position="fill")+
      ggplot2::scale_fill_viridis_d(option="plasma",
                                    labels=c(`0`="never",
                                             `1`="rarely",
                                             `2`="often",
                                             `3`="always")) +
      ggplot2::scale_y_continuous(breaks = c(0,.5,1), labels=c("0","0.5","1")) +
      ggplot2::labs(y="Proportion of associated phenotypes", x="Cell type") +
      ggplot2::coord_flip()+
      ggplot2::theme_minimal()

    ## `|` on two ggplots dispatches to patchwork's S3 method, which is only
    ## registered once the patchwork namespace is loaded.
    requireNamespace("patchwork")
    out[["bar"]][["plot"]] <- (gg_annot | gg_severity) +
      patchwork::plot_layout(axes = "collect", widths = c(1,.1))
    out[["bar"]][["data"]] <- celltypes_gpt
  }

  if("dot" %in% types){
    messager("Creating dot plot.")
    ## Which cell type most often causes death when disrupted?
    if(isTRUE(run_enrichment)){
      ## Enrichment approach
      if(!is.null(save_path) &&
         file.exists(save_path) &&
         isFALSE(force_new)){
        ## Use cached results
        messager("Importing cached Wilcoxon rank-sum test results:",save_path)
        wt_res <- readRDS(save_path)
      } else {
        ## Run new tests
        BPPARAM <- KGExplorer::set_cores(workers = workers)
        messager("Running Wilcoxon rank-sum tests:")
        ## Numeric working copy, built once. Each worker derives its own
        ## table from it instead of using `:=` on agg_gpt. `:=` modifies by
        ## reference, so doing that here would alter the agg_gpt that is
        ## merged below.
        agg_num <- agg_gpt[!is.na(value),
                           list(cl_id, cl_name, variable,
                                value=as.numeric(as.character(value)))]
        wt_res <- BiocParallel::bplapply(
          unique(agg_gpt$cl_id),
          function(ct){
            # Fresh table flagging the focal cell type vs. all others
            tmp <- agg_num[,list(cl_id, cl_name, variable, value,
                                 group=cl_id==ct)]
            # Remove annotations without enough distinct values/groups to test
            tmp[,n_samples:=data.table::uniqueN(value), by=variable]
            tmp[,n_groups:=data.table::uniqueN(group), by=variable]
            tmp <- tmp[n_samples>=2 & n_groups>=2]
            if (nrow(tmp) == 0) {
              return(NULL)
            }

            tmp |>
              dplyr::group_by(variable)|>
              rstatix::wilcox_test(value ~ group,
                                   ref.group = "FALSE",
                                   alternative = "less"
              ) |>
              ## Avoid p=0, which would give an infinite -log2(FDR) below
              dplyr::mutate(p=ifelse(p==0,.Machine$double.xmin,p)) |>
              dplyr::mutate(FDR=stats::p.adjust(p,method="fdr"))|>
              dplyr::mutate(cl_id=ct,
                            cl_name=tmp[cl_id==ct]$cl_name[1],
                            .before=0)|>
              data.table::data.table()
          }, BPPARAM =  BPPARAM) |> data.table::rbindlist()
        if(!is.null(save_path)){
          messager("Caching Wilcoxon rank-sum test results:",save_path)
          dir.create(dirname(save_path), showWarnings = FALSE, recursive = TRUE)
          saveRDS(wt_res,save_path)
        }
      }
      ## Merge with original data and aggregate per cell type x annotation.
      ## Decode factor labels so `value` is on the 0..3 ordinal scale.
      dat <- (merge(agg_gpt, wt_res)
              )[,list(value=mean(as.numeric(as.character(value)),na.rm=TRUE),
                      statistic=mean(as.numeric(statistic),na.rm=TRUE),
                      p=mean(as.numeric(p),na.rm=TRUE),
                      FDR=mean(as.numeric(FDR),na.rm=TRUE)),
                by=c("variable","cl_name","cl_id")]
      ## Cluster by data values
      hc <- data.table::dcast.data.table(dat,
                                        formula = "cl_id ~ variable",
                                        value.var = "p") |>
        KGExplorer::dt_to_matrix() |>
        stats::dist()|>
        stats::hclust()
      order_celltypes(dt = dat,
                      cl = cl,
                      levels = hc$labels[hc$order])
      ## Upper colour limit: ignore NA/NaN FDRs, and never go below the
      ## significance cut-off so the limits stay increasing even when no
      ## result is significant.
      max_log_fdr <- max(c(-log2(.05), -log2(dat$FDR)), na.rm = TRUE)
      out[["dot"]][["data"]] <- dat
      out[["dot"]][["plot"]] <- ggplot2::ggplot(dat,
                                                ggplot2::aes(x=variable, y=cl_name,
                                                             color=-log2(FDR),
                                                             size=value))+
        ggplot2::geom_point() +
        ggplot2::scale_color_gradient2(limits=c(-log2(.05),max_log_fdr),
                                       low=nonsig_fill,
                                       mid=ggplot2::alpha("blue", .7),
                                       high = ggplot2::alpha("red", .7),
                                       na.value = nonsig_fill)+
        ggplot2::scale_size_continuous(name="Mean\nannotation\nvalue",
                                       limits=c(0,3),
                                       range=c(-1,5),
                                       breaks = c(0:3),
                                       labels=c(`0`="never (0)",
                                                `1`="rarely (1)",
                                                `2`="often (2)",
                                                `3`="always (3)"))+
        ggplot2::scale_x_discrete(position = "top") +
        ggplot2::labs(x="GPT annotation", y="Cell type",
                      color=expression(-log[2]~(FDR))
                      )+
        ggplot2::theme_bw(base_size = base_size)+
        ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 45,
                                                           hjust = 0),
                       panel.background = ggplot2::element_blank(),
                       axis.line = ggplot2::element_line(colour = "black"),
                       legend.justification = "top")

    } else {
      ## top_n approach: per annotation, rank cell types by mean value
      ## (0..3 scale) and keep the top_n highest and lowest.
      gpt_celltypes <- (lapply(unique(agg_gpt$variable), function(x){
        ## Drop NA values first so no cell type gets a NaN mean. setorderv()
        ## places NA/NaN first by default, which would push such cell types
        ## into the "top" group.
        agg <- agg_gpt[variable==x & !is.na(value),
                       list(value=mean(as.numeric(as.character(value)),
                                       na.rm=TRUE)),
                       by=c("cl_name","cl_id")]|>
          data.table::setorderv("value",-1)
        cbind(
          variable=x,
          data.table::rbindlist(
            list(
              top=utils::head(agg, top_n),
              bottom=utils::tail(agg, top_n)
            ), idcol = "group"
          )
        )
      })|>data.table::rbindlist()
      )
      order_celltypes(dt = gpt_celltypes,
                      cl = cl)
      out[["dot"]][["data"]] <- gpt_celltypes
      out[["dot"]][["plot"]] <- ggplot2::ggplot(gpt_celltypes,
                                                ggplot2::aes(x=variable, y=cl_name,
                                                             color=group, size=value))+
        ggplot2::geom_point(alpha=.7) +
        ggplot2::scale_color_manual(values=c("bottom"="red","top"="blue"),
                                    breaks = c("top","bottom"))+
        ggplot2::scale_size_continuous(name="Mean\nannotation\nvalue",
                                       limits=c(0,3),
                                       range=c(-1,5),
                                       breaks = c(0:3),
                                       labels=c(`0`="never (0)",
                                                `1`="rarely (1)",
                                                `2`="often (2)",
                                                `3`="always (3)"))+
        ggplot2::scale_x_discrete(position = "top") +
        ggplot2::labs(x="GPT annotation", y="Cell type")+
        ggplot2::theme_bw(base_size = base_size)+
        ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 45, hjust = 0),
                       panel.background = ggplot2::element_blank(),
                       axis.line = ggplot2::element_line(colour = "black"),
                       legend.justification = "top")
    }
  }
  return(out)
}
neurogenomics/MultiEWCE documentation built on April 17, 2025, 9:27 p.m.