#' Select variable genes as outliers in a mean-versus-variability plot.
#'
#' Three scoring methods are available, all filling `dge@mean.var` with the columns
#' `avg_log_exp`, `variability` and `excess_variability`:
#' - "seurat": values computed by `Seurat::MeanVarPlot`.
#' - "spline": standardized residuals of per-gene sd around a `mgcv::gam` fit on the mean.
#' - "knn": standardized `dist_to_nn` values taken from `outlier_labeled_scatterplot`.
#'
#' Give either cutoffs (`excess_var_cutoff` and/or `log_expr_cutoff`) or the number /
#' proportion of genes to select -- not both.
#'
#' @param dge Seurat object. `@data` is read; `@mean.var` and `@var.genes` are written.
#' @param results_path Directory receiving the plots and text summaries.
#' @param test_mode Logical. When TRUE, selection is forced to the top 1.5 percent of genes.
#' @param prop_genes_to_select Proportion of genes to keep, ranked by `excess_variability`.
#' @param num_genes_to_select Number of genes to keep; converted to a proportion when
#'   `prop_genes_to_select` is not given.
#' @param excess_var_cutoff Lower bound on `excess_variability` when selecting by cutoff.
#' @param log_expr_cutoff Lower bound on `avg_log_exp` when selecting by cutoff.
#' @param method One of "seurat", "spline", "knn".
#' @return `dge` with `@var.genes` set and `@mean.var` augmented (including `included`).
#' @export
var_gene_select = function( dge, results_path, test_mode = FALSE, 
                            prop_genes_to_select = NULL,
                            num_genes_to_select = NULL,
                            excess_var_cutoff = NULL,
                            log_expr_cutoff = NULL,
                            method = "seurat" ){
  atat( method %in% c( "seurat", "spline", "knn" ) )
  dir.create.nice( results_path )
  # # Parse input
  if( !is.null( num_genes_to_select ) && is.null( prop_genes_to_select ) ){
    prop_genes_to_select = num_genes_to_select / nrow( dge@data )
  }
  cutoffs_given  = !( is.null( excess_var_cutoff ) && is.null( log_expr_cutoff ) )
  prop_was_given = !( is.null( prop_genes_to_select ) || is.na( prop_genes_to_select ) )
  assertthat::assert_that( xor( prop_was_given, cutoffs_given ) )
  # NOTE(review): the guard on this override was reconstructed as `test_mode`, the only
  # argument that is otherwise unused in this function -- confirm.
  if( test_mode ){
    prop_was_given = TRUE
    prop_genes_to_select = 0.015
  }
  # # Need to do this to initialize @mean.var even if not going to use those cutoffs
  if( is.null( log_expr_cutoff ) ){ log_expr_cutoff = 0 }
  if( is.null( excess_var_cutoff ) ){ excess_var_cutoff = 0 }
  dge = Seurat::MeanVarPlot( dge, x.low.cutoff = log_expr_cutoff, y.cutoff = excess_var_cutoff )
  if( method == "seurat" ){
    # Variable gene selection method 1 -- Seurat built-in
    dge@mean.var$avg_log_exp = dge@mean.var$data.x
    dge@mean.var$variability = dge@mean.var$data.y
    dge@mean.var$excess_variability = dge@mean.var$data.norm.y
  }
  if( method == "spline" ){
    # Variable gene selection method 2 -- residuals from spline curve
    sd_log_exp  = unlist( apply( FUN = sd  , dge@data, MARGIN = 1 ) )
    avg_log_exp = unlist( apply( FUN = mean, dge@data, MARGIN = 1 ) )
    gam_data = data.frame( y = sd_log_exp, x = avg_log_exp )
    # `s` must be visible by name while the formula is evaluated.
    s = mgcv::s
    mean_var_reg = mgcv::gam( data = gam_data, formula = y~s(x), family = mgcv::scat() )
    excess_variability = standardize( sd_log_exp - mean_var_reg$fitted.values )
    dge@mean.var$avg_log_exp = avg_log_exp
    dge@mean.var$variability = sd_log_exp
    dge@mean.var$excess_variability = excess_variability
  }
  if( method == "knn" ){
    # # Variable gene selection method 3 -- knn-based outlier detection
    avg_log_exp = unlist( apply( FUN = mean, dge@data, MARGIN = 1 ) )
    sd_log_exp  = unlist( apply( FUN = sd  , dge@data, MARGIN = 1 ) )
    p = outlier_labeled_scatterplot( data.frame( avg_log_exp = avg_log_exp,
                                                 sd_log_exp = sd_log_exp, 
                                                 gene = rownames( dge@data ) ) )
    dge@mean.var$avg_log_exp = avg_log_exp
    dge@mean.var$variability = sd_log_exp
    temp = p$plot_env$data$dist_to_nn
    dge@mean.var$excess_variability = standardize( temp )
  }
  # # Select genes, either by rank or by cutoff
  if( prop_was_given ){
    dge@var.genes = top_n_preserve_rownames( dge@mean.var, 
                                             n = round(prop_genes_to_select * nrow( dge@mean.var ) ),
                                             wt = excess_variability ) %>% rownames
  } else {
    dge@var.genes = subset( dge@mean.var, 
                            excess_variability > excess_var_cutoff & 
                              avg_log_exp > log_expr_cutoff ) %>% rownames
  }
  dge@mean.var$included = rownames( dge@mean.var ) %in% dge@var.genes
  # # Save mean_var_plots 
  p_reg = ggplot(dge@mean.var) + ggtitle("Mean versus coef. var for all genes") +
    geom_point(aes(avg_log_exp, variability, colour = included), alpha = 0.5) + 
    geom_vline(aes(xintercept = log_expr_cutoff))
  ggsave(file.path(results_path, "MeanVarPlot_reg.pdf"), p_reg)
  p_res = ggplot(dge@mean.var) + ggtitle("Mean versus excess var across genes") +
    geom_point(aes(avg_log_exp, excess_variability, colour = included), alpha = 0.5) + 
    geom_vline(aes(xintercept = log_expr_cutoff)) + 
    geom_hline(aes(yintercept = excess_var_cutoff))
  ggsave(file.path(results_path, "MeanVarPlot_res.pdf"), p_res)
  # # Save variable genes and parameters
  vgsrp = file.path( results_path, "var_gene_select" )
  dir.create.nice( vgsrp )
  cell_markers = get_rene_markers()$marker %>% harmonize_species(dge)
  variable_cell_markers = intersect( cell_markers, as.character( dge@var.genes ) )
  # First line of each file is "<count> total", formatted the same way in both files.
  variable_cell_markers = c( paste( length( variable_cell_markers ), "total" ), variable_cell_markers )
  text2file(file.path(vgsrp, "markers_among_variable_genes.txt"), variable_cell_markers)
  totalstring = paste(length(as.character(dge@var.genes)), "total")
  var_genes = c(totalstring, as.character(dge@var.genes))
  text2file(file.path(vgsrp, "variable_genes.txt"), var_genes)
  # The two cutoffs were defaulted to 0 above, so only the first two can still be NULL here.
  if( is.null( num_genes_to_select  ) ) { num_genes_to_select  = "NULL" }
  if( is.null( prop_genes_to_select ) ) { prop_genes_to_select = "NULL" }
  if( is.null( excess_var_cutoff    ) ) { excess_var_cutoff    = "NULL" }
  if( is.null( log_expr_cutoff      ) ) { log_expr_cutoff      = "NULL" }
  vsp        = c( prop_genes_to_select,   num_genes_to_select,   excess_var_cutoff,   log_expr_cutoff)
  names(vsp) = c("prop_genes_to_select", "num_genes_to_select", "excess_var_cutoff", "log_expr_cutoff")
  text2file(file.path(vgsrp, "var_gene_selection_params.txt"), collapse_by_name(vsp))
  return( dge )
}

#' Save QC histograms of per-cell complexity and UMI counts.
#'
#' Writes two PDFs into a `QC` folder under `result_path`:
#' `complexity.pdf` (number of genes with a nonzero raw count in each cell) and
#' `UMIs_by_cell.pdf` (log10 of the total raw count in each cell).
#'
#' @param dge Seurat object; only `@raw.data` is read.
#' @param result_path Output directory. The `QC` subfolder is created if needed.
#' @export
save_complexity_plot = function(dge, result_path){
  qc_dir = file.path( result_path, "QC" )
  dir.create.nice( qc_dir )

  complexity_df = data.frame( complexity = colSums( dge@raw.data > 0 ) )
  p = ggplot( complexity_df ) + 
    ggtitle("Complexity") + geom_histogram(aes(x = complexity)) 
  ggsave( file.path( qc_dir, "complexity.pdf" ), p )

  umi_df = data.frame( UMIs_by_cell = colSums( dge@raw.data ) )
  p = ggplot( umi_df ) + 
    ggtitle("UMIs per cell") + geom_histogram(aes(x = log10(UMIs_by_cell)))
  ggsave( file.path( qc_dir, "UMIs_by_cell.pdf" ), p )
}

## ------------------------------------------------------------------------

#' Decide if a gene is a ribosomal subunit or a pre-rRNA transcript.
#'
#' Matching is case-insensitive. A gene counts as ribosomal when its symbol is
#' `RN45S` or starts with `RPS` / `RPL`.
#'
#' @param genes Character vector of gene symbols.
#' @return Logical vector with one entry per element of `genes` (missing symbols give FALSE).
#' @export
is_rp = function( genes ){
  genes = toupper( genes )
  # %in% instead of == so that missing symbols yield FALSE rather than NA
  is_pre_rrna = genes %in% "RN45S"
  is_subunit  = substring( genes, 1, 3 ) %in% c( "RPS", "RPL" )
  return( is_pre_rrna | is_subunit )
}

#' Decide if genes are mitochondrial. Currently uses a crude "grep mt".
#'
#' Matching is case-insensitive: a gene counts as mitochondrial when its symbol
#' starts with `MT-` or `MT.`.
#'
#' @param genes Character vector of gene symbols.
#' @return Logical vector with one entry per element of `genes`.
#' @export
is_mt = function( genes ){
  genes = toupper( genes )
  return( substring( genes, 1, 3 ) %in% c( "MT-", "MT." ) )
}

#' Quantify ribosomal transcript content cell by cell.
#'
#' Alias that simply forwards to `add_rp_mt_percentage`, so the result is whatever
#' that function returns for `dge`.
#' @param dge Seurat object, passed through unchanged to `add_rp_mt_percentage`.
#' @export
add_rp_percentage = function(dge){
  add_rp_mt_percentage( dge )
}

#' Quantify ribosomal and mitochondrial transcript content cell by cell.
#'
#' Adds four metadata columns: `nUMI_mt`, `nUMI_rp` (raw counts summed over genes
#' flagged by `is_mt` / `is_rp`) and `mt_nUMI_pct`, `ribo_nUMI_pct` (those sums
#' divided by the `nUMI` metadata field; despite the names these are fractions, not
#' values multiplied by 100).
#'
#' @param dge Seurat object with `nUMI` available through `FetchData`.
#' @return `dge` with the four columns added via `AddMetaData`.
#' @export
add_rp_mt_percentage = function( dge ){
  nUMI = dge %>% FetchData("nUMI") %>% extract2(1) 
  # Restrict the raw counts to the genes and cells kept in @data so that the gene mask
  # and the per-cell totals line up with `nUMI`.
  raw = dge@raw.data[rownames(dge@data), colnames(dge@data)]
  genes = rownames( raw )
  sum_umis = function(predicate){
    # drop = FALSE keeps a matrix even when exactly one gene matches
    x = raw[ predicate( genes ), , drop = FALSE ] %>% colSums
    # A subset of a cell's counts can at most equal its total.
    atat( all( x <= nUMI ) ) 
    return( x )
  }
  nUMI_rp = sum_umis(predicate = is_rp)
  nUMI_mt = sum_umis(predicate = is_mt)
  to_add = data.frame( nUMI_mt = nUMI_mt,
                       nUMI_rp = nUMI_rp, 
                       mt_nUMI_pct   = nUMI_mt / nUMI,
                       ribo_nUMI_pct = nUMI_rp / nUMI, 
                       row.names = dge@cell.names )
  dge %<>% AddMetaData( metadata = to_add )
  return( dge )
}

#' Quantify cell-cycle activity by phase.
#'
#' Takes a Seurat object and adds cell cycle scores to the metadata (`object@data.info`).
#' There is one quantitative score column per phase, where the phases are the columns of
#' the table returned by `get_macosko_cc_genes()`. Each score is the mean of `dge@data`
#' over the phase's genes that are present in the object.
#'
#' @param dge Seurat object.
#' @param method Currently unused: scores are always per-phase averages.
#' @return `dge` with the score columns added (and `species` metadata, if it could be added).
#' @export
add_cc_score = function(dge, method = "average"){
  # Get cell cycle genes and expression data
  macosko_cell_cycle_genes = get_macosko_cc_genes()
  phases = colnames( macosko_cell_cycle_genes )
  raw_dge = as.matrix( dge@data ) # calculate scores on log1p-scale data

  # Species does not depend on the phase, so look it up once.
  try(dge %<>% add_maehrlab_metadata("species"))
  use_orthologs = FALSE
  if ( !"species" %in% AvailableData( dge ) ){
    warning(paste("No species metadata present. Assuming mouse, but please add species metadata.\n", 
                  "Try `object = add_maehrlab_metadata(object, 'species')`.\n" ) )
  } else {
    use_orthologs = "human" %in% Seurat::FetchData(dge, "species")[[1]]
  }

  # # To reconcile the data with the predefined list:
  # # - get human orthologs when appropriate
  # # - get both Capital and UPPERCASE
  # # - intersect gene list with available genes
  geneset = as.list(phases); names(geneset) = phases
  for( phase in phases ){
    phase_genes = macosko_cell_cycle_genes[, phase]
    human_ortho = c()
    if( use_orthologs ){
      human_ortho = unlist( lapply( X = phase_genes, 
                                    FUN = get_ortholog, from = "human", to = "mouse" ) )
      human_ortho = human_ortho[!is.na( human_ortho )]
    }
    candidates = union( Capitalize( phase_genes ), toupper( phase_genes ) )
    candidates = union( candidates, human_ortho )
    geneset[[phase]] = intersect( candidates, rownames( raw_dge ) )
  }

  # produce a score for each phase
  scores_mat = matrix( 0, nrow = length( phases ), ncol = ncol( raw_dge ) )
  rownames( scores_mat ) = colnames( macosko_cell_cycle_genes )
  colnames( scores_mat ) = colnames( raw_dge )
  for( phase in phases ){
    # drop = FALSE keeps a matrix when a phase has a single matching gene;
    # a phase with no matching genes yields NaN, as the mean of nothing.
    scores_mat[phase, ] = colMeans( raw_dge[ geneset[[phase]], , drop = FALSE ] )
  }
  scores_df = data.frame(t(scores_mat))
  dge = Seurat::AddMetaData(dge, scores_df)
  return( dge )
}

## ------------------------------------------------------------------------
#' Run PCA and then t-SNE on a Seurat object.
#'
#' @param dge Seurat object.
#' @param pc.use Principal components passed to `Seurat::RunTSNE` as `dims.use`;
#'   `max(pc.use)` components are stored by the PCA step.
#' @param ... Additional arguments passed to `Seurat::RunTSNE`.
#' @param test_mode Optional logical. When NULL (the default), a variable named
#'   `test_mode` visible from the environment where this function was defined is used
#'   if one exists, and FALSE otherwise. In test mode PCs 1:4 are used and Gaussian
#'   noise is added to `@pca.rot`.
#' @return `dge` after PCA and t-SNE.
#' @export
do_dim_red = function( dge, pc.use, ..., test_mode = NULL ){
  if( is.null( test_mode ) ){
    # Previously `test_mode` was read as a free variable; keep that lookup, but fall
    # back to FALSE instead of failing when no such variable exists.
    test_mode = get0( "test_mode", envir = parent.env( environment() ), ifnotfound = FALSE )
  }
  dge = Seurat::PCA( dge, do.print = FALSE, pcs.store = max( pc.use ) )
  if( isTRUE( test_mode ) ){
    pc.use = 1:4
    # # Solve problem downstream where t-SNE can't handle exact duplicates
    noise = matrix( 0.1*rnorm( prod( dim( dge@pca.rot ) ) ), nrow = nrow( dge@pca.rot ) )
    dge@pca.rot = dge@pca.rot + noise
  }
  dge = Seurat::RunTSNE( dge, dims.use = pc.use, ... )
  return( dge )
}

## ------------------------------------------------------------------------
#' Save one `VizPCA` figure per principal component.
#'
#' Files are written to `<results_path>/top_genes_for_each_pc/pc<i>.pdf`.
#'
#' @param dge Seurat object with PCA results.
#' @param results_path Output directory.
#' @param test_mode If TRUE, nothing is done and NULL is returned.
#' @param num_pc Number of leading principal components to plot.
#' @export
top_genes_by_pc = function(dge, results_path, test_mode, num_pc = 30){
  if(test_mode){return()} #Doesn't work with small data. Don't care; bypassing.
  out_dir = file.path(results_path, "top_genes_for_each_pc")
  dir.create.nice( out_dir )
  for(pc in seq_len(num_pc)){
    pdf(file.path(out_dir, paste0("pc", pc, ".pdf")))
    # Close the device even if plotting fails, so no PDF device is left open.
    tryCatch( VizPCA(dge, pcs.use = pc), finally = dev.off() )
  }
}

## ------------------------------------------------------------------------
#' Apply a bagged clustering procedure to a Seurat object using PCA as input.
#'
#' Adds one membership column per cluster (`cl_bag_1`, `cl_bag_2`, ...), plus
#' `cl_bag_assignment` (name of the column with the largest membership) and
#' `cl_bag_confidence` (that largest membership value).
#'
#' @param dge A Seurat object.
#' @param k Number of clusters. (Some may end up empty.)
#' @param num_pc Number of PC's to use from dge. Specify this or feature_names, not both.
#'   If neither is given, 25 PC's are used with a warning.
#' @param my_seed RNG seed, set immediately before clustering.
#' @param B Number of bootstrap replicates.
#' @param feature_names Fetched via FetchData as input. Specify this or num_pc, not both. Default unused.
#' @param ... Additional args passed to clue::cl_bag.
#' @return `dge` with the columns described above added to its metadata.
#' @export
ClusterBag = function(dge, k, num_pc = NULL, feature_names = NULL, my_seed = 100, B = 100, ... ){
  if(!is.null(feature_names) && !is.null(num_pc) ){
    stop("Please specify only one of 'num_pc' and 'feature_names'.")
  }
  if(is.null(feature_names) && is.null(num_pc) ){
    warning("Please specify one of 'num_pc' and 'feature_names'. Using 25 PC's by default.")
    num_pc = 25
  }
  if( !is.null( num_pc ) ){
    # Explicit column selection: unambiguous for both data frames and matrices.
    X = dge@pca.rot[ , seq_len( num_pc ), drop = FALSE ]
  } else {
    X = FetchData( dge, feature_names )
  }
  set.seed( my_seed )
  cl_obj = clue::cl_bag( X, k = k, B = B, ... ) 
  memberships = as.data.frame( cl_obj$.Data[,] )
  colnames(memberships) = paste0("cl_bag_", 1:ncol(memberships))
  rownames(memberships) = rownames(dge@data.info)
  assignment = apply(memberships, 1, which.max)
  confidence = apply(memberships, 1, max)
  memberships$cl_bag_assignment = paste0("cl_bag_", assignment)
  memberships$cl_bag_confidence = confidence
  dge %<>% AddMetaData( memberships )
  return( dge )
}

#' Wrapper for ClusterBag. Measures stability over a range of model sizes.
#'
#' Runs `ClusterBag` for every `k` from 2 to `k_max`, records the mean
#' `cl_bag_confidence` for each, and plots stability against `k`.
#'
#' @param dge A Seurat object.
#' @param k_max Largest number of clusters to try; must be at least 2.
#' @param ... Passed on to `ClusterBag`.
#' @return data.frame with columns `k` and `stability`.
#' @export
ClusterBagStability = function( dge, k_max = 15, ... ){
  # 2:k_max would silently count backwards for k_max < 2
  if( k_max < 2 ){
    stop( "`k_max` must be at least 2." )
  }
  kvals = 2:k_max
  stability = numeric( length( kvals ) )
  for( i in seq_along( kvals ) ){
    dge = ClusterBag( dge, kvals[i], ...)
    stability[i] = mean(FetchData(dge, "cl_bag_confidence")[[1]])
  }
  plot(kvals, stability)
  return( data.frame(k = kvals, stability = stability) )
}

#' A wrapper for Seurat's clustering functions.
#'
#' Clusters once per granularity, saves a t-SNE plot coloured by identity for each, and
#' writes a cluster summary (via `cluster_summary`) for the last clustering performed.
#'
#' @param dge Seurat object.
#' @param results_path Output directory; plots go into a subfolder named after `method`.
#' @param test_mode If TRUE and only one cluster results, cells are randomly split into
#'   two clusters so that downstream differential expression code can be exercised.
#' @param pc.use Principal components passed to `Seurat::FindClusters` (SNN only).
#' @param method is either "DBSCAN" or "SNN".
#' @param granularities_as_string should be numbers separated by commas and whitespace.
#'     The number of clusterings performed is the length of (the split and cleaned version of) `granularities_as_string`.
#'     The granularity number is used as the neighborhood radius in DBSCAN or the "resolution" in SNN.
#' @param merge_small Currently unused.
#' @return `dge` with its identities set by the last clustering.
#' @details 
#' For more information on density-based clustering (DBSCAN), look at the KDD-96 paper by Ester, Kriegel, Sander, and Xu.
#' "A density-based algorithm for discovering clusters in large spatial databases with noise."
#' For more information on the SNN-based clustering, look at (the parameter gamma in)
#' "A unified approach to mapping and clustering of bibliometric networks", 
#' Ludo Waltman, Nees Jan van Eck, and Ed C.M. Noyons
#'
#' A postcondition of this wrapper is that cluster 1 is not a legit cluster.
#' It is either empty or the set of "rejects". For SNN this is enforced by moving
#' cluster 1 to a new label one past the largest existing label.
#' @export
cluster_wrapper = function(dge, results_path, test_mode, 
                           pc.use = NULL, 
                           method = c("DBSCAN"),
                           granularities_as_string = "6",
                           merge_small = FALSE){
  granularities = as.numeric( trimws( strsplit( granularities_as_string, split = "," )[[1]] ) )
  dir.create.nice( file.path( results_path, method ) )
  for(my_resolution in granularities){
    if(method == "DBSCAN"){
      dge = Seurat::DBClustDimension(dge, reduction.use="tsne", G.use=my_resolution, set.ident = TRUE)
    } else {
      atat( method == "SNN" )
      dge = Seurat::FindClusters(dge, 
                                 pc.use = pc.use, 
                                 resolution = my_resolution, 
                                 print.output = FALSE)
      ident_no_1 = dge@ident %>% as.character %>% as.numeric
      ident_no_1[ ident_no_1==1 ] = max( ident_no_1 ) + 1
      ident_no_1 = as.character( ident_no_1 )
      dge = Seurat::SetIdent(dge, ident.use = ident_no_1)
    }
    tsne_colored( dge, file.path(results_path, method), colour = "ident",
                  fig_name = paste0("res=", my_resolution, ".pdf"))
  }

  if(test_mode && 1==length(levels(dge@ident))){
    warning("In test mode, got 1 cluster. 
            Randomly splitting into 2 clusters to facilitate debugging of differential expression code.")
    dge = Seurat::SetIdent(dge, ident.use = sample(1:2, size = length(dge@ident), replace = TRUE) %>% as.character)
  }
  cluster_summary(dge, results_path)
  return( dge )
}

#' Record summary information for each of the clusters (currently just size).
#'
#' Writes `cluster_summary.txt` (cluster label and cell count) and an extra copy,
#' `named_clusters.txt`, with placeholder name and colour columns that can be
#' hand-edited to manually name the clusters. Both files are tab-delimited.
#'
#' @param dge Seurat object; only `@ident` is read.
#' @param results_path Existing directory that receives the two files.
#' @return data.frame with columns `Cluster`, `Num Cells`, `cluster_name` and `color`.
#' @export
cluster_summary = function(dge, results_path){
  clus_summ = data.frame(table(dge@ident))
  names(clus_summ) = c("Cluster", "Num Cells")
  write.table(x = clus_summ, 
              file = file.path(results_path, "cluster_summary.txt"),
              sep = "\t", quote = FALSE, row.names = FALSE, col.names = TRUE)
  n_clusters = nrow( clus_summ )
  # rep() rather than a bare scalar so that an object with no clusters does not error
  clus_summ$cluster_name = rep( "insert_name_here", n_clusters )
  clus_summ$color = rainbow( n_clusters )
  write.table(x = clus_summ, 
              file = file.path(results_path, "named_clusters.txt"), 
              sep = "\t", quote = FALSE, row.names = FALSE, col.names = TRUE)
  return( clus_summ )
}

#' Compare two different analyses of the same cells.
#'
#' Colours the embedding of `dge` by the identities from `comparator_dge` and saves the
#' figure as `embedding=<dge_name>|colors=<comparator_name>.pdf` in `results_path`.
#' Identities are looked up by the cell names of `dge`.
#'
#' @param dge Seurat object whose embedding is plotted.
#' @param results_path Output directory.
#' @param comparator_dge Seurat object whose `@ident` supplies the colours.
#' @param dge_name,comparator_name Labels for the two objects, used in the file name.
#' @export
compare_views = function(dge, results_path, comparator_dge, dge_name, comparator_name){
  # NOTE(review): "|" is not a legal file-name character on Windows.
  figname = paste0("embedding=", dge_name, "|colors=", comparator_name, ".pdf")
  other_id = comparator_dge@ident[dge@cell.names]
  dge = Seurat::AddMetaData( dge, metadata = other_id, col.name = "other_id" )
  ggsave( filename = file.path(results_path, figname),
          plot = custom_feature_plot(dge, colour = "other_id"),
          width = 5.5, height = 5 )
}

## ------------------------------------------------------------------------

#' Find genes with expression patterns similar to the genes you've specified.
#'
#' Given a Seurat object and a list of gene names, this function returns genes
#' that are strongly associated with those markers. The score of a gene against one
#' marker is the inner product of their rows in `dge@scale.data`; scores against
#' several markers are combined with `aggregator`.
#'
#' @param dge A Seurat object with field `@scale.data` filled in.
#' @param markers A character vector giving gene names. If some are absent, UPPERCASE
#'   and Capitalized variants are tried as well.
#' @param n Integer; maximum number of results to return.
#' @param anticorr Allow negatively correlated genes (rank by absolute score); defaults to `FALSE`.
#' @param aggregator If multiple genes, how to combine their scores: a function or the
#'   name of one. Default is `sum`. For an "and"-like filter, use `min`.
#' @return Character vector of at most `n` gene names, best first. The markers
#'   themselves are never returned. Empty if none of the markers is in the data.
#' @export
get_similar_genes = function( dge, markers, n, anticorr = FALSE, aggregator = sum ){
  data.use = dge@scale.data
  if(!all(markers %in% rownames(data.use))){ 
    warning("Some of your markers have no data available. Trying Various CASE Changes.")
    markers = unique( c( markers, toupper(markers), Capitalize( markers ) ) )
  }
  markers = intersect(markers, rownames( data.use) ) 
  if( length( markers ) == 0 ){
    warning( "None of the requested markers are present in the data. Returning no genes." )
    return( character( 0 ) )
  }
  aggregator = match.fun( aggregator )
  # genes x markers matrix of scores, then one combined score per gene
  per_marker = data.use %*% t( data.use[markers, , drop = FALSE] )
  correlation = apply( per_marker, 1, aggregator )
  correlation = correlation[ setdiff( names( correlation ), markers ) ]
  if( anticorr ){
    ranked = sort( abs( correlation ), decreasing = TRUE )
  } else {
    ranked = sort( correlation, decreasing = TRUE )
  }
  # Cap at the number of candidates so that a large `n` does not produce NA names.
  similar_genes = names( ranked[ seq_len( min( n, length( ranked ) ) ) ] )
  return( similar_genes )  
}

## ------------------------------------------------------------------------
#' Programmatically feed gene-sets to Enrichr and save formatted results.
#'
#' Results go to `<results_path>/annot_<geneset_name>/`: `raw.txt` (top terms per
#' database), `annotated_genes.txt` (the submitted genes) and `color=database.pdf`
#' (the same table without the gene lists, shaded by database).
#'
#' @param results_path Output directory.
#' @param geneset Character vector of genes to submit.
#' @param geneset_name Label used in the output folder name.
#' @param desired_db Enrichr database names to query.
#' @param N_ANNOT_PER_DB Number of terms to keep per database, ranked by `P.value`.
#' @return NULL (with a warning) when `geneset` is empty; otherwise the value of the
#'   final `ggsave` call.
#' @export
do_enrichr = function( results_path, geneset, geneset_name, 
                       desired_db = c( "KEGG_2016", 
                                       "GO_Biological_Process_2015" ),
                       N_ANNOT_PER_DB = 2 ){
  if(length(geneset) == 0) {
    warning( "List of length zero fed to do_enrichr. Returning NULL.")
    return( NULL )
  }
  my_rp = file.path( results_path, paste0("annot_", geneset_name) )
  dir.create.nice( my_rp )
  # # Get enrichr results and parse them
  output_table = enrichR::enrichr( genes = geneset, databases = desired_db ) 
  output_table %<>% (base::mapply)(desired_db, FUN = function(X, db) {try({X$database = db}); return(X)}, SIMPLIFY = F)
  output_table %<>% (base::Reduce)( f=rbind )
  output_table %<>% (dplyr::group_by)( database ) %>% (dplyr::top_n)( wt = -P.value, n = N_ANNOT_PER_DB) %>% as.data.frame
  output_table %<>% (dplyr::mutate)( log10_qval = round( log10( Adjusted.P.value ), 1 ) )
  output_table = output_table[c("database", "Term", "Overlap", "log10_qval", "Genes")]
  # # Save raw table
  write.table( x = output_table, 
               file = file.path( my_rp, "raw.txt"  ),
               sep = "\t", quote = F, row.names = T, col.names = T )
  write.table( x = geneset, 
               file = file.path( my_rp, "annotated_genes.txt"  ),
               sep = "\t", quote = F, row.names = F, col.names = F )
  # # Save pretty version
  pretty_table = output_table
  pretty_table$Genes  = NULL
  # One colour per requested database: `color_idx` indexes into `desired_db`, so the
  # palette must be that long even if some database contributed no rows.
  n_colors = length( desired_db )
  color_idx = pretty_table$database %>% factor(levels = desired_db, ordered = T) %>% as.integer
  my_cols = scales::hue_pal()(n_colors)[ color_idx ]
  theme_color_db = gridExtra::ttheme_minimal( core=list( bg_params = list( fill = my_cols ) ) )
  # Name the arguments: the first positional argument of ggsave is the file name.
  ggsave( filename = file.path( my_rp, "color=database.pdf" ), 
          plot = gridExtra::tableGrob( d = pretty_table, theme = theme_color_db ), 
          width = 15, height = N_ANNOT_PER_DB*length(desired_db) / 3, limitsize = F )
}

## ------------------------------------------------------------------------

#' Helper function. Check if a list of markers is compatible with a given Seurat object 
#' so that genes and cluster assignments are present in both `marker_info` and
#' the Seurat object `dge`.
#'
#' @param dge Seurat object.
#' @param marker_info Data frame with (at least) columns `gene` and `cluster`.
#'   Factor columns trigger a warning and are compared as character.
#' @param ident.use Name of the variable, fetched with `FetchData`, holding cluster labels.
#' @return TRUE if every gene is a row name of `dge@data` and every cluster occurs in
#'   the fetched labels; FALSE (with a warning naming the problem) otherwise.
are_compatible = function( dge, marker_info, ident.use ){
  atat( all( c("gene", "cluster") %in% names( marker_info ) ) )
  factor_flag = FALSE
  for( i in seq_along( marker_info ) ) {
    factor_flag = factor_flag || is.factor( marker_info[[i]] )
    marker_info[[i]] = as.character( marker_info[[i]] )
  }
  if( factor_flag ){ warning( "Factor columns detected in marker_info." ) }
  dge_ident =  as.character( FetchData( dge, ident.use )[[1]] )
  gene_compat  = all( marker_info$gene    %in% rownames( dge@data ) )
  ident_compat = all( marker_info$cluster %in% dge_ident )
  if( !gene_compat ){
    warning("marker_info$gene has genes not available in Seurat object")
  }
  if( !ident_compat ){
    warning("marker_info$cluster has ID's not available in Seurat object")
  }
  return( gene_compat && ident_compat )
}

#' Check if a list of cell types is good to be used to order genes and cells in a heatmap,
#' and repair it if not.
#' @details This means:
#' - no duplicates
#' - up to order, should be equal to union of table cluster labels and dge cluster labels
#' - no missing values
#'
#' Missing cell types are appended, extraneous ones removed and duplicates dropped,
#' each with a warning.
#' @param dge Seurat object.
#' @param marker_info Data frame with a `cluster` column.
#' @param ident.use Name of the variable, fetched with `FetchData`, holding cluster labels.
#' @param desired_cluster_order Character vector of cell types, or NULL to use the
#'   order in which labels are first encountered.
#' @return Character vector: the repaired ordering.
fix_cluster_order = function(  dge, marker_info, ident.use, desired_cluster_order = NULL ){
  all_celltypes = union( Seurat::FetchData( dge, ident.use )[[1]], marker_info$cluster )
  if( is.null( desired_cluster_order ) ){
    warning( "No cluster order specified. Ordering clusters stupidly." )
    desired_cluster_order = all_celltypes
  }
  atae( typeof( desired_cluster_order ), "character" )
  missing = setdiff( all_celltypes, desired_cluster_order)
  extra   = setdiff( desired_cluster_order, all_celltypes)
  if( length( missing ) > 0 ){
    warning("Adding missing cell types to `desired_cluster_order` from Seurat object or marker table.")
    desired_cluster_order = c(desired_cluster_order, missing)
  }
  if( length( extra ) > 0 ){
    warning("Removing extraneous cell types from `desired_cluster_order`.")
    desired_cluster_order = intersect( desired_cluster_order, all_celltypes )
  }
  if( anyDuplicated( desired_cluster_order ) ){
    warning("Omitting duplicates in `desired_cluster_order`.")
    desired_cluster_order = unique( desired_cluster_order )
  }
  atat(  !any( is.na( desired_cluster_order ) ) )
  return( desired_cluster_order )
}

## ------------------------------------------------------------------------
#' Make a heatmap with one column for each cluster in `unique( Seurat::FetchData(dge, ident.use)[[1]])` and 
#' one row for every gene in `genes_in_order`. 
#'
#' If a cluster's expression values are stored in `x`, then `aggregator(x)` gets
#' (normalized and) plotted.
#' @param dge Seurat object
#' @param genes_in_order Genes to show, in plotting order. Duplicates and unavailable
#'   genes are dropped with a warning.
#' @param desired_cluster_order Levels of FetchData(dge, ident.use)[[1]], ordered how you want them to appear.
#'   Coerced to character and repaired by `fix_cluster_order`. If NULL, an arbitrary order is used.
#' @param ident.use Variable to aggregate by.
#' @param labels "regular" (all labels), "stagger" (all labels, alternating left-right to fit more genes), or "none". 
#'   If NULL, chosen from the number of genes: fewer than 30 "regular", fewer than 80 "stagger", else "none".
#' @param aggregator Function to aggregate expression within a group of cells. Try mean or prop_nz.
#' @param normalize "row", "column", "both", or "none". Normalization function gets applied across the axis this specifies.
#'   Any other value gives a warning and row normalization.
#' @param norm_fun Function to use for normalization. Try div_by_max or standardize.
#' @param genes_to_label A (small) subset of genes to label. If set, nullifies labels arg. 
#' @param main Title of plot.
#' @param return_type "plot" or "table". If "table", then instead of returning a heatmap, this 
#' returns the underlying matrix of normalized, cluster-aggregated values. If anything else, returns a ggplot.
#' @export
make_heatmap_for_table = function( dge, genes_in_order, 
                                   desired_cluster_order = NULL, 
                                   ident.use = "ident",
                                   labels = NULL, 
                                   aggregator = mean, 
                                   normalize = "row", 
                                   norm_fun = div_by_max,
                                   genes_to_label = NULL,
                                   main = "Genes aggregated by cluster",
                                   return_type = "plot" ){
  # genes_to_label always overrides labels, including the automatic default below.
  if( !is.null( genes_to_label ) ){
    if( !is.null( labels ) ){
      warning("labels is ignored when genes_to_label is (are?) specified.\n")
    }
    labels = "none"
  }
  # # Set up simple ident variable
  if( is.null( desired_cluster_order ) ){
    warning("No cell-type ordering given. Using arbitrary ordering.")
    # Seed with one existing label; fix_cluster_order appends the rest.
    desired_cluster_order = as.character( FetchData(dge, ident.use)[1, 1] )
  } else {
    desired_cluster_order = as.character( desired_cluster_order )
  }
  marker_info = data.frame( gene = genes_in_order, cluster = desired_cluster_order[[1]] )
  desired_cluster_order = fix_cluster_order( dge, marker_info, ident.use, desired_cluster_order )
  ident = FetchData(dge, ident.use) %>% 
    vectorize_preserving_rownames %>% 
    factor(levels = desired_cluster_order, ordered = T)
  ident = sort(ident)
  # # Sanitize input -- characters for genes, and no duplicate genes.
  genes_in_order = as.character( genes_in_order )
  if( anyDuplicated( genes_in_order ) ){
    warning( "Sorry, can't handle duplicate genes. Removing them." )
    genes_in_order = genes_in_order[ !duplicated( genes_in_order )]
  }
  if( !all( genes_in_order %in% AvailableData(dge) ) ){
    warning( "Some of those markers are not available." )
    genes_in_order = intersect( genes_in_order, AvailableData(dge) )
  }
  # # Get cluster mean expression for each gene and row normalize
  logscale_expression = Seurat::FetchData(dge, vars.all = genes_in_order)[names( ident ), ]
  expression_by_cluster = aggregate_nice( x = logscale_expression, by = list( ident ), FUN = aggregator )
  # Validate first, so that the fallback announced in the warning is actually applied.
  if( !( normalize %in% c( "row", "column", "both", "none" ) ) ){
    warning('normalize should be one of "row", "column", "both", or "none". Performing row normalization.')
    normalize = "row"
  }
  # The net result of this conditional should be to preserve the dimensions.
  # The table is transposed further down, so heatmap rows (genes) are MARGIN = 2 here.
  dim_old = dim(expression_by_cluster)
  if( normalize == "row" ){
    expression_by_cluster = apply(X = expression_by_cluster, FUN = norm_fun, MARGIN = 2 )  
  } else if( normalize == "column" ){
    expression_by_cluster = apply(X = expression_by_cluster, FUN = norm_fun, MARGIN = 1 ) %>% t 
  } else if( normalize == "both" ){
    expression_by_cluster = apply(X = expression_by_cluster, FUN = norm_fun, MARGIN = 1 ) %>% t  
    expression_by_cluster = apply(X = expression_by_cluster, FUN = norm_fun, MARGIN = 2 ) 
  }
  atat(all(dim(expression_by_cluster) == dim_old))
  expression_by_cluster = t(expression_by_cluster) 
  # # Stop and return data if desired
  if( return_type == "table" ){ return(expression_by_cluster) }
  # # Form matrix in shape of heatmap and then melt into ggplot
  plot_df_wide = cbind( as.data.frame( expression_by_cluster ) , gene = rownames(expression_by_cluster))
  plot_df_wide$y = 1:nrow(plot_df_wide)
  plot_df_long = reshape2::melt( plot_df_wide, 
                                 id.vars = c("gene", "y"), 
                                 value.name = "RelLogExpr")
  plot_df_long$Cluster = factor( as.character( plot_df_long$variable ),
                                 levels = desired_cluster_order )
  plot_df_long$gene = factor( as.character( plot_df_long$gene ),
                                 levels = genes_in_order )

  p = ggplot( plot_df_long ) + ggtitle( main ) +
    geom_tile( aes(x = Cluster, y = gene, fill = RelLogExpr ) )
  p = p + theme(axis.text.x = element_text(angle = 45, hjust = 1))

  # Set default labeling mode according to number of genes
  if( is.null( labels ) ){ 
    if( length( genes_in_order ) < 30 ){
      labels = "regular"
    } else if ( length( genes_in_order ) < 80 ){
      labels = "stagger"
    } else {
      labels = "none"
    }
  }
  # # Remove labels if desired
  if( labels=="none"){
    p = p + theme(axis.ticks.y=element_blank(), axis.text.y = element_blank()) 
  }
  # # Make a DF with labeling info
  label_df = plot_df_long
  do_stagger = (labels == "stagger")
  label_df$horiz_pos = rep( c(-0.75*do_stagger, 0), length.out = nrow( plot_df_long ) )
  label_df %<>% extract(c("horiz_pos", "y", "gene")) 
  label_df = label_df[!duplicated(label_df$gene), ]
  invisible_label_row = label_df[1, ]
  invisible_label_row$gene = ""
  invisible_label_row$horiz_pos = -0.75 + min(label_df$horiz_pos)
  invisible_label_row$y = 1
  label_df = rbind(invisible_label_row, label_df)
  # # Add staggered labels with an invisible one farther out to make room (if desired)
  if( do_stagger ){
    p = p + theme(axis.ticks.y=element_blank(), axis.text.y = element_blank())
    p = p + geom_text(data = label_df, aes(x = horiz_pos, y = y, label = gene ))
  }
  # # Add only labels for genes_to_label (if desired)
  if( !is.null(genes_to_label) ){
    label_df$horiz_pos = 0.5
    label_df$horiz_pos[[1]] = -0.5
    label_df$gene[!(label_df$gene %in% genes_to_label)] = ""
    label_df = label_df[!duplicated(label_df$gene), ]
    label_layer = geom_text(data = label_df, aes(x = horiz_pos, y = y, label = gene )) 
    # # Attempt with ggrepel. Unsolved issue: labels migrate on top of heatmap.
    # label_layer = tryCatch( expr = { ggrepel::geom_label_repel(data = label_df, aes(x = horiz_pos, y = y, label = gene )) },
    #                         error = function(e){     geom_text(data = label_df, aes(x = horiz_pos, y = y, label = gene )) } )
    p = p + label_layer 
  }

  return( p )
}

## ------------------------------------------------------------------------
#' Screen receptor-ligand pairs against a table of expression calls.
#'
#' Annotates the pairs from `get_ramilowski()` with the cell types expressing each
#' partner, then, for every ordered pairing of the tissues "BLD", "MES" and "TEC",
#' lists the ligands and receptors whose pair is expressed (ligand in the first tissue,
#' receptor in the second).
#'
#' Files written under `results_path`: `Receptor_ligand_all.txt`,
#' `ligand_lists/<lig>_to_<rec>.txt`, `receptor_lists/<lig>_to_<rec>.txt`,
#' a `background.txt` in each list folder, and `num_unique_ligands.txt`.
#'
#' @param is_expressed Logical matrix with genes as row names and cell types as column
#'   names; it must contain the columns "BLD", "MES" and "TEC".
#' @param results_path Output directory.
#' @return Invisibly, the 3 x 3 matrix of unique ligand counts per tissue pairing.
#' @export
screen_receptor_ligand = function( is_expressed, results_path ){
  # # Get receptor-ligand pairs; annotate with tissues expressed; save
  ramilowski = get_ramilowski()
  ramilowski$Ligand.ApprovedSymbol = NULL
  ramilowski$Receptor.ApprovedSymbol = NULL
  ramilowski = subset( ramilowski, ligand_mouse %in% rownames(is_expressed) )
  ramilowski = subset( ramilowski, receptor_mouse %in% rownames(is_expressed) )
  # For each gene, join the names of the cell types where it is expressed with "_".
  # Pasting inside a single apply avoids apply() collapsing equal-length results
  # into a matrix; drop = FALSE keeps a matrix when only one pair remains.
  expressing_types = function( genes ){
    apply( is_expressed[ genes, , drop = FALSE ], 1, 
           function( gene_row ) paste( names( which( gene_row ) ), collapse = "_" ) )
  }
  ramilowski$ligand_cell_types   = expressing_types( ramilowski$ligand_mouse )
  ramilowski$receptor_cell_types = expressing_types( ramilowski$receptor_mouse )
  write.table( ramilowski, file =  file.path( results_path, "Receptor_ligand_all.txt" ), 
               sep = "\t", row.names = F, col.names = T, quote = F )

  # After the two subset() calls above every remaining gene is a row of is_expressed,
  # so this padding is only a safeguard.
  absent = union( ramilowski$ligand_mouse, ramilowski$receptor_mouse ) %>% setdiff( rownames( is_expressed ) )
  if( length( absent ) > 0 ){
    zeropad = matrix(F, ncol = ncol( is_expressed ), nrow = length( absent ), 
                     dimnames = list( gene = absent, 
                                      celltype = colnames(is_expressed)) )
    is_expressed %<>% rbind( zeropad )
  }
  # # Get lists of receptors and ligands for each tissue pairing
  num_unique_ligands = matrix( 0, nrow = 3, ncol = 3, 
                               dimnames = list( lig_expr_tissue = c("BLD", "MES", "TEC"),
                                                rec_expr_tissue = c("BLD", "MES", "TEC") ) )
  dir.create.nice( file.path( results_path, "ligand_lists" ) )
  dir.create.nice( file.path( results_path, "receptor_lists" ) )
  for( lig_tissue in rownames(num_unique_ligands)){
    for( rec_tissue in colnames(num_unique_ligands)){
      eligible_subset = subset( ramilowski, 
                                is_expressed[ligand_mouse,   lig_tissue] & 
                                  is_expressed[receptor_mouse, rec_tissue] )
      num_unique_ligands[lig_tissue, rec_tissue] = length( unique( eligible_subset$ligand_mouse ) )
      write.table( unique(eligible_subset$ligand_mouse  ), row.names = F, col.names = F, quote = F,
                   paste0( results_path, "/ligand_lists/"  , lig_tissue, "_to_", rec_tissue, ".txt") )
      write.table( unique(eligible_subset$receptor_mouse), row.names = F, col.names = F, quote = F,
                   paste0( results_path, "/receptor_lists/", lig_tissue, "_to_", rec_tissue, ".txt") )
    }
  }
  # Background lists composed of everything that got successfully converted to mouse ortholog
  # For pathway analysis with a background list
  ramilowski_orig = read.table( file.path( PATH_TO_TABLES, "LigandReceptor_Ramilowski2015_mouse.txt" ), 
                                header = T, sep="\t", stringsAsFactors = F )
  write.table( ramilowski_orig$ligand_mouse   %>% unique, 
               paste0( results_path, "/ligand_lists/background.txt"),
               row.names = F, col.names = F, quote = F)
  write.table( ramilowski_orig$receptor_mouse %>% unique, 
               paste0( results_path, "/receptor_lists/background.txt"),
               row.names = F, col.names = F, quote = F)
  # # Save summary to file
  # # Printing (rather than write.table) keeps the full dimnames; capture.output
  # # closes the file itself, so console output is never left diverted.
  utils::capture.output( print( num_unique_ligands ), 
                         file = file.path( results_path, "num_unique_ligands.txt" ) )
  invisible( num_unique_ligands )
}

## ------------------------------------------------------------------------
#' Quickly explore many parameter settings
#'
#' For every row of `all_params`, this regresses out cell-cycle scores (plus any
#' extra variables), selects variable genes, runs PCA and t-SNE, clusters the
#' cells, and saves the object, summaries and plots in a subfolder of
#' `results_path` named via `collapse_by_name` on that row.
#'
#' @param dge Seurat object. It is updated in place from one row to the next.
#' @param results_path Folder to write into.
#' @param all_params Data frame with one row per parameter setting. It must
#'   contain `var_gene_method` and `plot_all_var_genes`, plus at least one of
#'   `excess_var_cutoff`, `log_expr_cutoff`, `prop_genes_to_select`. The columns
#'   `cc_method`, `num_pc`, `clust_method` and `clust_granularities_as_string`
#'   are read for every row. Optional columns: `extra_regressout` (comma-separated
#'   variable names), `TF_only`, `find_markers_thresh`.
#' @param test_mode If TRUE, `var_gene_method`, `cc_method` and
#'   `plot_all_var_genes` are overridden for all rows; the flag is also forwarded
#'   to the helper functions.
#' @return Called for its side effects (files written under `results_path`).
#' @export
explore_embeddings = function(dge, results_path, all_params, test_mode = F){
  atat( is.data.frame( all_params ) )
  # Only override the caller's settings when explicitly running in test mode.
  if( test_mode ){
    all_params[["var_gene_method"]] = "seurat"
    all_params[["cc_method"]] = "average"
    all_params[["plot_all_var_genes"]] = FALSE
  }
  required_params = c( "var_gene_method",
                       "plot_all_var_genes" )
  atat( all( required_params %in% names( all_params ) ) )
  atat( any( c( "excess_var_cutoff","log_expr_cutoff", "prop_genes_to_select" ) %in% names( all_params ) ) )

  # # Record the things you're gonna try
  dir.create.nice( results_path )
  write.table( all_params, file = file.path(results_path, "params_to_try.txt"),
               quote = F, sep = "\t", row.names = F)

  # # Try all the things
  prev_param_row = NULL
  for( i in rownames( all_params ) ){
    # drop = FALSE keeps a one-row data frame so `[[` returns NULL for absent columns.
    param_row = all_params[i, , drop = FALSE]
    print("Trying these settings:")
    print( param_row )
    rp_mini = file.path( results_path, collapse_by_name( param_row ) )

    # # remove cc variation
    if( "extra_regressout" %in% names( param_row ) ){
      # as.character guards against factor columns, which strsplit rejects.
      extra_vars = trimws( strsplit( as.character( param_row[["extra_regressout"]] ), "," )[[1]] )
      extra_vars = extra_vars[ !is.na( extra_vars ) & nzchar( extra_vars ) ]
    } else {
      extra_vars = c()
    }
    # Scores are only recomputed and regressed out when cc_method differs from the previous row.
    cc_method_changed = is.null( prev_param_row ) ||
      as.character( param_row[["cc_method"]] ) != as.character( prev_param_row[["cc_method"]] )
    if( cc_method_changed ){
      dge = add_cc_score(dge, method = param_row[["cc_method"]])
      cc_score_names = get_macosko_cc_genes() %>% colnames
      dge = Seurat::RegressOut(object = dge,
                               latent.vars = c(cc_score_names, extra_vars))
    }

    # # select genes; do dim red; cluster cells
    dge = var_gene_select( dge, results_path = rp_mini, test_mode = test_mode,
                           excess_var_cutoff   = param_row[["excess_var_cutoff"]],
                           log_expr_cutoff     = param_row[["log_expr_cutoff"]],
                           prop_genes_to_select = param_row[["prop_genes_to_select"]],
                           method = param_row[["var_gene_method"]])
    if( !is.null( param_row[[ "TF_only" ]] ) && param_row[[ "TF_only" ]] ){
      dge@var.genes %<>% intersect(get_mouse_tfs())
    }

    dge = Seurat::PCA(dge, pc.genes = dge@var.genes, do.print = F)
    pc.use = seq_len( param_row[["num_pc"]] )
    dge = Seurat::RunTSNE(dge, dims.use = pc.use, do.fast = T)
    dge = cluster_wrapper(dge, results_path = rp_mini, test_mode = test_mode,
                          method = param_row[["clust_method"]],
                          granularities_as_string = param_row[["clust_granularities_as_string"]],
                          pc.use = pc.use)

    # # save plots and summaries
    misc_summary_info( dge, results_path = rp_mini)
    saveRDS( dge, file.path( rp_mini, "dge.data") )
    top_genes_by_pc(   dge, results_path = rp_mini, test_mode)
    fm = param_row[["find_markers_thresh"]]
    if( !is.null( fm ) ){
      de_genes = find_de(dge, results_path = rp_mini, ident.use = "ident", thresh.use = fm )
      top_markers = get_top_de_genes(de_genes, results_path = rp_mini,
                                     test_mode = test_mode,
                                     top_n = 10, thresh_df = NULL)
      save_feature_plots(dge, results_path = rp_mini, gene_list = top_markers$gene, gene_list_name = "top_markers")
      # save_heatmap(dge, results_path = rp_mini, marker_info = de_genes,    main = "heatmap_all_de_genes")
      save_heatmap(dge, results_path = rp_mini, marker_info = top_markers, main = "heatmap_top_markers")
    }
    #save_feature_plots(dge, results_path = rp_mini)
    pavg = param_row[["plot_all_var_genes"]]
    if( !is.null( pavg ) && pavg  ){
      save_feature_plots(dge, results_path = rp_mini, gene_list = dge@var.genes, gene_list_name = "var_genes")
    }
    prev_param_row = param_row
  }
}

## ------------------------------------------------------------------------

#' Quickly compute summary statistics for all genes.
#'
#' Expression is aggregated within each level of `ident.use`, and every gene is
#' assigned to the cluster where its aggregated value is highest.
#'
#' @param dge Seurat object
#' @param ident.use Any factor variable available via FetchData(dge). 
#' @param aggregator Function used to aggregate within clusters. Try `prop_nz`.
#' @param genes_per_cluster Limit on number of genes returned per cluster.
#'   Clusters with fewer genes than this contribute all the genes they have.
#' @return Data frame with one row per gene, ordered by cluster and then by
#'   decreasing `avg_diff`. Columns: `cluster` (cluster with the highest
#'   aggregated value), `avg_diff` (highest value minus the mean of the other
#'   clusters' values), `max_pairwise_diff` (highest minus lowest value) and `gene`.
#' @export
DESummaryFast = function( dge, ident.use = "ident", aggregator = mean, genes_per_cluster = NULL ){
  # NOTE(review): assumes aggregate_nice returns a clusters-by-genes matrix with dimnames -- confirm.
  cluster_means = aggregate_nice( t(as.matrix(dge@data)), by = Seurat::FetchData(dge, ident.use), FUN = aggregator ) 
  gene_max = apply( cluster_means, 2, max )
  gene_min = apply( cluster_means, 2, min )
  top_cluster = apply( cluster_means, 2, function(x) names( which.max( x ) ) )
  # Mean over every cluster except the top one.
  rest_mean = function(x) mean( sort( x, decreasing = T )[-1] )
  big_list = data.frame( cluster = top_cluster, 
                         avg_diff = gene_max - apply( cluster_means, 2, rest_mean ), 
                         max_pairwise_diff = gene_max - gene_min, 
                         gene = colnames( cluster_means ) ) 
  big_list = big_list[order(big_list$cluster, -big_list$avg_diff), ]
  if( is.null( genes_per_cluster ) ){
    return( big_list )
  }
  # head() returns at most genes_per_cluster rows and never pads short clusters with NA rows.
  per_cluster = lapply( as.character( unique( big_list$cluster ) ), function( this_cluster ){
    head( big_list[big_list$cluster == this_cluster, ], genes_per_cluster )
  } )
  return( do.call( rbind, per_cluster ) )
}

#' Imitate Seurat::FindClusters but using k-means with the gap statistic.
#'
#' @param dge Seurat object
#' @param nclust NULL by default, so chosen via gap statistic.
#' @param pc.use Indices of the principal components used as clustering
#'   features; they are fetched as "PC1", "PC2", ...
#' @param iter.max Iteration limit passed to `kmeans`, both inside the gap
#'   statistic computation and for the final clustering.
#' @return The Seurat object with a `kmeans_cluster` metadata column added and
#'   the identity set to the k-means cluster assignment.
#' @export
KMeansGapStat = function( dge, nclust = NULL, pc.use = 1:25, iter.max = 20 ){

  kmeans_features = Seurat::FetchData( dge, paste0("PC", pc.use) )

  # Set num clusters via gap statistic
  if( is.null( nclust ) ){
    gap_stats = cluster::clusGap( x = kmeans_features, FUNcluster = kmeans, K.max = 25, spaceH0 = "scaledPCA", 
                                  iter.max = iter.max, B = 50 )
    plot( gap_stats )
    # Row k of Tab describes k clusters, so the number of clusters is the
    # position of the largest gap, not the gap value itself.
    nclust = unname( which.max( gap_stats$Tab[, "gap"] ) )
  }

  # Cluster and add output to Seurat object
  kmod = kmeans( kmeans_features, centers = nclust, iter.max = iter.max )
  dge %<>% Seurat::AddMetaData( setNames( kmod$cluster, rownames(dge@data.info) ), "kmeans_cluster" )
  dge %<>% Seurat::SetIdent( ident.use = kmod$cluster )
  return( dge )
}

#' Update the `@var.genes` slot of a Seurat object using cluster markers.
#'
#' @param dge Seurat object
#' @param eligible_genes dge@var.genes get intersected with this list before return.
#'   If NULL (the default), no restriction is applied.
#' @param log_fc For each gene, we compute the difference between the highest and lowest cluster mean. 
#' If it exceeds this, the gene is included in the active set.
#' @return A Seurat object.
#' @export
RefineVariableGenes = function( dge, log_fc = 1, eligible_genes = NULL ){
  if( length( unique( dge@ident ) ) < 2 ){
    # No between-cluster contrast is available, so fall back on PC1 loadings.
    warning( "Only one identity class present! Taking top genes from PCA." )
    dge@var.genes = subset( dge@pca.x[, 1, drop = F], abs(PC1) > 0.05 ) %>% rownames
  } else {
    cluster_means = aggregate_nice( t(as.matrix(dge@data)), by = dge@ident, FUN = mean )
    max_pairwise_diffs = apply( cluster_means, 2, max ) - apply( cluster_means, 2, min )
    # which() drops NA comparisons; indexing names() keeps the result a character vector.
    dge@var.genes = names( max_pairwise_diffs )[ which( max_pairwise_diffs > log_fc ) ]
  }
  # Intersecting with NULL would empty the gene list, so only restrict when a list is given.
  if( !is.null( eligible_genes ) ){
    dge@var.genes %<>% intersect( eligible_genes )
  }
  return( dge )
}

#' Explore a single-cell dataset, alternating between gene selection and clustering.
#' @param dge Seurat object.
#' @param log_fc Passed to RefineVariableGenes.
#' @param eligible_genes Restricts the set of genes used. Try setting it to `get_mouse_tfs()`.
#' @param maxiter Iteration limit. Issues warning if hit.
#' @param tol If the list changes by less than this many genes (added + removed), it stops.
#' @param verbose If TRUE, print progress and per-iteration gene counts.
#' @param cluster_method A user-provided function to re-estimate clusters, modeled off Seurat::FindClusters. 
#'      It should take a Seurat object as the first arg.
#'      It should also accept a `pc.use` input (even if you just discard it).
#'      For your convenience, the principal components get updated before this function is called.
#'      The TSNE embedding doesn't, so don't use DBClustDimension alone.
#' @param ... Extra parameters passed to `cluster_method`
#' @param num_pc Used in recomputing the number of PCs, and `1:num_pc` is passed to `cluster_method` as `pc.use`.
#' @return List with names "dge" (the updated seurat object) and "iteration_stats" (numbers of genes present, added, and removed at each iteration.)
#'      Rows of "iteration_stats" for iterations that never ran are left as NA.
#' @export
ExploreIteratively = function( dge, log_fc = 0.25, 
                               eligible_genes = NULL,
                               maxiter = 5, tol = 10, verbose = T,
                               cluster_method = function(...) Seurat::FindClusters(..., print.output = F),
                               num_pc = 25, ... ){
  # Initialize genes
  if( 0==length( dge@var.genes ) ){
    dge = Seurat::MeanVarPlot(dge)
  }

  # Store info re: geneset convergence
  iteration_stats = data.frame( total   = rep(NA, maxiter),
                                removed = rep(NA, maxiter),
                                added   = rep(NA, maxiter) )

  for( i in seq_len( maxiter ) ){
    # Update PCs
    if( length(dge@var.genes) < 100 ){
      dge = Seurat::PCA(dge, do.print = F)
    } else {
      dge = Seurat::PCAFast(dge, do.print = F, pcs.compute = num_pc)
    }

    # Cluster
    dge = cluster_method( dge, pc.use = 1:num_pc, ... )

    # Reselect genes
    old_genes = dge@var.genes
    dge = RefineVariableGenes( dge, log_fc = log_fc, eligible_genes = eligible_genes )
    iteration_stats$total[i]   = dge@var.genes %>% length
    iteration_stats$removed[i] = setdiff( old_genes, dge@var.genes ) %>% length
    iteration_stats$added[i]   = setdiff( dge@var.genes, old_genes ) %>% length

    # Report before the convergence check so the final iteration is shown too.
    if( verbose ){
      print(paste("Finished", i, "of", maxiter))
      print( "Current iteration_stats:" )
      print( iteration_stats[i, ] )
    }

    # Exit if list changes by fewer than `tol` genes
    if( iteration_stats$added[i] + iteration_stats$removed[i] < tol ){
      break
    }
    if(i==maxiter){ warning("Iteration limit reached.")}
  }

  if( verbose ){
    print("Re-computing t-SNE embedding.")
  }
  dge %<>% Seurat::RunTSNE(dims.use = 1:num_pc, do.fast = T)
  return( list(dge = dge, iteration_stats = iteration_stats) )
}
