# R/findMarkersTree.R

#' @title Generate marker decision tree from single-cell clustering output
#' @description Create a decision tree that identifies gene markers for given
#'  cell populations. The algorithm uses a decision tree procedure to generate
#'  a set of rules for each cell cluster defined by single-cell clustering.
#'  Splits are determined by one of two metrics at each branch point: a one-off
#'  metric that finds rules identifying a cluster by a single feature, and a
#'  balanced metric that finds rules identifying sets of similar clusters.
#' @param x A numeric \link{matrix} of counts or a
#'  \linkS4class{SingleCellExperiment}
#'  with the matrix located in the assay slot under \code{useAssay}.
#'  Rows represent features and columns represent cells.
#' @param useAssay A string specifying which \link{assay}
#'  slot to use if \code{x} is a
#'  \link[SingleCellExperiment]{SingleCellExperiment} object. Default "counts".
#' @param altExpName The name for the \link{altExp} slot
#'  to use. Default "featureSubset".
#' @param class Vector of cell cluster labels.
#' @param oneoffMetric A character string. Which one-off metric to run, either
#'  \code{"modified F1"} or \code{"pairwise AUC"}. Default is "modified F1".
#' @param metaclusters List where each element is a metacluster (e.g. a known
#'  cell type) containing the clusters within that metacluster (e.g. subtypes).
#' @param featureLabels Vector of feature assignments, e.g. which cluster
#'  does each gene belong to? Useful when using clusters of features
#'  (e.g. gene modules or Seurat PCs) and user wishes to expand tree results
#'  to individual features (e.g. score individual genes within marker gene
#'  modules).
#' @param counts Numeric counts matrix. Useful when using clusters
#'  of features (e.g. gene modules) and user wishes to expand tree results to
#'  individual features (e.g. score individual genes within marker gene
#'  modules). Row names should be individual feature names. Ignored if
#'  \code{x} is a \linkS4class{SingleCellExperiment} object.
#' @param celda A \emph{celda_CG} or \emph{celda_C} object.
#'  The counts matrix must be provided as well.
#' @param seurat A Seurat object. Note that the Seurat functions
#'  \emph{RunPCA} and \emph{FindClusters} must have been run on the object.
#' @param threshold Numeric between 0 and 1. The threshold for the oneoff
#'  metric. Smaller values will result in more one-off splits. Default is 0.90.
#' @param reuseFeatures Logical. Whether or not a feature can be used more than
#'  once on the same cluster. Default is FALSE.
#' @param altSplit Logical. Whether or not to force a marker for clusters that
#'  are solely defined by the absence of markers. Default is TRUE.
#' @param consecutiveOneoff Logical. Whether or not to allow one-off splits at
#'  consecutive branches. Default is FALSE.
#' @param autoMetaclusters Logical. Whether to identify metaclusters prior to
#'  creating the tree based on the distance between clusters in a UMAP
#'  dimensionality reduction projection. A metacluster is simply a large
#'  cluster that includes several clusters within it. Default is TRUE.
#' @param seed Numeric. Seed used to enable reproducible UMAP results
#'  for identifying metaclusters. Default is 12345.
#' @param ... Ignored. Placeholder to prevent check warning.
#' @return A named list with six elements:
#' \itemize{
#'   \item rules - A named list with one data frame for every label. Each
#'  data frame has five columns and gives the set of rules for distinguishing
#'  each label.
#'   \itemize{
#'    \item feature - Marker feature, e.g. marker gene name.
#'    \item direction - Relationship to feature value. -1 if cluster is
#'    down-regulated for this feature, 1 if cluster is up-regulated.
#'    \item stat - The performance value returned by the splitting metric for
#'  this split.
#'    \item statUsed - Which performance metric was used. "Split" if information
#'  gain and "One-off" if one-off.
#'    \item level - The level of the tree at which the rule was defined. 1 is
#'  level of the first split of the tree.
#'    \item metacluster - Optional. If metaclusters were used, the metacluster
#'     this rule is applied to.
#'   }
#'  \item dendro - A dendrogram object of the decision tree output. Plot with
#'  \code{plotMarkerDendro()}.
#'  \item classLabels - A vector of the class labels used in the model, i.e.
#'   cell cluster labels.
#'  \item metaclusterLabels - A vector of the metacluster labels
#'   used in the model
#'  \item prediction - A character vector of predicted labels for the
#'  training data using the final model. "MISSING" if the label prediction
#'  was ambiguous.
#'  \item performance - A named list denoting the training performance of the
#'  model:
#'  \itemize{
#'   \item accuracy - (number correct/number of samples) for the whole set of
#'  samples.
#'   \item balAcc - mean sensitivity across all clusters
#'   \item meanPrecision - mean precision across all clusters
#'   \item correct - the number of correct predictions of each cluster
#'   \item sizes - the number of actual counts of each cluster
#'   \item sensitivity - the sensitivity of the prediction of each cluster
#'   \item precision - the precision of the prediction of each cluster
#'  }
#' }
#' @examples
#' \dontrun{
#' # Generate simulated single-cell dataset using celda
#' sim_counts <- simulateCells("celda_CG", K = 5, L = 10, G = 100)
#'
#' # Celda clustering into 5 clusters & 10 modules
#' cm <- celda_CG(sim_counts, K = 5, L = 10, verbose = FALSE)
#'
#' # Get features matrix and cluster assignments
#' factorized <- factorizeMatrix(cm)
#' features <- factorized$proportions$cell
#' class <- celdaClusters(cm)
#'
#' # Generate Decision Tree
#' DecTree <- findMarkersTree(features, class)
#'
#' # Plot dendrogram
#' plotMarkerDendro(DecTree)
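#'
#' # Optionally supply known metaclusters; the grouping below is
#' # hypothetical and only illustrates the expected list format
#' # (each element names a metacluster and lists its clusters)
#' DecTreeMeta <- findMarkersTree(features, class,
#'     metaclusters = list(M1 = c(1, 2), M2 = c(3, 4, 5)))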
#' }
#' @export
setGeneric("findMarkersTree", function(x, ...) {
    standardGeneric("findMarkersTree")})


#' @rdname findMarkersTree
#' @export
setMethod("findMarkersTree",
    signature(x = "SingleCellExperiment"),
    function(x,
        useAssay = "counts",
        altExpName = "featureSubset",
        class,
        oneoffMetric = c("modified F1", "pairwise AUC"),
        metaclusters,
        featureLabels,
        counts,
        seurat,
        threshold = 0.90,
        reuseFeatures = FALSE,
        altSplit = TRUE,
        consecutiveOneoff = FALSE,
        autoMetaclusters = TRUE,
        seed = 12345) {

        altExp <- SingleCellExperiment::altExp(x, altExpName)

        if ("celda_parameters" %in% names(S4Vectors::metadata(altExp))) {
            counts <- SummarizedExperiment::assay(altExp, i = useAssay)

            # factorize matrix (proportion of each module in each cell)
            features <- factorizeMatrix(x,
                useAssay = useAssay,
                altExpName = altExpName)$proportions$cell

            # get class labels
            class <- celdaClusters(x, altExpName = altExpName)

            # get feature labels
            featureLabels <- paste0("L",
                celdaModules(x, altExpName = altExpName))
        } else if (methods::hasArg(seurat)) {
            # get expression matrix ("data" slot) from the Seurat object
            counts <- as.matrix(seurat@assays$RNA@data)

            # get class labels
            class <- as.character(Seurat::Idents(seurat))

            # get feature labels
            featureLabels <-
                unlist(apply(
                    seurat@reductions$pca@feature.loadings, 1,
                    function(x) {
                        return(names(x)[which(x == max(x))])
                    }
                ))

            # sum counts for each PC in each cell
            features <-
                matrix(
                    unlist(lapply(unique(featureLabels), function(pc) {
                        colSums(counts[featureLabels == pc, ])
                    })),
                    ncol = length(class),
                    byrow = TRUE,
                    dimnames = list(unique(featureLabels), colnames(counts))
                )

            # normalize column-wise (i.e. convert counts to proportions)
            features <- apply(features, 2, function(x) {
                x / sum(x)
            })
        }

        if (ncol(features) != length(class)) {
            stop("Number of columns of features must equal length of class")
        }

        if (any(is.na(class))) {
            stop("NA class values")
        }

        if (any(is.na(features))) {
            stop("NA feature values")
        }

        # Match the oneoffMetric argument
        oneoffMetric <- match.arg(oneoffMetric)

        branchPoints <- .findMarkersTree(features = features,
            class = class,
            oneoffMetric = oneoffMetric,
            metaclusters = metaclusters,
            featureLabels = featureLabels,
            counts = counts,
            seurat = seurat,
            threshold = threshold,
            reuseFeatures = reuseFeatures,
            altSplit = altSplit,
            consecutiveOneoff = consecutiveOneoff,
            autoMetaclusters = autoMetaclusters,
            seed = seed)

        return(branchPoints)
    }
)


#' @rdname findMarkersTree
#' @export
setMethod("findMarkersTree",
    signature(x = "matrix"),
    function(x,
        class,
        oneoffMetric = c("modified F1", "pairwise AUC"),
        metaclusters,
        featureLabels,
        counts,
        celda,
        seurat,
        threshold = 0.90,
        reuseFeatures = FALSE,
        altSplit = TRUE,
        consecutiveOneoff = FALSE,
        autoMetaclusters = TRUE,
        seed = 12345) {

        features <- x

        if (methods::hasArg(celda)) {
            # check that counts matrix is provided
            if (!methods::hasArg(counts)) {
                stop("Please provide counts matrix in addition to",
                    " celda object.")
            }

            # factorize matrix (proportion of each module in each cell)
            features <- factorizeMatrix(counts, celda)$proportions$cell

            # get class labels
            class <- celdaClusters(celda)$z

            # get feature labels
            featureLabels <- paste0("L", celdaClusters(celda)$y)
        } else if (methods::hasArg(seurat)) {
            # get expression matrix ("data" slot) from the Seurat object
            counts <- as.matrix(seurat@assays$RNA@data)

            # get class labels
            class <- as.character(Seurat::Idents(seurat))

            # get feature labels
            featureLabels <-
                unlist(apply(
                    seurat@reductions$pca@feature.loadings, 1,
                    function(x) {
                        return(names(x)[which(x == max(x))])
                    }
                ))

            # sum counts for each PC in each cell
            features <-
                matrix(
                    unlist(lapply(unique(featureLabels), function(pc) {
                        colSums(counts[featureLabels == pc, ])
                    })),
                    ncol = length(class),
                    byrow = TRUE,
                    dimnames = list(unique(featureLabels), colnames(counts))
                )

            # normalize column-wise (i.e. convert counts to proportions)
            features <- apply(features, 2, function(x) {
                x / sum(x)
            })
        }

        if (ncol(features) != length(class)) {
            stop("Number of columns of features must equal length of class")
        }

        if (any(is.na(class))) {
            stop("NA class values")
        }

        if (any(is.na(features))) {
            stop("NA feature values")
        }

        # Match the oneoffMetric argument
        oneoffMetric <- match.arg(oneoffMetric)

        branchPoints <- .findMarkersTree(features = features,
            class = class,
            oneoffMetric = oneoffMetric,
            metaclusters = metaclusters,
            featureLabels = featureLabels,
            counts = counts,
            seurat = seurat,
            threshold = threshold,
            reuseFeatures = reuseFeatures,
            altSplit = altSplit,
            consecutiveOneoff = consecutiveOneoff,
            autoMetaclusters = autoMetaclusters,
            seed = seed)

        return(branchPoints)
    }
)


.findMarkersTree <- function(features,
    class,
    oneoffMetric,
    metaclusters,
    featureLabels,
    counts,
    seurat,
    threshold,
    reuseFeatures,
    altSplit,
    consecutiveOneoff,
    autoMetaclusters,
    seed) {

  # Transpose features
  features <- t(features)

  # If no detailed cell types are provided or to be identified
  if (!methods::hasArg(metaclusters) && (!autoMetaclusters)) {
    message("Building tree...")

    # Set class to factor
    class <- as.factor(class)

    # Generate list of tree levels
    tree <- .generateTreeList(
      features,
      class,
      oneoffMetric,
      threshold,
      reuseFeatures,
      consecutiveOneoff
    )

    # Add alternative node for the solely down-regulated leaf
    if (altSplit) {
      tree <- .addAlternativeSplit(tree, features, class)
    }

    message("Computing performance metrics...")

    # Format tree output for plotting and generate summary statistics
    DTsummary <- .summarizeTree(tree, features, class)

    # Remove confusing 'value' column
    DTsummary$rules <- lapply(DTsummary$rules, function(x) {
      x["value"] <- NULL
      x
    })

    # Add column to each rules table which specifies its class
    DTsummary$rules <- mapply(cbind,
      "class" = as.character(names(DTsummary$rules)),
      DTsummary$rules,
      SIMPLIFY = FALSE
    )

    # Generate table for each branch point in the tree
    DTsummary$branchPoints <-
      .createBranchPoints(DTsummary$rules)

    # Add class labels to output
    DTsummary$classLabels <- class

    return(DTsummary)
  } else {
    # If metaclusters are provided or to be identified

    # consecutive one-offs break the code (tricky to find the first
    # balanced split)
    if (consecutiveOneoff) {
      stop(
        "Cannot use metaclusters if consecutive one-offs are allowed.",
        " Please set the consecutiveOneoff parameter to FALSE."
      )
    }

    # Check if need to identify metaclusters
    if (autoMetaclusters && !methods::hasArg(metaclusters)) {
      message("Identifying metaclusters...")

      # if seurat object then use seurat's UMAP parameters
      if (methods::hasArg(seurat)) {
        suppressMessages(seurat <-
          Seurat::RunUMAP(
            seurat,
            dims = seq(ncol(seurat@reductions$pca@feature.loadings))
          ))
        umap <- seurat@reductions$umap@cell.embeddings
      }
      else {
        if (is.null(seed)) {
          umap <- uwot::umap(
            t(sqrt(t(features))),
            n_neighbors = 15,
            min_dist = 0.01,
            spread = 1,
            n_sgd_threads = 1
          )
        } else {
          withr::with_seed(
            seed,
            umap <- uwot::umap(
              t(sqrt(t(features))),
              n_neighbors = 15,
              min_dist = 0.01,
              spread = 1,
              n_sgd_threads = 1
            )
          )
        }
      }
      # dbscan to find metaclusters
      dbscan <- dbscan::dbscan(umap, eps = 1)

      # place each population in the correct metacluster
      mapping <-
        unlist(lapply(
          sort(as.integer(
            unique(class)
          )),
          function(population) {
            # get indices of occurrences of this population
            indexes <-
              which(class == population)

            # get corresponding metaclusters
            metaIndices <-
              dbscan$cluster[indexes]

            # return corresponding metacluster with majority vote
            return(names(sort(table(
              metaIndices
            ), decreasing = TRUE)[1]))
          }
        ))

      # create list which will contain subtypes of each metacluster
      metaclusters <- vector(mode = "list")

      # fill in list of populations for each metacluster
      for (i in unique(mapping)) {
        metaclusters[[i]] <-
          sort(as.integer(unique(class)))[which(mapping == i)]
      }
      names(metaclusters) <- paste0("M", unique(mapping))

      message(paste("Identified", length(metaclusters), "metaclusters"))
    }

    # Check that cell types match class labels
    if (!all(unlist(metaclusters) %in% unique(class))) {
      stop(
        "Provided cell types do not match class labels. ",
        "Please check the 'metaclusters' argument."
      )
    }

    # Create vector with metacluster labels
    metaclusterLabels <- class
    for (i in names(metaclusters)) {
      metaclusterLabels[metaclusterLabels %in% metaclusters[[i]]] <- i
    }

    # Rename metaclusters with just one cluster
    oneCluster <-
      names(metaclusters[lengths(metaclusters) == 1])
    if (length(oneCluster) > 0) {
      oneClusterIndices <- which(metaclusterLabels %in% oneCluster)
      metaclusterLabels[oneClusterIndices] <-
        paste0(
          metaclusterLabels[oneClusterIndices], "(",
          class[oneClusterIndices], ")"
        )
      names(metaclusters[lengths(metaclusters) == 1]) <-
        paste0(
          names(metaclusters[lengths(metaclusters) == 1]), "(",
          unlist(metaclusters[lengths(metaclusters) == 1]), ")"
        )
    }

    # create temporary variables for top-level tree
    tmpThreshold <- threshold

    # create list to store split off classes at each threshold
    markerThreshold <- list()

    # Create top-level tree

    # while there is still a balanced split at the top-level
    while (TRUE) {
      # create tree
      message("Building top-level tree across all metaclusters...")
      tree <-
        .generateTreeList(
          features,
          as.factor(metaclusterLabels),
          oneoffMetric,
          tmpThreshold,
          reuseFeatures,
          consecutiveOneoff
        )

      # Add alternative node for the solely down-regulated leaf
      tree <- .addAlternativeSplit(
        tree, features,
        as.factor(metaclusterLabels)
      )

      # store clusters with markers at current threshold
      topLevel <- tree[[1]][[1]]
      if (topLevel$statUsed == "One-off") {
        markerThreshold[[as.character(tmpThreshold)]] <-
          unlist(lapply(
            topLevel[seq(length(topLevel) - 3)],
            function(marker) {
              return(marker$group1Consensus)
            }
          ))
      }

      # if no more balanced split
      if (length(tree) == 1) {
        # if all clusters have positive markers
        if (length(tree[[1]][[1]]) == (length(metaclusters) + 3)) {
          break
        } else {
          # decrease threshold by 10%
          tmpThreshold <- tmpThreshold * 0.9
          message("Decreasing classifier threshold to ", tmpThreshold)
          next
        }
      } else {
        # there is still a balanced split
        # get up-regulated clusters at first balanced split
        upClass <- tree[[2]][[1]][[1]]$group1Consensus

        # if only 2 clusters at the balanced split then merge them
        if ((length(upClass) == 1) &&
          (length(tree[[2]][[1]][[1]]$group2Consensus) == 1)) {
          upClass <- c(upClass, tree[[2]][[1]][[1]]$group2Consensus)
        }

        # update metacluster label of each cell
        tmpMeta <- metaclusterLabels
        tmpMeta[tmpMeta %in% upClass] <-
          paste(upClass, sep = "", collapse = "+")


        # create top-level tree again
        tmpTree <-
          .generateTreeList(
            features,
            as.factor(tmpMeta),
            oneoffMetric,
            tmpThreshold,
            reuseFeatures,
            consecutiveOneoff
          )

        # Add alternative node for the solely down-regulated leaf
        tmpTree <- .addAlternativeSplit(
          tmpTree, features,
          as.factor(tmpMeta)
        )

        # if new tree still has balanced split/no markers for some
        if ((length(tmpTree) > 1) ||
          (length(tree[[1]][[1]]) != (length(metaclusters) + 3))) {
          # decrease threshold by 10%
          tmpThreshold <- tmpThreshold * 0.9
          message("Decreasing classifier threshold to ", tmpThreshold)
        } else {
          # set final metacluster labels to new set of clusters
          metaclusterLabels <- tmpMeta

          # set final tree to current tree
          tree <- tmpTree

          ## update 'metaclusters' (list of metaclusters)
          # get celda clusters in these metaclusters
          newMetacluster <- unlist(metaclusters[upClass])
          # remove old metaclusters
          metaclusters[upClass] <- NULL
          # add new metacluster to list of metaclusters
          metaclusters[paste(upClass, sep = "", collapse = "+")] <-
            list(unname(newMetacluster))

          break
        }
      }
    }

    # re-format output
    finalTree <- tree
    tree <- list(rules = .mapClass2features(
      finalTree,
      features,
      as.factor(metaclusterLabels),
      topLevelMeta = TRUE
    )$rules)

    # keep markers at first threshold they reached only
    markersToRemove <- c()
    for (thresh in names(markerThreshold)) {
      thresholdClasses <- markerThreshold[[thresh]]
      for (cl in thresholdClasses) {
        curRules <- tree$rules[[cl]]
        lowMarkerIndices <- which(curRules$direction == 1 &
          curRules$stat < as.numeric(thresh))
        if (length(lowMarkerIndices) > 0 &&
          length(which(curRules$direction == 1)) > 1) {
          markersToRemove <- c(
            markersToRemove,
            curRules[lowMarkerIndices, "feature"]
          )
        }
      }
    }
    tree$rules <- lapply(tree$rules, function(rules) {
      return(rules[!rules$feature %in% markersToRemove, ])
    })

    # store final set of top-level markers
    topLevelMarkers <-
      unlist(lapply(tree$rules, function(cluster) {
        markers <- cluster[cluster$direction == 1, "feature"]
        return(paste(markers, collapse = ";"))
      }))

    # create tree dendrogram
    tree$dendro <-
      .convertToDendrogram(finalTree, as.factor(metaclusterLabels),
        splitNames = topLevelMarkers
      )

    # add metacluster label to rules table
    for (metacluster in names(tree$rules)) {
      tree$rules[[metacluster]]$metacluster <- metacluster
    }

    # Store tree's dendrogram in a separate variable
    dendro <- tree$dendro

    # Find which metaclusters have more than one cluster
    largeMetaclusters <-
      names(metaclusters[lengths(metaclusters) > 1])

    # Update subtype labels for large metaclusters
    subtypeLabels <- metaclusterLabels
    subtypeLabels[subtypeLabels %in% largeMetaclusters] <-
      paste0(
        subtypeLabels[subtypeLabels %in% largeMetaclusters],
        "(",
        class[subtypeLabels %in% largeMetaclusters],
        ")"
      )

    # Update metaclusters list
    for (metacluster in names(metaclusters)) {
      subtypes <- metaclusters[metacluster]
      subtypes <- lapply(subtypes, function(subtype) {
        paste0(metacluster, "(", subtype, ")")
      })
      metaclusters[metacluster] <- subtypes
    }

    # Create separate trees for each cell type with more than one cluster
    newTrees <- lapply(largeMetaclusters, function(metacluster) {
      # Print current status
      message("Building tree for metacluster ", metacluster)

      # Remove used features
      featUse <- colnames(features)
      if (!reuseFeatures) {
        tmpRules <- tree$rules[[metacluster]]
        featUse <-
          featUse[!featUse %in%
            tmpRules[tmpRules$direction == 1, "feature"]]
      }

      # Create new tree
      newTree <-
        .generateTreeList(
          features[metaclusterLabels == metacluster, featUse],
          as.factor(subtypeLabels[metaclusterLabels == metacluster]),
          oneoffMetric,
          threshold,
          reuseFeatures,
          consecutiveOneoff
        )

      # Add alternative node for the solely down-regulated leaf
      if (altSplit) {
        newTree <-
          .addAlternativeSplit(
            newTree,
            features[metaclusterLabels == metacluster, featUse],
            as.factor(subtypeLabels[metaclusterLabels == metacluster])
          )
      }

      newTree <- list(
        rules = .mapClass2features(
          newTree,
          features[metaclusterLabels == metacluster, ],
          as.factor(subtypeLabels[metaclusterLabels == metacluster])
        )$rules,
        dendro = .convertToDendrogram(
          newTree,
          as.factor(subtypeLabels[metaclusterLabels == metacluster])
        )
      )

      # Adjust 'rules' table for new tree
      newTree$rules <- lapply(newTree$rules, function(rules) {
        rules$level <- rules$level +
          max(tree$rules[[metacluster]]$level)
        rules$metacluster <- metacluster
        rules <- rbind(tree$rules[[metacluster]], rules)
      })

      return(newTree)
    })
    names(newTrees) <- largeMetaclusters

    # Fix max depth in original tree
    if (length(newTrees) > 0) {
      maxDepth <- max(unlist(lapply(newTrees, function(newTree) {
        lapply(newTree$rules, function(ruleDF) {
          ruleDF$level
        })
      })))
      addDepth <- maxDepth - attributes(dendro)$height

      dendro <- stats::dendrapply(dendro, function(node, addDepth) {
        if (attributes(node)$height > 1) {
          attributes(node)$height <- attributes(node)$height +
            addDepth + 1
        }
        return(node)
      }, addDepth)
    }

    # Find indices of cell type nodes in tree
    indices <- lapply(
      largeMetaclusters,
      function(metacluster) {
        # Initialize sub trees, indices string, and flag
        dendSub <- dendro
        index <- ""
        flag <- TRUE

        while (flag) {
          # Get the edge with the class of interest
          whEdge <- which(unlist(
            lapply(
              dendSub,
              function(edge) {
                metacluster %in%
                  attributes(edge)$classLabels
              }
            )
          ))

          # Add this as a string
          index <-
            paste0(index, "[[", whEdge, "]]")

          # Move to this branch
          dendSub <-
            eval(parse(text = paste0("dendro", index)))

          # Is this the only class in that branch
          flag <- length(attributes(dendSub)$classLabels) > 1
        }

        return(index)
      }
    )
    names(indices) <- largeMetaclusters

    # Add each cell type tree
    for (metacluster in largeMetaclusters) {
      # Get current tree
      metaclusterDendro <- newTrees[[metacluster]]$dendro

      # Adjust labels, member count, and midpoint of nodes
      dendro <- stats::dendrapply(dendro, function(node) {
        # Check if in right branch
        if (metacluster %in%
          as.character(attributes(node)$classLabels)) {
          # Replace cell type label with subtype labels
          labels <- attributes(node)$classLabels
          labels <- as.character(labels)
          labels <- labels[labels != metacluster]
          labels <- c(labels, unique(subtypeLabels)[
            grep(metacluster, unique(subtypeLabels))])
          attributes(node)$classLabels <- labels

          # Assign new member count for this branch
          attributes(node)$members <-
            length(attributes(node)$classLabels)

          # Assign new midpoint for this branch
          attributes(node)$midpoint <-
            (attributes(node)$members - 1) / 2
        }
        return(node)
      })

      # Replace label at new tree's branch point
      branchPointAttr <- attributes(eval(parse(text = paste0(
        "dendro", indices[[metacluster]]
      ))))
      branchPointLabel <- branchPointAttr$label
      branchPointStatUsed <- branchPointAttr$statUsed

      if (!is.null(branchPointLabel)) {
        attributes(metaclusterDendro)$label <- branchPointLabel
        attributes(metaclusterDendro)$statUsed <-
          branchPointStatUsed
      }

      # Fix height
      indLoc <-
        gregexpr("\\[\\[", indices[[metacluster]])[[1]]
      indLoc <- indLoc[length(indLoc)]
      parentIndexString <- substr(
        indices[[metacluster]],
        0,
        indLoc - 1
      )
      parentHeight <- attributes(eval(parse(
        text = paste0("dendro", parentIndexString)
      )))$height
      metaclusterHeight <-
        attributes(metaclusterDendro)$height
      metaclusterDendro <- stats::dendrapply(
        metaclusterDendro,
        function(node,
                 parentHeight,
                 metaclusterHeight) {
          if (attributes(node)$height > 1) {
            attributes(node)$height <-
              parentHeight - 1 -
              (metaclusterHeight -
                attributes(node)$height)
          }
          return(node)
        }, parentHeight, metaclusterHeight
      )

      # Add new tree to original tree
      eval(parse(text = paste0(
        "dendro", indices[[metacluster]], " <- metaclusterDendro"
      )))

      # Append new tree's 'rules' tables to original tree
      tree$rules <-
        append(tree$rules,
          newTrees[[metacluster]]$rules,
          after = which(names(tree$rules) == metacluster)
        )

      # Remove old tree's rules
      tree$rules <-
        tree$rules[-which(names(tree$rules) == metacluster)]
    }

    # Set final tree dendro
    tree$dendro <- dendro

    # Get performance statistics
    message("Computing performance statistics...")
    perfList <- .getPerformance(
      tree$rules,
      features,
      as.factor(subtypeLabels)
    )
    tree$prediction <- perfList$prediction
    tree$performance <- perfList$performance

    # Remove confusing 'value' column
    tree$rules <-
      lapply(tree$rules, function(x) {
        x["value"] <- NULL
        x
      })

    # add column to each rules table which specifies its class
    tree$rules <-
      mapply(cbind,
        "class" = as.character(names(tree$rules)),
        tree$rules,
        SIMPLIFY = FALSE
      )

    # create branch points table
    branchPoints <-
      .createBranchPoints(tree$rules, largeMetaclusters, metaclusters)

    # collapse all rules tables into one large table
    collapsed <- do.call("rbind", tree$rules)

    # get top-level rules
    topLevelRules <- collapsed[collapsed$level == 1, ]

    # add 'class' column
    topLevelRules$class <- topLevelRules$metacluster

    # add to branch point list
    branchPoints[["top_level"]] <- topLevelRules

    # check if need to expand features to gene-level
    if (methods::hasArg(featureLabels) &&
      methods::hasArg(counts)) {
      message("Computing scores for individual genes...")

      # make sure feature labels match those in the tree
      if (!all(unique(collapsed$feature) %in% unique(featureLabels))) {
        m <- "Provided feature labels don't match those in count matrix."
        stop(m)
      }

      # iterate over branch points
      branchPoints <- lapply(branchPoints, function(branch) {
        # iterate over unique features
        featAUC <-
          lapply(
            unique(branch$feature),
            .getGeneAUC,
            branch,
            subtypeLabels,
            metaclusterLabels,
            featureLabels,
            counts
          )

        # update branch table after merging genes data
        return(do.call("rbind", featAUC))
      })

      # simplify top-level in rules tables to only up-regulated markers
      tree$rules <- lapply(tree$rules, function(rule) {
        return(rule[-intersect(
          which(rule$level == 1),
          which(rule$direction == (-1))
        ), ])
      })

      ## add gene-level info to rules tables
      # collapse branch points tables into one
      collapsedBranches <- do.call("rbind", branchPoints)
      collapsedBranches$class <-
        as.character(collapsedBranches$class)

      # loop over rules tables and get relevant info
      tree$rules <- lapply(tree$rules, function(class) {
        # initialize table to return
        toReturn <- data.frame(NULL)

        # loop over rows of this class
        for (i in seq(nrow(class))) {
          # extract relevant genes from branch points tables
          genesAUC <- collapsedBranches[collapsedBranches$feature ==
            class$feature[i] &
            collapsedBranches$level == class$level[i] &
            collapsedBranches$class == class$class[i], ]

          # at the top level, match on the metacluster instead of the class
          if (class$level[i] == 1) {
            genesAUC <- collapsedBranches[collapsedBranches$feature ==
              class$feature[i] &
              collapsedBranches$level == class$level[i] &
              collapsedBranches$class == class$metacluster[i], ]
          }

          # merge table
          toReturn <- rbind(toReturn, genesAUC)
        }
        return(toReturn)
      })

      # remove table row names
      tree$rules <- lapply(tree$rules, function(t) {
        rownames(t) <- NULL
        return(t)
      })

      # add feature labels to output
      tree$featureLabels <- featureLabels
    }

    # simplify top-level branch point to save memory
    branchPoints$top_level <-
      branchPoints$top_level[branchPoints$top_level$direction == 1, ]
    branchPoints$top_level <-
      branchPoints$top_level[!duplicated(branchPoints$top_level), ]

    # remove branch points row names
    branchPoints <- lapply(branchPoints, function(br) {
      rownames(br) <- NULL
      return(br)
    })

    # adjust subtype labels
    branchPoints <- lapply(branchPoints, function(br) {
      br$class <- as.character(br$class)
      br$class[grepl("\\(.*\\)", br$class)] <- regmatches(
        br$class[grepl("\\(.*\\)", br$class)],
        regexpr(
          pattern = "(?<=\\().*?(?=\\)$)",
          br$class[grepl("\\(.*\\)", br$class)],
          perl = TRUE
        )
      )

      br$metacluster <- as.character(br$metacluster)
      br$metacluster[grepl("\\(.*\\)", br$metacluster)] <-
        gsub(
          "\\(.*\\)", "",
          br$metacluster[grepl("\\(.*\\)", br$metacluster)]
        )

      return(br)
    })
    # adjust subtype labels in the rules tables
    tree$rules <-
      suppressWarnings(lapply(tree$rules, function(r) {
        r$class <- as.character(r$class)
        r$class[grepl("\\(.*\\)", r$class)] <- regmatches(
          r$class[grepl("\\(.*\\)", r$class)],
          regexpr(
            pattern = "(?<=\\().*?(?=\\)$)",
            r$class[grepl("\\(.*\\)", r$class)],
            perl = TRUE
          )
        )

        r$metacluster[grepl("\\(.*\\)", r$metacluster)] <-
          gsub(
            "\\(.*\\)", "",
            r$metacluster[grepl("\\(.*\\)", r$metacluster)]
          )
        return(r)
      }))


    # add to tree
    tree$branchPoints <- branchPoints

    # return class labels
    tree$classLabels <- regmatches(
      subtypeLabels,
      regexpr(
        pattern = "(?<=\\().*?(?=\\)$)",
        subtypeLabels, perl = TRUE
      )
    )

    tree$metaclusterLabels <- metaclusterLabels
    tree$metaclusterLabels[grepl("\\(.*\\)", metaclusterLabels)] <-
      gsub(
        "\\(.*\\)", "",
        metaclusterLabels[grepl("\\(.*\\)", metaclusterLabels)]
      )

    # Final return
    return(tree)
  }
}
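The label cleanup above leans on two regular expressions: a look-around pattern that extracts the subtype from a trailing parenthesized suffix, and a `gsub` that strips that suffix to recover the metacluster. A minimal standalone sketch (the label `"Tcell (CD8)"` is a hypothetical example, not package output):

```r
lab <- "Tcell (CD8)"

# Extract the text between the final "(" and the ")" at the end of the string
subtype <- regmatches(
  lab,
  regexpr("(?<=\\().*?(?=\\)$)", lab, perl = TRUE)
)

# Drop the parenthesized suffix; note the trailing space is kept,
# exactly as in the cleanup above
metacluster <- gsub("\\(.*\\)", "", lab)
```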


# helper function to create table for each branch point in the tree
.createBranchPoints <-
  function(rules, largeMetaclusters, metaclusters) {
    # First step differs if metaclusters were used

    if (methods::hasArg(metaclusters) &&
      (length(largeMetaclusters) > 0)) {
      # iterate over metaclusters and add the rules for each level
      branchPoints <-
        lapply(largeMetaclusters, function(metacluster) {
          # get names of subtypes
          subtypes <- metaclusters[[metacluster]]

          # collapse rules tables of subtypes
          subtypeRules <- do.call("rbind", rules[subtypes])

          # get rules at each level
          levels <-
            lapply(seq(2, max(subtypeRules$level)), function(level) {
              return(subtypeRules[subtypeRules$level == level, ])
            })
          names(levels) <- paste0(
            metacluster, "_level_",
            seq(max(subtypeRules$level) - 1)
          )

          return(levels)
        })
      branchPoints <- unlist(branchPoints, recursive = FALSE)
    } else {
      # collapse all rules into one table
      collapsed <- do.call("rbind", rules)

      # subset rules at each level
      branchPoints <-
        lapply(seq(max(collapsed$level)), function(level) {
          return(collapsed[collapsed$level == level, ])
        })
      names(branchPoints) <-
        paste0("level_", seq(max(collapsed$level)))
    }

    # split each level into its branch points
    branchPoints <- lapply(branchPoints, function(level) {
      # check if need to split
      firstFeat <- level$feature[1]
      firstStat <- level$stat[1]
      if (setequal(
        level[
          level$feature == firstFeat &
            level$stat == firstStat,
          "class"
        ],
        unique(level$class)
      )) {
        return(level)
      }

      # initialize lists for new tables
      bSplits <- NA
      oSplits <- NA

      # get balanced split rows by themselves
      balS <- level[level$statUsed == "Split", ]

      # return table for each unique value of 'stat'
      if (nrow(balS) > 0) {
        # get unique splits (based on stat)
        unS <- unique(balS$stat)

        # return table for each unique split
        bSplits <- lapply(unS, function(s) {
          balS[balS$stat == s, ]
        })
      }

      # get one-off rows by themselves
      oneS <- level[level$statUsed == "One-off", ]

      if (nrow(oneS) > 0) {
        # check if need to split
        firstFeat <- oneS$feature[1]
        if (setequal(
          oneS[oneS$feature == firstFeat, "class"],
          unique(oneS$class)
        )) {
          oSplits <- oneS
        }

        # get class groups for each marker
        markers <- oneS[oneS$direction == 1, "feature"]
        groups <- unique(unlist(lapply(markers, function(m) {
          return(paste(as.character(oneS[oneS$feature == m, "class"]),
            collapse = " "
          ))
        })))

        # return table for each class group
        oSplits <- lapply(groups, function(x) {
          gr <- unlist(strsplit(x, split = " "))
          oneS[as.character(oneS$class) %in% gr, ]
        })
      }

      # rename new tables
      if (is.list(bSplits)) {
        names(bSplits) <- paste0(
          "split_",
          LETTERS[seq(length(bSplits), 1)]
        )
      }
      if (is.list(oSplits)) {
        names(oSplits) <- paste0(
          "one-off_",
          LETTERS[seq(length(oSplits), 1)]
        )
      }

      # return 2 sets of table
      toReturn <- list(oSplits, bSplits)
      toReturn <- toReturn[!is.na(toReturn)]
      toReturn <- unlist(toReturn, recursive = FALSE)
      return(toReturn)
    })

    # adjust for new tables
    branchPoints <- lapply(branchPoints, function(br) {
      if (inherits(br, "list")) {
        return(br)
      } else {
        return(list(br))
      }
    })
    branchPoints <- unlist(branchPoints, recursive = FALSE)
    # replace dots in names of new branches with underscores
    names(branchPoints) <- gsub(
      pattern = "\\.([^\\.]*)$",
      replacement = "_\\1",
      names(branchPoints)
    )

    return(branchPoints)
  }
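`unlist(..., recursive = FALSE)` joins nested list names with a dot, and the final `gsub` above swaps that last dot for an underscore. A quick illustration with a hypothetical branch name:

```r
nm <- "Tcell_level_1.split_A"

# Replace the last "." (and only the last) with "_"
fixed <- gsub("\\.([^\\.]*)$", "_\\1", nm)
```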

# helper function to get the AUC of individual genes within a feature cluster
.getGeneAUC <- function(marker,
                        table,
                        subtypeLabels,
                        metaclusterLabels,
                        featureLabels,
                        counts) {
  # get up-regulated & down-regulated classes for this feature
  upClass <-
    as.character(table[table$feature == marker &
      table$direction == 1, "class"])
  downClasses <-
    as.character(table[table$feature == marker &
      table$direction == (-1), "class"])

  # subset counts matrix
  if (table$level[1] > 1) {
    subCounts <-
      counts[, which(subtypeLabels %in% c(upClass, downClasses))]
  } else {
    subCounts <- counts[, which(metaclusterLabels %in%
      c(upClass, downClasses))]
  }

  # subset class labels
  if (table$level[1] > 1) {
    subLabels <- subtypeLabels[which(subtypeLabels %in%
      c(upClass, downClasses))]
  } else {
    subLabels <- metaclusterLabels[which(metaclusterLabels %in%
      c(upClass, downClasses))]
  }

  # set label to 0 if not class of interest
  subLabels <- as.numeric(subLabels %in% upClass)

  # get individual features within this marker
  markers <- rownames(counts)[which(featureLabels == marker)]

  # get one-vs-all AUC for each gene
  auc <- unlist(lapply(markers, function(markerGene) {
    as.numeric(pROC::auc(
      pROC::roc(
        subLabels,
        subCounts[markerGene, ],
        direction = "<",
        quiet = TRUE
      )
    ))
  }))
  names(auc) <- markers

  # sort by AUC
  auc <- sort(auc, decreasing = TRUE)

  # create table for this marker
  featTable <- table[table$feature == marker, ]
  featTable <-
    featTable[rep(seq_len(nrow(featTable)), each = length(auc)), ]
  featTable$gene <-
    rep(names(auc), length(c(upClass, downClasses)))
  featTable$geneAUC <- rep(auc, length(c(upClass, downClasses)))

  # return table for merging with main table
  return(featTable)
}
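.getGeneAUC computes each gene's one-vs-all AUC through pROC. For intuition, the same quantity can be sketched in base R via the rank-sum (Mann-Whitney) identity; `labels` and `values` below are hypothetical toy inputs, not package API:

```r
# 1 marks cells in the up-regulated class
labels <- c(1, 1, 1, 0, 0, 0)
values <- c(5, 4, 3.5, 3, 2, 1)

aucRankSum <- function(labels, values) {
  r <- rank(values) # midranks handle ties
  n1 <- sum(labels == 1)
  n0 <- sum(labels == 0)
  (sum(r[labels == 1]) - n1 * (n1 + 1) / 2) / (n1 * n0)
}

aucRankSum(labels, values) # perfectly separated classes give AUC 1
```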

# This function generates the decision tree by recursively separating classes.
.generateTreeList <- function(features,
                              class,
                              oneoffMetric,
                              threshold,
                              reuseFeatures,
                              consecutiveOneoff = FALSE) {
  # Initialize Tree
  treeLevel <- tree <- list()

  # Initialize the first split
  treeLevel[[1]] <- list()

  # Generate the first split at the first level
  treeLevel[[1]] <- .wrapSplitHybrid(
    features,
    class,
    threshold,
    oneoffMetric
  )

  # Add set of features used at this split
  treeLevel[[1]]$fUsed <- unlist(lapply(
    treeLevel[[1]][names(treeLevel[[1]]) != "statUsed"],
    function(X) {
      X$featureName
    }
  ))

  # Initialize split directions
  treeLevel[[1]]$dirs <- 1

  # Add split list as first level
  tree[[1]] <- treeLevel

  # Initialize tree depth
  mDepth <- 1

  # Build tree until all leaves contain a single cluster
  while (length(unlist(treeLevel)) > 0) {
    # Create list of branches on this level
    outList <-
      lapply(treeLevel, function(split, features, class) {
        # Check for consecutive oneoff
        tryOneoff <- TRUE
        if (!consecutiveOneoff && split$statUsed == "One-off") {
          tryOneoff <- FALSE
        }

        # If length(split) == 4 then this split is a binary node
        if (length(split) == 4 &&
          length(split[[1]]$group1Consensus) > 1) {
          # Create branch from this split.
          branch1 <- .wrapBranchHybrid(
            split[[1]]$group1,
            features,
            class,
            split$fUsed,
            threshold,
            reuseFeatures,
            oneoffMetric,
            tryOneoff
          )

          if (!is.null(branch1)) {
            # Add feature to list of used features.
            branch1$fUsed <- c(split$fUsed, unlist(lapply(
              branch1[names(branch1) != "statUsed"],
              function(X) {
                X$featureName
              }
            )))

            # Add the split direction (always 1 when splitting group 1)
            branch1$dirs <- c(split$dirs, 1)
          }
        } else {
          branch1 <- NULL
        }

        # If length(split) == 4 then this split is a binary node
        if (length(split) == 4 &&
          length(split[[1]]$group2Consensus) > 1) {
          # Create branch from this split
          branch2 <- .wrapBranchHybrid(
            split[[1]]$group2,
            features,
            class,
            split$fUsed,
            threshold,
            reuseFeatures,
            oneoffMetric,
            tryOneoff
          )

          if (!is.null(branch2)) {
            # Add feature to list of used features.
            branch2$fUsed <- c(split$fUsed, unlist(lapply(
              branch2[names(branch2) != "statUsed"],
              function(X) {
                X$featureName
              }
            )))

            # Add the split direction (always 2 when splitting group 2)
            branch2$dirs <- c(split$dirs, 2)
          }

          # If length(split) > 4 then this split has more than 2 edges
          # In this case group 1 will always denote leaves.
        } else if (length(split) > 4) {
          # Get samples that appear in group 1 of any node at this split
          group1Samples <- unique(unlist(lapply(
            split[!names(split) %in% c("statUsed", "fUsed", "dirs")],
            function(X) {
              X$group1
            }
          )))
          group2Samples <- unique(unlist(lapply(
            split[!names(split) %in% c("statUsed", "fUsed", "dirs")],
            function(X) {
              X$group2
            }
          )))
          group2Samples <- group2Samples[!group2Samples %in%
            group1Samples]

          # Check that there is still more than one class
          group2Classes <- levels(droplevels(class[rownames(features) %in%
            group2Samples]))
          if (length(group2Classes) > 1) {
            # Create branch from this split
            branch2 <- .wrapBranchHybrid(
              group2Samples,
              features,
              class,
              split$fUsed,
              threshold,
              reuseFeatures,
              oneoffMetric,
              tryOneoff
            )

            if (!is.null(branch2)) {
              # Add multiple features
              branch2$fUsed <-
                c(split$fUsed, unlist(lapply(
                  branch2[names(branch2) != "statUsed"],
                  function(X) {
                    X$featureName
                  }
                )))

              # Instead of 2, this direction is 1 + the num. splits
              branch2$dirs <- c(
                split$dirs,
                sum(!names(split) %in%
                  c("statUsed", "fUsed", "dirs")) + 1
              )
            }
          } else {
            branch2 <- NULL
          }
        } else {
          branch2 <- NULL
        }

        # Combine these branches
        outBranch <- list(branch1, branch2)

        # Only keep non-null branches
        outBranch <-
          outBranch[!unlist(lapply(outBranch, is.null))]
        if (length(outBranch) > 0) {
          return(outBranch)
        } else {
          return(NULL)
        }
      }, features, class)

    # Unlist outList so there is one element per branch at this level
    treeLevel <- unlist(outList, recursive = FALSE)

    # Increase tree depth
    mDepth <- mDepth + 1

    # Add this level to the tree
    tree[[mDepth]] <- treeLevel
  }
  return(tree)
}


# Wrapper to subset the feature and class set for each split
.wrapBranchHybrid <- function(groups,
                              features,
                              class,
                              fUsed,
                              threshold = 0.95,
                              reuseFeatures = FALSE,
                              oneoffMetric,
                              tryOneoff) {
  # Subset for branch to run split
  gKeep <- rownames(features) %in% groups

  # Drop previously used features unless reuseFeatures is TRUE
  if (reuseFeatures) {
    fSub <- features[gKeep, ]
  } else {
    fSub <-
      features[gKeep, !colnames(features) %in% fUsed, drop = FALSE]
  }

  # Drop levels (classes that are no longer present)
  cSub <- droplevels(class[gKeep])

  # If fSub has multiple features, run the split; otherwise return NULL
  if (ncol(fSub) > 1) {
    return(.wrapSplitHybrid(fSub, cSub, threshold, oneoffMetric, tryOneoff))
  } else {
    return(NULL)
  }
}

# Wrapper function to perform split metrics
.wrapSplitHybrid <- function(features,
                             class,
                             threshold = 0.95,
                             oneoffMetric,
                             tryOneoff = TRUE) {
  # Get best one-off splits
  ## Use modified F1 or pairwise AUC?
  if (tryOneoff) {
    if (oneoffMetric == "modified F1") {
      splitMetric <- .splitMetricModF1
    } else {
      splitMetric <- .splitMetricPairwiseAUC
    }
    splitStats <- .splitMetricRecursive(features,
      class,
      splitMetric = splitMetric
    )
    splitStats <- splitStats[splitStats >= threshold]
    statUsed <- "One-off"
  } else {
    splitStats <- integer(0)
  }


  # If no one-off split meets the threshold, fall back to the balanced split
  if (length(splitStats) == 0) {
    splitMetric <- .splitMetricIGpIGd
    splitStats <- .splitMetricRecursive(features,
      class,
      splitMetric = splitMetric
    )[1] # Use top
    statUsed <- "Split"
  }

  # Get split for best gene
  splitList <- lapply(
    names(splitStats),
    .getSplit,
    splitStats,
    features,
    class,
    splitMetric
  )


  # Combine feature rules when same group1 class arises

  if (length(splitList) > 1) {
    group1Vec <- unlist(lapply(splitList, function(X) {
      X$group1Consensus
    }), recursive = FALSE)

    splitList <- lapply(
      unique(group1Vec),
      function(group1, splitList, group1Vec) {
        # Get subset with same group1
        splitListSub <- splitList[group1Vec == group1]

        # Get feature, value, and stat for these
        splitFeature <- unlist(lapply(
          splitListSub,
          function(X) {
            X$featureName
          }
        ))
        splitValue <- unlist(lapply(
          splitListSub,
          function(X) {
            X$value
          }
        ))
        splitStat <- unlist(lapply(
          splitListSub,
          function(X) {
            X$stat
          }
        ))

        # Create a single object and add these
        splitSingle <- splitListSub[[1]]
        splitSingle$featureName <- splitFeature
        splitSingle$value <- splitValue
        splitSingle$stat <- splitStat

        return(splitSingle)
      }, splitList, group1Vec
    )
  }

  names(splitList) <- unlist(lapply(
    splitList,
    function(X) {
      paste(X$featureName, collapse = ";")
    }
  ))

  # Add statUsed
  splitList$statUsed <- statUsed

  return(splitList)
}

# Recursively run split metric on every feature
.splitMetricRecursive <- function(features, class, splitMetric) {
  splitStats <- vapply(colnames(features),
    function(feat, features, class, splitMetric) {
      splitMetric(feat, class, features, rPerf = TRUE)
    }, features, class, splitMetric,
    FUN.VALUE = double(1)
  )
  names(splitStats) <- colnames(features)
  splitStats <- sort(splitStats, decreasing = TRUE)

  return(splitStats)
}
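.splitMetricRecursive simply scores every feature column and returns the scores sorted best-first. A minimal sketch with column variance standing in for the real split metrics (.splitMetricModF1, .splitMetricPairwiseAUC, or .splitMetricIGpIGd); the matrix `m` is a toy input:

```r
m <- cbind(a = c(1, 1, 1), b = c(1, 2, 3), c = c(0, 5, 10))

# Score each column, then sort decreasing so the best feature comes first
scores <- vapply(colnames(m), function(f) stats::var(m[, f]),
  FUN.VALUE = double(1)
)
scores <- sort(scores, decreasing = TRUE)
```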

# Run pairwise AUC metric on a single feature
.splitMetricPairwiseAUC <-
  function(feat, class, features, rPerf = FALSE) {
    # Get current feature
    currentFeature <- features[, feat]

    # Get unique classes
    classUnique <- sort(unique(class))

    # Do one-to-all to determine top cluster
    # For each class K1 determine best AUC
    auc1toAll <-
      vapply(classUnique, function(k1, class, currentFeature) {
        # Set value to k1
        classK1 <- as.numeric(class == k1)

        # Get AUC value
        aucK1 <-
          pROC::auc(pROC::roc(
            classK1,
            currentFeature,
            direction = "<",
            quiet = TRUE
          ))

        # Return
        return(aucK1)
      }, class, currentFeature, FUN.VALUE = double(1))

    # Get class with best AUC (Class with generally highest values)
    classMax <- as.character(classUnique[which.max(auc1toAll)])

    # Get other classes
    classRest <- as.character(classUnique[classUnique != classMax])

    # for each second cluster k2
    aucFram <- as.data.frame(do.call(
      rbind,
      lapply(
        classRest,
        function(k2, k1, class, currentFeature) {
          # keep cells in k1 or k2 only
          obsKeep <- class %in% c(k1, k2)
          currentFeatureSubset <- currentFeature[obsKeep]

          # update cluster assignments
          currentClusters <- class[obsKeep]

          # label cells whether they belong to k1 (0 or 1)
          currentLabels <- as.integer(currentClusters == k1)

          # get AUC value for this feat-cluster pair
          rocK2 <-
            pROC::roc(currentLabels,
              currentFeatureSubset,
              direction = "<",
              quiet = TRUE
            )
          aucK2 <- rocK2$auc
          coordK2 <-
            pROC::coords(rocK2, "best", ret = "threshold", transpose = TRUE)[1]

          # Concatenate vectors
          statK2 <- c(threshold = coordK2, auc = aucK2)

          return(statK2)
        }, classMax, class, currentFeature
      )
    ))

    # Get minimum AUC across all pairwise comparisons
    aucMin <- min(aucFram$auc)

    # Get indices where this minimum AUC occurs
    aucMinIndices <- which(aucFram$auc == aucMin)

    # Use the maximum threshold among tied minima
    aucValue <- max(aucFram$threshold[aucMinIndices])

    # Return performance or value?
    if (rPerf) {
      return(aucMin)
    } else {
      return(aucValue)
    }
  }


# Run modified F1 metric on single feature
.splitMetricModF1 <-
  function(feat, class, features, rPerf = FALSE) {
    # Get number of samples
    len <- length(class)

    # Get Values
    featValues <- features[, feat]

    # Get order of values
    ord <- order(featValues, decreasing = TRUE)

    # Get sorted class and values
    featValuesSort <- featValues[ord]
    classSort <- class[ord]

    # Keep splits of the data where the class changes
    keep <- c(
      classSort[seq(1, (len - 1))] != classSort[seq(2, (len))] &
        featValuesSort[seq(1, (len - 1))] != featValuesSort[seq(2, (len))],
      FALSE
    )

    # Create data.matrix
    X <- stats::model.matrix(~ 0 + classSort)

    # Get cumulative sums
    sRCounts <- apply(X, 2, cumsum)

    # Keep only values where the class changes
    sRCounts <- sRCounts[keep, , drop = FALSE]
    featValuesKeep <- featValuesSort[keep]

    # Number of each class
    Xsum <- colSums(X)

    # Remove impossible splits (each side must hold a majority of some class)
    sRProbs <- sRCounts %*% diag(Xsum^-1)
    sKeepPossible <-
      rowSums(sRProbs >= 0.5) > 0 & rowSums(sRProbs < 0.5) > 0

    # Remove splits after any class reaches full probability (rare)
    maxCheck <-
      min(c(which(apply(sRProbs, 1, max) == 1), nrow(sRProbs)))
    sKeepCheck <- seq(1, nrow(sRProbs)) %in% seq(1, maxCheck)

    # Combine logical vectors
    sKeep <- sKeepPossible & sKeepCheck

    if (sum(sKeep) > 0) {
      # Remove these if they exist
      sRCounts <- sRCounts[sKeep, , drop = FALSE]
      featValuesKeep <- featValuesKeep[sKeep]

      # Get left counts
      sLCounts <- t(Xsum - t(sRCounts))

      # Calculate the harmonic mean of Sens, Prec, and Worst Alt Sens
      statModF1 <- vapply(seq(nrow(sRCounts)),
        function(i, Xsum, sRCounts, sLCounts) {
          # Right Side
          sRRowSens <-
            sRCounts[i, ] / Xsum # Right sensitivities
          sRRowPrec <-
            sRCounts[i, ] / sum(sRCounts[i, ]) # Right prec
          sRRowF1 <-
            2 * (sRRowSens * sRRowPrec) / (sRRowSens + sRRowPrec)
          sRRowF1[is.nan(sRRowF1)] <- 0 # Get right F1
          bestF1Ind <- which.max(sRRowF1) # Which is the best?
          bestSens <-
            sRRowSens[bestF1Ind] # The corresponding sensitivity
          bestPrec <-
            sRRowPrec[bestF1Ind] # The corresponding precision

          # Left Side
          sLRowSens <-
            sLCounts[i, ] / Xsum # Get left sensitivities
          worstSens <-
            min(sLRowSens[-bestF1Ind]) # Get the worst

          # Get harmonic mean of best sens, best prec, and worst sens
          HMout <- (3 * bestSens * bestPrec * worstSens) /
            (bestSens * bestPrec + bestPrec * worstSens +
              bestSens * worstSens)

          return(HMout)
        }, Xsum, sRCounts, sLCounts,
        FUN.VALUE = double(1)
      )

      # Get Max Value
      ModF1Max <- max(statModF1)

      # Get the first index where this maximum occurs
      ModF1Index <- which.max(statModF1)

      # Get value at this point
      ValueCeiling <- featValuesKeep[ModF1Index]
      ValueWhich <- which(featValuesSort == ValueCeiling)
      ModF1Value <- mean(c(
        featValuesSort[ValueWhich],
        featValuesSort[ValueWhich + 1]
      ))
    } else {
      ModF1Max <- 0
      ModF1Value <- NA
    }

    if (rPerf) {
      return(ModF1Max)
    } else {
      return(ModF1Value)
    }
  }
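The statistic maximized above is the harmonic mean of three quantities: the best class's sensitivity and precision on one side of the split, and the worst remaining sensitivity on the other side. Written out on its own (a sketch of the formula only, with hypothetical inputs):

```r
harmonicMean3 <- function(sens, prec, worstSens) {
  (3 * sens * prec * worstSens) /
    (sens * prec + prec * worstSens + sens * worstSens)
}

harmonicMean3(1, 1, 1) # a perfect split scores 1
```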

# Run Information Gain (probability + density) on a single feature
.splitMetricIGpIGd <- function(feat, class, features, rPerf = FALSE) {
    # Get number of samples
    len <- length(class)

    # Get Values
    featValues <- features[, feat]

    # Get order of values
    ord <- order(featValues, decreasing = TRUE)

    # Get sorted class and values
    featValuesSort <- featValues[ord]
    classSort <- class[ord]

    # Keep splits of the data where the class changes
    keep <- c(
      classSort[seq(1, (len - 1))] != classSort[seq(2, (len))] &
        featValuesSort[seq(1, (len - 1))] != featValuesSort[seq(2, (len))],
      FALSE
    )

    # Create data.matrix
    X <- stats::model.matrix(~ 0 + classSort)

    # Get cumulative sums
    sRCounts <- apply(X, 2, cumsum)

    # Keep only values where the class changes
    sRCounts <- sRCounts[keep, , drop = FALSE]
    featValuesKeep <- featValuesSort[keep]

    # Number of each class
    Xsum <- colSums(X)

    # Remove impossible splits
    sRProbs <- sRCounts %*% diag(Xsum^-1)
    sKeep <-
      rowSums(sRProbs >= 0.5) > 0 & rowSums(sRProbs < 0.5) > 0

    if (sum(sKeep) > 0) {
      # Remove these if they exist
      sRCounts <- sRCounts[sKeep, , drop = FALSE]
      featValuesKeep <- featValuesKeep[sKeep]

      # Get left counts
      sLCounts <- t(Xsum - t(sRCounts))

      # Multiply them to get probabilities
      sRProbs <- t(t(sRCounts) %*%
        diag(rowSums(sRCounts)^-1, nrow = nrow(sRCounts)))
      sLProbs <- t(t(sLCounts) %*%
        diag(rowSums(sLCounts)^-1, nrow = nrow(sLCounts)))

      # Multiply probabilities by their log
      sRTrans <- sRProbs * log(sRProbs)
      sRTrans[is.na(sRTrans)] <- 0
      sLTrans <- sLProbs * log(sLProbs)
      sLTrans[is.na(sLTrans)] <- 0

      # Get entropies
      HSR <- -rowSums(sRTrans)
      HSL <- -rowSums(sLTrans)

      # Get overall probabilities and entropy
      nProbs <- colSums(X) / len
      HS <- -sum(nProbs * log(nProbs))

      # Get split proportions
      sProps <- rowSums(sRCounts) / nrow(X)

      # Get information gain (Probability)
      IGprobs <- HS - (sProps * HSR + (1 - sProps) * HSL)
      IGprobs[is.nan(IGprobs)] <- 0
      IGprobsQuantile <- IGprobs / max(IGprobs)
      IGprobsQuantile[is.nan(IGprobsQuantile)] <- 0

      # Get proportions at each split
      classProps <- sRCounts %*% diag(Xsum^-1)
      classSplit <- classProps >= 0.5

      # Initialize information gain density vector
      splitIGdensQuantile <- rep(0, nrow(classSplit))

      # Get unique splits of the data
      classSplitUnique <- unique(classSplit)
      classSplitUnique <-
        classSplitUnique[!rowSums(classSplitUnique) %in%
          c(0, ncol(classSplitUnique)), , drop = FALSE]

      # Get density information gain
      if (nrow(classSplitUnique) > 0) {
        # Get log(determinant of full matrix)
        DET <- .psdet(stats::cov(features))

        # Information gain of every observation
        IGdens <- apply(
          classSplitUnique,
          1,
          .infoGainDensity,
          X,
          features,
          DET
        )

        names(IGdens) <- apply(
          classSplitUnique * 1,
          1,
          function(X) {
            paste(X, collapse = "")
          }
        )

        IGdens[is.nan(IGdens) | IGdens < 0] <- 0
        IGdensQuantile <- IGdens / max(IGdens)
        IGdensQuantile[is.nan(IGdensQuantile)] <- 0

        # Get ID of each class split
        splitsIDs <- apply(
          classSplit * 1,
          1,
          function(x) {
            paste(x, collapse = "")
          }
        )

        # Append information gain density vector
        for (ID in names(IGdens)) {
          splitIGdensQuantile[splitsIDs == ID] <- IGdensQuantile[ID]
        }
      }

      # Add this to the other matrix
      IG <- IGprobsQuantile + splitIGdensQuantile

      # Get IG(probability) of maximum value
      IGreturn <- IGprobs[which.max(IG)[1]]

      # Get maximum value
      maxVal <- featValuesKeep[which.max(IG)]
      wMax <- max(which(featValuesSort == maxVal))
      IGvalue <-
        mean(c(featValuesSort[wMax], featValuesSort[wMax + 1]))
    } else {
      IGreturn <- 0
      IGvalue <- NA
    }

    # Return the performance (IG) or the split value at maximum IG
    if (rPerf) {
      return(IGreturn)
    } else {
      return(IGvalue)
    }
  }
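The probability part of the gain above is the standard information gain with natural-log entropy, IG = H(S) - (p * H(S_R) + (1 - p) * H(S_L)). A standalone numeric check for a perfectly separating split of two balanced classes (toy probabilities, not package data):

```r
entropy <- function(p) {
  p <- p[p > 0]
  -sum(p * log(p))
}

HS <- entropy(c(0.5, 0.5)) # parent entropy, log(2)
HSR <- entropy(c(1, 0)) # right child is pure
HSL <- entropy(c(0, 1)) # left child is pure
IG <- HS - (0.5 * HSR + 0.5 * HSL)
```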

# Function to compute the log pseudo-determinant of a matrix
.psdet <- function(x) {
  if (sum(is.na(x)) == 0) {
    svalues <- zapsmall(svd(x)$d)
    sum(log(svalues[svalues > 0]))
  } else {
    0
  }
}
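.psdet returns the log pseudo-determinant, i.e. the sum of the logs of the nonzero singular values, so rank-deficient covariance matrices do not send the score to -Inf. A quick check on a hypothetical singular diagonal matrix:

```r
x <- diag(c(2, 3, 0)) # singular: det(x) == 0

svalues <- zapsmall(svd(x)$d) # singular values 3, 2, 0
logPsdet <- sum(log(svalues[svalues > 0])) # log(3) + log(2) = log(6)
```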

# Function to calculate density information gain
.infoGainDensity <- function(splitVector, X, features, DET) {
  # Get Subsets of the feature matrix
  sRFeat <- features[as.logical(rowSums(X[, splitVector, drop = FALSE])), ,
    drop = FALSE
  ]
  sLFeat <- features[as.logical(rowSums(X[, !splitVector, drop = FALSE])), ,
    drop = FALSE
  ]

  # Get pseudo-determinant of covariance matrices
  DETR <- .psdet(stats::cov(sRFeat))
  DETL <- .psdet(stats::cov(sLFeat))

  # Get relative sizes
  sJ <- nrow(features)
  sJR <- nrow(sRFeat)
  sJL <- nrow(sLFeat)

  IUout <- 0.5 * (DET - (sJR / sJ * DETR + sJL / sJ * DETL))

  return(IUout)
}

# Wrapper function for getting split statistics
.getSplit <-
  function(feat,
           splitStats,
           features,
           class,
           splitMetric) {
    stat <- splitStats[feat]
    splitVal <- splitMetric(feat, class, features, rPerf = FALSE)
    featValues <- features[, feat]

    # Get classes split to one node
    node1Class <- class[featValues > splitVal]

    # Get proportion of each class at each node
    group1Prop <- table(node1Class) / table(class)
    group2Prop <- 1 - group1Prop

    # Get class consensus
    group1Consensus <- names(group1Prop)[group1Prop >= 0.5]
    group2Consensus <- names(group1Prop)[group1Prop < 0.5]

    # Get group samples
    group1 <- rownames(features)[class %in% group1Consensus]
    group2 <- rownames(features)[class %in% group2Consensus]

    # Get class vector
    group1Class <- droplevels(class[class %in% group1Consensus])
    group2Class <- droplevels(class[class %in% group2Consensus])

    return(
      list(
        featureName = feat,
        value = splitVal,
        stat = stat,

        group1 = group1,
        group1Class = group1Class,
        group1Consensus = group1Consensus,
        group1Prop = c(group1Prop),

        group2 = group2,
        group2Class = group2Class,
        group2Consensus = group2Consensus,
        group2Prop = c(group2Prop)
      )
    )
  }
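The consensus logic in .getSplit assigns each class to the side of the threshold holding at least half of its cells. A toy illustration with two classes and a hypothetical cut at 2.5:

```r
class <- factor(c("A", "A", "A", "B", "B", "B"))
vals <- c(5, 4, 1, 2, 1, 0)

# Classes of the cells above the threshold
node1Class <- class[vals > 2.5]

# Fraction of each class above the threshold
group1Prop <- table(node1Class) / table(class)

# Majority vote per class
group1Consensus <- names(group1Prop)[group1Prop >= 0.5]
group2Consensus <- names(group1Prop)[group1Prop < 0.5]
```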

# Function to annotate an alternative split for a solely down-regulated node
.addAlternativeSplit <- function(tree, features, class) {
  # Unlist the decision tree
  DecTree <- unlist(tree, recursive = FALSE)

  # Get leaves
  groupList <- lapply(DecTree, function(split) {
    # Remove directions
    split <-
      split[!names(split) %in% c("statUsed", "fUsed", "dirs")]

    # Get groups
    group1 <- unique(unlist(lapply(
      split,
      function(node) {
        node$group1Consensus
      }
    )))
    group2 <- unique(unlist(lapply(
      split,
      function(node) {
        node$group2Consensus
      }
    )))

    return(list(
      group1 = group1,
      group2 = group2
    ))
  })

  # Get vector of each group
  group1Vec <-
    unique(unlist(lapply(groupList, function(g) {
      g$group1
    })))
  group2Vec <-
    unique(unlist(lapply(groupList, function(g) {
      g$group2
    })))

  # Get group that is never up-regulated
  group2only <- group2Vec[!group2Vec %in% group1Vec]

  # Check whether there are solely downregulated splits
  AltSplitInd <-
    which(unlist(lapply(groupList, function(g, group2only) {
      group2only %in% g$group2
    }, group2only)))

  if (length(AltSplitInd) > 0) {
    # Use the deepest split containing the solely downregulated group
    AltDec <- max(AltSplitInd)

    # Get split
    downSplit <- DecTree[[AltDec]]
    downNode <- downSplit[[1]]

    # Get classes to rerun
    branchClasses <- names(downNode$group1Prop)

    # Get samples from these classes and features from this cluster
    sampKeep <- class %in% branchClasses
    featKeep <- !colnames(features) %in% downSplit$fUsed

    # Subset class and features
    cSub <- droplevels(class[sampKeep])
    fSub <- features[sampKeep, featKeep, drop = FALSE]

    # Get best alternative split
    altStats <- do.call(
      rbind,
      lapply(
        colnames(fSub),
        function(feat,
                 splitMetric,
                 features,
                 class,
                 cInt) {
          Val <- splitMetric(feat, cSub, fSub, rPerf = FALSE)

          # Get node1 classes
          node1Class <- class[features[, feat] > Val]

          # Get sensitivity/precision/altSens
          Sens <- sum(node1Class == cInt) / sum(class == cInt)
          Prec <- mean(node1Class == cInt)

          # Get Sensitivity of Alternate Classes
          AltClasses <- unique(class)[unique(class) != cInt]
          AltSizes <- vapply(AltClasses,
            function(cAlt, class) {
              sum(class == cAlt)
            }, class,
            FUN.VALUE = double(1)
          )
          AltWrong <- vapply(AltClasses,
            function(cAlt, node1Class) {
              sum(node1Class == cAlt)
            }, node1Class,
            FUN.VALUE = double(1)
          )
          AltSens <- min(1 - (AltWrong / AltSizes))

          # Get harmonic mean
          HM <- (3 * Sens * Prec * AltSens) /
            (Sens * Prec + Prec * AltSens + Sens * AltSens)
          HM[is.nan(HM)] <- 0

          # Return
          return(data.frame(
            feat = feat,
            val = Val,
            stat = HM,
            stringsAsFactors = FALSE
          ))
        }, .splitMetricModF1, fSub, cSub, group2only
      )
    )
    altStats <-
      altStats[order(altStats$stat, decreasing = TRUE), ]

    # Get alternative splits
    splitStats <- altStats$stat[1]
    names(splitStats) <- altStats$feat[1]
    altSplit <- .getSplit(
      altStats$feat[1],
      splitStats,
      fSub,
      cSub,
      .splitMetricModF1
    )

    # Check that this split isolates the group2 class of interest
    if (length(altSplit$group1Consensus) == 1) {
      # Add it to split
      downSplit[[length(downSplit) + 1]] <- altSplit
      names(downSplit)[length(downSplit)] <- altStats$feat[1]
      downSplit <- downSplit[c(
        which(!names(downSplit) %in% c("statUsed", "fUsed", "dirs")),
        which(names(downSplit) %in% c("statUsed", "fUsed", "dirs"))
      )]

      # Get index of split to add it to
      branchLengths <- unlist(lapply(tree, length))
      branchCum <- cumsum(branchLengths)
      wBranch <- min(which(branchCum >= AltDec))
      if (wBranch == 1) {
        wSplit <- 1
      } else {
        wSplit <- which(seq(
          (branchCum[(wBranch - 1)] + 1),
          branchCum[wBranch]
        ) == AltDec)
      }

      # Add it to decision tree
      tree[[wBranch]][[wSplit]] <- downSplit
    } else {
      cat(paste0(
        "No non-ambiguous rule to separate ", group2only, " from ",
        paste(branchClasses, collapse = ", "),
        ". No alternative split added.\n"
      ))
    }
  } else {
    #  print("No solely down-regulated cluster to add alternative split.")
  }

  return(tree)
}

#' @title Get cluster estimates using rules generated by
#'  `celda::findMarkersTree`
#' @description Get decisions for a matrix of features. Estimate cell
#'  cluster membership using feature matrix input.
#' @param rules List object. The `rules` element from `findMarkersTree`
#'  output.
#' @param features A L(features) by N(samples) numeric matrix.
#' @return A character vector of label predictions. Returns NA for cells
#'  whose cluster estimation was ambiguous.
#' @export
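#' @examples
#' \dontrun{
#' # Hedged sketch: assumes `DecTree`, `features`, and `class` as produced
#' # in the plotMarkerDendro() example below, i.e. findMarkersTree() output
#' # together with its training feature matrix and cluster labels
#' votes <- getDecisions(DecTree$rules, features)
#'
#' # Compare estimated labels against the training classes; ambiguous
#' # cells are returned as NA
#' table(votes, class, useNA = "ifany")
#' }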
getDecisions <- function(rules, features) {
  features <- t(features)
  votes <- apply(features, 1, .predictClass, rules)
  return(votes)
}

# Function to predict class from list of rules
.predictClass <- function(samp, rules) {
  # Initialize possible classes and level
  classes <- names(rules)
  level <- 1

  # Set the maximum level possible to prevent an infinite loop
  maxLevel <- max(unlist(lapply(rules, function(ruleSet) {
    ruleSet$level
  })))

  while (length(classes) > 1 && level <= maxLevel) {
    # Get possible classes
    clLogical <-
      unlist(lapply(classes, function(cl, rules, level, samp) {
        # Get the rules for this class
        ruleClass <- rules[[cl]]

        # Get the rules for this level
        ruleClass <-
          ruleClass[ruleClass$level == level, , drop = FALSE]

        # Subset class for the features at this level
        ruleClass$sample <- samp[ruleClass$feature]

        # If multiple rules have direction == 1, keep the one with the top stat
        if (sum(ruleClass$direction == 1) > 1) {
          ruleClass <- ruleClass[order(ruleClass$direction,
            decreasing = TRUE
          ), ]
          ruleClass <- ruleClass[c(
            which.max(ruleClass$stat[ruleClass$direction == 1]),
            which(ruleClass$direction == -1)
          ), , drop = FALSE]
        }

        # Check for followed rules
        ruleClass$check <- ruleClass$sample >= ruleClass$value
        ruleClass$check[ruleClass$direction == -1] <-
          !ruleClass$check[ruleClass$direction == -1]

        # Check that all rules were followed
        ruleFollowed <- mean(ruleClass$check &
          ruleClass$direction == 1) > 0 |
          mean(ruleClass$check) == 1

        return(ruleFollowed)
      }, rules, level, samp))

    # Subset possible classes
    classes <- classes[clLogical]

    # Add level
    level <- level + 1
  }

  # Return if only one class selected
  if (length(classes) == 1) {
    return(classes)
  } else {
    return(NA)
  }
}

# Function to summarize and format tree list output by .generateTreeList
.summarizeTree <- function(tree, features, class) {
  # Format tree into dendrogram object
  dendro <- .convertToDendrogram(tree, class)

  # Map classes to features
  class2features <- .mapClass2features(tree, features, class)

  # Get performance of the tree on training samples
  perfList <-
    .getPerformance(class2features$rules, features, class)

  return(
    list(
      rules = class2features$rules,
      dendro = dendro,
      prediction = perfList$prediction,
      performance = perfList$performance
    )
  )
}

# Function to reformat raw tree output to a dendrogram
.convertToDendrogram <- function(tree, class, splitNames = NULL) {
  # Unlist decision tree (one element for each split)
  DecTree <- unlist(tree, recursive = FALSE)

  if (is.null(splitNames)) {
    # Name split by gene and threshold
    splitNames <- lapply(DecTree, function(split) {
      # Remove non-split elements
      dirs <- paste0(split$dirs, collapse = "_")
      split <-
        split[!names(split) %in% c("statUsed", "fUsed", "dirs")]

      # Get set of features and values for each
      featuresplits <- lapply(split, function(node) {
        paste(node$featureName, collapse = ";")
      })

      # Get split directions
      names(featuresplits) <- paste(dirs,
        seq(length(featuresplits)),
        sep = "_"
      )

      return(featuresplits)
    })
    splitNames <- unlist(splitNames)
    names(splitNames) <- sub("1_", "", names(splitNames))
  } else {
    names(splitNames) <- seq(length(DecTree[[1]]) - 3)
  }

  # Get Stat Used
  statUsed <- unlist(lapply(DecTree, function(split) {
    split$statUsed
  }))
  statRep <- unlist(lapply(
    DecTree,
    function(split) {
      length(split[!names(split) %in% c("statUsed", "fUsed", "dirs")])
    }
  ))
  statUsed <- unlist(lapply(
    seq(length(statUsed)),
    function(i) {
      rep(statUsed[i], statRep[i])
    }
  ))
  names(statUsed) <- names(splitNames)

  # Create Matrix of results
  mat <-
    matrix(0, nrow = length(DecTree), ncol = length(unique(class)))
  colnames(mat) <- unique(class)
  for (i in seq(1, length(DecTree))) {
    # If there is only one split, assignment is straightforward
    split <- DecTree[[i]]
    split <-
      split[!names(split) %in% c("statUsed", "fUsed", "dirs")]
    if (length(split) == 1) {
      mat[i, split[[1]]$group1Consensus] <- 1
      mat[i, split[[1]]$group2Consensus] <- 2

      # Otherwise, assign more than two groups across the multi-way split
    } else {
      # Get classes in group 1
      group1classUnique <- unique(lapply(
        split,
        function(X) {
          X$group1Consensus
        }
      ))
      group1classVec <- unlist(group1classUnique)

      # Get classes always in group 2
      group2classUnique <- unique(unlist(lapply(
        split,
        function(X) {
          X$group2Consensus
        }
      )))
      group2classUnique <-
        group2classUnique[!group2classUnique %in%
          group1classVec]

      # Assign
      for (j in seq(length(group1classUnique))) {
        mat[i, group1classUnique[[j]]] <- j
      }
      mat[i, group2classUnique] <- j + 1
    }
  }

  ## Collapse matrix to get the set of directions to include in the dendrogram
  matCollapse <- sort(apply(
    mat,
    2,
    function(x) {
      paste(x[x != 0], collapse = "_")
    }
  ))
  matUnique <- unique(matCollapse)

  # Get branchlist
  bList <- c()
  j <- 1
  for (i in seq(max(ncharX(matUnique)))) {
    sLength <- matUnique[ncharX(matUnique) >= i]
    sLength <- unique(subUnderscore(sLength, i))
    for (k in sLength) {
      bList[j] <- k
      j <- j + 1
    }
  }

  # Initialize dendrogram list
  val <- max(ncharX(matUnique)) + 1
  dendro <- list()
  attributes(dendro) <- list(
    members = length(matCollapse),
    classLabels = unique(class),
    height = val,
    midpoint = (length(matCollapse) - 1) / 2,
    label = NULL,
    name = NULL
  )

  for (i in bList) {
    # Add element
    iSplit <- unlist(strsplit(i, "_"))
    iPaste <- paste0(
      "dendro",
      paste(paste0("[[", iSplit, "]]"), collapse = "")
    )
    eval(parse(
      text =
        paste0(iPaste, "<-list()")
    ))

    # Add attributes
    classLabels <- names(matCollapse[subUnderscore(
      matCollapse,
      ncharX(i)
    ) == i])
    members <- length(classLabels)

    # Add height, set to one if leaf
    height <- val - ncharX(i)

    # Check that this isn't a terminal split
    if (members == 1) {
      height <- 1
    }

    # Add labels and stat used
    if (i %in% names(splitNames)) {
      lab <- splitNames[i]
      statUsedI <- statUsed[i]
    } else {
      lab <- NULL
      statUsedI <- NULL
    }
    att <- list(
      members = members,
      classLabels = classLabels,
      edgetext = lab,
      height = height,
      midpoint = (members - 1) / 2,
      label = lab,
      statUsed = statUsedI,
      name = i
    )
    eval(parse(text = paste0("attributes(", iPaste, ") <- att")))

    # Add leaves
    leaves <- matCollapse[matCollapse == i]
    if (length(leaves) > 0) {
      for (l in seq(1, length(leaves))) {
        # Add element
        lPaste <- paste0(iPaste, "[[", l, "]]")
        eval(parse(text = paste0(lPaste, "<-list()")))

        # Add attributes
        members <- 1
        leaf <- names(leaves)[l]
        height <- 0
        att <- list(
          members = members,
          classLabels = leaf,
          height = height,
          label = leaf,
          leaf = TRUE,
          name = i
        )
        eval(parse(text = paste0("attributes(", lPaste, ") <- att")))
      }
    }
  }
  class(dendro) <- "dendrogram"
  return(dendro)
}

# Function to count the number of underscore-separated elements in a string,
# e.g. ncharX("3_1_2") returns 3
ncharX <- function(x) {
  unlist(lapply(strsplit(x, "_"), length))
}

# Function to subset the first n underscore-separated elements of a string,
# e.g. subUnderscore("3_1_2", 2) returns "3_1"
subUnderscore <- function(x, n) {
  unlist(lapply(
    strsplit(x, "_"),
    function(y) {
      paste(y[seq(n)], collapse = "_")
    }
  ))
}

# Function to calculate performance statistics
.getPerformance <- function(rules, features, class) {
  # Get classification accuracy, balanced accuracy, and per-class sensitivity
  ## Get predictions
  votes <- getDecisions(rules, t(features))
  votes[is.na(votes)] <- "MISSING"

  ## Calculate accuracy statistics and per class sensitivity
  class <- as.character(class)
  acc <- mean(votes == class)
  classCorrect <- vapply(unique(class),
    function(x) {
      sum(votes == x & class == x)
    },
    FUN.VALUE = double(1)
  )
  classCount <- c(table(class))[unique(class)]
  sens <- classCorrect / classCount

  ## Calculate balanced accuracy
  balacc <- mean(sens)

  ## Calculate per class and mean precision
  voteCount <- c(table(votes))[unique(class)]
  prec <- classCorrect / voteCount
  meanPrecision <- mean(prec)

  ## Add performance metrics
  performance <- list(
    accuracy = acc,
    balAcc = balacc,
    meanPrecision = meanPrecision,
    correct = classCorrect,
    sizes = classCount,
    sensitivity = sens,
    precision = prec
  )

  return(list(
    prediction = votes,
    performance = performance
  ))
}

# Create rules mapping classes to feature sequences
.mapClass2features <-
  function(tree, features, class, topLevelMeta = FALSE) {
    # Get class to feature indices
    class2featuresIndices <- do.call(rbind, lapply(
      seq(length(tree)),
      function(i) {
        treeLevel <- tree[[i]]
        c2fsub <- as.data.frame(do.call(rbind, lapply(
          treeLevel,
          function(split) {
            # Keep track of stat used for rule list
            statUsed <- split$statUsed

            # Keep only split information
            split <- split[!names(split) %in%
              c("statUsed", "fUsed", "dirs")]

            # Create data frame of split rules
            edgeFram <-
              do.call(rbind, lapply(split, function(edge) {
                # Create data.frame of groups, split-dirs, feature IDs
                groups <-
                  c(edge$group1Consensus, edge$group2Consensus)
                sdir <- c(
                  rep(1, length(edge$group1Consensus)),
                  rep(-1, length(edge$group2Consensus))
                )
                feat <- edge$featureName
                val <- edge$value
                stat <- edge$stat
                data.frame(
                  class = rep(groups, length(feat)),
                  feature = rep(feat, each = length(groups)),
                  direction = rep(sdir, length(feat)),
                  value = rep(val, each = length(groups)),
                  stat = rep(stat, each = length(groups)),
                  stringsAsFactors = FALSE
                )
              }))

            # Add stat used
            edgeFram$statUsed <- statUsed

            return(edgeFram)
          }
        )))
        c2fsub$level <- i
        return(c2fsub)
      }
    ))
    rownames(class2featuresIndices) <- NULL

    # Generate list of rules for each class
    if (topLevelMeta) {
      orderedClass <- unique(class2featuresIndices[
        class2featuresIndices$direction == 1, "class"
      ])
    } else {
      orderedClass <- levels(class)
    }

    rules <-
      lapply(orderedClass, function(cl, class2featuresIndices) {
        class2featuresIndices[
          class2featuresIndices$class == cl,
          colnames(class2featuresIndices) != "class"
        ]
      }, class2featuresIndices)
    names(rules) <- orderedClass

    return(list(rules = rules))
  }

#' @title Plots dendrogram of \emph{findMarkersTree} output
#' @description Generates a dendrogram of the rules and performance
#' (optional) of the decision tree generated by findMarkersTree().
#' @param tree List object. The output of findMarkersTree()
#' @param classLabel A character value. The name of a specific label to draw
#'  the path and rules. If NULL (default), the tree for all clusters is shown.
#' @param addSensPrec Logical. Print training sensitivities and precisions
#'  for each cluster below leaf label? Default is FALSE.
#' @param maxFeaturePrint Numeric value. Maximum number of markers to print
#'  at a given split. Default is 4.
#' @param leafSize Numeric value. Size of text below each leaf. Default is 10.
#' @param boxSize Numeric value. Size of rule labels. Default is 2.
#' @param boxColor Character value. Color of rule labels. Default is black.
#' @examples
#' \dontrun{
#' # Generate simulated single-cell dataset using celda
#' sim_counts <- celda::simulateCells("celda_CG", K = 4, L = 10, G = 100)
#'
#' # Celda clustering into 5 clusters & 10 modules
#' cm <- celda_CG(sim_counts$counts, K = 5, L = 10, verbose = FALSE)
#'
#' # Get features matrix and cluster assignments
#' factorized <- factorizeMatrix(sim_counts$counts, cm)
#' features <- factorized$proportions$cell
#' class <- celdaClusters(cm)
#'
#' # Generate Decision Tree
#' DecTree <- findMarkersTree(features, class, threshold = 1)
#'
#' # Plot dendrogram
#' plotMarkerDendro(DecTree)
#' }
#' @return A ggplot2 object
#' @export
plotMarkerDendro <- function(tree,
                             classLabel = NULL,
                             addSensPrec = FALSE,
                             maxFeaturePrint = 4,
                             leafSize = 10,
                             boxSize = 2,
                             boxColor = "black") {
  # Get necessary elements
  dendro <- tree$dendro

  # Get performance information (training or CV based)
  if (addSensPrec) {
    performance <- tree$performance

    # Create vector of per class performance
    perfVec <- paste0(
      "Sens. ",
      format(round(performance$sensitivity, 2), nsmall = 2),
      "\n Prec. ",
      format(round(performance$precision, 2), nsmall = 2)
    )
    names(perfVec) <- names(performance$sensitivity)
  }

  # Get dendrogram segments
  dendSegs <-
    ggdendro::dendro_data(dendro, type = "rectangle")$segments

  # Get necessary coordinates to add labels to
  # These will have y > 1
  dendSegs <-
    unique(dendSegs[dendSegs$y > 1, c("x", "y", "yend", "xend")])

  # Labeled splits will be vertical (x != xend) or
  # Length 0 (x == xend & y == yend)
  dendSegsAlt <- dendSegs[
    dendSegs$x != dendSegs$xend |
      (dendSegs$x == dendSegs$xend &
        dendSegs$y == dendSegs$yend),
    c("x", "xend", "y")
  ]
  colnames(dendSegsAlt)[1] <- "xalt"

  # Label names will be at nodes, these will
  # Occur at the end of segments
  segs <- as.data.frame(dendextend::get_nodes_xy(dendro))
  colnames(segs) <- c("xend", "yend")

  # Add labels to nodes
  segs$label <-
    gsub(";", "\n", dendextend::get_nodes_attr(dendro, "label"))

  # Subset for max
  segs$label <-
    sapply(segs$label, function(lab, maxFeaturePrint) {
      loc <- gregexpr("\n", lab)[[1]][maxFeaturePrint]
      if (!is.na(loc)) {
        lab <- substr(lab, 1, loc - 1)
      }
      return(lab)
    }, maxFeaturePrint)

  segs$statUsed <- dendextend::get_nodes_attr(dendro, "statUsed")

  # If highlighting a class label, remove non-class specific rules
  if (!is.null(classLabel)) {
    if (!classLabel %in% names(tree$rules)) {
      stop("classLabel not a valid class ID.")
    }
    dendro <- .highlightClassLabel(dendro, classLabel)
    keepLabel <- dendextend::get_nodes_attr(dendro, "keepLabel")
    keepLabel[is.na(keepLabel)] <- FALSE
    segs$label[!keepLabel] <- NA
  }

  # Remove non-labelled nodes &
  # leaf nodes (yend == 0)
  segs <- segs[!is.na(segs$label) & segs$yend != 0, ]

  # Merge to full set of coordinates
  dendSegsLabelled <- merge(dendSegs, segs)

  # Remove duplicated labels
  dendSegsLabelled <- dendSegsLabelled[order(dendSegsLabelled$y,
    decreasing = TRUE
  ), ]
  dendSegsLabelled <- dendSegsLabelled[!duplicated(dendSegsLabelled[
    ,
    c(
      "xend", "x", "yend",
      "label", "statUsed"
    )
  ]), ]

  # Merge with alternative x-coordinates for alternative split
  dendSegsLabelled <- merge(dendSegsLabelled, dendSegsAlt)

  # Order by height and coordinates
  dendSegsLabelled <-
    dendSegsLabelled[order(dendSegsLabelled$x), ]

  # Find information gain splits
  igSplits <- dendSegsLabelled$statUsed == "Split" &
    !duplicated(dendSegsLabelled[, c("xalt", "y")])

  # Set xend for IG splits
  dendSegsLabelled$xend[igSplits] <-
    dendSegsLabelled$xalt[igSplits]

  # Set y for non-IG splits
  dendSegsLabelled$y[!igSplits] <-
    dendSegsLabelled$y[!igSplits] - 0.2

  # Get index of leaf labels
  leafLabels <- dendextend::get_leaves_attr(dendro, "label")

  # Adjust leaf labels if there are metacluster labels
  if (!is.null(tree$metaclusterLabels)) {
    leafLabels <- regmatches(
      leafLabels,
      regexpr(
        pattern = "(?<=\\().*?(?=\\)$)",
        leafLabels, perl = TRUE
      )
    )
  }

  # Add sensitivity and precision measurements
  if (addSensPrec) {
    leafLabels <- paste(leafLabels, perfVec, sep = "\n")
    leafAngle <- 0
    leafHJust <- 0.5
    leafVJust <- -1
  } else {
    leafAngle <- 90
    leafHJust <- 1
    leafVJust <- 0.5
  }

  # Create plot of dendrogram
  suppressMessages(
    dendroP <- ggdendro::ggdendrogram(dendro) +
      ggplot2::geom_label(
        data = dendSegsLabelled,
        ggplot2::aes(
          x = dendSegsLabelled$xend,
          y = dendSegsLabelled$y,
          label = dendSegsLabelled$label
        ),
        size = boxSize,
        label.size = 1,
        fontface = "bold",
        vjust = 1,
        nudge_y = 0.1,
        color = boxColor
      ) +
      ggplot2::theme_bw() +
      ggplot2::scale_x_reverse(
        breaks =
          seq(length(leafLabels)),
        label = leafLabels
      ) +
      ggplot2::scale_y_continuous(expand = c(0, 0)) +
      ggplot2::theme(
        panel.grid.major.y = ggplot2::element_blank(),
        legend.position = "none",
        panel.grid.minor.y = ggplot2::element_blank(),
        panel.grid.minor.x = ggplot2::element_blank(),
        panel.grid.major.x = ggplot2::element_blank(),
        panel.border = ggplot2::element_blank(),
        axis.title = ggplot2::element_blank(),
        axis.ticks = ggplot2::element_blank(),
        axis.text.x = ggplot2::element_text(
          hjust = leafHJust,
          angle = leafAngle,
          size = leafSize,
          family = "Palatino",
          face = "bold",
          vjust = leafVJust
        ),
        axis.text.y = ggplot2::element_blank()
      )
  )

  # Check if need to add metacluster labels
  if (!is.null(tree$metaclusterLabels)) {
    # store metacluster labels to add
    newLabels <- unique(tree$branchPoints$top_level$metacluster)

    # adjust labels for metaclusters of size one
    newLabels <- unlist(lapply(newLabels, function(curMeta) {
      if (substr(curMeta, nchar(curMeta), nchar(curMeta)) == ")") {
        return(gsub(
          pattern = "\\(.*\\)$",
          replacement = "",
          x = curMeta
        ))
      } else {
        return(curMeta)
      }
    }))

    # Create table for metacluster labels
    metaclusterText <- dendSegsLabelled[
      dendSegsLabelled$y ==
        max(dendSegsLabelled$y),
      c("xend", "y", "label")
    ]
    metaclusterText$label <- newLabels

    # Add metacluster labels to top of plot
    dendroP <- dendroP +
      ggplot2::geom_text(
        data = metaclusterText,
        ggplot2::aes(
          x = metaclusterText$xend,
          y = metaclusterText$y,
          label = metaclusterText$label,
          fontface = 2
        ),
        angle = 90,
        nudge_y = 0.5,
        family = "Palatino",
        size = leafSize / 3
      )

    # adjust coordinates of plot to show labels
    dendroP <- dendroP + ggplot2::coord_cartesian(
      ylim =
        c(
          0,
          max(dendSegsLabelled$y +
            1)
        )
    )
  }

  # Increase line width slightly for aesthetic purposes
  dendroP$layers[[2]]$aes_params$size <- 1.3

  return(dendroP)
}

# Function to reformat the dendrogram to draw path to a specific class
.highlightClassLabel <- function(dendro, classLabel) {
  # Reorder dendrogram
  flag <- TRUE
  bIndexString <- ""

  # Get branch
  branch <- eval(parse(text = paste0("dendro", bIndexString)))

  while (flag) {
    # Get attributes
    att <- attributes(branch)

    # Get split with the label of interest
    labList <- lapply(branch, function(split) {
      attributes(split)$classLabels
    })
    wSplit <- which(unlist(lapply(
      labList,
      function(vec) {
        classLabel %in% vec
      }
    )))

    # Keep labels for this branch
    branch <- lapply(branch, function(edge) {
      attributes(edge)$keepLabel <- TRUE
      return(edge)
    })

    # Make a dendrogram class again
    class(branch) <- "dendrogram"
    attributes(branch) <- att

    # Add branch to dendro
    eval(parse(text = paste0("dendro", bIndexString, "<- branch")))

    # Create new bIndexString
    bIndexString <- paste0(bIndexString, "[[", wSplit, "]]")

    # Get branch
    branch <- eval(parse(text = paste0("dendro", bIndexString)))

    # Add flag
    flag <- attributes(branch)$members > 1
  }

  return(dendro)
}


#' @title Generate heatmap for a marker decision tree
#' @description Creates heatmap for a specified branch point in a marker tree.
#' @param tree A decision tree returned from \link{findMarkersTree} function.
#' @param counts Numeric matrix. Gene-by-cell counts matrix.
#' @param branchPoint Character. Name of branch point to plot heatmap for.
#' Name should match those in \emph{tree$branchPoints}.
#' @param featureLabels Vector of feature cluster assignments. Length should
#' be equal to number of rows in counts matrix, and formatting should match
#' that used in \emph{findMarkersTree()}. Required when using clusters
#' of features and not previously provided to \emph{findMarkersTree()}
#' @param topFeatures Integer. Number of genes to plot per marker module.
#' Genes are sorted based on their AUC for their respective cluster.
#' Default is 10.
#' @param silent Logical. Whether to avoid plotting heatmap to screen.
#' Default is FALSE.
#' @return A heatmap visualizing the counts matrix for the cells and genes at
#' the specified branch point.
#' @examples
#' \dontrun{
#' # Generate simulated single-cell dataset using celda
#' sim_counts <- simulateCells("celda_CG", K = 4, L = 10, G = 100)
#'
#' # Celda clustering into 5 clusters & 10 modules
#' cm <- celda_CG(sim_counts, K = 5, L = 10, verbose = FALSE)
#'
#' # Get features matrix and cluster assignments
#' factorized <- factorizeMatrix(cm)
#' features <- factorized$proportions$cell
#' class <- celdaClusters(cm)
#'
#' # Generate Decision Tree
#' DecTree <- findMarkersTree(features, class, threshold = 1)
#'
#' # Plot example heatmap
#' plotMarkerHeatmap(DecTree, assay(sim_counts),
#'   branchPoint = "top_level",
#'   featureLabels = paste0("L", celdaModules(cm)))
#' }
#' @export
plotMarkerHeatmap <- function(tree,
           counts,
           branchPoint,
           featureLabels,
           topFeatures = 10,
           silent = FALSE) {
    # get branch point to plot
    branch <- tree$branchPoints[[branchPoint]]

    # check that user entered valid branch point name
    if (is.null(branch)) {
      stop(
        "Invalid branch point.",
        " Branch point name should match one of those in tree$branchPoints."
      )
    }

    # convert counts matrix to matrix (e.g. from dgCMatrix)
    counts <- as.matrix(counts)

    # get marker features
    marker <- unique(branch$feature)

    # add feature labels
    if ("featureLabels" %in% names(tree)) {
      featureLabels <- tree$featureLabels
    }

    # check that feature labels are provided
    if (missing(featureLabels) &
      !("featureLabels" %in% names(tree)) &
      (sum(marker %in% rownames(counts)) != length(marker))) {
      stop("Please provide feature labels, i.e. gene cluster labels")
    } else if (missing(featureLabels) &&
      !("featureLabels" %in% names(tree)) &&
      (sum(marker %in% rownames(counts)) == length(marker))) {
      # counts rows are the marker features themselves, so use their names
      featureLabels <- rownames(counts)
    }

    # make sure feature labels match the table
    if (!all(branch$feature %in% featureLabels)) {
      stop(
        "Provided feature labels don't match those in the tree.",
        " Please check the feature names in the tree's rules' table."
      )
    }

    # if top-level in metaclusters tree
    if (branchPoint == "top_level") {
      # get unique metaclusters
      metaclusters <- unique(branch$metacluster)

      # list which will contain final set of genes for heatmap
      whichFeatures <- c()

      # loop over unique metaclusters
      for (meta in metaclusters) {
        # subset table
        curMeta <- branch[branch$metacluster == meta, ]

        # if we have gene-level info in the tree
        if ("gene" %in% names(branch)) {
          # sort by gene AUC score
          curMeta <-
            curMeta[order(curMeta$geneAUC, decreasing = TRUE), ]

          # get genes
          genes <- unique(curMeta$gene)

          # keep top N features
          genes <- utils::head(genes, topFeatures)

          # get gene indices
          markerGenes <- which(rownames(counts) %in% genes)

          # get features with non-zero variance to avoid clustering error
          markerGenes <- .removeZeroVariance(
            counts,
            cells = which(
              tree$metaclusterLabels %in%
                unique(curMeta$metacluster)
            ),
            markers = markerGenes
          )

          # add to list of features
          whichFeatures <- c(whichFeatures, markerGenes)
        } else {
          # current markers
          curMarker <- unique(curMeta$feature)

          # get marker gene indices
          markerGenes <- which(featureLabels %in% curMarker)

          # get features with non-zero variance to avoid error
          markerGenes <- .removeZeroVariance(
            counts,
            cells = which(
              tree$metaclusterLabels %in%
                unique(curMeta$metacluster)
            ),
            markers = markerGenes
          )

          # add to list of features
          whichFeatures <- c(whichFeatures, markerGenes)
        }
      }

      # order the metaclusters by size
      colOrder <- data.frame(
        groupName = names(sort(
          table(tree$metaclusterLabels),
          decreasing = TRUE
        )),
        groupIndex = seq_along(unique(tree$metaclusterLabels))
      )

      # order the markers for metaclusters
      allMarkers <- stats::setNames(as.list(colOrder$groupName),
          colOrder$groupName)
      allMarkers <- lapply(allMarkers, function(x) {
        unique(branch[branch$metacluster == x, "feature"])
      })
      rowOrder <- data.frame(
        groupName = unlist(allMarkers),
        groupIndex = seq_along(unlist(allMarkers))
      )
      toRemove <-
        which(!rowOrder$groupName %in% featureLabels[whichFeatures])
      if (length(toRemove) > 0) {
        rowOrder <- rowOrder[-toRemove, ]
      }

      # sort cells according to metacluster size
      x <- tree$metaclusterLabels
      y <- colOrder$groupName
      sortedCells <- seq_len(ncol(counts))[order(match(x, y))]

      # create heatmap with only the markers
      return(
        plotHeatmap(
          counts = counts,
          z = tree$metaclusterLabels,
          y = featureLabels,
          featureIx = whichFeatures,
          cellIx = sortedCells,
          showNamesFeature = TRUE,
          main = "Top-level",
          silent = silent,
          treeheightFeature = 0,
          colGroupOrder = colOrder,
          rowGroupOrder = rowOrder,
          treeheightCell = 0
        )
      )
    }

    # if balanced split
    if (branch$statUsed[1] == "Split") {
      # keep entries for balanced split only (in case of alt. split)
      split <- branch$feature[1]
      branch <- branch[branch$feature == split, ]

      # get up-regulated and down-regulated classes
      upClasses <- unique(branch[branch$direction == 1, "class"])
      downClasses <-
        unique(branch[branch$direction == (-1), "class"])

      # re-order cells to keep up and down separate on the heatmap
      reorderedCells <- c(
        (which(tree$classLabels %in% upClasses)
        [order(tree$classLabels[tree$classLabels %in% upClasses])]),
        (which(tree$classLabels %in% downClasses)
        [order(tree$classLabels[tree$classLabels %in% downClasses])])
      )

      # cell annotation based on split
      cellAnno <-
        data.frame(
          split = rep("Down-regulated", ncol(counts)),
          stringsAsFactors = FALSE
        )
      cellAnno$split[which(tree$classLabels %in% upClasses)] <-
        "Up-regulated"
      rownames(cellAnno) <- colnames(counts)

      # if we have gene-level info in the tree
      if ("gene" %in% names(branch)) {
        # get genes
        genes <- unique(branch$gene)

        # keep top N features
        genes <- utils::head(genes, topFeatures)

        # get gene indices
        whichFeatures <- which(rownames(counts) %in% genes)

        # get features with non-zero variance to avoid error
        whichFeatures <- .removeZeroVariance(counts,
          cells = which(tree$classLabels %in%
            unique(branch$class)),
          markers = whichFeatures
        )

        # create heatmap with only the split feature and split classes
        return(
          plotHeatmap(
            counts = counts,
            z = tree$classLabels,
            y = featureLabels,
            featureIx = whichFeatures,
            cellIx = reorderedCells,
            clusterCell = FALSE,
            showNamesFeature = TRUE,
            main = branchPoint,
            silent = silent,
            treeheightFeature = 0,
            treeheightCell = 0,
            annotationCell = cellAnno
          )
        )
      } else {
        # get features with non-zero variance to avoid error
        whichFeatures <-
          .removeZeroVariance(
            counts,
            cells = reorderedCells,
            markers = which(featureLabels ==
              branch$feature[1])
          )

        # create heatmap with only the split feature and split classes
        return(
          plotHeatmap(
            counts = counts,
            z = tree$classLabels,
            y = featureLabels,
            featureIx = whichFeatures,
            cellIx = reorderedCells,
            clusterCell = FALSE,
            showNamesFeature = TRUE,
            main = branchPoint,
            silent = silent,
            treeheightFeature = 0,
            treeheightCell = 0,
            annotationCell = cellAnno
          )
        )
      }
    }

    # if one-off split
    if (branch$statUsed[1] == "One-off") {
      # get unique classes
      classes <- unique(branch$class)

      # vector which will accumulate the final set of feature indices
      whichFeatures <- c()

      # loop over unique classes
      for (class in classes) {
        # subset table
        curClass <-
          branch[branch$class == class & branch$direction == 1, ]

        # if we have gene-level info in the tree
        if ("gene" %in% names(branch)) {
          # get genes
          genes <- unique(curClass$gene)

          # keep top N features
          genes <- utils::head(genes, topFeatures)

          # get gene indices
          markerGenes <- which(rownames(counts) %in% genes)

          # get features with non-zero variance to avoid error
          markerGenes <- .removeZeroVariance(
            counts,
            cells = which(tree$classLabels %in%
              unique(curClass$class)),
            markers = markerGenes
          )

          # add to list of features
          whichFeatures <- c(whichFeatures, markerGenes)
        } else {
          # get features with non-zero variance to avoid error
          markerGenes <- .removeZeroVariance(
            counts,
            cells = which(tree$classLabels %in%
              unique(curClass$class)),
            markers = which(featureLabels %in%
              unique(curClass$feature))
          )

          # add to list of features
          whichFeatures <- c(whichFeatures, markerGenes)
        }
      }

      # order the clusters such that up-regulated come first
      colOrder <- data.frame(
        groupName = unique(branch[
          order(branch$direction, decreasing = TRUE),
          "class"
        ]),
        groupIndex = seq_along(unique(branch$class))
      )

      # order the markers for clusters
      allMarkers <- stats::setNames(
        as.list(colOrder$groupName),
        colOrder$groupName
      )
      allMarkers <- lapply(allMarkers, function(x) {
        unique(branch[branch$class == x & branch$direction == 1, "feature"])
      })
      rowOrder <- data.frame(
        groupName = unlist(allMarkers),
        groupIndex = seq_along(unlist(allMarkers))
      )
      toRemove <-
        which(!rowOrder$groupName %in% featureLabels[whichFeatures])
      if (length(toRemove) > 0) {
        rowOrder <- rowOrder[-toRemove, ]
      }

      # sort cells according to class order (up-regulated classes first)
      x <- tree$classLabels
      y <- colOrder$groupName
      sortedCells <- seq_len(ncol(counts))[order(match(x, y))]
      sortedCells <-
        sortedCells[seq_len(sum(tree$classLabels %in% classes))]

      # create heatmap with only the split features and split classes
      return(
        plotHeatmap(
          counts = counts,
          z = tree$classLabels,
          y = featureLabels,
          featureIx = whichFeatures,
          cellIx = sortedCells,
          showNamesFeature = TRUE,
          main = branchPoint,
          silent = silent,
          treeheightFeature = 0,
          colGroupOrder = colOrder,
          rowGroupOrder = rowOrder,
          treeheightCell = 0
        )
      )
    }
  }

# helper function to remove zero-variance genes from a set of marker indices
.removeZeroVariance <- function(counts, cells, markers) {
  # subset counts matrix to the cells of interest
  counts <- counts[, cells, drop = FALSE]

  # scale rows
  counts <- t(scale(t(counts)))

  # get indices of genes which contain NA after scaling (zero variance)
  zeroVarianceGenes <- which(!stats::complete.cases(counts))

  # find overlap between zero-variance genes and marker genes
  zeroVarianceMarkers <- intersect(zeroVarianceGenes, markers)

  # return indices of marker genes with non-zero variance
  if (length(zeroVarianceMarkers) > 0) {
    return(setdiff(markers, zeroVarianceMarkers))
  } else {
    return(markers)
  }
}
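
# Usage sketch (hypothetical toy data, kept as a comment so the file still
# sources cleanly): .removeZeroVariance() exists because the heatmap code
# scales each feature's row, and a feature that is constant across the
# selected cells has zero standard deviation, so scaling yields NaN and
# breaks clustering. For example:
#   counts <- matrix(c(1, 1, 1, 1,   # gene 1: constant across cells
#                      0, 5, 2, 7),  # gene 2: variable across cells
#                    nrow = 2, byrow = TRUE)
#   .removeZeroVariance(counts, cells = seq_len(4), markers = c(1, 2))
#   # returns 2: the zero-variance gene 1 is dropped from the markers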


celda documentation built on Nov. 8, 2020, 8:24 p.m.