R/nnls_project.R

Defines functions nnls.project

Documented in nnls.project

#' Get mapping weights (H) for a set of samples on a latent model (W.model) by NNLS
#'
#' @description
#' This function will minimize mean((samples - W.model %*% H)^2) over non-negative H
#' (non-negative least squares), optionally processing the samples in column blocks.
#'
#' @param samples Samples to be projected onto W.model (same number of rows as W.model,
#'   one sample per column). Base matrices and Matrix-package matrices are accepted;
#'   Matrix-package inputs are converted to dense before fitting.
#' @param W.model NMF factor model (genes x factors) for mapping to cells
#' @param n.threads Number of threads/CPUs to use. Default to 0 (all cores);
#'   negative values are treated as 0.
#' @param block.size How many samples to project concurrently. If NULL (default) or
#'   larger than the number of samples, everything is projected in a single call.
#'   Otherwise it must be a single number >= 1 (fractional values are rounded down).
#' @param verbose Boolean; when projecting in blocks, show a text progress bar.
#' @param scd.tol Tolerance of the sequential coordinate NNLS fit (default 1e-8).
#'   Used in both the single-call and the blocked path.
#' @param scd.max.iter Maximum iterations of sequential coordinate descent within the
#'   NNLS fit (default 500). Used in both the single-call and the blocked path.
#' @return A sparse \code{Matrix} of mapping weights H (factors x samples) whose column
#'   names are taken from \code{samples}. The result is sparse regardless of whether
#'   \code{samples} was sparse.
nnls.project <- function(samples, W.model, n.threads = 0, block.size = NULL,
                         verbose = TRUE, scd.tol = 1e-8, scd.max.iter = 500) {
  if (!requireNamespace("Matrix", quietly = TRUE)) {
    stop("Package 'Matrix' is required by nnls.project()", call. = FALSE)
  }
  # c_project() is only ever handed base (dense) matrices, so densify Matrix inputs.
  to_dense <- function(x) if (inherits(x, "Matrix")) as.matrix(x) else x
  W.model <- to_dense(W.model)
  if (n.threads < 0) n.threads <- 0
  if (nrow(W.model) != nrow(samples)) stop("Number of rows in W.model are not equal to number of rows in samples matrix")
  if (!is.null(block.size) && !(length(block.size) == 1L && isTRUE(block.size >= 1))) {
    stop("block.size must be NULL or a single number >= 1", call. = FALSE)
  }

  # Fit one set of columns and return its weights as a sparse Matrix.
  project <- function(x) {
    fit <- c_project(to_dense(x), W.model, as.integer(n.threads),
                     as.integer(scd.max.iter), as.double(scd.tol))
    Matrix::Matrix(fit$H, sparse = TRUE)
  }

  n.samples <- ncol(samples)
  if (is.null(block.size) || block.size > n.samples) {
    H <- project(samples)
  } else {
    # Integer block width so block boundaries never skip or repeat columns.
    block.size <- floor(block.size)
    starts <- seq.int(1, n.samples, by = block.size)
    if (verbose > 0) {
      pb <- utils::txtProgressBar(char = "=", style = 3, max = length(starts), width = 50)
      on.exit(close(pb), add = TRUE)
    }
    H <- vector("list", length(starts))
    for (i in seq_along(starts)) {
      cols <- starts[i]:min(starts[i] + block.size - 1, n.samples)
      # drop = FALSE keeps a one-column final block as a matrix rather than a vector.
      H[[i]] <- project(samples[, cols, drop = FALSE])
      if (verbose > 0) utils::setTxtProgressBar(pb = pb, value = i)
    }
    H <- do.call(cbind, H)
  }
  colnames(H) <- colnames(samples)
  H
}
zdebruine/scNMF documentation built on Jan. 1, 2021, 1:50 p.m.