#' Cross-validation for NMF
#'
#' @description
#' Cross-validation for NMF rank determination based on the angle between bipartite factorizations
#'
#' @details
#' nnmf.cv splits the dataset into non-overlapping halves, by either row or column, and runs NMF on both halves across a series of ranks k.
#' Factors in the two resulting NMF models are matched one-to-one by cosine similarity, and the angle between the models is calculated as the mean of the angles between matched factors.
#' The value of k with the minimum mean angle is the rank at which the latent space is most robust.
#'
#' This cross-validation procedure can be run multiple times on permutations of the dataset, but if only a single run is requested (n.starts = 1), a semi-non-random "smart split" is applied which maximizes signal redundancy between the bipartite partitions of the dataset.
#' Generally, a single run with smart.split is sufficient to determine the optimal rank k and captures most of the information that would be learned from multiple starts on entirely random partitions.
#' The scNMF::canyon.plot function is useful for visualizing the results of nnmf.cv, both for determining the optimal rank k and for tuning the cross-validation procedure.
#' After determining the optimal rank, scNMF::nnmf may be run at that rank.
#'
#' Subsetting: for large datasets, nnmf.cv may be run on a subset of the data if signal redundancy is sufficient. However, if signal redundancy is insufficient, nnmf.cv may not reveal any "canyon" or local minimum.
#'
#' @param A A matrix to be factorized (e.g. the result of average.expression) or a Seurat object with cluster centers in a dimensional reduction slot. If sparse, it will be coerced to dense format. To use the entire data from the Seurat object, specify reduction = NULL.
#' @param k Array of integer ranks (default seq(from = 5, to = 20, by = 2))
#' @param n.starts Number of random starts, each run at all given values of k for a unique set of indices (default 1)
#' @param verbose 0 = no output, 1 = a progress bar for each start, 2 = a message for each factorization, 3 = full details for each factorization
#' @param byrow Bipartition by rows rather than columns (default TRUE)
#' @param smart.split Boolean, whether to use smart.split to determine indices if n.starts = 1. Smart split maximizes the signal redundancy between the bipartition of the dataset to achieve optimal cross-validation results. Generally, a single run of smart.split is as informative as multiple runs on random subsets. TRUE by default.
#' @param seed Random seed for reproducibility.
#' @param rel.tol Stopping criterion for each NNMF run, defined as the relative tolerance between two successive iterations: |e2-e1|/avg(e1,e2). Default 1e-3; 1e-2 may be useful for faster coarse-grained preliminary analysis of large datasets, while small datasets may benefit from a tighter tolerance such as 1e-4.
#' @param n.threads Number of threads/CPUs to use (default is 0, for all cores)
#' @param return.models Boolean, should W and H matrices be returned for each run (default FALSE). W and H matrices can take up significant memory in large cross-validation experiments.
#' @param reduction If a Seurat object is provided, the reduction from which to take feature loadings (i.e. cluster centers); specify NULL to use the entire counts matrix from the default assay ("dclus" by default).
#' @param dist.method "cosine" (default) or "bhattacharyya" for computing distances between factors when matching factors across the two models. In exceptionally sparse datasets, Bhattacharyya distance can outperform cosine distance.
#' @param max.iter Maximum number of alternating NNLS solves (default 1000)
#' @param smart.split.block.size Integer, default 200. Smaller is faster, larger achieves better separation of redundant features. The block size determines how many features are run through bipartite matching at a time; the rate-limiting component is the bipartite graph solver. When the block size is small, the similarity of matched features will be lower; when it is large, the similarity of matched features will be higher and the cross-validation result may be better.
#' @param trace An integer specifying the multiple of ANLS iterations at which the MSE should be calculated and checked for convergence against rel.tol. To check error at every iteration, specify 1; to avoid checking error entirely, specify trace > max.iter. The default of 5 is generally an efficient and effective value. For particularly sparse or heterogeneous datasets that require hundreds of ANLS iterations, a trace of 10 or 20 may speed up the calculation slightly.
#' @return A list with cross-validation results, most easily visualized by running scNMF::canyon.plot on the result. The list includes a tall-format data frame of factor angles (factor.angles, with columns "k", "factor.angle", "seed"), a tall-format data frame of model angles (model.angles, with columns "k", "model.angle", "seed"), and, if models were requested, a list of models and matched factors within a list of starts.
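#' @examples
#' \dontrun{
#' # Minimal usage sketch (illustrative only): `centers` is assumed to be a dense
#' # genes x clusters matrix, e.g. the result of scNMF::average.expression
#' cv <- nnmf.cv(centers, k = seq(from = 5, to = 20, by = 2), n.starts = 1)
#' scNMF::canyon.plot(cv)
#' # the rank with the smallest mean model angle is the most robust choice of k
#' best.k <- cv$model.angles$k[which.min(cv$model.angles$model.angle)]
#' }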
nnmf.cv <- function(A,
byrow = TRUE,
k = seq(from = 5, to = 20, by = 2),
max.iter = 1000,
rel.tol = 1e-3,
n.threads = 0,
verbose = 1,
trace = 5,
seed = 123,
n.starts = 1,
return.models = FALSE,
smart.split = TRUE,
smart.split.block.size = 200,
reduction = "dclus",
dist.method = "cosine") {
require(qlcMatrix)
require(RcppHungarian)
require(wordspace)
if (class(A)[1] == "Seurat") {
if (!is.null(reduction)) {
if (!(reduction %in% Reductions(A))) stop("Input reduction = ", reduction, " but this name is not in Reductions(A): ", paste(Reductions(A), collapse = ", "))
if (verbose > 0) message("Extracting feature loadings from reduction = ", reduction, "\n")
A <- A@reductions[[reduction]]@feature.loadings
} else {
if (verbose > 0) message("Extracting counts matrix from default assay of Seurat object, and converting to dense matrix format")
A <- as.matrix(GetAssayData(A))
}
}
if (class(A)[1] == "dgCMatrix") A <- as.matrix(A)
if (!is.matrix(A)) stop("Input A does not satisfy is.matrix() after attempting to convert from Seurat and/or dgCMatrix")
if (n.threads < 0) stop("Specify 0 or a positive integer for n.threads")
if (is.null(k)) stop("Specify a positive integer value or array of integer values for k")
if (rel.tol > 0.1) warning("rel.tol is greater than 0.1, results may be unstable")
if (trace < 1) stop("trace must be a positive integer")
if (length(seed) > 1) stop("specify a single integer seed (specified seed has length(seed) > 1)")
if (is.logical(verbose)) verbose <- as.integer(verbose)
if (n.starts < 1) stop("n.starts must be a positive integer")
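# fixed tolerance and iteration cap passed to c_nnmf (presumably for the inner NNLS updates);
# verbosity is reduced by two levels so that per-factorization detail appears only at verbose >= 3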
inner.rel.tol <- 1e-6
inner.max.iter <- 100
verbose.nnmf <- verbose - 2
if (verbose.nnmf < 0) verbose.nnmf <- 0
summaries <- list()
if (is.null(seed)) stop("A seed must be given")
set.seed(seed)
# if only one start is requested, apply smart split to maximize redundancy between both halves of the data
# loop through all features by block.size and separate by cosine similarity, where the most similar features are assigned to opposite subsetting groupings
indices <- c()
if (n.starts == 1 && smart.split == TRUE) {
if (verbose > 0) message("\nCalculating smart split indices:")
lenA <- if (byrow == FALSE) ncol(A) else nrow(A)
if (smart.split.block.size > lenA) smart.split.block.size <- lenA
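# process the features in num.iter blocks; block.size is recomputed so the blocks are of roughly equal size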
num.iter <- floor(lenA / smart.split.block.size)
block.size <- floor(lenA / num.iter)
if (verbose > 0) pb <- txtProgressBar(char = "=", style = 3, max = num.iter, width = 50)
for (m in 1:num.iter) {
ind.begin <- (m - 1) * block.size + 1
ind.end <- m * block.size
if (byrow == TRUE) {
A.sub <- t(A[ind.begin:ind.end,])
} else {
A.sub <- A[, ind.begin:ind.end]
}
# for a given set of genes, find the maximally similar off diagonal pairings
cos.dist <- 1 - as.matrix(qlcMatrix::cosSparse(A.sub))
diag(cos.dist) <- 1
matched <- HungarianSolver(cos.dist)$pairs
ind <- c()
# assign one gene from each off diagonal pairing to the indices vector
for (i in 1:nrow(matched)) {
if (which(matched[, 1] == i) < which(matched[, 2] == i)) ind <- c(ind, i)
}
indices <- c(indices, colnames(A.sub)[ind])
if (verbose > 0) setTxtProgressBar(pb = pb, value = m)
}
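# convert the selected feature names back to integer row/column indices into A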
if (byrow == TRUE) {
len.names <- 1:length(rownames(A))
names(len.names) <- rownames(A)
} else {
len.names <- 1:length(colnames(A))
names(len.names) <- colnames(A)
}
indices <- as.vector(len.names[indices])
}
for (start in 1:n.starts) {
# say some nice things if requested to do so
if (verbose > 0 && n.starts > 1) {
message("\n\nSTART #", start, "/", n.starts, ": Running NMF for ", length(k), " values of k")
if (verbose > 2) message("________________________________________________________________________________________")
} else if (verbose > 0) {
message("\n\nRunning NMF for ", length(k), " values of k")
}
if (verbose == 1) pb <- txtProgressBar(char = "=", style = 3, max = length(k), width = 50)
# if n.starts is greater than 1, generate indices of length ncol(A) or nrow(A) given a boolean value for byrow and subset A correspondingly
set.seed(seed + start)
if (n.starts > 1 || smart.split == FALSE) {
if (byrow == FALSE) {
indices <- sample(1:ncol(A), floor(ncol(A) / 2))
} else {
indices <- sample(1:nrow(A), floor(nrow(A) / 2))
}
}
if (byrow == FALSE) {
A1 <- A[, indices]
A2 <- A[, - indices]
} else {
A1 <- A[indices,]
A2 <- A[-indices,]
}
# summary will hold all metadata (and models if requested) generated from each run of nnmf
summary <- list()
for (i in 1:length(k)) {
if (verbose > 2) message("\n***********************************************")
if (verbose > 1) message("\n Running NMF for k = ", k[i])
if (verbose > 1) message(" ...factorizing first subset")
res1 <- c_nnmf(
A1,
as.integer(k[i]),
as.integer(max.iter),
as.double(rel.tol),
as.integer(n.threads),
as.integer(verbose.nnmf),
as.integer(inner.max.iter),
as.double(inner.rel.tol),
as.integer(trace)
)
if (verbose > 2) message("\n")
if (verbose > 1) message(" ...factorizing second subset")
res2 <- c_nnmf(
A2,
as.integer(k[i]),
as.integer(max.iter),
as.double(rel.tol),
as.integer(n.threads),
as.integer(verbose.nnmf),
as.integer(inner.max.iter),
as.double(inner.rel.tol),
as.integer(trace)
)
if (verbose > 2) message("\n")
if (verbose > 1) message(" Matching factors and calculating mean standard error")
# compute cosine distance between all factors in W or H (whichever shares variables in common based on row-wise or column-wise splitting)
if (dist.method == "cosine") {
ifelse(byrow == FALSE,
cosDists <- 1 - as.matrix(cosSparse(res1$W, res2$W)),
cosDists <- 1 - as.matrix(cosSparse(t(res1$H), t(res2$H)))
)
} else {
# bhjattacharyya distance
ifelse(byrow == FALSE,
cosDists <- wordspace::dist.matrix(sqrt(t(res1$W)), sqrt(t(res2$W)), byrow = TRUE, method = "euclidean"),
cosDists <- wordspace::dist.matrix(sqrt(res1$H), sqrt(res2$H), byrow = TRUE, method = "euclidean")
)
cosDists <- max(cosDists) - cosDists
}
# find the best possible bipartite matching for all factors using the Hungarian algorithm
matched <- HungarianSolver(cosDists)$pairs
colnames(matched) <- c("model1", "model2")
# get the angles between matched factors
angles <- apply(matched, 1, function(x) cosDists[x[1], x[2]])
if (verbose > 1) message(" ...mean angle between both factorizations: ", round(mean(angles), 5))
if (verbose == 1) setTxtProgressBar(pb = pb, value = i)
if (return.models == TRUE) {
summary[[i]] <- list("k" = k[i], "factor.angles" = angles, "model.angle" = mean(angles), "seed" = seed + start, "model1" = res1, "model2" = res2, "matched.factors" = matched)
} else {
summary[[i]] <- list("k" = k[i], "factor.angles" = angles, "model.angle" = mean(angles), "seed" = seed + start)
}
}
summaries[[paste0("start", start)]] <- summary
}
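# collect per-start results into tall-format data frames of model angles and factor angles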
model.angles <- factor.angles <- list()
for (i in 1:length(summaries)) {
model.angles[[i]] <- data.frame(t(sapply(summaries[[i]], function(x) c("k" = x$k, "model.angle" = x$model.angle, "seed" = x$seed))))
factor.angles[[i]] <- data.frame(t(rbind(
"k" = unlist(lapply(summaries[[i]], function(x) rep(x$k, x$k))),
"factor.angle" = unlist(lapply(summaries[[i]], function(x) x$factor.angles)),
"seed" = unlist(sapply(summaries[[i]], function(x) rep(x$seed, x$k)))
)))
}
summaries[["model.angles"]] <- do.call(rbind, model.angles)
summaries[["factor.angles"]] <- do.call(rbind, factor.angles)
if (return.models == TRUE) {
return(summaries)
} else {
return(list("model.angles" = summaries$model.angles, "factor.angles" = summaries$factor.angles))
}
}