#' @title Train a prediction model
#' @description Trains a prediction model from an \code{scPred} object stored in a \code{Seurat} object
#' @param object A \code{Seurat} or \code{scPred} object after running
#' \code{getFeatureSpace}
#' @param model Classification model supported via \code{caret} package. A list of all models can be found here:
#' https://topepo.github.io/caret/available-models.html
#' Default: support vector machine with radial kernel
#' @param preProcess A string vector that defines a pre-processing of the predictor data. Current possibilities are
#' "BoxCox", "YeoJohnson", "expoTrans", "center", "scale", "range", "knnImpute", "bagImpute", "medianImpute",
#' "pca", "ica" and "spatialSign". The default is "center" and "scale". See preProcess and trainControl on the
#' procedures and how to adjust them
#' @param resampleMethod Resample model used in \code{trainControl} function from \code{caret}.
#' Default: K-fold cross validation
#' @param number Number of iterations for resample method. See \code{trainControl} function
#' @param seed Numeric seed for resample method. Fixed to ensure reproducibility
#' @param tuneLength An integer denoting the amount of granularity in the tuning parameter grid.
#' By default, this argument is the number of levels for each tuning parameters that should be generated by train.
#' See `?caret::train` documentation
#' @param metric Performance metric to be used to select best model: `ROC` (area under the ROC curve),
#' `PR` (area under the precision-recall curve), `Accuracy`, and `Kappa`
#' @param returnData If \code{TRUE}, training data is returned within \code{scPred} object.
#' @param savePredictions Specifies the set of hold-out predictions for each resample that should be
#' returned. Values can be either "all", "final", or "none".
#' @param allowParallel Allow parallel processing for resampling?
#' @param reclassify Cell types to reclassify using a different model. If \code{NULL} (default), a model
#' is trained for every class in the feature space and the stored models are replaced as a whole.
#' Otherwise only the models of the given cell types are trained and replaced. All values must be
#' names of classes present in the feature space.
#' @return The input object updated with the trained models. For a \code{Seurat} input, the same
#' \code{Seurat} object with the updated \code{scPred} object stored in \code{misc$scPred}; for an
#' \code{scPred} input, the updated \code{scPred} object. The models are stored in the \code{train}
#' slot as a list named by cell class (e.g. cell type). See \code{train} function for details.
#' @keywords train, model
#' @importFrom methods is
#' @importFrom caret trainControl prSummary train twoClassSummary
#' @importFrom pbapply pblapply
#' @export
#' @author
#' Jose Alquicira Hernandez
trainModel <- function(object,
                       model = "svmRadial",
                       preProcess = c("center", "scale"),
                       resampleMethod = "cv",
                       number = 5,
                       seed = 66,
                       tuneLength = 3,
                       metric = c("ROC", "PR", "Accuracy", "Kappa"),
                       returnData = FALSE,
                       savePredictions = "final",
                       allowParallel = FALSE,
                       reclassify = NULL
){

  # Validations -------------------------------------------------------------

  # Check class: reject only objects that are neither 'Seurat' nor 'scPred'.
  # The negation must cover both tests, otherwise every 'scPred' object is
  # rejected.
  if(!(is(object, "Seurat") || is(object, "scPred"))){
    stop("object must be 'Seurat' or 'scPred'")
  }

  if(is(object, "Seurat")){
    # Keep the Seurat container so the trained scPred object can be stored back
    seurat_object <- object
    object <- get_scpred(object)
    if(is.null(object))
      stop("No features have been determined. Use 'getFeatureSpace()' function")
    object_class <- "Seurat"
  }else{
    object_class <- "scPred"
  }

  if(is.null(reclassify)){
    classes <- names(object@features)
  }else{
    # Fail early with a clear message instead of an obscure error raised while
    # looking up the features of an unknown class
    unknown <- setdiff(reclassify, names(object@features))
    if(length(unknown) > 0){
      stop("Cell types not found in the feature space: ",
           paste(unknown, collapse = ", "))
    }
    classes <- reclassify
  }

  metric <- match.arg(metric)
  reduction <- object@reduction

  # Single place where the per-class training call is spelled out; arguments
  # are passed by name so they cannot drift out of order
  train_class <- function(positiveClass){
    .trainModel(positiveClass = positiveClass,
                spmodel = object,
                model = model,
                reduction = reduction,
                preProcess = preProcess,
                resampleMethod = resampleMethod,
                tuneLength = tuneLength,
                seed = seed,
                metric = metric,
                number = number,
                returnData = returnData,
                savePredictions = savePredictions,
                allowParallel = allowParallel)
  }

  # Train a prediction model for each class
  cat(crayon::green(cli::symbol$record, " Training models for each cell type...\n"))

  if(length(classes) == 1){
    # A single class is trained directly, without the progress bar
    modelsRes <- list(train_class(classes[1]))
  }else{
    modelsRes <- pblapply(classes, train_class)
  }
  names(modelsRes) <- classes

  cat(crayon::green("DONE!\n"))

  if(is.null(reclassify)){
    object@train <- modelsRes
  }else{
    # Only overwrite the models of the reclassified cell types
    object@train[names(modelsRes)] <- modelsRes
  }

  # Return the same kind of object that was supplied
  if(object_class == "Seurat"){
    seurat_object@misc$scPred <- object
    seurat_object
  }else{
    object
  }
}
# Fits a binary classifier (positive class vs. "other") for one cell class.
#
# positiveClass: name of the class to train for; used to index
#   `spmodel@features` and, after `.make_names()`, as the positive label.
# spmodel: object providing the `features`, `cell_embeddings` and `metadata`
#   slots read below.
# model, preProcess, tuneLength: forwarded to `train()` as `method`,
#   `preProcess` and `tuneLength`.
# resampleMethod, number, returnData, savePredictions, allowParallel:
#   forwarded to `trainControl()`.
# seed: if not NULL, passed to `set.seed()` before training.
# metric: "ROC" and "PR" select a dedicated summary function; any other value
#   uses the `trainControl()` default summary.
# reduction: accepted for interface compatibility; not used in this function.
#
# Returns the object produced by `train()`.
.trainModel <- function(positiveClass,
                        spmodel,
                        model,
                        reduction,
                        preProcess,
                        resampleMethod,
                        tuneLength,
                        seed,
                        metric,
                        number,
                        returnData,
                        savePredictions,
                        allowParallel){

  class_features <- spmodel@features[[positiveClass]]

  # Only informs the user; training is still attempted afterwards
  if(nrow(class_features) == 0){
    message("No informative principal components were identified for class: ", positiveClass)
  }

  # Predictors: the embeddings restricted to the features of this class
  feature_names <- as.character(class_features$feature)
  predictors <- scPred:::subsetMatrix(spmodel@cell_embeddings, feature_names)

  # Outcome: every label different from the positive class becomes "other"
  positive_label <- .make_names(positiveClass)
  labels <- as.character(spmodel@metadata$response)
  labels[labels != positive_label] <- "other"
  labels <- factor(labels, levels = c(positive_label, "other"))

  if(!is.null(seed)) set.seed(seed)

  # Resampling settings shared by every metric
  ctrl_args <- list(classProbs = TRUE,
                    method = resampleMethod,
                    number = number,
                    returnData = returnData,
                    savePredictions = savePredictions,
                    allowParallel = allowParallel)

  # Metric-specific summary function; "Accuracy" and "Kappa" add none
  if(metric == "ROC"){
    ctrl_args$summaryFunction <- twoClassSummary
  }else if(metric == "PR"){
    ctrl_args$summaryFunction <- prSummary
    # With prSummary the model is selected on the "AUC" metric
    metric <- "AUC"
  }

  trCtrl <- do.call(trainControl, ctrl_args)

  train(x = predictors,
        y = labels,
        method = model,
        metric = metric,
        trControl = trCtrl,
        preProcess = preProcess,
        tuneLength = tuneLength)
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.