R/trainModel.R

Defines functions .trainModel trainModel

Documented in trainModel

#' @title Train a prediction model
#' @description Trains one classification model per cell class (e.g. cell type) from the feature space of an
#' \code{scPred} object. The \code{scPred} object may be passed directly or stored in a \code{Seurat} object.
#' @param object A \code{Seurat} or \code{scPred} object after running
#' \code{getFeatureSpace}
#' @param model Classification model supported via \code{caret} package. A list of all models can be found here:
#' https://topepo.github.io/caret/available-models.html
#' Default: support vector machine with radial kernel
#' @param preProcess A string vector that defines a pre-processing of the predictor data. Current possibilities are
#' "BoxCox", "YeoJohnson", "expoTrans", "center", "scale", "range", "knnImpute", "bagImpute", "medianImpute",
#' "pca", "ica" and "spatialSign". The default is "center" and "scale". See preProcess and trainControl on the
#' procedures and how to adjust them
#' @param resampleMethod Resample model used in \code{trainControl} function from \code{caret}.
#' Default: K-fold cross validation
#' @param number Number of folds or resampling iterations for the resample method. See \code{trainControl} function
#' @param seed Numeric seed passed to \code{set.seed} before training each model, to ensure reproducibility.
#' Use \code{NULL} to skip seeding
#' @param tuneLength An integer denoting the amount of granularity in the tuning parameter grid.
#' By default, this argument is the number of levels for each tuning parameter that should be generated by train.
#' See `?caret::train` documentation
#' @param metric Performance metric to be used to select best model: `ROC` (area under the ROC curve),
#' `PR` (area under the precision-recall curve), `Accuracy`, and `Kappa`. Default: `ROC`
#' @param returnData If \code{TRUE}, training data is returned within each trained model stored in the
#' \code{scPred} object.
#' @param savePredictions Specifies the set of hold-out predictions for each resample that should be
#' returned. Values can be either "all", "final", or "none".
#' @param allowParallel Allow parallel processing for resampling?
#' @param reclassify Cell types to reclassify using a different model. When supplied, only the models for these
#' classes are (re)trained and replaced; all other trained models are kept. Every name must be one of the classes
#' for which \code{getFeatureSpace} determined features.
#' @return The input object with its trained models updated. For a \code{Seurat} input, the \code{Seurat} object
#' whose \code{misc$scPred} element holds the updated \code{scPred} object; for an \code{scPred} input, the updated
#' \code{scPred} object. Its \code{train} slot is a named list of \code{train} objects, one per cell class.
#' See \code{train} function for details.
#' @keywords train, model
#' @importFrom methods is
#' @importFrom caret trainControl prSummary train twoClassSummary
#' @importFrom pbapply pblapply
#' @export
#' @author
#' Jose Alquicira Hernandez

trainModel <- function(object,
                       model = "svmRadial",
                       preProcess = c("center", "scale"),
                       resampleMethod = "cv",
                       number = 5,
                       seed = 66,
                       tuneLength = 3,
                       metric = c("ROC", "PR", "Accuracy", "Kappa"),
                       returnData = FALSE,
                       savePredictions = "final",
                       allowParallel = FALSE,
                       reclassify = NULL
){
    
    # Validations -------------------------------------------------------------
    
    # Check class. The negation must wrap both tests; negating only the first
    # one would reject every 'scPred' object.
    if(!(is(object, "Seurat") || is(object, "scPred"))){
        stop("object must be 'Seurat' or 'scPred'")
    }
    
    if(is(object, "Seurat")){
        # Keep the container so the updated scPred object can be written back
        seurat_object <- object
        object <- get_scpred(object)
        
        if(is.null(object))
            stop("No features have been determined. Use 'getFeatureSpace()' function")
        
        object_class <- "Seurat"
        
    }else{
        object_class <- "scPred"
    }
    
    # Classes to train: every class with a feature space, or the requested subset
    if(is.null(reclassify)){
        classes <- names(object@features)
    }else{
        # Fail early with a clear message instead of inside .trainModel
        unknown <- setdiff(reclassify, names(object@features))
        if(length(unknown) > 0){
            stop("Unknown class(es) in 'reclassify': ",
                 paste(unknown, collapse = ", "),
                 ". Available classes: ",
                 paste(names(object@features), collapse = ", "))
        }
        classes <- reclassify
    }
    metric <- match.arg(metric)
    reduction <- object@reduction
    
    # Train a prediction model for each class
    cat(crayon::green(cli::symbol$record, " Training models for each cell type...\n"))
    
    # A single class is trained without the pbapply progress bar
    apply_fun <- if(length(classes) == 1) lapply else pblapply
    modelsRes <- apply_fun(classes, .trainModel,
                           spmodel = object,
                           model = model,
                           reduction = reduction,
                           preProcess = preProcess,
                           resampleMethod = resampleMethod,
                           tuneLength = tuneLength,
                           seed = seed,
                           metric = metric,
                           number = number,
                           returnData = returnData,
                           savePredictions = savePredictions,
                           allowParallel = allowParallel)
    names(modelsRes) <- classes
    
    cat(crayon::green("DONE!\n"))
    
    # Full training replaces all models; reclassification only the given ones
    if(is.null(reclassify)){
        object@train <- modelsRes
    }else{
        object@train[names(modelsRes)] <- modelsRes
    }
    
    if(object_class == "Seurat"){
        seurat_object@misc$scPred <- object
        seurat_object
        
    }else{
        object
    }
}

# Fit a one-vs-rest caret model for a single cell class.
#
# Cells whose response equals `.make_names(positiveClass)` form the first
# (positive) factor level; every other cell is relabelled "other". Predictors
# are the columns of `spmodel@cell_embeddings` listed in the `feature` column
# of `spmodel@features[[positiveClass]]`.
#
# positiveClass: class to model; an element of names(spmodel@features).
# spmodel:       scPred object providing @features, @cell_embeddings and
#                @metadata$response.
# reduction:     not used in this function; kept so the caller's argument
#                list stays unchanged.
# metric:        "ROC", "PR", "Accuracy" or "Kappa". "PR" is optimised via the
#                "AUC" value reported by caret::prSummary.
# seed:          if not NULL, passed to set.seed() before building the
#                resampling control and training.
# Other arguments are forwarded to caret::trainControl / caret::train.
# Returns the fitted caret `train` object.
.trainModel <- function(positiveClass,
                        spmodel,
                        model,
                        reduction,
                        preProcess,
                        resampleMethod,
                        tuneLength,
                        seed,
                        metric,
                        number,
                        returnData,
                        savePredictions, 
                        allowParallel){
    
    class_features <- spmodel@features[[positiveClass]]
    
    if(nrow(class_features) == 0){
        # NOTE(review): training still proceeds with an empty predictor set;
        # confirm whether this should stop instead of only messaging.
        message("No informative principal components were identified for class: ", positiveClass)
    }
    
    names_features <- as.character(class_features$feature)
    # subsetMatrix lives in this package's namespace; no `scPred:::` needed
    features <- subsetMatrix(spmodel@cell_embeddings, names_features)
    
    # One-vs-rest response with the positive class as the first level
    positive_label <- .make_names(positiveClass)
    response <- as.character(spmodel@metadata$response)
    response[response != positive_label] <- "other"
    response <- factor(response, levels = c(positive_label, "other"))
    
    # Only the summary function depends on the metric; Accuracy/Kappa use
    # caret's default (the same default trainControl applies)
    summary_function <- switch(metric,
                               ROC = twoClassSummary,
                               PR = prSummary,
                               caret::defaultSummary)
    
    # prSummary reports the area under the precision-recall curve as "AUC"
    if(metric == "PR") metric <- "AUC"
    
    if(!is.null(seed)) set.seed(seed)
    
    trCtrl <- trainControl(classProbs = TRUE,
                           method = resampleMethod,
                           number = number,
                           summaryFunction = summary_function,
                           returnData = returnData,
                           savePredictions = savePredictions,
                           allowParallel = allowParallel)
    
    train(x = features,
          y = response,
          method = model,
          metric = metric,
          trControl = trCtrl,
          preProcess = preProcess,
          tuneLength = tuneLength)
}
powellgenomicslab/scPred documentation built on July 16, 2021, 12:14 a.m.