R/getDecisions.R

Defines functions .predictClass getDecisions

Documented in getDecisions

#' @title Gets cluster estimates using rules generated by
#'  \link{findMarkersTree}
#' @description Get decisions for a matrix of features. Estimate cell
#'  cluster membership using feature matrix input.
#' @param rules List object. The \code{rules} element from
#'  \code{findMarkersTree} output.
#' @param features A L (features) by N (samples) numeric matrix. Rows are
#'  looked up by the \code{feature} entries of \code{rules}, so row names
#'  should match them.
#' @return A character vector of label predictions, one per sample (column
#'  of \code{features}). \code{NA} is returned for a sample whose cluster
#'  estimation was ambiguous.
#' @examples
#' \dontrun{
#' library(M3DExampleData)
#' counts <- M3DExampleData::Mmus_example_list$data
#' # Subset 500 genes for fast clustering
#' counts <- as.matrix(counts[seq(1501, 2000), ])
#' # Cluster genes into 10 modules and samples into 5 clusters
#' sce <- celda_CG(counts = counts, L = 10, K = 5, verbose = FALSE)
#' # Get features matrix and cluster assignments
#' factorized <- factorizeMatrix(sce)
#' features <- factorized$proportions$cell
#' class <- celdaClusters(sce)
#' # Generate Decision Tree
#' DecTree <- findMarkersTree(features,
#'   class,
#'   oneoffMetric = "modified F1",
#'   threshold = 1,
#'   consecutiveOneoff = FALSE)
#' # Get sample estimates in training data
#' getDecisions(DecTree$rules, features)
#' }
#' @export
getDecisions <- function(rules, features) {
  # One row per sample, so apply() hands .predictClass a vector of feature
  # values (named by the former row names) for a single sample.
  features <- t(features)

  # No samples: return an empty result instead of letting apply() invoke
  # .predictClass on a dummy row.
  if (nrow(features) == 0) {
    return(character(0))
  }

  votes <- apply(features, 1, .predictClass, rules)

  # If every estimate is NA, apply() yields a logical vector; coerce so the
  # result is always character as documented. storage.mode<- keeps names.
  storage.mode(votes) <- "character"
  votes
}

# Predict the class of one sample from a list of decision-tree rules.
#
# samp:  numeric vector of feature values for a single sample; it is indexed
#        by the `feature` column of the rules (by name when features are
#        character).
# rules: named list with one element per class. Each element is a
#        data.frame-like table read through its `feature`, `direction`
#        (1: value must be >= `value`; -1: value must be < `value`),
#        `value`, `stat` and `level` columns.
#
# Candidate classes are filtered level by level; a class is dropped as soon
# as the sample does not follow its rules at the current level. Returns the
# single remaining class name, or NA_character_ when none or several remain
# (ambiguous estimate).
.predictClass <- function(samp, rules) {
  # Initialize the candidate classes and the starting level
  classes <- names(rules)
  level <- 1

  # Deepest level over all classes; bounds the loop so it always terminates
  maxLevel <- max(unlist(lapply(rules, function(ruleSet) {
    ruleSet$level
  })))

  while (length(classes) > 1 && level <= maxLevel) {
    # For each remaining class: does the sample follow its rules here?
    clLogical <- vapply(classes, function(cl) {
      # Rules of this class at the current level
      ruleClass <- rules[[cl]]
      ruleClass <- ruleClass[ruleClass$level == level, , drop = FALSE]

      # No rules at this level (presumably a leaf reached at a shallower
      # depth): nothing can be violated, so the class stays a candidate.
      # Previously mean() of an empty vector produced NA here, which put NA
      # entries into `classes` and crashed the next iteration.
      if (nrow(ruleClass) == 0) {
        return(TRUE)
      }

      # Sample values for the features used at this level
      ruleClass$sample <- samp[ruleClass$feature]

      # For multiple direction == 1, use only the one with the top stat
      if (sum(ruleClass$direction == 1) > 1) {
        ruleClass <- ruleClass[order(
          ruleClass$direction,
          decreasing = TRUE
        ), , drop = FALSE]
        # After sorting, direction == 1 rows come first, so the which.max
        # index among them is also their row index
        ruleClass <- ruleClass[c(
          which.max(
            ruleClass$stat[ruleClass$direction == 1]
          ),
          which(ruleClass$direction == -1)
        ), , drop = FALSE]
      }

      # A rule holds when the value is >= the threshold (direction 1) or
      # below it (direction -1)
      ruleClass$check <- ruleClass$sample >= ruleClass$value
      isNegative <- ruleClass$direction == -1
      ruleClass$check[isNegative] <- !ruleClass$check[isNegative]

      # Followed if any positive rule holds, or if every rule holds
      ruleFollowed <- mean(
        ruleClass$check & ruleClass$direction == 1
      ) > 0 |
        mean(ruleClass$check) == 1

      # An undetermined result (NA feature value, or feature absent from
      # `samp`) counts as not followed so NA never ends up in `classes`
      isTRUE(ruleFollowed)
    }, logical(1), USE.NAMES = FALSE)

    # Keep only the classes whose rules were followed
    classes <- classes[clLogical]

    # Move to the next level
    level <- level + 1
  }

  # A single remaining class is the prediction; otherwise it is ambiguous
  if (length(classes) == 1) {
    classes
  } else {
    NA_character_
  }
}

Try the celda package in your browser

Any scripts or data that you put into this service are public.

celda documentation built on Nov. 8, 2020, 8:24 p.m.