R/plotting.R

Defines functions ModuleRadarPlot PlotModulePreservationLollipop ModuleTopologyBarplot ModuleTopologyHeatmap PlotLollipop PlotDMEsLollipop PlotDMEsVolcano PlotKMEs PlotModuleTraitCorrelation PlotModulePreservation DoHubGeneHeatmap MotifOverlapBarPlot OverlapBarPlot OverlapDotPlot ModuleUMAPPlot HubGeneNetworkPlot ModuleNetworkPlot EnrichrDotPlot EnrichrBarPlot ModuleFeaturePlot ModuleCorrNetwork ModuleCorrelogram PlotDendrogram PlotSoftPowers

Documented in DoHubGeneHeatmap EnrichrBarPlot EnrichrDotPlot HubGeneNetworkPlot ModuleCorrelogram ModuleCorrNetwork ModuleFeaturePlot ModuleNetworkPlot ModuleRadarPlot ModuleTopologyBarplot ModuleTopologyHeatmap ModuleUMAPPlot MotifOverlapBarPlot OverlapBarPlot OverlapDotPlot PlotDendrogram PlotDMEsLollipop PlotDMEsVolcano PlotKMEs PlotLollipop PlotModulePreservation PlotModulePreservationLollipop PlotModuleTraitCorrelation PlotSoftPowers

#' PlotSoftPowers
#'
#' Plot Soft Power Threshold results
#'
#' For each soft power threshold in the power table of the WGCNA experiment,
#' plots the scale-free topology model fit and (optionally) the mean, median
#' and max connectivity, highlighting one soft power. The shaded region in the
#' model-fit plot marks fits below 0.8.
#'
#' @param seurat_obj A Seurat object
#' @param selected_power power to highlight in the plots. If NULL, the lowest
#' power with a scale-free topology fit (SFT.R.sq) of at least 0.8 is used; if
#' no power reaches 0.8, a warning is issued and the power with the best fit is
#' highlighted instead. When the power table has a "group" column, the selected
#' power is applied to every group.
#' @param point_size the size of the points in the plot
#' @param text_size the size of the text in the plot
#' @param plot_connectivity logical indicating whether to plot the connectivity in addition to the scale free topology fit.
#' @param wgcna_name The name of the WGCNA experiment in seurat_obj
#' @return For a single power table: one ggplot when plot_connectivity is FALSE,
#' otherwise a list of four ggplots (model fit, mean, median and max
#' connectivity). When the power table has a "group" column with more than one
#' group, a list with one such element per group.
#' @keywords scRNA-seq
#' @export
#' @examples
#' PlotSoftPowers
PlotSoftPowers <- function(
  seurat_obj,
  selected_power = NULL,
  point_size = 5,
  text_size=3,
  plot_connectivity = TRUE,
  wgcna_name = NULL
){

  if(is.null(wgcna_name)){wgcna_name <- seurat_obj@misc$active_wgcna}

  pt <- GetPowerTable(seurat_obj, wgcna_name)

  # one power table per group (if groups are present), without the group column
  if("group" %in% colnames(pt)){
    power_tables <- pt %>% dplyr::group_split(group)
    power_tables <- lapply(power_tables, function(x){ dplyr::select(x, -group) })
  } else{
    power_tables <- list(pt)
  }

  # power to highlight for one power table: the user's choice wins, otherwise
  # the lowest power whose scale-free topology fit reaches 0.8
  pick_power <- function(power_table){
    if(!is.null(selected_power)){
      return(selected_power)
    }
    passing <- power_table$Power[which(power_table$SFT.R.sq >= 0.8)]
    if(length(passing) == 0){
      warning(
        "No soft power reached a scale-free topology fit of 0.8; ",
        "highlighting the power with the best fit instead.",
        call. = FALSE
      )
      return(power_table$Power[which.max(power_table$SFT.R.sq)])
    }
    min(passing)
  }

  # styling shared by every panel
  panel_theme <- theme(
    axis.line.x = element_blank(),
    axis.line.y = element_blank(),
    panel.border = element_rect(colour = "black", fill=NA, size=1)
  )

  # one connectivity panel; y_col is 'mean.k.', 'median.k.' or 'max.k.'
  connectivity_plot <- function(power_table, soft_power, y_col, y_label){
    metric_df <- data.frame(
      Power = power_table$Power,
      value = power_table[[y_col]],
      text_color = power_table$text_color,
      stringsAsFactors = FALSE
    )
    highlight_df <- metric_df[metric_df$Power == soft_power, c('Power', 'value')]

    ggplot(metric_df, aes(x=Power, y=value)) +
      geom_hline(yintercept = highlight_df$value, linetype='dashed') +
      geom_vline(xintercept = soft_power, linetype='dashed') +
      geom_point(
        data = highlight_df,
        aes(x=Power, y=value),
        inherit.aes=FALSE,
        color = 'black',
        size=point_size
      ) +
      geom_text(label=metric_df$Power, color=metric_df$text_color, size=text_size) +
      scale_y_continuous(labels=scales::comma) +
      ylab(y_label) +
      xlab('Soft Power Threshold') +
      panel_theme
  }

  plot_list <- lapply(power_tables, function(cur_pt){

    soft_power <- pick_power(cur_pt)
    sft_r <- cur_pt$SFT.R.sq[cur_pt$Power == soft_power]

    # white text on the highlighted (black) point, black text elsewhere
    cur_pt$text_color <- ifelse(cur_pt$Power == soft_power, 'white', 'black')

    # scale-free topology fit; shaded band marks fits below 0.8
    p1 <- ggplot(cur_pt, aes(x=Power, y=SFT.R.sq)) +
      geom_rect(
        data = cur_pt[1,],
        aes(xmin=-Inf, xmax=Inf, ymin=-Inf, ymax=0.8), fill='grey80', alpha=0.8, color=NA
      ) +
      geom_hline(yintercept = sft_r, linetype='dashed') +
      geom_vline(xintercept = soft_power, linetype='dashed') +
      geom_point(
        data = cur_pt[cur_pt$Power == soft_power, c('Power', 'SFT.R.sq')],
        aes(x=Power, y=SFT.R.sq),
        inherit.aes=FALSE,
        color = 'black',
        size=point_size
      ) +
      geom_text(label=cur_pt$Power, color=cur_pt$text_color, size=text_size) +
      scale_y_continuous(limits = c(0,1), breaks=c(0, 0.2, 0.4, 0.6, 0.8, 1)) +
      ylab('Scale-free Topology Model Fit') +
      xlab('Soft Power Threshold') +
      panel_theme

    if(!plot_connectivity){
      return(p1)
    }

    list(
      p1,
      connectivity_plot(cur_pt, soft_power, 'mean.k.', 'Mean Connectivity'),
      connectivity_plot(cur_pt, soft_power, 'median.k.', 'Median Connectivity'),
      connectivity_plot(cur_pt, soft_power, 'max.k.', 'Max Connectivity')
    )
  })

  if(length(plot_list) == 1){
    return(plot_list[[1]])
  }
  plot_list

}



#' PlotDendrogram
#'
#' Plot WGCNA dendrogram
#'
#' Draws the first dendrogram stored in the network data of a WGCNA experiment
#' with the module color of each gene shown underneath, using
#' WGCNA::plotDendroAndColors, and returns that function's result.
#'
#' @param seurat_obj A Seurat object
#' @param groupLabels label for the row of module colors under the dendrogram
#' @param wgcna_name The name of the WGCNA experiment in seurat_obj; defaults to the active experiment
#' @param dendroLabels passed to WGCNA::plotDendroAndColors
#' @param hang passed to WGCNA::plotDendroAndColors
#' @param addGuide passed to WGCNA::plotDendroAndColors
#' @param guideHang passed to WGCNA::plotDendroAndColors
#' @param main plot title, passed to WGCNA::plotDendroAndColors
#' @param ... further arguments passed to WGCNA::plotDendroAndColors
#' @keywords scRNA-seq
#' @export
#' @examples
#' PlotDendrogram
PlotDendrogram <- function(
  seurat_obj, groupLabels="Module colors", wgcna_name=NULL,
  dendroLabels = FALSE, hang = 0.03, addGuide = TRUE, guideHang = 0.05,
  main = "", ...
){

  # fall back to the active WGCNA experiment
  experiment <- if(is.null(wgcna_name)) seurat_obj@misc$active_wgcna else wgcna_name

  # network (holds the dendrograms) and per-gene module assignments
  network_data <- GetNetworkData(seurat_obj, experiment)
  module_df <- GetModules(seurat_obj, experiment)
  gene_colors <- as.character(module_df$color)

  WGCNA::plotDendroAndColors(
    network_data$dendrograms[[1]],
    gene_colors,
    groupLabels = groupLabels,
    dendroLabels = dendroLabels,
    hang = hang,
    addGuide = addGuide,
    guideHang = guideHang,
    main = main,
    ...
  )
}

#' ModuleCorrelogram
#'
#' Plot Module Eigengene correlogram
#'
#' Computes pairwise correlations (Hmisc::rcorr) between module features and
#' draws them with corrplot::corrplot. If MEs2 is given, the features are
#' instead correlated against the columns of MEs2 (rows = MEs2 columns,
#' columns = modules of seurat_obj).
#'
#' @param seurat_obj A Seurat object
#' @param MEs2 optional second feature table (cells x modules) to correlate against; used as given (grey is not removed from it)
#' @param features What to plot? Can select hMEs, MEs, scores, or average
#' @param order,method,type,tl.col,tl.srt,sig.level,pch.cex passed to corrplot::corrplot
#' @param exclude_grey logical; drop the grey module from seurat_obj's features
#' @param col color palette; if NULL a seagreen-white-darkorchid1 ramp of ncolors colors is used
#' @param ncolors number of colors in the default palette
#' @param wgcna_name The name of the WGCNA experiment in seurat_obj; defaults to the active experiment
#' @param wgcna_name2 currently unused
#' @param ... further arguments passed to corrplot::corrplot
#' @return the value returned by corrplot::corrplot
#' @keywords scRNA-seq
#' @export
#' @examples
#' ModuleCorrelogram
ModuleCorrelogram <- function(
  seurat_obj, MEs2=NULL,
  features = 'hMEs',
  order='original', method='ellipse',
  exclude_grey=TRUE, type='upper',
  tl.col = 'black', tl.srt=45,
  sig.level = c(0.0001, 0.001, 0.01, 0.05),
#  insig='label_sig',
  pch.cex=0.7, col=NULL, ncolors=200,
  wgcna_name=NULL, wgcna_name2=NULL, ...
){

  if(is.null(wgcna_name)){wgcna_name <- seurat_obj@misc$active_wgcna}

  # get MEs, module data from seurat object
  invalid_msg <- 'Invalid feature selection. Valid choices: hMEs, MEs, scores, average'
  if(!is.character(features) || length(features) != 1){
    stop(invalid_msg)
  }
  MEs <- switch(
    features,
    hMEs = GetMEs(seurat_obj, TRUE, wgcna_name),
    MEs = GetMEs(seurat_obj, FALSE, wgcna_name),
    scores = GetModuleScores(seurat_obj, wgcna_name),
    average = GetAvgModuleExpr(seurat_obj, wgcna_name),
    stop(invalid_msg)
  )

  MEs <- as.matrix(MEs)

  # drop = FALSE keeps a matrix even if only one module remains
  if(exclude_grey){
    MEs <- MEs[, colnames(MEs) != 'grey', drop=FALSE]
  }

  # setup color scheme:
  if(is.null(col)){
    colfunc <- grDevices::colorRampPalette(c('seagreen',  'white', 'darkorchid1'))
    col <- colfunc(ncolors)
  }

  # perform correlation
  if(is.null(MEs2)){
    res <- Hmisc::rcorr(x=MEs)
  } else{

    MEs2 <- as.matrix(MEs2)
    d1_names <- colnames(MEs); d2_names <- colnames(MEs2)

    # suffix names so the two datasets cannot collide inside rcorr
    colnames(MEs) <- paste0(d1_names, '_D1')
    colnames(MEs2) <- paste0(d2_names, '_D2')

    res <- Hmisc::rcorr(x=MEs, y=MEs2)

    # rcorr stacks the columns of x then y; keep the (D2 x D1) cross block,
    # selected by position rather than by pattern-matching the names
    d1_idx <- seq_len(ncol(MEs))
    d2_idx <- ncol(MEs) + seq_len(ncol(MEs2))

    res$r <- res$r[d2_idx, d1_idx, drop=FALSE]
    dimnames(res$r) <- list(d2_names, d1_names)

    res$P <- res$P[d2_idx, d1_idx, drop=FALSE]
    dimnames(res$P) <- list(d2_names, d1_names)

  }
  # rcorr reports NA p-values (e.g. on the diagonal); treat them as significant
  res$P[is.na(res$P)] <- 0

  # plot correlogram
  corrplot::corrplot(
    res$r,
    p.mat = res$P,
    type=type, order=order,
    method=method, tl.col=tl.col,
    tl.srt=tl.srt, sig.level=sig.level,
    pch.cex=pch.cex,
    col = col, ...
  )

}


#' ModuleCorrNetwork
#'
#' Plot a network of module correlations
#'
#' Builds an undirected graph whose vertices are modules and whose edges are
#' pairwise correlations (Hmisc::rcorr) between module features with an
#' absolute value of at least cor_cutoff. The graph is drawn with igraph using a
#' Fruchterman-Reingold layout from qgraph. Edge colors run from seagreen
#' (negative) through white to darkorchid1 (positive), and edge width is
#' abs(correlation) * edge_scale.
#'
#' @param seurat_obj A Seurat object
#' @param cluster_col column of seurat_obj@meta.data with cell groupings; if NULL, Idents(seurat_obj) is used
#' @param exclude_grey logical; drop the grey module
#' @param features What to plot? Can select hMEs, MEs, scores, or average
#' @param reduction reduction used to compute per-cluster centroids. They are stored as vertex attributes x/y; the explicit layout passed to plot() is what positions the vertices.
#' @param cor_cutoff minimum absolute correlation for an edge to be kept
#' @param label_vertices logical; draw module names on the vertices
#' @param edge_scale multiplier applied to abs(correlation) to get edge widths
#' @param vertex_size vertex size
#' @param niter number of iterations of the layout algorithm
#' @param vertex_frame logical; draw a black frame around vertices
#' @param wgcna_name The name of the WGCNA experiment in seurat_obj; defaults to the active experiment
#' @keywords scRNA-seq
#' @export
#' @examples
#' ModuleCorrNetwork
ModuleCorrNetwork <- function(
  seurat_obj, cluster_col=NULL, exclude_grey=TRUE,
  features = 'hMEs',
  reduction='umap', cor_cutoff=0.2, label_vertices=FALSE, edge_scale=5,
  vertex_size=15, niter=100, vertex_frame=FALSE, wgcna_name=NULL
){

  if(is.null(wgcna_name)){wgcna_name <- seurat_obj@misc$active_wgcna}

  # Get module features
  invalid_msg <- 'Invalid feature selection. Valid choices: hMEs, MEs, scores, average'
  if(!is.character(features) || length(features) != 1){
    stop(invalid_msg)
  }
  MEs <- switch(
    features,
    hMEs = GetMEs(seurat_obj, TRUE, wgcna_name),
    MEs = GetMEs(seurat_obj, FALSE, wgcna_name),
    scores = GetModuleScores(seurat_obj, wgcna_name),
    average = GetAvgModuleExpr(seurat_obj, wgcna_name),
    stop(invalid_msg)
  )
  MEs <- as.data.frame(MEs)

  # get modules
  modules <- GetModules(seurat_obj, wgcna_name)

  # exclude grey
  if(exclude_grey){
    MEs <- MEs[, colnames(MEs) != 'grey', drop=FALSE]
    modules <- modules %>% subset(color != 'grey')
  }

  # get list of modules
  mods <- colnames(MEs)

  # what clusters are we using?
  if(is.null(cluster_col)){
    clusters <- Idents(seurat_obj)
  } else{
    # as.factor so that character metadata columns also work
    clusters <- droplevels(as.factor(seurat_obj@meta.data[[cluster_col]]))
  }

  # average feature value of each module within each cluster (clusters x modules)
  cluster_ME_av <- do.call(
    rbind,
    lapply(split(MEs[, mods, drop=FALSE], clusters), colMeans)
  ) %>% as.data.frame

  # cluster with the highest average for each module (first one if tied)
  top_clusters <- vapply(mods, function(x){
    rownames(cluster_ME_av)[which.max(cluster_ME_av[[x]])]
  }, character(1))

  # reduction centroids of each cluster
  red_df <- as.data.frame(seurat_obj@reductions[[reduction]]@cell.embeddings)
  red_av <- do.call(
    rbind,
    lapply(split(red_df[, 1:2], clusters), colMeans)
  ) %>% as.data.frame

  # correlation matrix over module columns; keep the upper triangle only
  cor_mat <- Hmisc::rcorr(as.matrix(MEs[, mods, drop=FALSE]))$r
  cor_mat[lower.tri(cor_mat)] <- NA

  # melt to long format, drop NAs, self-edges and weak edges
  cor_df <- reshape2::melt(cor_mat) %>% na.omit
  cor_df <- cor_df %>% subset(!(Var1 == Var2))
  cor_df <- cor_df %>% subset(abs(value) >= cor_cutoff)

  # vertex df:
  v_df <- data.frame(
    name = mods,
    cluster = as.character(top_clusters)
  )

  # add module colors:
  unique_mods <- dplyr::distinct(modules[,c('module', 'color')])
  rownames(unique_mods) <- unique_mods$module
  v_df$color <- unique_mods[v_df$name,'color']

  # add reduction coords:
  v_df$x <- red_av[v_df$cluster, 1]
  v_df$y <- red_av[v_df$cluster, 2]

  # make graph:
  g <- igraph::graph_from_data_frame(
    cor_df,
    directed=FALSE,
    vertices=v_df
  )

  # force-directed layout weighted by the correlations
  e <- igraph::as_edgelist(g, names=FALSE)
  l <- qgraph::qgraph.layout.fruchtermanreingold(
    e, vcount = igraph::vcount(g),
    weights=igraph::E(g)$value,
    repulse.rad=igraph::vcount(g),
    niter=niter
  )

  # build a throwaway ggplot only to map correlations onto the diverging
  # palette; the two dummy rows pin the color scale to [-1, 1]
  plot_df <- rbind(cor_df, data.frame(Var1=c('x', 'y'), Var2=c('y', 'x'), value=c(-1,1)))
  temp <- ggplot(plot_df, aes(x=value, y=value, color=value)) +
    geom_point() + scale_color_gradient2(high='darkorchid1', mid='white', low='seagreen', midpoint=0)
  temp <- ggplot_build(temp)
  igraph::E(g)$color <- temp$data[[1]]$colour[seq_len(nrow(cor_df))]

  # label the vertices?
  if(label_vertices){labels <- igraph::V(g)$name} else{labels <- NA}

  # vertex_frame
  if(vertex_frame){frame_color <- 'black'} else{frame_color <- igraph::V(g)$color}

  # plot the graph
  plot(
    g, layout=l,
    edge.color=igraph::E(g)$color,
    edge.curved=0,
    edge.width=abs(igraph::E(g)$value) * edge_scale,
    vertex.color=igraph::V(g)$color,
    vertex.frame.color=frame_color,
    vertex.label=labels,
    vertex.label.family='Helvetica',
    vertex.label.color = 'black',
    vertex.label.cex=0.5,
    vertex.size=vertex_size
  )
}


#' ModuleFeaturePlot
#'
#' Plot module eigengenes as a FeaturePlot
#'
#' Makes one ggplot per module, coloring each cell in a 2D embedding by the
#' module's feature value, using the module color as the high end of the scale.
#'
#' @param seurat_obj A Seurat object
#' @param module_names modules to plot; if NULL, all modules except grey
#' @param wgcna_name The name of the WGCNA experiment in seurat_obj; defaults to the active experiment
#' @param reduction reduction to plot; its first two dimensions are used
#' @param features What to plot? Can select hMEs, MEs, scores, or average
#' @param order_points TRUE (highest values drawn on top), FALSE (no reordering), or "shuffle" (random order)
#' @param restrict_range logical; clip values to a range symmetric around 0.
#' Only applied when the values span both negative and positive numbers.
#' Forced to FALSE for features = "average" or ucell = TRUE.
#' @param point_size point size
#' @param alpha point transparency
#' @param label_legend currently unused
#' @param ucell logical; use a single-ended (0 to +) color scale
#' @param raster logical; rasterise points with ggrastr
#' @param raster_dpi,raster_scale passed to ggrastr::rasterise
#' @param plot_ratio aspect ratio passed to coord_fixed (ignored if not numeric)
#' @param title logical; add the module name as the title
#' @return a ggplot if a single module is plotted, otherwise a named list of ggplots
#' @keywords scRNA-seq
#' @export
#' @examples
#' ModuleFeaturePlot
ModuleFeaturePlot <- function(
  seurat_obj, module_names=NULL, wgcna_name = NULL,
  reduction='umap', features = 'hMEs',
  order_points=TRUE, restrict_range=TRUE, point_size = 0.5, alpha=1,
  label_legend = FALSE, ucell = FALSE, raster=FALSE, raster_dpi=500,
  raster_scale=1, plot_ratio = 1, title=TRUE
){

  if(is.null(wgcna_name)){wgcna_name <- seurat_obj@misc$active_wgcna}

  # get MEs, module data from seurat object
  if(features == 'hMEs'){
    MEs <- GetMEs(seurat_obj, TRUE, wgcna_name)
  } else if(features == 'MEs'){
    MEs <- GetMEs(seurat_obj, FALSE, wgcna_name)
  } else if(features == 'scores'){
    MEs <- GetModuleScores(seurat_obj, wgcna_name)
  } else if(features == 'average'){
    MEs <- GetAvgModuleExpr(seurat_obj, wgcna_name)
    restrict_range <- FALSE
  } else{
    stop('Invalid feature selection. Valid choices: hMEs, MEs, scores, average')
  }

  # override restrict_range if ucell
  if(ucell){restrict_range <- FALSE}

  # use all modules except gray if not specified by the user
  modules <- GetModules(seurat_obj, wgcna_name)
  if(is.null(module_names)){
    module_names <- levels(modules$module)
    module_names <- module_names[module_names != 'grey']
  }

  # get reduction from seurat obj; only the first two dimensions are plotted
  umap <- seurat_obj@reductions[[reduction]]@cell.embeddings
  x_name <- colnames(umap)[1]
  y_name <- colnames(umap)[2]

  # merge into one df for plotting
  plot_df <- cbind(umap, MEs) %>% as.data.frame()

  plot_list <- list()
  for(cur_mod in module_names){

    print(cur_mod)

    # get the color for this module:
    cur_color <- modules %>% subset(module == cur_mod) %>% .$color %>% unique %>% as.character

    # clip to a range symmetric around 0, using the smaller of the two
    # magnitudes; only meaningful when the values straddle 0
    plot_range <- range(plot_df[[cur_mod]])
    if(restrict_range && isTRUE(plot_range[1] < 0 && plot_range[2] > 0)){
      if(abs(plot_range[1]) > abs(plot_range[2])){
        plot_range[1] <- -1*plot_range[2]
      } else{
        plot_range[2] <- -1*plot_range[1]
      }
      plot_df[[cur_mod]] <- pmin(pmax(plot_df[[cur_mod]], plot_range[1]), plot_range[2])
    }

    # select the two plotted dimensions explicitly so "val" is always the module
    cur_plot_df <- plot_df[, c(x_name, y_name, cur_mod)]
    colnames(cur_plot_df)[3] <- "val"

    # order points:
    if(isTRUE(order_points)){
      cur_plot_df <- cur_plot_df %>% dplyr::arrange(val)
    } else if(identical(order_points, "shuffle")){
      cur_plot_df <- cur_plot_df[sample(nrow(cur_plot_df)),]
    }

    # plot with ggplot
    p <- cur_plot_df %>%
      ggplot(aes_string(x=x_name, y=y_name, color="val"))

    # rasterise?
    if(raster){
      p <- p + ggrastr::rasterise(geom_point(size=point_size, alpha=alpha), dpi=raster_dpi, scale=raster_scale)
    } else{
      p <- p + geom_point(size=point_size, alpha=alpha)
    }

    # add title and theme:
    p <- p + umap_theme() + labs(color="")

    if(title){
      p <- p + ggtitle(cur_mod)
    }

    # aspect ratio:
    if(is.numeric(plot_ratio)){
      p <- p + coord_fixed(ratio = plot_ratio)
    }

    # diverging scale for eigengenes/scores, single-ended scale for UCell
    if(!ucell){
      p <- p + scale_color_gradient2(
        low='grey75', mid='grey95', high=cur_color,
        breaks = plot_range,
        labels = c('-', '+'),
        guide = guide_colorbar(ticks=FALSE, barwidth=0.5, barheight=4)
      )
    } else{
      p <- p + scale_color_gradient(
        low='grey95', high=cur_color,
        breaks = plot_range,
        labels = c('0', '+'),
        guide = guide_colorbar(ticks=FALSE, barwidth=0.5, barheight=4)
      )
    }
    plot_list[[cur_mod]] <- p

  }

  # return plot
  if(length(plot_list) == 1){
    p <- plot_list[[1]]
  } else{
    p <- plot_list
  }

  p

}


#' EnrichrBarPlot
#'
#' Generates bar plots from Enrichr data to visualize enriched terms for hdWGCNA modules in a Seurat object.
#' The function outputs a PDF file for each module, with separate bar plots for each database.
#'
#' @param seurat_obj A Seurat object
#' @param outdir A string specifying the directory where the output PDF files will be saved. Default is "enrichr_plots".
#' @param n_terms An integer indicating the number of top enriched terms to include in each bar plot. Default is 25.
#' @param p_cutoff A numeric value specifying the significance threshold for including terms (p-value or adjusted p-value).
#'                 Only terms with p-values at or below this threshold are plotted. Default is 0.05.
#' @param p_adj Logical indicating whether to use the adjusted p-value (default: TRUE) or raw p-value for filtering terms.
#' @param plot_size A numeric vector of length 2 specifying the width and height of the output PDF files in inches. Default is c(6, 15).
#' @param logscale Logical specifying whether to log-transform the enrichment scores before plotting. Default is FALSE.
#' @param plot_bar_color A string specifying the color of the bars in the bar plots. If NULL (default), bars are colored according
#'                       to the module's assigned color.
#' @param plot_text_color A string specifying the color of the text labels on the bar plots. If NULL (default), labels are
#'                        grey on black bars and black otherwise.
#' @param wgcna_name The name of the hdWGCNA experiment in the seurat_obj@misc slot
#'
#' @details
#' This function processes the Enrichr output stored in a Seurat object, filters enriched terms by significance, and generates
#' bar plots for each WGCNA module. Separate plots are created for each database included in the Enrichr results. The top
#' enriched terms for each module and database are ordered by their combined enrichment score. If there are ties in the enrichment
#' scores, the function uses all tied terms. Modules without any significant terms are skipped.
#'
#' The bar plot text and bar colors can be customized, and plots can be saved with log-transformed enrichment scores if specified.
#' Text wrapping is applied to long term names for better readability.
#'
#' @return
#' Generates PDF files in the specified output directory, with one file per WGCNA module. Each file contains bar plots for
#' all databases associated with that module. The PDF device is closed even if drawing a plot fails.
#'
#' @export
EnrichrBarPlot <- function(
  seurat_obj,
  outdir = "enrichr_plots",
  n_terms = 25,
  p_cutoff = 0.05,
  p_adj = TRUE,
  plot_size = c(6,15),
  logscale=FALSE,
  plot_bar_color=NULL,
  plot_text_color=NULL,
  wgcna_name=NULL
){

  # get data from active assay if wgcna_name is not given
  if(is.null(wgcna_name)){wgcna_name <- seurat_obj@misc$active_wgcna}

  # get modules:
  modules <- GetModules(seurat_obj, wgcna_name)
  mods <- levels(modules$module)
  mods <- mods[mods != 'grey']

  # get Enrichr table
  enrichr_df <- GetEnrichrTable(seurat_obj, wgcna_name)
  dbs <- as.character(unique(enrichr_df$db))

  # subset based on significance level:
  if(p_adj){
    enrichr_df <- subset(enrichr_df, Adjusted.P.value <= p_cutoff)
  } else{
    enrichr_df <- subset(enrichr_df, P.value <= p_cutoff)
  }

  # helper function to wrap text
  wrapText <- function(x, len) {
    sapply(x, function(y) paste(strwrap(y, len), collapse = "\n"), USE.NAMES = FALSE)
  }

  # axis label and label x-offset depend only on logscale
  if(logscale){
    lab <- 'Enrichment log(combined score)'
    x <- 0.2
  } else{
    lab <- 'Enrichment (combined score)'
    x <- 5
  }

  # make output dir if it doesn't exist:
  if(!dir.exists(outdir)){dir.create(outdir, recursive=TRUE)}

  # loop through modules:
  for(cur_mod in mods){

    cur_terms <- subset(enrichr_df, module == cur_mod)
    print(cur_mod)

    # get color for this module:
    cur_color <- modules %>% subset(module == cur_mod) %>% .$color %>% unique %>% as.character
    if(!is.null(plot_bar_color)){
      cur_color <- plot_bar_color
    }

    # skip if there are not any terms for this module:
    if(nrow(cur_terms) == 0){next}
    cur_terms$wrap <- wrapText(cur_terms$Term, 45)

    # text color: grey on black bars, black otherwise, unless user-specified
    if(is.null(plot_text_color)){
      text_color <- if(identical(as.character(cur_color), 'black')) 'grey' else 'black'
    } else{
      text_color <- plot_text_color
    }

    # plot top n_terms as barplot
    plot_list <- list()
    for(cur_db in dbs){

      plot_df <- subset(cur_terms, db==cur_db) %>%
        dplyr::slice_max(order_by=Combined.Score, n=n_terms)

      if(logscale){
        plot_df$Combined.Score <- log(plot_df$Combined.Score)
      }

      # make bar plot; term names are drawn inside the bars
      plot_list[[cur_db]] <- ggplot(plot_df, aes(x=Combined.Score, y=reorder(wrap, Combined.Score)))+
        geom_bar(stat='identity', position='identity', color='white', fill=cur_color) +
        geom_text(aes(label=wrap), x=x, color=text_color, size=3.5, hjust='left') +
        scale_x_continuous(expand = c(0, 0), limits = c(0, NA)) +
        xlab(lab) + ylab('') + ggtitle(cur_db) +
        theme(
          panel.grid.major=element_blank(),
          panel.grid.minor=element_blank(),
          legend.title = element_blank(),
          axis.ticks.y=element_blank(),
          axis.text.y=element_blank(),
          plot.title = element_text(hjust = 0.5),
          axis.line.y=element_blank()
        )
    }

    # make pdfs in output dir; always close the device, even on error
    pdf(file.path(outdir, paste0(cur_mod, '.pdf')), width=plot_size[1], height=plot_size[2])
    tryCatch(
      for(cur_plot in plot_list){
        print(cur_plot)
      },
      finally = dev.off()
    )
  }
}


#' EnrichrDotPlot
#'
#' Generate a dot plot visualizing enrichment results from Enrichr for hdWGCNA modules.
#'
#' This function creates dot plots from Enrichr results associated with hdWGCNA modules.
#' Each module is represented by its most enriched terms from the specified Enrichr database.
#' The size of the dots indicates the enrichment score, and the color indicates the statistical significance (-log10 transformed p-value).
#'
#' @param seurat_obj A Seurat object
#' @param database A character string specifying the name of the Enrichr database to use (e.g., "GO_Biological_Process_2021").
#' @param mods A character vector specifying the names of modules to include in the plot.
#'        If `mods` contains "all" (default), all modules except the "grey" module are included.
#' @param n_terms An integer specifying the number of top enriched terms to plot for each module (default = 3).
#' @param p_cutoff A numeric value specifying the p-value threshold for filtering enriched terms (default = 0.05).
#' @param p_adj A logical value indicating whether to use adjusted p-values (`TRUE`, default) or raw p-values (`FALSE`) for filtering.
#' @param break_ties A logical value indicating whether to randomly select among tied terms to enforce `n_terms` (default = `TRUE`).
#'        If `FALSE`, all terms tied at the cutoff are kept.
#' @param term_size A numeric value specifying the font size of the enriched terms displayed on the y-axis (default = 10).
#' @param wgcna_name The name of the hdWGCNA experiment in the seurat_obj@misc slot
#' @details
#' - The function first retrieves WGCNA module and Enrichr data from the specified Seurat object.
#' - Modules are filtered based on the `mods` parameter, and enriched terms are filtered by significance using `p_cutoff` and `p_adj`.
#' - For each module, the top `n_terms` terms of `database` are selected based on the Combined Score. If ties occur, the `break_ties` parameter determines how they are resolved.
#' - A dot plot is generated where each dot represents an enriched term, its size corresponds to the Combined Score (log10-transformed),
#'   and its color indicates the significance (-log10 transformed p-value, capped at its 95th percentile).
#'
#' @return A ggplot2 object representing the dot plot of enriched terms for the specified modules and database.
#' @export
EnrichrDotPlot <- function(
  seurat_obj,
  database,
  mods="all",
  n_terms = 3,
  p_cutoff = 0.05,
  p_adj = TRUE,
  break_ties=TRUE,
  term_size=10,
  wgcna_name=NULL
){

  # get data from active assay if wgcna_name is not given
  if(is.null(wgcna_name)){wgcna_name <- seurat_obj@misc$active_wgcna}

  # get modules:
  modules <- GetModules(seurat_obj, wgcna_name)

  # using all modules? (%in% keeps this valid for vectors of module names)
  if('all' %in% mods){
    mods <- levels(modules$module)
    mods <- mods[mods != 'grey']
  }

  # get Enrichr table
  enrichr_df <- GetEnrichrTable(seurat_obj, wgcna_name)

  # subset based on significance level:
  if(p_adj){
    enrichr_df <- subset(enrichr_df, Adjusted.P.value <= p_cutoff)
  } else{
    enrichr_df <- subset(enrichr_df, P.value <= p_cutoff)
  }

  # add color to enrich_table
  mod_colors <- dplyr::select(modules, c(module, color)) %>% dplyr::distinct()
  enrichr_df$color <- mod_colors[match(enrichr_df$module, mod_colors$module), 'color']

  # helper function to wrap text
  wrapText <- function(x, len) {
    sapply(x, function(y) paste(strwrap(y, len), collapse = "\n"), USE.NAMES = FALSE)
  }

  # candidate rows: significant terms of the requested database and modules
  plot_df <- subset(enrichr_df, db == database & module %in% mods)

  # shuffling first makes slice_max(with_ties = FALSE) break ties at random
  if(break_ties){
    plot_df <- plot_df[sample(nrow(plot_df)), , drop=FALSE]
  }

  # top n_terms per module (exactly n_terms when break_ties is TRUE)
  plot_df <- plot_df %>%
    dplyr::group_by(module) %>%
    dplyr::slice_max(order_by=Combined.Score, n=n_terms, with_ties=!break_ties) %>%
    dplyr::ungroup() %>%
    as.data.frame()

  # drop trailing identifiers such as " (GO:0000000)" and wrap long names
  plot_df <- plot_df %>% dplyr::mutate(Term = stringr::str_replace(Term, " \\s*\\([^\\)]+\\)", ""))
  plot_df$Term <- wrapText(plot_df$Term, 45)

  # set modules factor and re-order:
  plot_df$module <- factor(
    as.character(plot_df$module),
    levels=levels(modules$module)
  )
  plot_df <- dplyr::arrange(plot_df, module)

  # set Terms factor
  plot_df$Term <- factor(
    as.character(plot_df$Term),
    levels = unique(as.character(plot_df$Term))
  )

  if(p_adj){
    plot_df$p <- plot_df$Adjusted.P.value
  } else{
    plot_df$p <- plot_df$P.value
  }

  # -log10(p), capped at its 95th percentile so outliers don't wash out the scale
  plot_df$logp <- -log10(plot_df$p)
  max_p <- stats::quantile(plot_df$logp, 0.95)
  plot_df$logp <- ifelse(plot_df$logp > max_p, max_p, plot_df$logp)

  p <- plot_df  %>%
    ggplot(aes(x=module, y=Term, size=log10(Combined.Score), color=logp)) +
    geom_point() +
    Seurat::RotatedAxis() +
    ylab('') + xlab('') +
    labs(
        color = bquote("-log"[10]~"(P)"),
        size= bquote("log"[10]~"(Enrich)")
    ) +
    scale_y_discrete(limits=rev) +
    ggtitle(database) +
    theme(
      plot.title = element_text(hjust = 0.5),
      axis.line.x = element_blank(),
      axis.line.y = element_blank(),
      axis.text.y = element_text(size=term_size),
      panel.border = element_rect(colour = "black", fill=NA, size=1),
      panel.grid = element_line(size=0.25, color='lightgrey')
    )

  p
}


#' ModuleNetworkPlot
#'
#' Visualizes the top hub genes for selected modules as a circular network plot with one
#' inner circle and one outer circle. One PDF per module is written to `outdir`.
#'
#' @param seurat_obj A Seurat object
#' @param n_inner Number of genes to plot on the inner circle
#' @param n_outer Number of genes to plot on the outer circle.
#' @param n_conns Number of gene-gene co-expression connections to plot, sorted by the strongest connections.
#' @param mods Names of the modules to plot. If mods = "all", all modules except grey are plotted.
#' @param outdir The directory where the plots will be stored.
#' @param wgcna_name The name of the hdWGCNA experiment in the seurat_obj@misc slot
#' @param plot_size A vector containing the width and height of the network plots. example: c(5,5)
#' @param edge.alpha value between 0 and 1 determining the alpha (transparency) scaling factor for the network edges
#' @param edge.width value determining the width of the network edges
#' @param vertex.label.cex vertex label font size
#' @param vertex.size vertex size
#' @param ... currently ignored
#' @details Modules with fewer than n_inner + 1 hub genes are skipped. The PDF
#' device is closed even if drawing a plot fails.
#' @import igraph
#' @export
ModuleNetworkPlot <- function(
  seurat_obj,
  n_inner = 10,
  n_outer = 15,
  n_conns = 500,
  mods="all",
  outdir="ModuleNetworks",
  wgcna_name=NULL,
  plot_size = c(6,6),
  edge.alpha=0.25,
  edge.width=1,
  vertex.label.cex=1,
  vertex.size=6, ...
){

  # get data from active assay if wgcna_name is not given
  if(is.null(wgcna_name)){wgcna_name <- seurat_obj@misc$active_wgcna}

  # get modules:
  modules <- GetModules(seurat_obj, wgcna_name)

  # using all modules?
  if('all' %in% mods){
    mods <- levels(modules$module)
    mods <- mods[mods != 'grey']
  } else{

    # check that the modules are present
    if(!all(mods %in% unique(as.character(modules$module)))){
      stop(paste0("Some selected modules are not found in wgcna_name: ", wgcna_name))
    }
    modules <- modules %>% subset(module %in% mods)
  }

  # check if we have eigengene-based connectivities:
  if(!all(paste0('kME_', as.character(mods)) %in% colnames(modules))){
    stop('Eigengene-based connectivity (kME) not found. Did you run ModuleEigengenes and ModuleConnectivity?')
  }

  # create output folder
  if(!dir.exists(outdir)){dir.create(outdir)}

  # tell the user that the output is going to the output dir
  cat(paste0("Writing output files to ", outdir, "\n"))

  # get TOM
  TOM <- GetTOM(seurat_obj, wgcna_name)

  # get hub genes:
  n_hubs <- n_inner + n_outer
  hub_df <- GetHubGenes(seurat_obj, n_hubs=n_hubs, wgcna_name=wgcna_name)

  # loop over modules
  for(cur_mod in mods){
    cur_color <- modules %>% subset(module == cur_mod) %>% .$color %>% unique %>% as.character
    cur_genes <- subset(hub_df, module == cur_mod) %>% .$gene_name
    n_genes <- length(cur_genes)

    # skip if there's too few genes to fill the inner circle plus one outer gene
    if(n_genes < (n_inner+1)){
      print(paste0('Skipping ', cur_mod, ', too few genes to plot.'))
      next
    }
    print(cur_mod)

    # sub-TOM restricted to this module's hub genes
    matchind <- match(cur_genes, colnames(TOM))
    reducedTOM <- as.matrix(TOM[matchind, matchind])

    # keep only the cur_n_conns strongest entries, zero out the rest
    orderind <- order(reducedTOM, decreasing=TRUE)
    cur_n_conns <- min(n_conns, length(orderind))
    if(cur_n_conns < length(orderind)){
      reducedTOM[orderind[seq.int(cur_n_conns + 1, length(orderind))]] <- 0
    }

    # scale between 0 and 1
    reducedTOM <- scale01(reducedTOM)

    # melt TOM into long format; edge transparency follows the scaled TOM value
    edge_df <- reducedTOM %>% reshape2::melt()
    edge_df$color_alpha <- alpha(cur_color, alpha=edge_df$value)

    # circular layout: first n_inner genes on a half-radius inner circle,
    # remaining genes on the outer circle (only vertex counts matter here)
    inner_layout <- igraph::layout_in_circle(igraph::make_empty_graph(n_inner, directed=FALSE))
    outer_layout <- igraph::layout_in_circle(igraph::make_empty_graph(n_genes - n_inner, directed=FALSE))
    layoutCircle <- rbind(inner_layout/2, outer_layout)

    g1 <- igraph::graph_from_data_frame(
      edge_df,
      directed=FALSE
    )

    # always close the PDF device, even if plotting fails
    pdf(file.path(outdir, paste0(cur_mod, '.pdf')), width=plot_size[1], height=plot_size[2], useDingbats=FALSE)
    tryCatch(
      plot(g1,
        edge.color=adjustcolor(igraph::E(g1)$color_alpha, alpha.f=edge.alpha),
        edge.curved=0,
        edge.width=edge.width,
        vertex.color=cur_color,
        vertex.label=as.character(cur_genes),
        vertex.label.dist=1.1,
        vertex.label.degree=-pi/4,
        vertex.label.color="black",
        vertex.label.family='Helvetica',
        vertex.label.font = 3,
        vertex.label.cex=vertex.label.cex,
        vertex.frame.color='black',
        layout= jitter(layoutCircle),
        vertex.size=vertex.size,
        main=paste(cur_mod)
      ),
      finally = dev.off()
    )

  }

}



#' HubGeneNetworkPlot
#'
#' Construct a unified network plot comprising hub genes for multiple modules.
#'
#' @param seurat_obj A Seurat object
#' @param mods Names of the modules to plot. If mods = "all", all modules except grey are plotted.
#' @param n_hubs The number of hub genes to plot for each module.
#' @param n_other The number of non-hub genes to sample from each module
#' @param sample_edges logical determining whether edges are randomly downsampled (TRUE) or the strongest edges are kept (FALSE)
#' @param edge_prop The proportion of edges in the graph to sample.
#' @param return_graph logical determining whether we return the graph (TRUE) or plot the graph (FALSE)
#' @param edge.alpha Scaling factor for the edge opacity
#' @param vertex.label.cex The font size of the gene labels
#' @param hub.vertex.size The size of the hub gene nodes
#' @param other.vertex.size The size of the other gene nodes
#' @param wgcna_name The name of the hdWGCNA experiment in the seurat_obj@misc slot
#' @param ... additional arguments passed to both igraph::layout_with_fr and plot
#' @keywords scRNA-seq
#' @export
#' @examples
#' HubGeneNetworkPlot
HubGeneNetworkPlot <- function(
  seurat_obj, mods="all",
  n_hubs=6, n_other=3,
  sample_edges = TRUE,
  edge_prop = 0.5,
  return_graph=FALSE,
  edge.alpha=0.25,
  vertex.label.cex=0.5,
  hub.vertex.size=4,
  other.vertex.size=1,
  wgcna_name=NULL,
  ...
){

  # get data from active assay if wgcna_name is not given
  if(is.null(wgcna_name)){wgcna_name <- seurat_obj@misc$active_wgcna}

  # gene -> module assignment table
  modules <- GetModules(seurat_obj, wgcna_name)

  # using all modules?
  if(all('all' %in% mods)){
    mods <- levels(modules$module)
    mods <- mods[mods != 'grey']
  } else{

    # check that the modules are present
    if(!all(mods %in% unique(as.character(modules$module)))){
      stop(paste0("Some selected modules are not found in wgcna_name: ", wgcna_name))
    }
  }

  # restrict to the selected modules in BOTH cases, so genes from excluded
  # modules (e.g. grey when mods = "all") are never sampled as "other" genes
  modules <- modules %>% subset(module %in% mods)

  TOM <- GetTOM(seurat_obj, wgcna_name)

  # top n_hubs genes by kME in each module
  hub_list <- lapply(mods, function(cur_mod){
    cur <- subset(modules, module == cur_mod)
    cur[,c('gene_name', paste0('kME_', cur_mod))] %>%
      top_n(n_hubs) %>% .$gene_name
  })
  names(hub_list) <- mods

  # sample the same number of non-hub genes in each module
  other_genes <- modules %>%
    subset(!(gene_name %in% unlist(hub_list))) %>%
    group_by(module) %>%
    sample_n(n_other, replace=TRUE) %>%
    .$gene_name %>% unique

  # subset TOM by the selected genes:
  selected_genes <- c(unlist(hub_list), other_genes)
  selected_modules <- modules %>% subset(gene_name %in% selected_genes)
  subset_TOM <- TOM[selected_genes, selected_genes]

  # per-vertex attributes for the network plot
  selected_modules$geneset <- ifelse(
    selected_modules$gene_name %in% other_genes, 'other', 'hub'
  )
  selected_modules$size <- ifelse(selected_modules$geneset == 'hub', hub.vertex.size, other.vertex.size)
  selected_modules$label <- ifelse(selected_modules$geneset == 'hub', as.character(selected_modules$gene_name), '')
  selected_modules$fontcolor <- ifelse(selected_modules$color == 'black', 'gray50', 'black')

  # make sure all nodes have at least one edge: the cutoff is the smallest
  # per-row maximum, so every gene keeps at least its strongest connection
  edge_cutoff <- min(sapply(seq_len(nrow(subset_TOM)), function(i){max(subset_TOM[i,])}))
  edge_df <- reshape2::melt(subset_TOM) %>% subset(value >= edge_cutoff)

  # edges inside a module take the module color, edges between modules are grey90
  gene_colors <- as.character(modules$color)
  col1 <- gene_colors[match(as.character(edge_df$Var1), modules$gene_name)]
  col2 <- gene_colors[match(as.character(edge_df$Var2), modules$gene_name)]
  edge_df$color <- ifelse(col1 == col2, col1, 'grey90')

  # subset edges within each color group:
  groups <- unique(edge_df$color)
  if(sample_edges){

    # randomly sample
    temp <- do.call(rbind, lapply(groups, function(cur_group){
      cur_df <- edge_df %>% subset(color == cur_group)
      n_edges <- nrow(cur_df)
      cur_sample <- sample(1:n_edges, round(n_edges * edge_prop))
      cur_df[cur_sample,]
    }))
  } else{

    # get top strongest edges
    temp <- do.call(rbind, lapply(groups, function(cur_group){
      cur_df <- edge_df %>% subset(color == cur_group)
      n_edges <- nrow(cur_df)
      cur_df %>% dplyr::top_n(round(n_edges * edge_prop), wt=value)
    }))
  }

  edge_df <- temp

  # scale edge values between 0 and 1 for each color group
  edge_df <- edge_df %>%
    group_by(color) %>%
    mutate(value=scale01(value)) %>%
    dplyr::ungroup()

  # edge opacity follows the scaled TOM value (alpha() is vectorized)
  edge_df$color <- alpha(edge_df$color, alpha=edge_df$value)

  g <- igraph::graph_from_data_frame(
    edge_df,
    directed=FALSE,
    vertices=selected_modules
  )

  l <- igraph::layout_with_fr(g, ...)

  if(return_graph){return(g)}

  plot(
    g, layout=l,
    edge.color=adjustcolor(igraph::E(g)$color, alpha.f=edge.alpha),
    vertex.size=igraph::V(g)$size,
    edge.curved=0,
    edge.width=0.5,
    vertex.color=igraph::V(g)$color,
    vertex.frame.color=igraph::V(g)$color,
    vertex.label=igraph::V(g)$label,
    vertex.label.family='Helvetica',
    vertex.label.font = 3,
    vertex.label.color = igraph::V(g)$fontcolor,
    vertex.label.cex=vertex.label.cex,
    ...
  )

}


#' ModuleUMAPPlot
#'
#' Makes a igraph network plot using the module UMAP
#'
#' @param seurat_obj A Seurat object
#' @param sample_edges logical determining whether we downsample edges for plotting (TRUE), or take the strongest edges.
#' @param edge_prop proportion of edges to plot. If sample_edges=FALSE, the strongest edges are selected.
#' @param label_hubs the number of hub genes to label in each module
#' @param edge.alpha scaling factor for edge opacity
#' @param vertex.label.cex font size for labeled genes
#' @param label_genes optional character vector of additional genes to label; the top label_hubs hub genes per module are always labeled
#' @param return_graph logical determining whether to plot the graph (FALSE) or return the igraph object (TRUE)
#' @param keep_grey_edges logical determining whether to show edges between genes in different modules (grey edges)
#' @param wgcna_name The name of the hdWGCNA experiment in the seurat_obj@misc slot
#' @param ... currently unused; accepted for backward compatibility
#' @keywords scRNA-seq
#' @export
#' @examples
#' ModuleUMAPPlot
ModuleUMAPPlot <- function(
  seurat_obj,
  sample_edges = TRUE, # TRUE if we sample edges randomly, FALSE if we take the top edges
  edge_prop = 0.2,
  label_hubs = 5, # how many hub genes to label?
  edge.alpha=0.25,
  vertex.label.cex=0.5,
  label_genes = NULL,
  return_graph = FALSE, # this returns the igraph object instead of plotting
  keep_grey_edges = TRUE,
  wgcna_name=NULL,
  ...
){

  if(is.null(wgcna_name)){wgcna_name <- seurat_obj@misc$active_wgcna}

  TOM <- GetTOM(seurat_obj, wgcna_name)
  modules <- GetModules(seurat_obj, wgcna_name)

  # get the UMAP df:
  umap_df <- GetModuleUMAP(seurat_obj, wgcna_name)
  mods <- levels(umap_df$module)
  mods <- mods[mods != 'grey']

  # edges connect every gene in the UMAP to each hub gene
  subset_TOM <- TOM[umap_df$gene, umap_df$gene[umap_df$hub == 'hub']]

  # the top label_hubs genes by kME in each module are always labeled
  hub_list <- lapply(mods, function(cur_mod){
    cur <- subset(modules, module == cur_mod)
    cur[,c('gene_name', paste0('kME_', cur_mod))] %>%
      top_n(label_hubs) %>% .$gene_name
  })
  names(hub_list) <- mods
  hub_labels <- as.character(unlist(hub_list))

  if(is.null(label_genes)){
    label_genes <- hub_labels
  } else{
    # every user-supplied gene must be present in the UMAP
    if(!all(label_genes %in% umap_df$gene)){
      stop("Some genes in label_genes not found in the UMAP.")
    }
    label_genes <- unique(c(label_genes, hub_labels))
  }

  # subset module df by genes in the UMAP df:
  selected_modules <- modules[umap_df$gene,]
  selected_modules <- cbind(selected_modules, umap_df[,c('UMAP1', 'UMAP2', 'hub', 'kME')])

  selected_modules$label <- ifelse(selected_modules$gene_name %in% label_genes, as.character(selected_modules$gene_name), '')
  selected_modules$fontcolor <- ifelse(selected_modules$color == 'black', 'gray50', 'black')

  # frame color: same as the module for all genes, black outline for labeled genes
  selected_modules$framecolor <- ifelse(selected_modules$gene_name %in% label_genes, 'black', selected_modules$color)

  # melt TOM into long format
  edge_df <- subset_TOM %>% reshape2::melt()

  # edges inside a module take the module color, edges between modules are grey90
  gene_colors <- as.character(selected_modules$color)
  col1 <- gene_colors[match(as.character(edge_df$Var1), selected_modules$gene_name)]
  col2 <- gene_colors[match(as.character(edge_df$Var2), selected_modules$gene_name)]
  edge_df$color <- ifelse(col1 == col2, col1, 'grey90')

  # keep grey edges?
  if(!keep_grey_edges){
    edge_df <- edge_df %>% subset(color != 'grey90')
  }

  # subset edges within each color group:
  groups <- unique(edge_df$color)
  if(sample_edges){
    # randomly sample
    temp <- do.call(rbind, lapply(groups, function(cur_group){
      cur_df <- edge_df %>% subset(color == cur_group)
      n_edges <- nrow(cur_df)
      cur_sample <- sample(1:n_edges, round(n_edges * edge_prop))
      cur_df[cur_sample,]
    }))
  } else{

    # get top strongest edges
    temp <- do.call(rbind, lapply(groups, function(cur_group){
      cur_df <- edge_df %>% subset(color == cur_group)
      n_edges <- nrow(cur_df)
      cur_df %>% dplyr::top_n(round(n_edges * edge_prop), wt=value)
    }))
  }

  edge_df <- temp

  # scale edge values between 0 and 1 for each color group
  edge_df <- edge_df %>%
    group_by(color) %>%
    mutate(value=scale01(value)) %>%
    dplyr::ungroup()

  # igraph draws edges in row order: put weak edges first and grey edges
  # before module-colored edges so the strong colored edges end up on top
  edge_df <- edge_df %>% arrange(value)
  edge_df <- rbind(
    subset(edge_df, color == 'grey90'),
    subset(edge_df, color != 'grey90')
  )

  # edge opacity follows the scaled TOM value; grey edges are drawn at half opacity
  edge_df$color_alpha <- ifelse(
    edge_df$color == 'grey90',
    alpha(edge_df$color, alpha=edge_df$value/2),
    alpha(edge_df$color, alpha=edge_df$value)
  )

  # re-order vertices so hubs are plotted on top
  selected_modules <- rbind(
    subset(selected_modules , hub == 'other'),
    subset(selected_modules , hub != 'other')
  )

  # re-order vertices so labeled genes are on top
  selected_modules <- rbind(
    subset(selected_modules , label == ''),
    subset(selected_modules , label != '')
  )

  # vertices follow the row order of selected_modules, which the layout matrix below relies on
  g <- igraph::graph_from_data_frame(
    edge_df,
    directed=FALSE,
    vertices=selected_modules
  )

  if(return_graph){return(g)}

  plot(
    g,
    layout= as.matrix(selected_modules[,c('UMAP1', 'UMAP2')]),
    edge.color=adjustcolor(igraph::E(g)$color_alpha, alpha.f=edge.alpha),
    vertex.size=igraph::V(g)$kME * 3,
    edge.curved=0,
    edge.width=0.5,
    vertex.color=igraph::V(g)$color,
    vertex.label=igraph::V(g)$label,
    vertex.label.dist=1.1,
    vertex.label.degree=-pi/4,
    vertex.label.family='Helvetica',
    vertex.label.font = 3,
    vertex.label.color = igraph::V(g)$fontcolor,
    vertex.label.cex=vertex.label.cex,
    vertex.frame.color=igraph::V(g)$framecolor,
    margin=0
  )

}


#' OverlapDotPlot
#'
#' Plots the results from OverlapModulesDEGs as a dot plot (modules by DEG groups)
#'
#' @param overlap_df the Module/DEG overlap table from OverlapModulesDEGs
#' @param plot_var the name of the overlap statistic to plot
#' @param logscale logical controlling whether to plot the result on a log scale, useful for odds ratio
#' @param neglog logical controlling whether to negate the plotted value; together with logscale=TRUE this plots -log, useful for p-value / FDR
#' @param plot_significance logical controlling whether to plot the significance levels on top of the dots
#' @param ... currently unused; accepted for backward compatibility
#' @keywords scRNA-seq
#' @export
#' @examples
#' OverlapDotPlot
OverlapDotPlot <- function(
  overlap_df, plot_var = 'odds_ratio',
  logscale=TRUE,
  neglog=FALSE,
  plot_significance=TRUE,
  ...
){

  label <- plot_var
  if(logscale){
    overlap_df[[plot_var]] <- log(overlap_df[[plot_var]])
    label <- paste0('log(', plot_var, ')')
  }
  if(neglog){
    overlap_df[[plot_var]] <- -1 * overlap_df[[plot_var]]
    label <- paste0('-', label)
  }

  # .data[[plot_var]] selects the column by name through the tidy-eval pronoun,
  # instead of relying on get() resolving inside the data mask
  p <- overlap_df %>% ggplot(aes(x=module, y=group)) +
    geom_point(aes(
      size=.data[[plot_var]],
      alpha=.data[[plot_var]]),
      color=overlap_df$color
    ) +
    RotatedAxis() +
    ylab('') + xlab('') + labs(size=label, alpha=label) +
    theme(
      plot.title = element_text(hjust = 0.5),
      axis.line.x = element_blank(),
      axis.line.y = element_blank(),
      panel.border = element_rect(colour = "black", fill=NA, size=1)
    )

  # plot significance level?
  if(plot_significance){
    p <- p + geom_text(aes(label=Significance))
  }

  p
}

#' OverlapBarPlot
#'
#' Plots the results from OverlapModulesDEGs as a bar plot
#'
#' @param overlap_df the Module/DEG overlap table from OverlapModulesDEGs
#' @param plot_var the name of the overlap statistic to plot
#' @param logscale logical controlling whether to plot the result on a log scale, useful for odds ratio
#' @param neglog logical controlling whether to plot -log of the (possibly already log-transformed) value, useful for p-value / FDR
#' @param label_size the size of the module labels in the bar plot
#' @param ... currently unused; accepted for backward compatibility
#' @return a named list of ggplot objects, one per DEG group
#' @keywords scRNA-seq
#' @export
#' @examples
#' OverlapBarPlot
OverlapBarPlot <- function(
  overlap_df,
  plot_var = 'odds_ratio',
  logscale=FALSE, neglog=FALSE,
  label_size=2,
  ...
){

  label <- plot_var

  # dashed reference line: odds ratio of 1 or FDR of 0.05.
  # Other statistics have no reference line (yint stays NULL).
  yint <- NULL
  if(plot_var == 'odds_ratio'){
    yint <- 1
  } else if(plot_var == 'fdr'){
    yint <- 0.05
  }

  if(logscale){
    overlap_df[[plot_var]] <- log(overlap_df[[plot_var]])
    label <- paste0('log(', plot_var, ')')
    if(!is.null(yint)){yint <- log(yint)}
  }
  if(neglog){
    overlap_df[[plot_var]] <- -1 * log(overlap_df[[plot_var]])
    label <- paste0('-log(', label, ')')
    if(!is.null(yint)){yint <- -1 * log(yint)}
  }

  groups <- overlap_df$group %>% as.character %>% unique

  plot_list <- list()
  for(cur_group in groups){

    cur_df <- overlap_df %>%
      subset(group == cur_group)

    p <- cur_df %>%
      ggplot(aes(x=reorder(module, .data[[plot_var]]), y=.data[[plot_var]])) +
      geom_bar(stat='identity', fill=cur_df$color) +
      coord_flip() +
      xlab('') + ylab(label) +
      ggtitle(cur_group) +
      theme(
        axis.line.y=element_blank(),
        axis.ticks.y=element_blank(),
        axis.text.y = element_blank(),
        plot.title = element_text(hjust = 0.5)
      )

    if(!is.null(yint)){
      p <- p + geom_hline(yintercept=yint, linetype='dashed', color='gray')
    }

    # module names are drawn inside the bars instead of on the axis
    p <- p +
      geom_text(
        aes(label=module, x=module, y=.data[[plot_var]]), color='black', size=label_size, hjust='inward'
      )

    plot_list[[cur_group]] <- p
  }

  plot_list

}

#' MotifOverlapBarPlot
#'
#' Displays the top n TFs in a set of modules as a bar plot
#'
#' For each module, writes a PDF to outdir with the n_tfs transcription factors
#' that have the highest motif-overlap odds ratio, next to their motif logos.
#'
#' @param seurat_obj A Seurat object
#' @param n_tfs the number of TFs to show for each module
#' @param plot_size vector of length 2 giving the width and height of each output PDF
#' @param outdir directory where the PDFs are written (created if it does not exist)
#' @param motif_font font passed to ggseqlogo::geom_logo for the motif logos
#' @param module_names the modules to plot; defaults to all modules except grey
#' @param wgcna_name The name of the hdWGCNA experiment in the seurat_obj@misc slot
#' @keywords scRNA-seq
#' @export
#' @examples
#' MotifOverlapBarPlot
MotifOverlapBarPlot <- function(
  seurat_obj,
  n_tfs = 10,
  plot_size = c(5,6),
  outdir = "MotifOverlaps/",
  motif_font = 'helvetica_regular',
  module_names = NULL,
  wgcna_name=NULL
){

  if(is.null(wgcna_name)){wgcna_name <- seurat_obj@misc$active_wgcna}

  # make output dir if it doesn't exist:
  if(!dir.exists(outdir)){dir.create(outdir)}

  # modules of the requested hdWGCNA experiment (not necessarily the active one)
  modules <- GetModules(seurat_obj, wgcna_name)
  mods <- levels(modules$module)
  mods <- mods[mods != 'grey']

  if(is.null(module_names)){module_names <- mods}

  # get overlap info from Seurat obj
  overlap_df <- GetMotifOverlap(seurat_obj, wgcna_name)
  motif_df <- GetMotifs(seurat_obj)

  # get pfm list from Seurat obj
  pfm <- GetPFMList(seurat_obj)

  # add motif ID to the overlap_df
  overlap_df$motif_ID <- motif_df$motif_ID[match(overlap_df$tf, motif_df$motif_name)]

  # subset by the modules that we are using:
  overlap_df <- overlap_df %>% subset(module %in% module_names)

  for(cur_mod in module_names){
    print(cur_mod)

    # top n_tfs TFs for the current module, strongest first
    plot_df <- overlap_df %>% subset(module == cur_mod) %>% top_n(n_tfs, wt=odds_ratio) %>%
      arrange(desc(odds_ratio))

    if(nrow(plot_df) == 0){
      warning(paste0('No motif overlaps found for module ', cur_mod, ', skipping.'))
      next
    }

    # make the barplot
    p1 <- plot_df %>% ggplot(aes(y=reorder(tf, odds_ratio), fill=odds_ratio, x=odds_ratio)) +
      geom_bar(stat='identity', width=0.7) + NoLegend() +
      scale_fill_gradient(high=unique(plot_df$color), low='grey90') +
      ylab('') + theme(
        axis.line.y = element_blank(),
        axis.text.y = element_blank(),
        plot.margin = margin(
          t = 0, r = 0, b = 0, l = 0
        )
      )

    # make the motif logo plots (one row per TF, in the same order as the bars):
    plot_list <- list()
    for(i in seq_len(nrow(plot_df))){
      cur_id <- plot_df$motif_ID[i]
      cur_name <- plot_df$tf[i]
      plot_list[[cur_id]] <- ggplot() +
        ggseqlogo::geom_logo( as.matrix(pfm[[cur_id]]), font=motif_font) +
        ggseqlogo::theme_logo() +
        xlab('') + ylab(cur_name) + theme(
          axis.text.x=element_blank(),
          axis.text.y=element_blank(),
          axis.title.y = element_text(angle=0),
          plot.margin = margin(t = 0, r = 0, b = 0, l = 0)
        )
    }

    # wrap the motif logos
    patch1 <- wrap_plots(plot_list, ncol=1)

    # assemble the plot with patchwork
    outplot <- (patch1 | p1) +
          plot_layout(ncol=2, widths=c(1,2)) +
          plot_annotation(title=paste0('Motif overlaps with ', cur_mod),
          theme = theme(plot.title=element_text(hjust=0.5))
        )

    # make sure the pdf device is closed even if printing fails
    pdf(paste0(outdir, '/', cur_mod, '_motif_overlaps.pdf'), width=plot_size[1], height=plot_size[2], useDingbats=FALSE)
    tryCatch(
      print(outplot),
      finally = dev.off()
    )

  }

}



#' Plots gene expression of hub genes as a heatmap
#'
#' This function makes an expression heatmap of the top n hub genes per module
#' using Seurat's DoHeatmap, and then assembles them all into one big heatmap.
#'
#' @param seurat_obj A Seurat object
#' @param n_hubs the number of hub genes (top kME) to show for each module
#' @param n_cells the maximum number of cells to sample from each group
#' @param group.by meta.data column used to group cells; defaults to the active Idents
#' @param module_names the modules to plot; defaults to all modules except grey
#' @param combine logical; if TRUE return one assembled patchwork plot, if FALSE a list of individual heatmaps
#' @param draw.lines passed to Seurat::DoHeatmap
#' @param disp.min,disp.max expression cutoffs passed to Seurat::DoHeatmap
#' @param wgcna_name The name of the hdWGCNA experiment in the seurat_obj@misc slot
#' @keywords scRNA-seq
#' @export
#' @examples
#' DoHubGeneHeatmap
DoHubGeneHeatmap <- function(
  seurat_obj,
  n_hubs = 10,
  n_cells = 200,
  group.by=NULL,
  module_names = NULL,
  combine=TRUE, #returns a list of individual heatmaps if FALSE
  draw.lines=TRUE,
  disp.min = -2.5, disp.max = 2.5, # cutoff expression values
  wgcna_name=NULL
){

  if(is.null(wgcna_name)){wgcna_name <- seurat_obj@misc$active_wgcna}

  # use idents as grouping variable if not specified
  if(is.null(group.by)){
    group.by <- 'temp_ident'
    seurat_obj$temp_ident <- Idents(seurat_obj)
  }

  # drop missing levels; as.factor() lets character columns through,
  # since droplevels() has no method for character vectors
  seurat_obj@meta.data[[group.by]] <- droplevels(as.factor(seurat_obj@meta.data[[group.by]]))

  # get modules (grey excluded)
  modules <- GetModules(seurat_obj, wgcna_name)
  modules <- modules %>% subset(module != 'grey') %>% mutate(module = droplevels(module))
  mods <- levels(modules$module)

  if(!is.null(module_names)){
    if(!all(module_names %in% mods)){
      stop(paste0("Some selected modules are not found in wgcna_name: ", wgcna_name))
    }
    mods <- module_names
    modules <- modules %>% subset(module %in% mods)
  }

  # module names & colors, ordered like mods so the colorbar lines up with the heatmaps
  mod_colors <- modules %>% dplyr::select(c(module, color)) %>% dplyr::distinct()
  mod_colors <- mod_colors[match(mods, mod_colors$module),]

  # top n_hubs genes by kME in each module, strongest first
  hub_list <- lapply(mods, function(cur_mod){
    cur <- subset(modules, module == cur_mod)
    cur <- cur[,c('gene_name', paste0('kME_', cur_mod))] %>%
      top_n(n_hubs)
    colnames(cur)[2] <- 'var'
    cur %>% arrange(desc(var)) %>% .$gene_name
  })
  names(hub_list) <- mods

  seurat_obj$barcode <- colnames(seurat_obj)
  temp <- table(seurat_obj@meta.data[[group.by]])

  # sample up to n_cells cells from each group
  sampled_list <- vector('list', length(temp))
  for(i in seq_along(temp)){

    if(temp[[i]] < n_cells){
      cur_df <- seurat_obj@meta.data %>% subset(get(group.by) == names(temp)[i])
    } else{
      cur_df <- seurat_obj@meta.data %>% subset(get(group.by) == names(temp)[i]) %>% sample_n(n_cells)
    }
    sampled_list[[i]] <- cur_df
  }
  df <- do.call(rbind, sampled_list)

  # make sampled seurat obj for plotting:
  seurat_plot <- seurat_obj %>% subset(barcode %in% df$barcode)

  # one heatmap per module; only the first keeps its legend
  plot_list <- list()
  for(i in seq_along(hub_list)){

    if(i == 1){
      plot_list[[i]] <- DoHeatmap(
        seurat_plot,
        features = hub_list[[i]],
        group.by=group.by,
        raster=TRUE, slot='scale.data',
        disp.min = disp.min, disp.max=disp.max,
        label=FALSE,
        group.bar=FALSE,
        draw.lines=draw.lines
      )
    } else{
      plot_list[[i]] <- DoHeatmap(
       seurat_plot,
       features=hub_list[[i]],
       group.by=group.by,
       raster=TRUE, slot='scale.data',
       group.bar.height=0,
       label=FALSE, group.bar=FALSE,
       draw.lines=draw.lines,
       disp.min = disp.min, disp.max=disp.max
     ) + NoLegend()
    }

    # margin:
    plot_list[[i]] <- plot_list[[i]] +
      theme(
        plot.margin = margin(0,0,0,0),
        axis.text.y = element_text(face='italic')
      )  + scale_y_discrete(position = "right")

  }

  # useful for ratios on colorbars
  n_total_cells <- ncol(seurat_plot)
  width_cbar <- n_total_cells / 50

  # module colorbar: one single-color bar per module, stacked in the same order as plot_list
  mod_colors$value <- n_hubs
  mod_colors$dummy <- 'colorbar'
  cbar_list <- list()
  for(i in seq_len(nrow(mod_colors))){
    cbar_list[[i]] <- mod_colors[i,] %>% ggplot(aes(y=value, x=dummy)) +
      geom_bar(position='stack', stat='identity', fill=mod_colors[i,]$color) +
      umap_theme() + theme(
        plot.margin=margin(0,0,0,0)
      )
  }
  p_cbar <- wrap_plots(cbar_list, ncol=1)

  if(combine){
    out <- wrap_plots(plot_list, ncol=1) +plot_layout(guides='collect')
    out <- (p_cbar | out) + plot_layout(widths=c(width_cbar, n_total_cells))
  } else{
    out <- plot_list
  }

  out

}

#' PlotModulePreservation
#'
#' Plotting function for Module Preservation statistics
#'
#' @param seurat_obj A Seurat object
#' @param name The name of the module preservation analysis to plot given in ModulePreservation
#' @param statistics Which module preservation statistics to plot? Choices are summary, rank, all, or a custom character vector of statistic names
#' @param plot_labels logical determining whether to plot the module labels
#' @param label_size the size of the module labels
#' @param mod_point_size the size of the points in each plot
#' @param wgcna_name The name of the hdWGCNA experiment in the seurat_obj@misc slot
#' 
#' @details
#' This function creates a scatter plot showing the module preservation statistics for each module 
#' compared to the size of the module (number of genes). 
#' 
#' @return a single ggplot if only one statistic is plotted, otherwise a named list of ggplots
#' @export
PlotModulePreservation <- function(
  seurat_obj,
  name,
  statistics = 'summary', # can be summary, rank, all, or a custom list
  plot_labels = TRUE,
  label_size = 4,
  mod_point_size = 4,
  wgcna_name = NULL
){

  if(is.null(wgcna_name)){wgcna_name <- seurat_obj@misc$active_wgcna}

  # get the module preservation stats:
  mod_pres <- GetModulePreservation(seurat_obj, name, wgcna_name)
  obs_df <- mod_pres$obs
  Z_df <- mod_pres$Z

  # module colors; modules without a color (e.g. the gold reference module) become 'gold'
  modules <- GetModules(seurat_obj, wgcna_name)
  module_colors <- modules %>% dplyr::select(c(module, color)) %>% dplyr::distinct()
  mods <- rownames(Z_df)
  mod_colors <- module_colors$color[match(mods, module_colors$module)]
  mod_colors <- ifelse(is.na(mod_colors), 'gold', mod_colors)

  # what are we going to plot? identical() keeps the checks scalar, so a
  # custom vector of statistics does not trigger a length > 1 if() condition
  if(identical(statistics, 'summary')){
    stat_list <- c("Zsummary.qual", "Zsummary.pres")
  } else if(identical(statistics, 'rank')){
    stat_list <- colnames(obs_df[,-1])[grepl("Rank", colnames(obs_df[,-1]))]
  } else if(identical(statistics, 'all')){
    stat_list <- c(colnames(obs_df[,-1])[grepl("Rank", colnames(obs_df[,-1]))], colnames(Z_df[,-1]))
  } else{
    stat_list <- statistics
  }

  stat_list <- stat_list[stat_list != 'moduleSize']

  plot_list <- list()
  for(statistic in stat_list){

    if(statistic %in% colnames(obs_df)){
      values <- obs_df[,statistic]
    } else if(statistic %in% colnames(Z_df)){
      values <- Z_df[,statistic]
    } else{
      stop("Invalid name for statistic.")
    }

    # setup plotting df
    plot_df <- data.frame(
      module = mods,
      color = mod_colors,
      value = values,
      size = Z_df$moduleSize
    )

    # don't include grey & gold:
    plot_df <- plot_df %>% subset(!(module %in% c('grey', 'gold')))

    # named palette so each module gets its own color regardless of row order
    module_palette <- setNames(plot_df$color, plot_df$module)

    if(grepl("Rank", statistic)){
      # ranks: lower is better, so flip the y axis
      cur_p <-  plot_df %>% ggplot(aes(x=size, y=value, fill=module, color=module)) +
        geom_point(size=mod_point_size, pch=21, color='black') +
        scale_y_reverse()
    } else{
      # Z statistics: shade Z < 2 (not preserved) and 2 <= Z < 10 (moderate)
      cur_p <- plot_df %>% ggplot(aes(x=size, y=value, fill=module, color=module)) +
        geom_rect(
          data = plot_df[1,],
          aes(xmin=0, xmax=Inf, ymin=-Inf, ymax=2), fill='grey75', alpha=0.8, color=NA) +
        geom_rect(
          data=plot_df[1,],
          aes(xmin=0, xmax=Inf, ymin=2, ymax=10), fill='grey92', alpha=0.8, color=NA) +
        geom_point(size=mod_point_size, pch=21, color='black')
    }

    cur_p <- cur_p +
      scale_fill_manual(values=module_palette) +
      scale_color_manual(values=module_palette) +
      scale_x_continuous(trans='log10') +
      ylab(statistic) +
      xlab("Module Size") +
      ggtitle(statistic) +
      NoLegend() +
      theme(
        plot.title = element_text(hjust = 0.5)
      )

    if(plot_labels){
      cur_p <- cur_p + ggrepel::geom_text_repel(label = plot_df$module, size=label_size, max.overlaps=Inf, color='black')
    }

    plot_list[[statistic]] <- cur_p

  }

  if(length(plot_list) == 1){return(plot_list[[1]])}

  plot_list

}

#' PlotModuleTraitCorrelation
#'
#' Plotting function for Module-Trait Correlation results (from ModuleTraitCorrelation)
#'
#' @param seurat_obj A Seurat object
#' @param high_color color for positive correlation
#' @param mid_color color for zero correlation
#' @param low_color color for negative correlation
#' @param label NULL for no labels, or 'pval' / 'fdr' to label each cell of the heatmap with that significance value
#' @param label_symbol show the labels as 'stars' or as 'numeric'
#' @param plot_max maximum value of correlation to show on the colorbar
#' @param text_size size of the labels
#' @param text_color color of the text labels
#' @param text_digits how many digits to show in the text labels
#' @param combine logical determining whether to plot as one combined plot (TRUE) or to return individual plots as a list (FALSE)
#' @param wgcna_name The name of the hdWGCNA experiment in the seurat_obj@misc slot
#' @keywords scRNA-seq
#' @export
PlotModuleTraitCorrelation <- function(
  seurat_obj,
  high_color = 'red',
  mid_color = 'grey90',
  low_color = 'blue',
  label = NULL,
  label_symbol = 'stars',
  plot_max = NULL,
  text_size = 2,
  text_color = 'black',
  text_digits = 3,
  combine = TRUE,
  wgcna_name = NULL
){

  if(is.null(wgcna_name)){wgcna_name <- seurat_obj@misc$active_wgcna}

  # validate label up front; otherwise p_df below would be undefined
  if(!is.null(label) && !(label %in% c('fdr', 'pval'))){
    stop('Invalid input for label. Valid choices are fdr or pval.')
  }

  # get the module trait correlation results:
  temp <- GetModuleTraitCorrelation(seurat_obj, wgcna_name)
  cor_list <- temp$cor
  pval_list <- temp$pval
  fdr_list <- temp$fdr

  if(is.null(dim(cor_list[[1]]))){
    stop('ModuleTraitCorrelation was run only for one trait. Heatmaps are not suggested for visualizing only one variable!')
  }

  # module colors (grey excluded), in module factor-level order
  modules <- GetModules(seurat_obj, wgcna_name)
  module_colors <- modules %>%
    dplyr::select(c(module, color)) %>%
    dplyr::distinct() %>% subset(module != 'grey') %>%
    dplyr::arrange(module)
  mod_colors <- module_colors$color

  # dummy variable
  module_colors$var <- 1

  # make the colorbar as its own heatmap
  module_colorbar <- module_colors %>%
    ggplot(aes(x=module, y=var, fill=module)) +
    geom_tile() +
    scale_fill_manual(values=mod_colors) +
    NoLegend() +
    RotatedAxis() +
    theme(
      plot.title=element_blank(),
      axis.line=element_blank(),
      axis.ticks.y=element_blank(),
      axis.text.y = element_blank(),
      axis.title = element_blank(),
      plot.margin=margin(0,0,0,0)
    )

  # one heatmap (traits x modules) per entry of cor_list
  plot_list <- list()
  for(i in names(cor_list)){
    cor_mat <- as.matrix(cor_list[[i]])
    pval_mat <- as.matrix(pval_list[[i]])
    fdr_mat <- as.matrix(fdr_list[[i]])

    plot_df <- reshape2::melt(cor_mat)
    colnames(plot_df) <- c("Trait", "Module", "cor")

    if(!is.null(label)){
      if(label == 'fdr'){
        p_df <- reshape2::melt(fdr_mat)
      } else{
        p_df <- reshape2::melt(pval_mat)
      }
      colnames(p_df) <- c("Trait", "Module", "pval")

      # add pval to plot_df (same melt order as cor_mat)
      plot_df$pval <- p_df$pval

      if(label_symbol == 'stars'){
        plot_df$significance <- gtools::stars.pval(plot_df$pval)
      } else if(label_symbol == 'numeric'){
        plot_df$significance <- ifelse(
          plot_df$pval <= 0.05,
          formatC(plot_df$pval, digits=text_digits), ''
        )
      } else{
        stop('Invalid input for label_symbol. Valid choices are stars or numeric.')
      }
    }

    # get limits for plot:
    if(is.null(plot_max)){
      max_plot <- max(abs(range(plot_df$cor)))
    } else{
      max_plot <- plot_max

      # clip values outside of the specified range:
      plot_df$cor <- ifelse(abs(plot_df$cor) >= plot_max, plot_max * sign(plot_df$cor), plot_df$cor)
    }

    p <- ggplot(plot_df, aes(x=Module, y=as.numeric(Trait), fill=cor)) +
      geom_tile() +
      scale_fill_gradient2(
        limits=c(-1*max_plot,max_plot),
        high=high_color,
        mid=mid_color,
        low=low_color,
        guide = guide_colorbar(ticks=FALSE, barwidth=16, barheight=0.5)
      ) +
      scale_y_continuous(
        breaks = 1:length(levels(plot_df$Trait)),
        labels=levels(plot_df$Trait),
        sec.axis = sec_axis(
          ~.,
          breaks = 1:length(levels(plot_df$Trait)),
          labels=levels(plot_df$Trait)
        )
      ) +
      RotatedAxis() + ylab('') + xlab('') + ggtitle(i) +
      theme(
        plot.title=element_text(hjust=0.5),
        axis.line=element_blank(),
        axis.ticks.y=element_blank(),
        axis.text.y.left = element_blank(),
        axis.title.y = element_text(angle=0, vjust=0.5),
        legend.title = element_blank(),
        legend.position='bottom'
      )

    if(!is.null(label)){
      p <- p + geom_text(label=plot_df$significance, color=text_color, size=text_size)
    }

    plot_list[[i]] <- p

  }

  if(combine){

    # strip per-panel titles/x axes; the name of each panel becomes its y label
    for(i in seq_along(plot_list)){

      plot_list[[i]] <- plot_list[[i]] +
        ylab(names(plot_list)[i]) +
        theme(
          plot.margin = margin(t = 0, r = 0, b = 0, l = 0),
          axis.title.x = element_blank(),
          plot.title = element_blank(),
          legend.position='bottom',
          axis.text.x = element_blank(),
          axis.ticks=element_blank(),
          axis.title.y = element_text(angle=0, vjust=0.5)
        )
    }

    # assemble with patchwork; the module colorbar goes at the bottom
    plot_list[['module']] <- module_colorbar

    out <- wrap_plots(plot_list, ncol=1) +
      plot_layout(
        guides = 'collect',
        heights = c(rep(1, length(plot_list)-1), 0.15)
      ) +
      plot_annotation(
        theme=theme(
          plot.title=element_text(hjust=0.5),
          legend.position = 'bottom',
          legend.justification = 0.5
        )
      )

    return(out)

  } else{
    return(plot_list)
  }

}




#' PlotKMEs
#'
#' Plotting function to show genes by kME value in each module
#'
#' For every non-grey module, draws a bar plot of all module genes ranked by
#' their kME, next to a text panel listing the top \code{n_hubs} hub genes.
#'
#' @param seurat_obj A Seurat object
#' @param n_hubs number of hub genes to display
#' @param text_size controls the size of the hub gene text
#' @param ncol number of columns to display individual plots
#' @param plot_widths the relative width between the kME rank plot and the hub gene text
#' @param wgcna_name the name of the WGCNA experiment in the seurat object
#' @return A patchwork object combining one kME rank plot + hub gene label per module
#' @keywords scRNA-seq
#' @export
PlotKMEs <- function(
  seurat_obj,
  n_hubs=10,
  text_size=2,
  ncol = 5,
  plot_widths = c(3,2),
  wgcna_name = NULL
){

  if(is.null(wgcna_name)){wgcna_name <- seurat_obj@misc$active_wgcna}

  modules <- GetModules(seurat_obj, wgcna_name) %>% subset(module != 'grey')
  mods <- levels(modules$module); mods <- mods[mods != 'grey']
  mod_colors <- modules %>% subset(module %in% mods) %>%
    dplyr::select(c(module, color)) %>%
    dplyr::distinct()

  # kME for every gene in each module: the bar plot shows the full kME ranking,
  # while only the top n_hubs genes are listed in the text panel.
  kme_df <- GetHubGenes(seurat_obj, n_hubs=Inf, wgcna_name=wgcna_name)

  plot_list <- lapply(mods, function(x){
    cur_color <- subset(mod_colors, module == x) %>% .$color
    cur_df <- subset(kme_df, module == x)

    # top hub genes, listed from highest to lowest kME
    top_genes <- cur_df %>%
      dplyr::top_n(n_hubs, wt=kME) %>%
      dplyr::arrange(dplyr::desc(kME)) %>%
      .$gene_name

    p <- cur_df %>% ggplot(aes(x = reorder(gene_name, kME), y = kME)) +
      geom_bar(stat='identity', width=1, color = cur_color, fill=cur_color) +
      ggtitle(x) +
      theme(
        axis.ticks.x = element_blank(),
        axis.text.x = element_blank(),
        plot.title = element_text(hjust=0.5),
        axis.title.x = element_blank(),
        axis.line.x = element_blank()
      )

    p_anno <- ggplot() + annotate(
      "label",
      x = 0,
      y = 0,
      label = paste0(top_genes, collapse="\n"),
      size=text_size,
      fontface = 'italic',
      label.size=0
    ) + theme_void()

    p + p_anno + plot_layout(widths=plot_widths)
  })

  wrap_plots(plot_list, ncol=ncol)

}



#' PlotDMEsVolcano
#'
#' Plotting function for the results of FindDMEs and FindAllDMEs.
#'
#' This function generates a volcano plot for differential module expression (DME) analysis results. 
#' It can handle both two-group comparisons (using the output of `FindDMEs`) and one-vs-all comparisons 
#' (using the output of `FindAllDMEs`).
#'
#' Adjusted p-values of exactly zero are replaced by the smallest non-zero adjusted p-value, and
#' infinite fold changes are replaced by +/- the largest finite absolute fold change, so that every
#' module can be drawn. Modules with an adjusted p-value below 0.05 are labeled.
#'
#' @param seurat_obj A Seurat object containing the WGCNA analysis in the @misc slot.
#' @param DMEs A dataframe output from FindDMEs or FindAllDMEs containing DME results.
#' @param plot_labels Logical, determines whether to plot the module labels on the volcano plot. Default is TRUE.
#' @param mod_point_size Numeric, the size of the points on the volcano plot. Default is 4.
#' @param label_size Numeric, the size of the module labels on the plot. Default is 4.
#' @param show_cutoff Logical, determines whether to plot the significance cutoff. Set this to FALSE if using facet_wrap. Default is TRUE.
#' @param wgcna_name Character, the name of the hdWGCNA experiment in the `seurat_obj@misc` slot. Default is NULL, in which case it pulls the active WGCNA experiment from `seurat_obj@misc$active_wgcna`.
#' @param xlim_range A numeric vector of length 2 specifying the x-axis limits for the log2 fold change. Default is NULL, which uses a symmetric range covering the largest absolute fold change.
#' @param ylim_range A numeric vector of length 2 specifying the y-axis limits for the -log10(p-value). Default is NULL, which automatically calculates limits based on the data.
#' @keywords scRNA-seq, volcano plot, differential expression, WGCNA
#' @export
#' @return A ggplot object containing the volcano plot for the DME results.
#' @examples
#' # Example usage:
#' # Assuming `seurat_obj` is your Seurat object and `DMEs` is the output from FindDMEs
#' PlotDMEsVolcano(seurat_obj, DMEs, wgcna_name = "MG")
PlotDMEsVolcano <- function(
  seurat_obj,
  DMEs,
  plot_labels=TRUE,
  mod_point_size=4,
  label_size=4,
  show_cutoff = TRUE,
  wgcna_name=NULL,
  xlim_range=NULL,  # controls x-axis limits
  ylim_range=NULL   # controls y-axis limits
){

  if(is.null(wgcna_name)){wgcna_name <- seurat_obj@misc$active_wgcna}

  # remove NAs:
  DMEs <- na.omit(DMEs)

  # replace adjusted p-values of exactly zero with the lowest non-zero value so
  # that -log10(p) stays finite (fall back to the smallest positive double if
  # every value is zero):
  nonzero_p <- DMEs$p_val_adj[DMEs$p_val_adj != 0]
  lowest <- if(length(nonzero_p) > 0) min(nonzero_p) else .Machine$double.xmin
  DMEs$p_val_adj[DMEs$p_val_adj == 0] <- lowest

  # fix infinite fold changes using the largest finite *absolute* fold change,
  # so that both strongly up- and strongly down-regulated modules fit the plot:
  finite_fc <- DMEs$avg_log2FC[is.finite(DMEs$avg_log2FC)]
  max_fc <- max(abs(finite_fc))
  DMEs$avg_log2FC[DMEs$avg_log2FC == -Inf] <- -1 * max_fc
  DMEs$avg_log2FC[DMEs$avg_log2FC == Inf] <- max_fc

  # get modules and module colors
  modules <- GetModules(seurat_obj, wgcna_name) %>% subset(module != 'grey') %>% mutate(module=droplevels(module))
  module_colors <- modules %>% dplyr::select(c(module, color)) %>% dplyr::distinct()
  mod_colors <- module_colors$color; names(mod_colors) <- as.character(module_colors$module)

  # annotate modules with significant DME; as.character() guards against
  # ifelse() returning factor integer codes when module is a factor
  DMEs$anno <- ifelse(DMEs$p_val_adj < 0.05, as.character(DMEs$module), '')

  # set default x-axis limit if not provided (symmetric around zero)
  if(is.null(xlim_range)){
    xlim_range <- c((-1*max_fc)-0.1, max_fc+0.1)
  }

  # set default y-axis limit if not provided
  if(is.null(ylim_range)){
    ymax <- max(-log10(DMEs$p_val_adj))
    ylim_range <- c(0, ymax + 1)
  }

  # plot basics
  p <- DMEs %>%
    ggplot(aes(x=avg_log2FC, y=-log10(p_val_adj), fill=module, color=module))

  if(show_cutoff){
    # shaded band marks the non-significant region (adj. p >= 0.05)
    p <- p +
      geom_vline(xintercept=0, linetype='dashed', color='grey75', alpha=0.8) +
      geom_rect(
        data=DMEs[1,],
        aes(xmin=-Inf, xmax=Inf, ymin=-Inf, ymax=-log10(0.05)), fill='grey75', alpha=0.8, color=NA)
  }

  # add points:
  p <- p + geom_point(size=mod_point_size, pch=21, color='black')

  # label points?
  if(plot_labels){
    p <- p + ggrepel::geom_text_repel(aes(label=anno), color='black', min.segment.length=0, max.overlaps=Inf, size=label_size)
  }

  p <- p +
    scale_fill_manual(values=mod_colors) +
    scale_color_manual(values=mod_colors) +
    xlim(xlim_range) +
    ylim(ylim_range) +
    xlab(bquote("Average log"[2]~"(Fold Change)")) +
    ylab(bquote("-log"[10]~"(Adj. P-value)")) +
    theme(
      panel.border = element_rect(color='black', fill=NA, linewidth=1),
      panel.grid.major = element_blank(),
      axis.line = element_blank(),
      plot.title = element_text(hjust = 0.5),
      legend.position='bottom'
    ) + NoLegend()

  p
}


             
             
             
#' PlotDMEsLollipop
#'
#' Plotting function for the results of FindDMEs and FindAllDMEs
#'
#' Draws one lollipop plot per comparison found in \code{DMEs[[group.by]]}
#' (optionally restricted to \code{comparison}). When neither \code{group.by}
#' nor \code{comparison} is given, all rows of \code{DMEs} are drawn in a
#' single plot.
#'
#' @param seurat_obj A Seurat object
#' @param DMEs dataframe output from FindDMEs or FindAllDMEs
#' @param wgcna_name The name of the hdWGCNA experiment in the seurat_obj@misc slot
#' @param group.by the column name of the selected comparison in the DMEs dataframe
#' @param comparison character vector or a list of character vectors containing the comparison in the DMEs dataframe
#' @param pvalue p_value or fdr used for the comparison in the DMEs dataframe
#' @param avg_log2FC the column name of the average log2 fold change in the DMEs dataframe
#' @keywords scRNA-seq
#' @export
#' @return A named list of ggplot objects (one per comparison), or a single
#' ggplot object when neither group.by nor comparison is provided.
PlotDMEsLollipop <- function(
  seurat_obj,
  DMEs,
  wgcna_name,
  group.by=NULL,
  comparison= NULL,
  pvalue,
  avg_log2FC = 'avg_log2FC'
){

  # ggforestplot is only used through `::`, so its namespace must be
  # available but the package does not need to be attached.
  if (!requireNamespace("ggforestplot", quietly = TRUE)) {
    print('Missing package: ggforestplot')
    print('Installing package: ggforestplot')
    devtools::install_github("NightingaleHealth/ggforestplot")
  }

  if(!(pvalue %in% colnames(DMEs))){
    stop('Selected pvalue is not found in DMEs dataframe column names.')
  }

  if(missing(wgcna_name) || !(wgcna_name %in% names(seurat_obj@misc))){
    stop('Please provide wgcna_name or the selected wgcna_name is not found in seurat_obj@misc.')
  }

  modules <- GetModules(seurat_obj, wgcna_name) %>%
    subset(module != 'grey') %>%
    mutate(module=droplevels(module))

  # lollipop plot for one set of DMEs, with alternating background stripes
  plot_one <- function(cur_DMEs){
    PlotLollipop(modules, cur_DMEs, pvalue, avg_log2FC = avg_log2FC) +
      NoLegend() +
      ggforestplot::geom_stripes(aes(y=module), inherit.aes=FALSE, data=cur_DMEs)
  }

  has_group <- !missing(group.by)
  has_comparison <- !missing(comparison)

  if(!has_group && has_comparison){
    stop('The group.by column is not provided in the DMEs data, and comparison cannot be found.')
  }

  if(!has_group && !has_comparison){
    print('Please be aware comparison group/groups are not provided, which may casue an ERROR. PlotDMEsLollipop function will automatically assume all values are within the same group.')
    return(plot_one(DMEs))
  }

  # determine which comparisons to plot
  if(has_comparison){
    comparisons <- comparison
    if(!(all(comparisons %in% DMEs[[group.by]]))){
      stop('Not all selected comparisons are found in DMEs[[group.by]], or the comparison column, DMEs[[group.by]], is not correctly supplied.')
    }
  } else{
    if (!(group.by %in% names(DMEs))) {
      stop('The group.by column is not found in the DMEs data.')
    }
    comparisons <- unique(DMEs[[group.by]])
  }

  plot_list <- list()
  for(cur_comp in comparisons){
    print(cur_comp)
    # which() drops NA matches, mirroring subset()
    cur_DMEs <- DMEs[which(DMEs[[group.by]] == cur_comp), , drop = FALSE]
    plot_list[[cur_comp]] <- plot_one(cur_DMEs) + ggtitle(cur_comp)
  }

  plot_list
}




#' PlotLollipop
#'
#' Internal helper for PlotDMEsLollipop that draws a lollipop plot of
#' differential module eigengene (DME) results for a single comparison.
#'
#' Modules are ordered by fold change (largest at the top). Modules with
#' \code{cur_DMEs[[pvalue]] < 0.05} are outlined with a circle, all others
#' with an X. Point size scales with log(number of genes in the module).
#'
#' @param modules module assignment table (as returned by GetModules) with columns module and color
#' @param cur_DMEs dataframe of DME results (output of FindDMEs or FindAllDMEs) for one comparison
#' @param pvalue name of the column in cur_DMEs holding the p-values or FDRs used to mark significance
#' @param avg_log2FC name of the column in cur_DMEs holding the average log2 fold change
#' @keywords scRNA-seq
#' @return A ggplot object
#' @examples
#' \dontrun{
#' p <- PlotLollipop(modules, DMEs, pvalue = 'p_val_adj', avg_log2FC = 'avg_log2FC')
#' }
PlotLollipop <- function(
  modules,
  cur_DMEs,
  pvalue,
  avg_log2FC = 'avg_log2FC'
){

    # plotting shape: 21 = circle (significant), 4 = X (not significant)
    cur_DMEs$shape <- ifelse(cur_DMEs[[pvalue]] < 0.05, 21, 4)

    # ascending fold change so the largest value ends up at the top of the y-axis
    cur_DMEs <- cur_DMEs[order(cur_DMEs[[avg_log2FC]]), , drop = FALSE]

    # unique() avoids a duplicated-levels error when a module appears in
    # more than one row (e.g. several comparisons plotted together)
    cur_DMEs$module <- factor(
        as.character(cur_DMEs$module),
        levels = unique(as.character(cur_DMEs$module))
    )

    # add number of genes per module
    n_genes <- table(modules$module)
    cur_DMEs$n_genes <- as.numeric(n_genes[as.character(cur_DMEs$module)])

    mod_colors <- dplyr::select(modules, c(module, color)) %>% dplyr::distinct()
    cp <- mod_colors$color; names(cp) <- mod_colors$module

    p <- cur_DMEs %>%
    ggplot(aes(y=module, x=.data[[avg_log2FC]], size=log(n_genes), color=module)) +
    geom_vline(xintercept=0, color='black') +
    geom_segment(aes(y=module, yend=module, x=0, xend=.data[[avg_log2FC]]), linewidth=0.5, alpha=0.3) +
    geom_point() +
    geom_point(shape=cur_DMEs$shape, color='black', fill=NA) +
    scale_color_manual(values=cp, guide='none') +
    ylab('') +
    xlab(bquote("Avg. log"[2]~"(Fold Change)")) +
    theme(
        axis.line.y = element_blank(),
        axis.ticks.y = element_blank(),
        plot.title = element_text(hjust=0.5, face='plain', size=10)
    )

    p

}





#' ModuleTopologyHeatmap
#'
#' Plots a heatmap of the co-expression network topology of a given module.
#'
#' @return ggplot object containing the ModuleTopologyHeatmap, or a list (see return_genes)
#'
#' @param seurat_obj A Seurat object
#' @param mod the name of the co-expression module to plot
#' @param matrix specify which matrix to plot, use 'TOM' (topological overlap matrix) or 'Cor' (correlation matrix), or pass a custom square matrix where the rownames and colnames match the genes in this module
#' @param matrix_name name of the matrix plotted that will be used as the label in the plot legend
#' @param order_by order genes in this module by 'kME' (default) or by 'degree' (sum of all connections to this gene in the co-expression network)
#' @param high_color color used for high values in the heatmap, default is the module's unique color
#' @param low_color color used for low values in the heatmap, default is 'white'
#' @param raster logical indicating whether or not to rasterise the plot
#' @param raster_dpi dpi used for a rasterised plot
#' @param plot_max maximum value to plot on the heatmap, can pass a numeric value or a string indicating the quantile ('q99' would be the 99th percentile)
#' @param plot_min minimum value to plot on the heatmap, can pass a numeric value or a string indicating the quantile ('q1' would be the 1st percentile)
#' @param return_genes logical indicating whether or not to return a list with elements 'plot', 'genes' (the gene order used), 'plot_max' and 'plot_min' instead of only the plot
#' @param genes_order a character vector of genes to plot in this specific order, this option will override the order_by parameter
#' @param TOM_use  The name of the hdWGCNA experiment containing the TOM that will be used for plotting
#' @param wgcna_name The name of the hdWGCNA experiment in the seurat_obj@misc slot
#' @param ... additional arguments passed to WGCNA::adjacency when matrix = 'Cor'
#' @details
#' ModuleTopologyHeatmap generates a triangular heatmap plot showing the network "topology" of a 
#' specific co-expression module. Each cell in the heatmap represents a gene-gene pair, and the 
#' heatmap is colored by the strength of the connection between these two genes. By default 
#' the genes in this heatmap are ordered in both the rows and the columns based on their importance 
#' in the module, ranked either by eigengene-based connectivity (kME) or by network degree.  
#'
#' @import Seurat
#' @export
ModuleTopologyHeatmap <- function(
    seurat_obj,
    mod,
    matrix = 'TOM',
    matrix_name = NULL,
    order_by = 'kME',
    high_color = NULL,
    low_color = 'white',
    raster=TRUE,
    raster_dpi=200,
    plot_max = 'q99',
    plot_min = 0,
    return_genes=FALSE,
    genes_order = NULL,
    TOM_use = NULL,
    wgcna_name=NULL,
    ... # pass to WGCNA::adjacency
){

    if(is.null(wgcna_name)){wgcna_name <- seurat_obj@misc$active_wgcna}
    CheckWGCNAName(seurat_obj, wgcna_name)
    if(is.null(TOM_use)){TOM_use <- wgcna_name}

    # rasterising needs the optional ggrastr package (called via `::`);
    # fall back to plain tiles if it is not installed
    if(!requireNamespace('ggrastr', quietly = TRUE)){raster <- FALSE}

    # get the modules table
    modules <- GetModules(seurat_obj, wgcna_name)
    cur_genes <- as.character(subset(modules, module == mod) %>% .$gene_name)
    hub_df <- GetHubGenes(seurat_obj, mods = mod, n_hubs=Inf, wgcna_name=wgcna_name)

    if(is.null(high_color)){
        high_color <- subset(modules, module == mod) %>% .$color %>% unique
    }

    # get the degree table, and subset for this module 
    degrees <- GetDegrees(seurat_obj, wgcna_name) %>% 
        subset(module == mod)

    # a user-supplied gene order takes precedence over order_by
    use_custom_order <- !is.null(genes_order)
    if(use_custom_order){
        cur_genes <- genes_order[genes_order %in% cur_genes]
    }

    # get the Matrix to plot:
    if(inherits(matrix, 'matrix')){
        mat <- matrix 

        if(is.null(matrix_name)){
            stop("Must provide matrix_name if you supply a matrix (Instead of 'TOM' or 'Cor') to the matrix argument.")
        }
        
        # check that it's a square matrix.
        if(ncol(mat) != nrow(mat)){
            stop("Invalid matrix. Must be a square matrix.")
        }

        # check that the rownames and colnames are all in the Seurat obj 
        if(!(all(rownames(mat) %in% rownames(seurat_obj)) && all(colnames(mat) %in% rownames(seurat_obj)))){
            stop("Invalid matrix. rownames and colnames must be found in the rownames(seurat_obj).")
        }

    } else if(identical(matrix, 'TOM')){
        # Get the TOM
        mat <- GetTOM(seurat_obj, TOM_use)
        matrix_name <- matrix
    } else if(identical(matrix, 'Cor')){
        # get the expression matrix:
        datExpr <- as.matrix(GetDatExpr(seurat_obj, wgcna_name))

        # calculate the unsigned adjacency matrix
        mat <- WGCNA::adjacency(
            datExpr, 
            power=1,
            ...
        )
        matrix_name <- matrix
    } else{
        stop("Invalid selection for matrix. Must choose 'TOM' or 'Cor', or provide a square matrix.")
    }

    # only keep the genes that are present in the matrix:
    cur_genes <- cur_genes[cur_genes %in% colnames(mat)]

    # subset tables to match the cur genes:
    hub_df <- hub_df %>% subset(gene_name %in% cur_genes)
    degrees <- degrees %>% subset(gene_name %in% cur_genes)

    # select the method to order genes (only when no custom order was given):
    if(!use_custom_order){
        if(order_by == 'degree'){
            cur_genes <- as.character(degrees$gene_name)
        } else if(order_by == 'kME'){
            cur_genes <- as.character(hub_df$gene_name)
        }
    }

    # format the matrix for plotting: keep the upper triangle only
    tmp <- mat[cur_genes,cur_genes]
    tmp[lower.tri(tmp)] <- 0
    plot_df <- reshape2::melt(tmp) %>% subset(Var1 != Var2)

    # set the order for the genes
    plot_df$Var1 <- factor(plot_df$Var1, levels = cur_genes)
    plot_df$Var2 <- factor(plot_df$Var2, levels = cur_genes)
    
    # maximum plot values ('qXX' strings are interpreted as quantiles)
    if(is.character(plot_max)){
        maxquant <- as.numeric(gsub('q', '', plot_max ))/100
        plot_max <- as.numeric(quantile(plot_df$value, maxquant))
    } 
    plot_df$value <- ifelse(plot_df$value > plot_max, plot_max, plot_df$value)

    # minimum plot values
    if(is.character(plot_min)){
        minquant <- as.numeric(gsub('q', '', plot_min ))/100
        plot_min <- as.numeric(quantile(plot_df$value, minquant))
    } 
    plot_df$value <- ifelse(plot_df$value < plot_min, plot_min, plot_df$value)

    # assemble the ggplot
    p <- plot_df %>%
        ggplot(aes(x=Var1, y=Var2, fill=value))

    # rasterise?
    if(raster){
        p <- p + ggrastr::rasterise(geom_tile(), dpi=raster_dpi)
    } else{
        p <- p + geom_tile()
    }

    # set color 
    p <- p +    
        scale_fill_gradient(low=low_color, high=high_color, limits=c(plot_min, plot_max))

    # theme 
    p <- p + theme(
        axis.line.x = element_blank(),
        axis.line.y = element_blank(),
        axis.text.x = element_blank(),
        axis.text.y = element_blank(),
        axis.ticks.x = element_blank(),
        axis.ticks.y = element_blank(),
        plot.title = element_text(hjust=0.5)
    ) + coord_equal() + ylab('') + xlab('') 

    # legend title
    p <- p + labs(fill = matrix_name)

    if(return_genes){
        outs <- list()
        outs[['plot']] <- p 
        outs[['genes']] <- as.character(cur_genes)
        outs[['plot_max']] <- plot_max 
        outs[['plot_min']] <- plot_min
        return(outs)
    }

    # return plot 
    p

}

#' ModuleTopologyBarplot
#'
#' Plots a ranked barplot of genes in a co-expression module by intramodular connectivity
#'
#' @return ggplot object containing the ModuleTopologyBarplot, or a list (see return_genes)
#'
#' @param seurat_obj A Seurat object
#' @param mod the name of the co-expression module to plot
#' @param features specify the features to use in the barplot, 'kME' or 'degree' or 'weighted_degree' (degree scaled to 0 or 1)
#' @param plot_color color used for the bar plot, default is the module's unique color
#' @param alpha logical indicating whether or not to add opacity to the barplot based on the strength (kME or degree) 
#' @param genes_order a character vector of genes to plot in this specific order; by default genes are plotted in the order of the kME / degree table
#' @param return_genes logical indicating whether or not to return a list with the plot (first element, 'plot') and the plotted gene order (second element, 'genes') instead of only the plot
#' @param wgcna_name The name of the hdWGCNA experiment in the seurat_obj@misc slot
#' @details
#' ModuleTopologyBarplot generates a barplot showing the intramodular connectivity of each gene 
#' in a specific co-expression module. Each bar in this plot represents a single gene, and they are
#' ranked based on the strength of their connections within that particular module. A custom gene 
#' ordering can be supplied, which is helpful when comparing the module topologies side by side with 
#' more than one dataset.
#' 
#' @import Seurat
#' @export
ModuleTopologyBarplot <- function(
    seurat_obj,
    mod,
    features = 'kME', # or 'degree', or 'weighted_degree',
    plot_color = NULL, # to use the module color
    alpha=TRUE,
    genes_order = NULL,
    return_genes = FALSE,
    wgcna_name=NULL
){

    if(is.null(wgcna_name)){wgcna_name <- seurat_obj@misc$active_wgcna}
    CheckWGCNAName(seurat_obj, wgcna_name)

    # get the modules table
    modules <- GetModules(seurat_obj, wgcna_name)
    
    if(is.null(plot_color)){
        plot_color <- subset(modules, module == mod) %>% .$color %>% unique
    }

    # only fetch the connectivity table that the selected feature needs
    if(features == 'kME'){
        hub_df <- GetHubGenes(seurat_obj, mods = mod, n_hubs=Inf, wgcna_name=wgcna_name)
        plot_df <- hub_df %>% dplyr::rename(value = kME)
        label <- 'kME'
        plot_limits <- c(-1, 1)
    } else if(features %in% c('degree', 'weighted_degree')){
        degrees <- GetDegrees(seurat_obj, wgcna_name) %>% 
            subset(module == mod)
        if(features == 'degree'){
            plot_df <- degrees %>% dplyr::rename(value = degree)
            plot_limits <- c(0, max(plot_df$value))
        } else{
            plot_df <- degrees %>% dplyr::rename(value = weighted_degree)
            plot_limits <- c(0,1)
        }
        label <- 'Degree'
    } else{
        stop("Invalid selection for features. Must select 'kME', 'degree', or 'weighted_degree'")
    }

    # set factor levels:
    if(!is.null(genes_order)){
        genes_order <- genes_order[genes_order %in% plot_df$gene_name]
        plot_df <- subset(plot_df, gene_name %in% genes_order)
        plot_df$gene_name <- factor(as.character(plot_df$gene_name), levels=genes_order)
        plot_df <- plot_df %>% dplyr::arrange(gene_name)
    } else{
        plot_df$gene_name <- factor(as.character(plot_df$gene_name), levels=as.character(plot_df$gene_name))
    }

    p <- plot_df %>% 
        ggplot(aes(x=gene_name, y=value)) 

    if(alpha){
        p <- p + 
            geom_bar(aes(alpha=value), stat='identity', width=1, fill=plot_color) 
    } else{
        p <- p + geom_bar(stat='identity', width=1, fill=plot_color) 
    }

    p <- p + 
        ylab(label) + 
        scale_y_continuous(expand = c(0, 0), limits = plot_limits) +
        theme(
            axis.ticks.x = element_blank(),
            axis.text.x = element_blank(),
            axis.line.x = element_blank(),
            plot.title = element_text(hjust=0.5),
            axis.title.x = element_blank()
        )
    
    if(return_genes){
        # named for consistency with ModuleTopologyHeatmap; positional access still works
        return(list(plot = p, genes = as.character(plot_df$gene_name)))
    }

    p

}


#' PlotModulePreservationLollipop
#'
#' Plots a ranked lollipop plot of co-expression modules based on the results of module preservation analysis.
#'
#' @return ggplot object containing the PlotModulePreservationLollipop
#'
#' @param seurat_obj A Seurat object
#' @param name The name to give the module preservation analysis.
#' @param features The name of the module preservation features to plot. 
#' @param fdr logical indicating whether or not to plot FDR-corrected p-values
#' @param wgcna_name The name of the hdWGCNA experiment in the seurat_obj@misc slot
#' @details
#' PlotModulePreservationLollipop generates a lollipop plot showing module preservation results. If the module
#' preservation test was performed using the WGCNA method, the statistic that will be shown is the Z-summary 
#' preservation statistic. If the analysis was performed using NetRep, then the statistic that will be shown is 
#' the FDR corrected averaged p-values from the module preservation permutation test.
#' 
#' @import Seurat
#' @export
PlotModulePreservationLollipop <- function(
    seurat_obj,
    name,
    features = NULL,
    fdr = TRUE,
    wgcna_name = NULL
){

    if(is.null(wgcna_name)){wgcna_name <- seurat_obj@misc$active_wgcna}
    CheckWGCNAName(seurat_obj, wgcna_name)

    # get the module info
    modules <- GetModules(seurat_obj, wgcna_name)
    mod_colors <- modules %>% dplyr::select(module, color) %>% dplyr::distinct()
    mod_cp <- mod_colors$color; names(mod_cp) <- mod_colors$module

    # get the module preservation stats:
    mod_pres <- GetModulePreservation(seurat_obj, name, wgcna_name)

    # was this made with NetRep or WGCNA?
    netrep_used <- all(names(mod_pres) %in% c('nulls', 'observed', 'p.values', 'nVarsPresent', 'propVarsPresent', 'totalSize', 'alternative'))

    if(netrep_used){

        if(is.null(features)){features <- 'average'}

        # use the exact element name rather than relying on `$` partial matching
        pval_mat <- mod_pres$p.values

        # check for valid features
        valid_features <- c('average', colnames(pval_mat))
        if(! features %in% valid_features){
          stop(paste0('Invalid selection for features. Valid selections are: ', paste(valid_features, collapse=', ')))
        }

        # pvals
        plot_df1 <- reshape2::melt(pval_mat)
        plot_df1$Var2 <- paste0('pval.', plot_df1$Var2)
        plot_df1$type <- 'pval'

        # compute FDR:
        plot_df_fdrs <- plot_df1 
        plot_df_fdrs$value <- p.adjust(plot_df1$value,'fdr')
        plot_df_fdrs$type <- 'fdr'

        # observed stats
        plot_df2 <- reshape2::melt(mod_pres$observed)
        plot_df2$type <- 'observation'

        plot_df <- rbind(plot_df1,  plot_df_fdrs, plot_df2) %>% 
            dplyr::rename(module = Var1, stat = Var2)

        # compute the average p-vals:
        if(features == 'average'){
          plot_title <- "Summary"
          if(fdr){
            plot_df <- plot_df %>% 
                    subset(type == 'fdr') %>%
                    dplyr::group_by(module) %>% 
                    dplyr::summarise(value = mean(value)) 
            label <- bquote("-log"[10]~"(Avg. FDR)")
          } else{
            plot_df <- plot_df %>% 
                    subset(type == 'pval') %>%
                    dplyr::group_by(module) %>% 
                    dplyr::summarise(value = mean(value)) 
            label <- bquote("-log"[10]~"(Avg. P-value)")
          }
        } else{
          plot_title <- features
          plot_df <- plot_df %>% subset(stat == paste0('pval.',features))
          if(fdr){
            plot_df <- plot_df %>% subset(type == 'fdr')
            label <- bquote("-log"[10]~"(FDR)")
          } else{
            plot_df <- plot_df %>% subset(type == 'pval')
            label <- bquote("-log"[10]~"(P-value)")
          }
        }

        # add info about the module size (number of genes). Index by module
        # *name*: indexing a named vector with a factor uses its integer codes.
        mod_sizes <- mod_pres$nVarsPresent 
        if(is.null(names(mod_sizes))){
            # assume the same module order as the p-value matrix rows
            names(mod_sizes) <- rownames(pval_mat)
        }
        plot_df$mod_size <- unname(mod_sizes[as.character(plot_df$module)])

        # order by p-val 
        plot_df <- plot_df %>% dplyr::arrange(dplyr::desc(value))
        plot_df$module <- factor(as.character(plot_df$module), levels=as.character(plot_df$module))

        p <- plot_df %>%
        ggplot(aes(y=module, x=-log10(value), size= mod_size, color=module, fill=module)) + 
            # shaded region marks p >= 0.05
            geom_rect(
                data = plot_df[1,],
                aes(xmin=-Inf, ymax=Inf, ymin=-Inf, xmax=-log10(0.05)), fill='grey90', alpha=0.8, color=NA) +
        geom_segment(aes(y=module, yend=module, x=0, xend=-log10(value)), linewidth=0.5, alpha=0.5) +
        geom_point(shape=21, color='black') +
        scale_color_manual(values=mod_cp, guide='none') +
        scale_fill_manual(values=mod_cp, guide='none') +
        ylab('') + labs(size=bquote("N"[genes])) +
        xlab(label) +
        theme(
            panel.border = element_rect(linewidth=1, color='black', fill=NA),
            axis.line.y = element_blank(),
            axis.line.x = element_blank(),
            plot.title = element_text(hjust=0.5, face='bold')
        ) + ggtitle(plot_title)

    } else{

        mod_pres <- mod_pres$Z

        if(is.null(features)){
          features <- 'Zsummary.pres'
        }

        # check for valid features
        if(! features %in% colnames(mod_pres)){
          stop(paste0('Invalid selection for features. Valid selections are: ', paste(colnames(mod_pres), collapse=', ')))
        }

        plot_df <- mod_pres[,c('moduleSize', features)]
        colnames(plot_df)[2] <- 'value'
        plot_df$module <- rownames(plot_df)
        plot_df <- plot_df %>% subset(! module %in% c('gold', 'grey'))
        plot_df <- plot_df %>% dplyr::arrange(value)
        plot_df$module <- factor(as.character(plot_df$module), levels=as.character(plot_df$module))

        # shaded bands: Z < 2 (no evidence of preservation), 2 <= Z < 10 (weak to moderate)
        p <- plot_df %>%
        ggplot(aes(y=module, x=value, size= moduleSize, color=module, fill=module)) + 
        geom_rect(
                data = plot_df[1,],
                aes(xmin=-Inf, ymax=Inf, ymin=-Inf, xmax=2), fill='grey75', alpha=0.8, color=NA) +
            geom_rect(
            data=plot_df[1,],
            aes(ymin=-Inf, ymax=Inf, xmin=2, xmax=10), fill='grey92', alpha=0.8, color=NA) + 
        geom_segment(aes(y=module, yend=module, x=0, xend=value), linewidth=0.5, alpha=0.5) +
        geom_point(shape=21, color='black') +
        scale_color_manual(values=mod_cp, guide='none') +
        scale_fill_manual(values=mod_cp, guide='none') +
        ylab('') + xlab('') +
        ggtitle(features) +
        theme(
            panel.border = element_rect(linewidth=1, color='black', fill=NA),
            axis.line.y = element_blank(),
            axis.line.x = element_blank(),
            plot.title = element_text(hjust=0.5, face='bold')
        ) + RotatedAxis()
    }

    # return the plot
    p

}



#' ModuleRadarPlot
#'
#' Plots the expression level (module eigengene) of each co-expression module 
#' for different groups as a radar plot.
#'
#' @return If combine is TRUE, a patchwork object combining one radar plot per module;
#' otherwise a named list (one entry per module) of ggplot objects.
#'
#' @param seurat_obj A Seurat object
#' @param group.by the column name in seurat_obj@meta.data used to group cells. If NULL, Idents(seurat_obj) is used.
#' @param barcodes A list of barcodes from colnames(seurat_obj) which will be used to subset the data before plotting. 
#' @param combine logical indicating whether or not to combine plots using patchwork
#' @param ncol The number of columns for the combined plot if patchwork is being used
#' @param wgcna_name The name of the hdWGCNA experiment in the seurat_obj@misc slot
#' @param fill logical passed to the fill argument of ggradar::ggradar
#' @param draw.points logical passed to the draw.points argument of ggradar::ggradar
#' @param ... additional parameters for ggradar
#' @details
#' ModuleRadarPlot visualizes the expression level (module eigengene, ME) of each co-expression module for different 
#' groups in a radial coordinate system. The ME for each module is averaged within each group, negative averages 
#' are set to zero, and the result is plotted radially (one plot per module, one spoke per group). The averages are 
#' not otherwise rescaled, so the ggradar axis limits may need to be adjusted through the additional parameters. 
#' The resulting plots help us to interpret which cell groups (clusters, cell types, etc) are expressing each module.
#' 
#' @import Seurat
#' @export
ModuleRadarPlot <- function(
  seurat_obj,
  group.by = NULL,
  barcodes = NULL,
  combine = TRUE,
  ncol=4, 
  wgcna_name = NULL,
  fill=TRUE,
  draw.points=FALSE,
  ... # additional params for ggradar
){

  if(is.null(wgcna_name)){wgcna_name <- seurat_obj@misc$active_wgcna}
  CheckWGCNAName(seurat_obj, wgcna_name)

  # ggradar is only called through ggradar::, so it just needs to be
  # installed (not attached); install it from GitHub on demand.
  if (!requireNamespace("ggradar", quietly = TRUE)) {
    print('Missing package: ggradar')
    print('Installing package: ggradar')
    devtools::install_github("ricardo-bion/ggradar", dependencies = TRUE)
  }

  # use idents as the group unless a metadata column is requested
  if(is.null(group.by)){
    cell_grouping <- Idents(seurat_obj)
  } else{
    if(length(group.by) != 1 || !(group.by %in% colnames(seurat_obj@meta.data))){
      stop('Invalid selection for group.by, must be a single column name of seurat_obj@meta.data')
    }
    cell_grouping <- seurat_obj@meta.data[,group.by]
    names(cell_grouping) <- colnames(seurat_obj)
  }

  # order of the groups (radar spokes): factor levels if available,
  # otherwise order of first appearance
  if(is.factor(cell_grouping)){
    group_order <- levels(cell_grouping)
  } else{
    group_order <- unique(cell_grouping)
  }

  # get the module info
  modules <- GetModules(seurat_obj, wgcna_name)
  mod_colors <- modules %>% dplyr::select(c(module, color)) %>% dplyr::distinct()
  mods <- levels(modules$module); mods <- mods[mods != 'grey']

  # get the MEs of the selected experiment (not necessarily the active one);
  # drop=FALSE keeps a data.frame even if only one module remains
  MEs <- GetMEs(seurat_obj, wgcna_name=wgcna_name)
  MEs <- MEs[,colnames(MEs) != 'grey', drop=FALSE]

  # are we subsetting?
  if(!is.null(barcodes)){
    if(!(all(barcodes %in% colnames(seurat_obj)))){
      stop('Invalid selection for barcodes, some are not found in the colnames(seurat_obj)')
    }
    MEs <- MEs[barcodes,, drop=FALSE]
    cell_grouping <- cell_grouping[barcodes]
  }

  MEs$cluster <- cell_grouping
  clusters <- as.character(unique(cell_grouping))

  # keep only groups that actually contain cells: unused factor levels or
  # groups removed by barcode subsetting have no column in plot_df below
  group_order <- as.character(group_order)
  group_order <- group_order[group_order %in% clusters]

  # calculate the mean ME for each group
  plot_df <- MEs %>% 
    dplyr::group_by(cluster) %>% 
    dplyr::summarise_all(mean) %>%
    as.data.frame() 

  # re-format the dataframe: one row per module, one column per group
  rownames(plot_df) <- plot_df$cluster
  plot_df <- dplyr::select(plot_df, -cluster) 
  plot_df <- t(plot_df) %>% as.data.frame()
  # negative average MEs are clipped to zero for the radar plot
  plot_df[plot_df < 0] <- 0 
  plot_df$group <- rownames(plot_df)

  # order modules by their factor levels, groups by group_order;
  # ggradar expects the first column to hold the row (module) labels
  plot_df$group <- factor(as.character(plot_df$group), levels=mods)
  plot_df <- plot_df %>% dplyr::arrange(group) %>% as.data.frame()
  plot_df <- plot_df[,c('group', group_order)]

  # make the radar plots for each module
  plot_list <- list()
  for(i in seq_len(nrow(plot_df))){
    cur_mod <- as.character(plot_df[i,'group'])
    cur_color <- subset(mod_colors, module == cur_mod) %>% .$color
    plot_list[[cur_mod]] <- ggradar::ggradar(
      plot_df[i,], group.colours=cur_color,
      draw.points=draw.points,
      fill=fill, ...
      ) + 
      Seurat::NoLegend() + 
      ggtitle(cur_mod) + 
      theme(
        plot.title = element_text(face='bold', hjust=0.5)
      )
  }

  # combine plots with patchwork? ncol must be named: positionally it would
  # be swallowed by wrap_plots' ... and silently ignored
  if(combine){
    patch <- wrap_plots(plot_list, ncol=ncol)
    return(patch)
  }
  plot_list

}
smorabit/scWGCNA documentation built on April 5, 2025, 3:57 p.m.