#' @include Utils.R
#' @include CellTypist.R
#' @import mlr3
#' @import mlr3learners
#' @import ranger
#' @import ggplot2
#'
utils::globalVariables(
names = c('Freq', 'Passed', 'Prediction', 'Reference', 'Second', 'Top', 'TopLabel', 'booster', 'type', 'value', 'splitrule'),
package = 'RIRA',
add = TRUE
)
#' @title Creates a binary classifier to classify cells within a Seurat object
#'
#' @description Creates a binary classifier to classify cells
#' @param training_matrix A labeled training data.frame built by TrainModelsFromSeurat from the counts or data slot (one column per gene, plus a 'celltype' column and a trailing 0/1 target column)
#' @param celltype The celltype (provided by TrainModelsFromSeurat) used as the classifier's positive class
#' @param hyperparameter_tuning logical that determines whether or not hyperparameter tuning should be performed.
#' @param learner The mlr3 learner that should be used. Currently fixed to "classif.ranger" if hyperparameter tuning is FALSE. Otherwise, "classif.xgboost" and "classif.ranger" are supported.
#' @param inner_resampling The resampling strategy that is used for hyperparameter optimization. Holdout ("hout" or "holdout") and cross validation ("cv" or "cross-validation") are supported.
#' @param outer_resampling The resampling strategy that is used to determine overfitting. Holdout ("hout" or "holdout") and cross validation ("cv" or "cross-validation") are supported.
#' @param inner_folds The number of folds to be used for inner_resampling if cross-validation is performed.
#' @param outer_folds The number of folds to be used for outer_resampling if cross-validation is performed.
#' @param inner_ratio The ratio of training to testing data to be used for inner_resampling if holdout resampling is performed.
#' @param outer_ratio The ratio of training to testing data to be used for outer_resampling if holdout resampling is performed.
#' @param n_models The number of models to be trained during hyperparameter tuning. The model with the highest accuracy will be selected and returned.
#' @param n_cores If non-null, this number of workers will be used with future::plan
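#' @examples
#' \dontrun{
#'   # Minimal sketch, assuming `training_matrix` is a labeled data.frame shaped as
#'   # TrainModelsFromSeurat builds it (one column per gene, a 'celltype' column of
#'   # ground-truth labels, and a trailing 0/1 target column for the cell type of interest):
#'   fit <- TrainModel(training_matrix, celltype = 'Tcell')
#'   fit$metrics
#' }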
#' @export
TrainModel <- function(training_matrix, celltype, hyperparameter_tuning = F, learner = "classif.ranger", inner_resampling = "cv", outer_resampling = "cv", inner_folds = 4, inner_ratio = 0.8, outer_folds = 3, outer_ratio = 0.8, n_models = 20, n_cores = NULL){
set.seed(GetSeed())
if (!is.null(n_cores)){
future::plan("multisession", workers = n_cores)
}
#Fix gene names to conform with Seurat Object processing
colnames(training_matrix) <- gsub(names(training_matrix), pattern = "-", replacement = ".")
#create classification matrix
classification.data <- training_matrix
#Trim the celltype_column containing all of the ground truth celltypes
#Note, this column is named "celltype" and is not the same as the column named by the passed celltype argument, which changes according to the celltype being predicted
classification.data <- subset(classification.data, select = -celltype)
#trim the binarized label/celltype/truth column (should be the last column)
classification.data <- classification.data[,1:(ncol(classification.data)-1)]
celltype_binary <- training_matrix[,ncol(training_matrix)]
classification.data[,"celltype_binary"] <- as.factor(celltype_binary)
colnames(classification.data) <- make.names(colnames(classification.data),unique = T)
task <- mlr3::TaskClassif$new(classification.data, id = "CellTypeBinaryClassifier", target = "celltype_binary")
if (!hyperparameter_tuning){
#Use a ranger random forest tree with default parameters and holdout (80% train, 20% test) resampling
#classification.data has column number = number_of_genes + 1, so the mtry argument uses the full gene matrix
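#e.g. with 5,000 genes the data has 5,001 columns and mtry = 5,000, so every gene is considered at each split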
learner <- mlr3::lrn("classif.ranger", importance = "permutation", num.trees=500, mtry=(ncol(classification.data)-1), predict_type = "prob")
set_threads(learner)
train_set <- sample(task$nrow, 0.8 * task$nrow)
test_set <- setdiff(seq_len(task$nrow), train_set)
model <- learner$train(task, row_ids = train_set)$predict(task, row_ids=test_set)
confusion <- caret::confusionMatrix(factor(model$response), factor(model$truth))
return(list(model = learner, metrics = confusion))
} else {
#Set model-independent values for the autotuner
measure <- msr("classif.ce")
terminator <- mlr3verse::trm("evals", n_evals = n_models)
#Use a grid search with resolution = 5, i.e. five grid points per numeric hyperparameter
#(a low/medium-low/medium/medium-high/high style grid); the terminator above limits the search to n_models evaluations
tuner <- mlr3verse::tnr("grid_search", resolution = 5)
#Define resampling method used for hyperparameter tuning
if (inner_resampling == "cv" || inner_resampling == "cross-validation"){
inner_resample <- rsmp("cv", folds = inner_folds)
} else if (inner_resampling == "hout" || inner_resampling == "holdout"){
inner_resample <- rsmp("holdout", ratio = inner_ratio)
} else {
stop("Unknown inner_resampling method provided. Please select one of cross-validation, cv, hout, or holdout")
}
#Define resampling method used to determine overfitting
if (outer_resampling == "cv" || outer_resampling == "cross-validation"){
outer_resample <- rsmp("cv", folds = outer_folds)
} else if (outer_resampling == "hout" || outer_resampling == "holdout"){
outer_resample <- rsmp("holdout", ratio = outer_ratio)
} else {
stop("Unknown outer_resampling method provided. Please select one of cross-validation, cv, hout, or holdout")
}
#Set learner and define parameter space
if (learner == "classif.ranger"){
#Define learner
learner <- mlr3::lrn("classif.ranger", importance = "permutation", predict_type = "prob")
#Define Ranger Hyperparameter Space (RandomBotv2)
tune_ps <- mlr3verse::ps(
num.trees = mlr3verse::p_int(lower = 10, upper = 2000),
sample.fraction = mlr3verse::p_dbl(lower = 0.1, upper = 1),
respect.unordered.factors = mlr3verse::p_fct(levels = c("ignore", "order", "partition")),
min.node.size = mlr3verse::p_int(lower = 1, upper = 100),
splitrule = mlr3verse::p_fct(levels = c("gini", "extratrees")),
num.random.splits = mlr3verse::p_int(lower = 1, upper = 100, depends = splitrule == "extratrees")
)
} else if (learner == "classif.xgboost"){
#Update task
task <- mlr3::TaskClassif$new(classification.data, id = "CellTypeBinaryClassifier", target = "celltype_binary")
#Define learner
learner <- mlr3::lrn("classif.xgboost", predict_type = "prob")
#Define XGBoost model's Hyperparameter Space (RandomBotv2)
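#Note: several parameters are searched on a log scale; the sampled value x is passed through its trafo (e.g. eta = 10^x, nrounds = round(exp(x))) before reaching xgboost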
tune_ps <- mlr3verse::ps(
booster = mlr3verse::p_fct(levels = c("gblinear", "gbtree", "dart")),
nrounds = mlr3verse::p_int(lower = 2, upper = 8, trafo = function(x) as.integer(round(exp(x)))),
eta = mlr3verse::p_dbl(lower = -4, upper = 0, trafo = function(x) 10^x),
gamma = mlr3verse::p_dbl(lower = -5, upper = 1, trafo = function(x) 10^x),
lambda = mlr3verse::p_dbl(lower = -4, upper = 3, trafo = function(x) 10^x),
alpha = mlr3verse::p_dbl(lower = -4, upper = 3, trafo = function(x) 10^x),
subsample = mlr3verse::p_dbl(lower = 0.1, upper = 1),
max_depth = mlr3verse::p_int(lower = 1, upper = 15),
min_child_weight = mlr3verse::p_dbl(lower = -1, upper = 0, trafo = function(x) 10^x),
colsample_bytree = mlr3verse::p_dbl(lower = 0.1, upper = 1),
colsample_bylevel = mlr3verse::p_dbl(lower = 0.1, upper = 1),
rate_drop = mlr3verse::p_dbl(lower = 0, upper = 1, depends = booster == 'dart'),
skip_drop = mlr3verse::p_dbl(lower = 0, upper = 1, depends = booster == 'dart')
)
} else {
stop("Unknown learner provided. Please select one of classif.ranger or classif.xgboost")
}
}
#Define the autotuner using the parameter spaces and conditions defined above
at <- mlr3tuning::AutoTuner$new(
learner = learner,
resampling = inner_resample,
measure = measure,
search_space = tune_ps,
terminator = terminator,
tuner = tuner
)
#Train the initial model to optimize hyperparameters
at$train(task)
#instantiate outer resampling
outer_resample$instantiate(task)
#score model using outer resampling strategy & output full resampling to be mined for metrics
full_resampling <- resample(task, at, outer_resample, store_models = TRUE)
return(list(model = at, metrics = full_resampling))
}
#' @title Wrapper function for TrainModel to train a suite of binary classifiers for each cell type present in the data
#'
#' @description Wrapper function for TrainModel to train a suite of binary classifiers for each cell type present in the data
#' @param seuratObj The Seurat Object to be updated
#' @param celltype_column The metadata column containing the celltypes. One classifier will be created for each celltype present in this column.
#' @param assay SeuratObj assay containing the desired count matrix/metadata
#' @param slot Slot containing the count data. Should be restricted to counts, data, or scale.data.
#' @param output_dir The directory in which models, metrics, and training data will be saved.
#' @param hyperparameter_tuning Logical that determines whether or not hyperparameter tuning should be performed.
#' @param learner The mlr3 learner that should be used. Currently fixed to "classif.ranger" if hyperparameter tuning is FALSE. Otherwise, "classif.xgboost" and "classif.ranger" are supported.
#' @param inner_resampling The resampling strategy that is used for hyperparameter optimization. Holdout ("hout" or "holdout") and cross validation ("cv" or "cross-validation") are supported.
#' @param outer_resampling The resampling strategy that is used to determine overfitting. Holdout ("hout" or "holdout") and cross validation ("cv" or "cross-validation") are supported.
#' @param inner_folds The number of folds to be used for inner_resampling if cross-validation is performed.
#' @param outer_folds The number of folds to be used for outer_resampling if cross-validation is performed.
#' @param inner_ratio The ratio of training to testing data to be used for inner_resampling if holdout resampling is performed.
#' @param outer_ratio The ratio of training to testing data to be used for outer_resampling if holdout resampling is performed.
#' @param n_models The number of models to be trained during hyperparameter tuning. The model with the highest accuracy will be selected and returned.
#' @param n_cores If non-null, this number of workers will be used with future::plan
#' @param gene_list If non-null, the input count matrix will be subset to these features
#' @param gene_exclusion_list If non-null, the input count matrix will be subset to drop these features
#' @param verbose Whether or not to print the metrics data for each model after training.
#' @param min_cells_per_class If provided, any classes (and corresponding cells) with fewer than this many cells will be dropped from the training data
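#' @examples
#' \dontrun{
#'   # Sketch assuming `seuratObj` is a normalized Seurat object with a hypothetical
#'   # 'CellType' metadata column holding the ground-truth labels:
#'   TrainModelsFromSeurat(seuratObj, celltype_column = 'CellType', output_dir = './classifiers')
#' }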
#' @export
TrainModelsFromSeurat <- function(seuratObj, celltype_column, assay = "RNA", slot = "data", output_dir = "./classifiers", hyperparameter_tuning = F, learner = "classif.ranger", inner_resampling = "cv", outer_resampling = "cv", inner_folds = 4, inner_ratio = 0.8, outer_folds = 3, outer_ratio = 0.8, n_models = 20, n_cores = NULL, gene_list = NULL, gene_exclusion_list = NULL, verbose = TRUE, min_cells_per_class = 20){
if (methods::missingArg(celltype_column)) {
stop('Must provide the celltype_column argument')
}
if (!celltype_column %in% names(seuratObj@meta.data)) {
stop(paste0('The column: ', celltype_column, ' is not present in the seurat object'))
}
if (endsWith(output_dir, "/")){
output_dir <- gsub(output_dir, pattern = "/$", replacement = "")
}
if (!is.null(min_cells_per_class) && min_cells_per_class > 0) {
seuratObj <- .DropLowCountClasses(seuratObj, celltype_column, min_cells_per_class)
}
#Read the raw data from a seurat object and parse into an mlr3-compatible labeled matrix
raw_data_matrix <- attr(x = seuratObj@assays[[assay]], which = slot)
if (!all(is.null(gene_list))) {
gene_list <- ExpandGeneList(gene_list)
if (!all(gene_list %in% rownames(raw_data_matrix))) {
missing <- gene_list[!gene_list %in% rownames(raw_data_matrix)]
stop(paste0('All features in gene_list must be present in the Seurat object features. Missing: ', paste0(missing, collapse = ',')))
}
raw_data_matrix <- raw_data_matrix[gene_list,]
}
if (!all(is.null(gene_exclusion_list))) {
gene_exclusion_list <- ExpandGeneList(gene_exclusion_list)
raw_data_matrix <- raw_data_matrix[!rownames(raw_data_matrix) %in% gene_exclusion_list,]
}
training_matrix <- as.data.frame(Matrix::t(as.matrix(raw_data_matrix)))
training_matrix$celltype <- seuratObj@meta.data[,celltype_column]
celltypes <- unique(training_matrix[,"celltype"])
#Create output directories
#Trained binary classifiers will be saved to /models
#Parseable metrics .rds files will be saved to /metrics
#Training data .rds files will be saved to /training_data
for (fn in c(output_dir, paste0(output_dir,"/models"), paste0(output_dir,"/metrics"), paste0(output_dir, "/training_data"))) {
if (!dir.exists(fn)){
dir.create(fn)
}
}
print(paste0("Output directory: ", output_dir))
# TODO: rather than use cell types directly as file names, we should use make.names() or something to ensure they are valid and sane (i.e. 'CD8+ T cells')
# TODO: rather than saving one RDS per classifier, would it make more sense to save a list of cellType -> classifier? This bundles everything into one file on disk?
#Iterate over celltypes and train a binary classifier for each celltype present in the celltype_column
print(paste0("Total cell types: ", length(celltypes)))
for (celltype in celltypes){
print(paste0("Training: ", celltype))
temp_training_matrix <- training_matrix
temp_training_matrix[,celltype] <- ifelse(training_matrix[,"celltype"]==celltype,1,0)
if (sum(duplicated(names(temp_training_matrix))) > 0) {
stop(paste0('Found duplicate names in the input matrix: ', paste0(names(temp_training_matrix)[duplicated(names(temp_training_matrix))], collapse = ',')))
}
names(temp_training_matrix) <- make.names(names(temp_training_matrix), unique = T)
temp_model <- TrainModel(temp_training_matrix, celltype, hyperparameter_tuning = hyperparameter_tuning, learner = learner, inner_resampling = inner_resampling, outer_resampling = outer_resampling, inner_folds = inner_folds,inner_ratio = inner_ratio, outer_folds = outer_folds, outer_ratio = outer_ratio, n_models = n_models, n_cores = n_cores)
#trim the "celltype" column (leaving just the labeled variable celltype column as truth) and save the training matrix
temp_training_matrix <- subset(temp_training_matrix, select = -celltype)
saveRDS(temp_training_matrix, file = paste0(output_dir, "/training_data/", make.names(celltype), "_Training_Matrix.rds"))
if (hyperparameter_tuning){
#Save the trained model to the output directory
saveRDS(temp_model$model, file = paste0(output_dir, "/models/", make.names(celltype), "_BinaryClassifier.rds"))
saveRDS(temp_model$metrics, file = paste0(output_dir, "/metrics/", make.names(celltype), "_BinaryClassifier_Resampled.rds"))
} else {
#Save the trained model to the output directory
saveRDS(temp_model$model, file = paste0(output_dir, "/models/", make.names(celltype), "_BinaryClassifier.rds"))
saveRDS(temp_model$metrics, file = paste0(output_dir, "/metrics/", make.names(celltype), "_BinaryClassifier_ConfusionMatrix.rds"))
}
}
print("All models trained!")
#Parse and print accuracy metrics from metrics rds files
if (verbose){
#Grab model names from model directory
data <- NULL
metrics_files <- list.files(paste0(output_dir, "/metrics"))
for (metrics_file in metrics_files){
toAdd <- .ParseMetricsFile(paste0(output_dir,"/metrics/",metrics_file))
if (all(is.null(data))) {
data <- toAdd
} else {
data <- rbind(data, toAdd)
}
}
print(ggplot(data, aes(x = type, y = value, fill = type)) +
geom_bar(stat = 'identity', color = 'black') +
labs(x = 'Type', y = 'Accuracy') +
egg::theme_presentation(base_size = 12) +
ggtitle("Accuracy") +
ylim(0,1) +
theme(
legend.position = 'none'
)
)
}
}
.ParseMetricsFile <- function(metrics_file){
data <- NULL
if (grepl("ConfusionMatrix", metrics_file)){
confusion <- readRDS(metrics_file)
label <- gsub(basename(metrics_file), pattern = '_BinaryClassifier_ConfusionMatrix.rds', replacement = '')
dat <- as.data.frame(confusion$table)
dat2 <- as.data.frame(prop.table(confusion$table))
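#Heatmap with tiles colored by proportion but labeled with raw counts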
P1 <- ggplot2::ggplot(dat2, aes(x = Prediction, y = Reference, z = Freq, fill = Freq)) +
geom_tile() +
geom_text(data = dat, mapping = aes(x = Prediction, y = Reference, label = Freq), inherit.aes = FALSE, size = 14) +
egg::theme_presentation(base_size = 12) +
ggtitle(label) +
scale_fill_gradient2() +
theme(
legend.position = 'none'
)
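#Companion heatmap labeled with proportions instead of counts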
P2 <- ggplot2::ggplot(dat2, aes(x = Prediction, y = Reference, z = Freq, fill = Freq)) +
geom_tile() +
geom_text(data = dat2, mapping = aes(x = Prediction, y = Reference, label = Freq), inherit.aes = FALSE, size = 14) +
egg::theme_presentation(base_size = 12) +
ggtitle(label) +
scale_fill_gradient2() +
theme(
legend.position = 'none'
)
print(P1 + P2)
toAdd <- data.frame(type = label, metric = "Accuracy", value = confusion$overall["Accuracy"])
if (all(is.null(data))) {
data <- toAdd
} else {
data <- rbind(data, toAdd)
}
} else if (grepl("Resampled", metrics_file)){
full_resampling <- readRDS(metrics_file)
#print(mlr3tuning::extract_inner_tuning_results(full_resampling))
print(full_resampling$score())
print(full_resampling$aggregate())
} else{
stop("Unexpected metrics file filename. Please ensure the metrics files are generated by TrainModelsFromSeurat or adhere to the naming convention if manually generated.")
}
return(data)
}
#' @title Applies a trained binary classifier to get per-cell probabilities.
#'
#' @description Applies a trained model to get per-cell probabilities.
#' @param seuratObj The Seurat Object to be updated
#' @param model Either the full filepath to a model RDS file, or the name of a built-in model.
#' @param fieldToClass A list mapping the target field name in the seurat object to the classifier level. The latter is either numeric, or the string class label. For example: list('CD4_T' = 1, 'CD8_T' = 2)
#' @param batchSize To conserve memory, data will be chunked into batches of at most this many cells
#' @param assayName The assay holding gene expression data
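#' @examples
#' \dontrun{
#'   # Sketch using a hypothetical classifier RDS written by TrainModelsFromSeurat;
#'   # class 2 corresponds to the positive ('1') class of the binary classifier:
#'   seuratObj <- ScoreCellsWithSavedModel(
#'     seuratObj,
#'     model = './classifiers/models/Tcell_BinaryClassifier.rds',
#'     fieldToClass = list('Tcell_probability' = 2)
#'   )
#' }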
#' @export
ScoreCellsWithSavedModel <- function(seuratObj, model, fieldToClass, batchSize = 20000, assayName = 'RNA') {
classifier <- .ResolveModel(modelFile = model)
#De-sparse and transpose seuratObj normalized data & make names unique
gene_expression_matrix <- Matrix::t(Seurat::GetAssayData(seuratObj, assay = assayName, slot = "data"))
# NOTE: make.names() will convert hyphens to periods, and will also prefix genes that start with a digit, like 7SK.2 -> X7SK.2
colnames(gene_expression_matrix) <- make.names(colnames(gene_expression_matrix))
# TODO: Is there a more universal way to get the model to report its features?
# See: https://mlr3.mlr-org.com/reference/LearnerClassif.html
modelFeats <- NULL
if ('importance' %in% classifier$properties) {
modelFeats <- names(classifier$importance())
} else if (!is.null(classifier$model) && grepl(x = classifier$model$TypeDetail, pattern = 'logistic regression')) {
modelFeats <- colnames(classifier$model$W)
toRemove <- names(classifier$param_set$levels)
modelFeats <- modelFeats[!modelFeats %in% toRemove]
}
if (!all(is.null(modelFeats))){
missing <- modelFeats[!modelFeats %in% colnames(gene_expression_matrix)]
if (length(missing) > 0) {
stop(paste0('The following features are used in the model and missing from the input: ', paste0(sort(missing), collapse = ',')))
}
# Subset input data to match model:
toDrop <- !(colnames(gene_expression_matrix) %in% modelFeats)
if (sum(toDrop) > 0) {
print(paste0('Dropping features not present in model: ', sum(toDrop), ' of ', ncol(gene_expression_matrix)))
gene_expression_matrix <- gene_expression_matrix[,!toDrop, drop = FALSE]
}
} else {
warning(paste0('Unable to infer features from model, type: ', classifier$model$TypeDetail))
}
print(paste0('Features shared between gene matrix and model: ', ncol(gene_expression_matrix)))
nBatches <- ifelse(is.na(batchSize), yes = 1, no = ceiling(nrow(gene_expression_matrix) / batchSize))
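#e.g. 45,000 cells with batchSize = 20,000 yields 3 batches covering rows 1-20000, 20001-40000, and 40001-45000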
probability_vectors <- list()
for (batchIdx in 1:nBatches){
start <- 1 + ((batchIdx-1) * batchSize)
end <- min((batchIdx * batchSize), nrow(gene_expression_matrix))
print(paste0("Iteration ", batchIdx, " of ", nBatches, ", (", start, "-", end, ")"))
#columns are named '0','1', so the first column is '0' and the second column is '1'
dat <- stats::predict(classifier, newdata = data.frame(gene_expression_matrix[start:end,]), predict_type = 'prob')
for (fieldName in names(fieldToClass)) {
idx <- fieldToClass[[fieldName]]
if (is.na(suppressWarnings(as.numeric(idx)))) {
# Try to resolve from model:
if (idx %in% levels(classifier$model$ClassNames)) {
idx <- which(levels(classifier$model$ClassNames) == idx)
} else {
stop(paste0('Unknown class: ', idx))
}
} else {
idx <- as.numeric(idx)
}
if (batchIdx == 1) {
probability_vectors[[fieldName]] <- dat[,idx]
} else {
probability_vectors[[fieldName]] <- c(probability_vectors[[fieldName]], dat[,idx])
}
}
}
for (fieldName in names(probability_vectors)) {
probability_vector <- probability_vectors[[fieldName]]
if (length(probability_vector) != ncol(seuratObj)) {
stop(paste0('Error calculating probability_vector. Length was: ', length(probability_vector)))
}
#append probabilities to seurat metadata
seuratObj@meta.data[[fieldName]] <- probability_vector
if (length(names(seuratObj@reductions)) > 0) {
print(Seurat::FeaturePlot(seuratObj, features = fieldName))
}
}
return(seuratObj)
}
#' @title Applies trained models to get celltype probabilities
#'
#' @description Applies trained models to get celltype probabilities.
#' @param seuratObj The Seurat Object to be updated
#' @param models A named vector of models, where the values are the modelName (for built-in model), or filePath to an RDS file. The names of the vector should be the cell-type label for cells scored positive by the classifier.
#' @param fieldName The name of the metadata column to store the result
#' @param batchSize To conserve memory, data will be chunked into batches of at most this many cells
#' @param assayName The assay holding gene expression data
#' @param minimum_probability The minimum probability for a confident cell type assignment
#' @param minimum_delta The minimum difference in probabilities necessary to call one celltype over another.
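#' @examples
#' \dontrun{
#'   # Sketch with hypothetical model paths; the vector names become the candidate
#'   # cell-type labels and the '<name>_probability' metadata columns:
#'   models <- c(
#'     Tcell = './classifiers/models/Tcell_BinaryClassifier.rds',
#'     Bcell = './classifiers/models/Bcell_BinaryClassifier.rds'
#'   )
#'   seuratObj <- PredictCellTypeProbability(seuratObj, models = models)
#' }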
#' @export
PredictCellTypeProbability <- function(seuratObj, models, fieldName = 'RIRA_Consensus', batchSize = 20000, assayName= "RNA", minimum_probability = 0.5, minimum_delta = 0.25){
fieldNames <- c()
for (modelName in names(models)) {
print(paste0('Scoring with model: ', modelName))
probColName <- paste0(modelName, '_probability')
fieldNames <- c(fieldNames, probColName)
fieldToClass <- list()
fieldToClass[[probColName]] <- 2 # Assume binary classifier for now
seuratObj <- ScoreCellsWithSavedModel(seuratObj, model = models[[modelName]], fieldToClass = fieldToClass, batchSize = batchSize, assayName = assayName)
}
seuratObj <- AssignCellType(seuratObj, probabilityColumns = fieldNames, fieldName = fieldName, minimum_probability = minimum_probability, minimum_delta = minimum_delta)
return(seuratObj)
}
#' @title Assigns celltype label based on probabilities in the metadata
#'
#' @description Assigns celltype label based on probabilities in the metadata
#' @param seuratObj The Seurat Object to be updated
#' @param probabilityColumns The set of columns containing probabilities for each classifier to include
#' @param fieldName The name of the metadata column to store the result
#' @param minimum_probability The minimum probability for a confident cell type assignment
#' @param minimum_delta The minimum difference in probabilities necessary to call one celltype over another.
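#' @examples
#' \dontrun{
#'   # Sketch assuming '<celltype>_probability' metadata columns already exist,
#'   # e.g. as written by PredictCellTypeProbability:
#'   seuratObj <- AssignCellType(
#'     seuratObj,
#'     probabilityColumns = c('Tcell_probability', 'Bcell_probability')
#'   )
#' }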
#' @export
AssignCellType <- function(seuratObj, probabilityColumns, fieldName = 'RIRA_Consensus', minimum_probability = 0.5, minimum_delta = 0.25){
probabilities_matrix <- seuratObj@meta.data[,probabilityColumns, drop = F]
if (ncol(probabilities_matrix) == 0) {
stop('Unable to find cell type probability columns!')
}
seuratObj@meta.data[,fieldName] <- "Unassigned"
#Iterate over the cells in the seurat object
toPlot <- NULL
for (cell in 1:nrow(probabilities_matrix)){
#Find the name of the column with maximum probability and grab the celltype and store it as "top_label"
max_probability_column <- which.max(probabilities_matrix[cell,])
max_probability <- max(probabilities_matrix[cell,])
top_label <- gsub('_probability$', '', names(max_probability_column))
second_highest_probability <- max(probabilities_matrix[cell, names(probabilities_matrix) != names(max_probability_column)])
#Check if the cell's highest probability classification exceeds the minimum probability set for a confident call.
#Additionally check if the highest probability and second highest probability are at least minimum_delta apart. If not, assign Unknown.
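#e.g. with the defaults (0.5 / 0.25), top = 0.8 and second = 0.4 passes and the cell gets top_label, while top = 0.6 and second = 0.5 fails the delta check and the cell is labeled Unknown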
seuratObj@meta.data[cell,fieldName] <- ifelse( ((max_probability >= minimum_probability) & ((max_probability - second_highest_probability) > minimum_delta)), yes = top_label , no = "Unknown")
passed <- seuratObj@meta.data[cell,fieldName] == top_label
toAdd <- data.frame(CellBarcode = colnames(seuratObj)[cell], Top = max_probability, TopLabel = top_label, Second = second_highest_probability, Passed = passed)
if (all(is.null(toPlot))) {
toPlot <- toAdd
} else {
toPlot <- rbind(toPlot, toAdd)
}
}
minVal <- min(c(toPlot$Top, toPlot$Second)) * 0.9 # provide consistent x/y limits
P1 <- ggplot2::ggplot(toPlot, aes(x = Top, y = Second, color = TopLabel, shape = Passed)) +
geom_point() +
ggtitle("Cell Type Probabilities") +
egg::theme_presentation(base_size = 10) +
ylim(minVal, 1) +
xlim(minVal, 1) +
labs(x = 'Highest Probability', y = 'Second Highest Probability', color = 'Top Label', shape = 'Passed?')
print(P1)
return(seuratObj)
}
.ResolveModel <- function(modelFile) {
if (file.exists(modelFile)) {
return(readRDS(file = modelFile))
}
savedModel <- system.file(paste0("models/", modelFile, ".rds"), package = "RIRA")
if (!file.exists(savedModel)) {
stop(paste0('Unable to find model: ', modelFile))
}
return(readRDS(file = savedModel))
}
#' @title Interprets the feature importance of each model
#'
#' @description Interprets the feature importance of each model using DALEX and permutation feature importance
#' @param output_dir The output directory into which TrainModelsFromSeurat saved training data and models
#' @param plot_type Argument to pass to model_parts(). Ratio or difference is recommended for large feature sets where the base model's AUC loss will be outside the plotting range.
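#' @examples
#' \dontrun{
#'   # Sketch assuming models and training data were written by TrainModelsFromSeurat:
#'   InterpretModels(output_dir = './classifiers', plot_type = 'ratio')
#' }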
#' @export
InterpretModels <- function(output_dir= "./classifiers", plot_type = "ratio"){
if (endsWith(output_dir, "/")){
output_dir <- gsub(output_dir, pattern = "/$", replacement = "")
}
#Iterate through models in the models directory
model_names <- list.files(paste0(output_dir, "/models"))
for (model_name in model_names){
#Read in model
model <- readRDS(paste0(output_dir, "/models/",model_name))
#Grab celltype from model filename
celltype <- gsub('_BinaryClassifier\\.rds$', '', model_name)
#Get associated training data
training_data <- readRDS(paste0(output_dir, "/training_data/",celltype,"_Training_Matrix.rds"))
#trim truth column
data <- training_data[, 1:(ncol(training_data)-1)]
y <- training_data[, ncol(training_data)]
#create explainer
explainer <- DALEXtra::explain_mlr3(model = model,
data = data,
y = y,
label = celltype,
colorize = FALSE)
#variable feature importance
parts <- DALEX::model_parts(explainer, type = plot_type)
print(plot(parts, max_vars=12, show_boxplots = FALSE))
}
}