#' Trains scRNA-seq data via tunable logistic regression model
#'
#' @param train_data Seurat object, SummarizedExperiment object or expression matrix for training
#' @param train_cell_type The cell types/clusters in the training data set
#' @param test_data Seurat object, SummarizedExperiment/SingleCellExperiment object or expression matrix for testing later
#' @param train_genes Genes to use for training. If not provided, it will try to pick from all genes in the training dataset as per default glmnet.
#' @param standardize a logical value specifying whether or not to standardize the train matrix
#' @param nfolds integer specifying bin for cross validation. Use all samples if doing LOOCV.
#' @param a tunable regularization parameter. 0 = ridge (L2), 1 = LASSO (L1), in between = Elastic-net
#' @param l.min logical. Choose between lambda.min or lambda.1se
#' @param multinomial logical. Choose between family = 'binomial' or 'multinomial'.
#' @param nParallel integer specifying number of cores for parallelization.
#' @param ... other functions pass to glmnet
#' @return Generates a trained model for predicting cell types for scRNAseq data
#' @examples
#' fit <- trainScSimilarity(trainData, clusters, testData)
#' @import glmnet
#' @import Matrix
#' @import doMC
#' @import SummarizedExperiment
#' @import SingleCellExperiment
#' @export
#'
trainScSimilarity <- function(train_data, train_cell_type, test_data, train_genes = NULL,
standardize = TRUE, nfolds = 10, a = 0.9, l.min = FALSE, multinomial = FALSE,
nParallel = parallel::detectCores(), ...) {
fit <- list()
require(Matrix)
standardizeSparse <- function(A) {
A@x <- A@x/rep.int(Matrix::colSums(A), diff(A@p))
return(A)
}
getPopulationOffset = function(y) {
if (!is.factor(y))
y = factor(y)
if (length(levels(y)) != 2)
stop("y must be a two-level factor")
off = sum(y == levels(y)[2])/length(y)
off = log(off/(1 - off))
return(rep(off, length(y)))
}
if (nParallel > 1) {
doMC::registerDoMC(cores = nParallel)
PARALLEL = TRUE
} else {
PARALLEL = FALSE
}
if (is.null(train_genes)) {
if (class(train_data) %in% c("SingleCellExperiment", "SummarizedExperiment")) {
require(SummarizedExperiment)
require(SingleCellExperiment)
train_dat <- assay(train_data)
} else if (class(train_data) == "Seurat") {
train_dat <- tryCatch(train_data@data, error = function(e) {
tryCatch(GetAssayData(object = train_data), error = function(e) {
warning(paste0("are you sure this is a seurat v3 object?"))
return(NULL)
})
})
} else {
train_dat <- as.matrix(train_data)
}
if (class(test_data) %in% c("SingleCellExperiment", "SummarizedExperiment")) {
require(SummarizedExperiment)
require(SingleCellExperiment)
test_dat <- assay(test_data)
} else if (class(test_data) == "Seurat") {
test_dat <- tryCatch(test_data@data, error = function(e) {
tryCatch(GetAssayData(object = test_data), error = function(e) {
warning(paste0("are you sure this is a seurat v3 object?"))
return(NULL)
})
})
} else {
test_dat <- as.matrix(test_data)
}
cat(paste0("No pre-defined genes provided. Filtering ", crayon::red(dim(train_dat)[1]), " genes for training"), sep = "\n")
genes.intersect <- intersect(row.names(test_dat), row.names(train_dat))
train_dat <- train_dat[which(row.names(train_dat) %in% genes.intersect), ]
cat("Transposing matrix", sep = "\n")
if (class(train_dat) == "matrix") {
train_dat <- Matrix::Matrix(train_dat, sparse = TRUE)
}
train_dat <- Matrix::t(train_dat)
Zero_col <- which(Matrix::colSums(train_dat) == 0)
duplicated_col <- which(duplicated(colnames(train_dat)) == TRUE)
if (length(c(Zero_col, duplicated_col)) != 0) {
cat(paste0("Removing ", crayon::red(length(c(Zero_col, duplicated_col))), " genes with no variance"), sep = "\n")
train_dat <- train_dat[, -c(Zero_col, duplicated_col)]
}
cat(paste0("Submitting ", crayon::red(ncol(train_dat)), " intersecting genes to glmnet for selecting predictors"), sep = "\n")
if (standardize == TRUE) {
cat("Standardizing training dataset", sep = "\n")
train_dat <- standardizeSparse(train_dat)
}
labels = levels(train_cell_type)
if (length(labels) == 0) {
train_cell_type <- factor(train_cell_type)
labels <- levels(train_cell_type)
}
if (multinomial == FALSE) {
cat(paste0(crayon::magenta("Training model with family = "), crayon::yellow("binomial")), sep = "\n")
for (label in labels) {
cat(crayon::green(paste0("Training model for ", crayon::red(label))), sep = "\n")
celltype = factor(train_cell_type == label)
fit[[label]] = tryCatch(glmnet::cv.glmnet(train_dat, celltype, offset = getPopulationOffset(celltype), family = "binomial", alpha = a, nfolds = nfolds, type.measure = "class", parallel = PARALLEL, ...), error = function(e)
{
tryCatch(glmnet::cv.glmnet(train_dat, celltype, offset = getPopulationOffset(celltype), family = "binomial", alpha = a, nfolds = nfolds, type.measure = "class", parallel = PARALLEL, lambda = exp(seq(-10, -3, length.out = 100)), ...), error = function(e2)
{
warning(paste0("Could not train model for variable ", label))
return(NULL)
})
})
}
cat("Extracting best gene features...", sep = "\n")
for (i in 1:length(fit)) {
if (l.min) {
cat("Choosing the best model but with the caveat that may be too complex, may be slightly overfitted", sep = "\n")
fit_out <- as.matrix(coef(fit[[i]], s = fit[[i]]$lambda.min))
} else {
cat("Choosing the simplest model that has comparable error to the best model given the uncertainty", sep = "\n")
fit_out <- as.matrix(coef(fit[[i]], s = fit[[i]]$lambda.1se))
}
fit_out <- as.data.frame(fit_out)
fit_out <- fit_out[fit_out[, 1] != 0, , drop = FALSE]
cat(sep = "\n")
cat(paste0("Best genes for ", names(fit)[i]), sep = "\n")
print(fit_out, sep = "\n")
}
return(fit)
} else {
cat(paste0(crayon::magenta("Training model with family = "), crayon::yellow("multinomial")), sep = "\n")
fit <- tryCatch(glmnet::cv.glmnet(train_dat, train_cell_type, family = "multinomial", alpha = a, nfolds = nfolds, type.measure = "class", parallel = PARALLEL, ...), error = function(e)
{
tryCatch(glmnet::cv.glmnet(train_dat, train_cell_type, family = "multinomial", alpha = a, nfolds = nfolds, type.measure = "class", parallel = PARALLEL, lambda = exp(seq(-10, -3, length.out = 100)), ...), error = function(e2)
{
warning(paste0("Could not train model with family = multinomial. Try with multinomial = FALSE"))
return(NULL)
})
})
cat("Extracting best gene features", sep = "\n")
for (i in 1:length(fit$glmnet.fit$beta)) {
if (l.min) {
fit_out <- as.matrix(coef(fit, s = fit$lambda.min)[[i]])
} else {
fit_out <- as.matrix(coef(fit, s = fit$lambda.1se)[[i]])
}
fit_out <- as.data.frame(fit_out)
fit_out <- fit_out[fit_out[, 1] != 0, , drop = FALSE]
cat(sep = "\n")
cat(paste0("Best genes for ", levels(train_cell_type)[i]), sep = "\n")
print(fit_out, sep = "\n")
}
return(fit)
}
} else {
if (class(train_data) %in% c("SingleCellExperiment", "SummarizedExperiment")) {
require(SummarizedExperiment)
require(SingleCellExperiment)
all_genes <- elementMetadata(train_data)[, 1]
train_dat <- assay(train_data[which(all_genes %in%
train_genes)])
} else if (class(train_data) == "Seurat") {
all_genes <- tryCatch(all_genes <- rownames(train_data@data),
error = function(e) {
all_genes <- tryCatch(all_genes <- rownames(GetAssayData(object = train_data)),
error = function(e) {
warning(paste0("are you sure this is a seurat v3 object?"))
return(NULL)
})
})
train_dat <- tryCatch(train_dat <- train_data@data[which(all_genes %in%
train_genes), ], error = function(e) {
train_dat <- tryCatch(train_dat <- GetAssayData(object = train_data)[which(all_genes %in%
train_genes), ], error = function(e) {
warning(paste0("are you sure this is a seurat v3 object?"))
return(NULL)
})
})
} else {
all_genes <- rownames(train_data)
train_dat <- train_data[which(all_genes %in% train_genes), ]
train_dat <- as.matrix(train_dat)
}
cat(paste0("provided ", dim(train_dat)[1], " genes for training model"), sep = "\n")
if (class(test_data) %in% c("SingleCellExperiment", "SummarizedExperiment")) {
require(SummarizedExperiment)
require(SingleCellExperiment)
test_dat <- assay(test_data)
} else if (class(test_data) == "Seurat") {
test_dat <- tryCatch(test_data@data, error = function(e) {
tryCatch(GetAssayData(object = test_data), error = function(e) {
warning(paste0("are you sure this is a seurat v3 object?"))
return(NULL)
})
})
} else {
test_dat <- as.matrix(test_data)
}
genes.intersect <- intersect(row.names(test_dat), row.names(train_dat))
train_dat <- train_dat[which(row.names(train_dat) %in% genes.intersect), ]
cat(paste0("trimed to ", dim(train_dat)[1], " intersecting genes for training model"), sep = "\n")
cat("Transposing matrix", sep = "\n")
if (class(train_dat) == "matrix") {
train_dat <- Matrix::Matrix(train_dat, sparse = TRUE)
}
train_dat <- Matrix::t(train_dat)
Zero_col <- which(Matrix::colSums(train_dat) == 0)
duplicated_col <- which(duplicated(colnames(train_dat)) == TRUE)
if (length(c(Zero_col, duplicated_col)) != 0) {
cat(paste0("Removing ", crayon::red(length(c(Zero_col, duplicated_col))), " genes with no variance"), sep = "\n")
train_dat <- train_dat[, -c(Zero_col, duplicated_col)]
}
if (standardize == TRUE) {
cat("Standardizing training dataset", sep = "\n")
train_dat <- standardizeSparse(train_dat)
}
labels = levels(train_cell_type)
if (length(labels) == 0) {
train_cell_type <- factor(train_cell_type)
labels <- levels(train_cell_type)
}
if (multinomial == FALSE) {
cat(paste0(crayon::magenta("Training model with family = "), crayon::yellow("binomial")), sep = "\n")
for (label in labels) {
cat(crayon::green(paste0("Training model for ", crayon::red(label))), sep = "\n")
celltype = factor(train_cell_type == label)
fit[[label]] = tryCatch(glmnet::cv.glmnet(train_dat, celltype, family = "binomial", alpha = a, nfolds = nfolds, type.measure = "class", parallel = PARALLEL, ...), error = function(e)
{
tryCatch(glmnet::cv.glmnet(train_dat, celltype, family = "binomial", alpha = a, nfolds = nfolds, type.measure = "class", parallel = PARALLEL, lambda = exp(seq(-10, -3, length.out = 100)), ...), error = function(e2)
{
warning(paste0("Could not train model for variable ", label))
return(NULL)
})
})
}
cat("Extracting best gene features", sep = "\n")
for (i in 1:length(fit)) {
if (l.min) {
cat("Choosing the best model but with the caveat that may be too complex, may be slightly overfitted", sep = "\n")
fit_out <- as.matrix(coef(fit[[i]], s = fit[[i]]$lambda.min))
} else {
cat("Choosing the simplest model that has comparable error to the best model given the uncertainty", sep = "\n")
fit_out <- as.matrix(coef(fit[[i]], s = fit[[i]]$lambda.1se))
}
fit_out <- as.data.frame(fit_out)
fit_out <- fit_out[fit_out[, 1] != 0, , drop = FALSE]
cat(sep = "\n")
cat(paste0("Best genes for ", names(fit)[i]), sep = "\n")
print(fit_out, sep = "\n")
}
return(fit)
} else {
cat(paste0(crayon::magenta("Training model with family = "), crayon::yellow("multinomial")), sep = "\n")
fit <- tryCatch(glmnet::cv.glmnet(train_dat, train_cell_type, family = "multinomial", alpha = a, nfolds = nfolds, type.measure = "class", parallel = PARALLEL, ...), error = function(e)
{
tryCatch(glmnet::cv.glmnet(train_dat, train_cell_type, family = "multinomial", alpha = a, nfolds = nfolds, type.measure = "class", parallel = PARALLEL, lambda = exp(seq(-10, -3, length.out = 100)), ...), error = function(e2)
{
warning(paste0("Could not train model with family = multinomial. Try with multinomial = FALSE"))
return(NULL)
})
})
cat("Extracting best gene features", sep = "\n")
for (i in 1:length(fit$glmnet.fit$beta)) {
if (l.min) {
fit_out <- as.matrix(coef(fit, s = fit$lambda.min)[[i]])
} else {
fit_out <- as.matrix(coef(fit, s = fit$lambda.1se)[[i]])
}
fit_out <- as.data.frame(fit_out)
fit_out <- fit_out[fit_out[, 1] != 0, , drop = FALSE]
cat(sep = "\n")
cat(paste0("Best genes for ", levels(train_cell_type)[i]), sep = "\n")
print(fit_out, sep = "\n")
}
return(fit)
}
}
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.