R/tuneTrain.R

#' @title Tuning and Training the Data
#' @description \code{tuneTrain} splits the data into training and test sets, then tunes and trains a model and makes predictions on the test set. It returns a list containing the model objects, data frames and plots.
#' @param data object of class "data.frame" with target variable and predictor variables.
#' @param y character. Target variable.
#' @param p numeric. Proportion of data to be used for training. Default: 0.7
#' @param method character. Type of model to use for classification or regression.
#' @param length integer. Number of values to output for each tuning parameter. If a random search is used (\code{search = "random"} in \code{\link[caret]{trainControl}}), this becomes the maximum number of tuning parameter combinations that are generated by the random search. Default: 10.
#' @param control character. Resampling method to use. Choices include: "boot", "boot632", "optimism_boot", "boot_all", "cv", "repeatedcv", "LOOCV", "LGOCV", "none", "oob", timeslice, "adaptive_cv", "adaptive_boot", or "adaptive_LGOCV". Default: "repeatedcv". See \code{\link[caret]{train}} for specific details on the resampling methods.
#' @param number integer. Number of cross-validation folds or number of resampling iterations. Default: 10.
#' @param repeats integer. Number of complete sets of folds to compute (i.e. the number of repetitions) for repeated k-fold cross-validation, used if "repeatedcv" is chosen as the resampling method in \code{control}. Default: 10.
#' @param summary expression. Computes performance metrics across resamples. For numeric \code{y}, \code{\link[caret]{defaultSummary}} calculates the root mean squared error (RMSE), R-squared and mean absolute error (MAE). For factor \code{y}, the overall accuracy and Kappa are calculated. See \code{\link[caret]{trainControl}} and \code{\link[caret]{defaultSummary}} for details on specification and summary options. Default: multiClassSummary.
#' @param process character. Defines the pre-processing transformation of predictor variables to be done. Options are: "BoxCox", "YeoJohnson", "expoTrans", "center", "scale", "range", "knnImpute", "bagImpute", "medianImpute", "pca", "ica", or "spatialSign". See \code{\link[caret]{preProcess}} for specific details on each pre-processing transformation. Default: c('center', 'scale').
#' @param positive character. The positive class for the target variable if \code{y} is factor. Usually, it is the first level of the factor.
#' @param parallelComputing logical. Indicates whether to use parallel processing for model training. Default: FALSE
#' @param ... additional arguments to be passed to the \code{\link[caret]{train}} function in the package \code{caret}.
#' @return A list object with results from tuning and training the model selected in \code{method}, together with predictions and class probabilities. The training and test data sets obtained from splitting the data are also returned.
#'
#' If \code{y} is factor, class probabilities are calculated for each class. If \code{y} is numeric, predicted values are calculated.
#'
#' A ROC curve is created if \code{y} is factor. Otherwise, a plot of residuals versus predicted values is created if \code{y} is numeric.
#'
#' \code{tuneTrain} relies on packages \code{caret}, \code{ggplot2} and \code{plotROC} to perform the modelling and plotting.
#' @details Types of classification and regression models available for use with \code{tuneTrain} can be found using \code{names(getModelInfo())}. The results given depend on the type of model used.
#'
#' For classification models, class probabilities and ROC curve are given in the results. For regression models, predictions and residuals versus predicted plot are given. \code{y} should be converted to either factor if performing classification or numeric if performing regression before specifying it in \code{tuneTrain}.
#'
#' @author Zakaria Kehel, Bancy Ngatia, Khadija Aziz
#' @examples
#' if(interactive()){
#'  data(septoriaDurumWC)
#'  knn.mod <- tuneTrain(data = septoriaDurumWC,y = 'ST_S',method = 'knn',positive = 'R')
#'  
#'  nnet.mod <- tuneTrain(data = septoriaDurumWC,y = 'ST_S',method = 'nnet',positive = 'R')
#'
#' }
#' @seealso
#'  \code{\link[caret]{createDataPartition}},
#'  \code{\link[caret]{trainControl}},
#'  \code{\link[caret]{train}},
#'  \code{\link[caret]{predict.train}},
#'  \code{\link[ggplot2]{ggplot}},
#'  \code{\link[plotROC]{geom_roc}},
#'  \code{\link[plotROC]{calc_auc}}
#' @rdname tuneTrain
#' @export
#' @importFrom caret createDataPartition trainControl train predict.train
#' @importFrom utils stack
#' @importFrom ggplot2 ggplot aes geom_histogram theme_bw scale_colour_brewer scale_fill_brewer labs coord_equal annotate geom_point
#' @importFrom plotROC geom_roc style_roc calc_auc
#' @importFrom stats resid
#' @importFrom foreach registerDoSEQ
#' @importFrom doParallel registerDoParallel
#' @importFrom parallel detectCores makeCluster stopCluster


tuneTrain <- function (data, y, p = 0.7, method = method, parallelComputing = FALSE,
                       length = 10, control = "repeatedcv", number = 10, 
                       repeats = 10, process = c('center', 'scale'),
                       summary= multiClassSummary,positive, ...) 
{
    # Workflow:
    #   1. validate arguments
    #   2. split `data` into training / test sets (stratified on `y`)
    #   3. coarse tuning pass (`tuneLength = length`)
    #   4. refined tuning pass on a grid centred on the best coarse value
    #      (knn, rf, svmLinear2, nnet), or on `tuneLength` for other methods
    #   5. predictions on the test set plus diagnostic plots
    #
    # Returns a named list. For factor `y`: tuning and final models, class
    # probabilities, their histogram, the ROC curve and AUC (two-class targets
    # only), the training row names and both data sets. For numeric `y`:
    # tuning and final models, test-set predictions with residuals, a
    # residuals-vs-predicted plot, the training row names and both data sets.

    # ---- 1. Argument validation ------------------------------------------
    # The formal default `method = method` is self-referential; without this
    # check, omitting `method` gives a cryptic "promise already under
    # evaluation" error.
    if (missing(method)) {
        stop("'method' must be supplied; see names(caret::getModelInfo()).",
             call. = FALSE)
    }
    if (!is.character(y) || base::length(y) != 1L || !y %in% colnames(data)) {
        stop("'y' must be the name of a single column in 'data'.",
             call. = FALSE)
    }

    target <- data[[y]]

    # Resolve the summary function explicitly so the default does not depend
    # on caret being attached, and use a regression-appropriate summary when
    # the caller did not choose one and the target is numeric.
    if (missing(summary)) {
        summary <- if (is.numeric(target)) {
            caret::defaultSummary
        } else {
            caret::multiClassSummary
        }
    }

    # ---- 2. Train / test split -------------------------------------------
    # Fixed seed keeps the partition and resampling reproducible.
    set.seed(1234)
    predictor_names <- colnames(data)[colnames(data) != y]
    trainIndex <- caret::createDataPartition(y = target, p = p, list = FALSE)
    trainset <- as.data.frame(data[trainIndex, ])
    testset <- as.data.frame(data[-trainIndex, ])
    Train_Index <- row.names(trainset)

    trainx <- trainset[colnames(trainset) %in% predictor_names]
    trainy <- trainset[[y]]
    testx <- testset[colnames(testset) %in% predictor_names]
    testy <- testset[[y]]

    # ---- Optional parallel backend ---------------------------------------
    if (isTRUE(as.logical(parallelComputing))) {
        cores <- parallel::detectCores()
        # Leave four cores free when possible, but never ask for fewer than
        # one worker (detectCores() may return NA or a small number).
        workers <- if (is.na(cores)) 1L else max(1L, cores - 4L)
        cls <- parallel::makeCluster(workers)
        # Guarantee cleanup even if training fails part-way through.
        on.exit({
            parallel::stopCluster(cls)
            foreach::registerDoSEQ()
        }, add = TRUE)
        doParallel::registerDoParallel(cls)
    }

    # ---- Resampling controls ---------------------------------------------
    ctrl <- caret::trainControl(method = control, number = number,
                                repeats = repeats)
    # Class probabilities are only meaningful for classification targets.
    ctrl2 <- caret::trainControl(method = control, number = number,
                                 repeats = repeats,
                                 classProbs = !is.numeric(target),
                                 summaryFunction = summary)

    # Lower bound of a refined grid: step back from the best value unless
    # that would reach zero or below, in which case start at the best value.
    grid_start <- function(best, step) {
        if (best - step <= 0) best else best - step
    }

    # ---- 3./4. Tuning and training ---------------------------------------
    if (method == "treebag") {
        # Bagged trees have no tuning parameter: a single pass is the model.
        tune.mod <- caret::train(trainx, trainy, method = method,
                                 tuneLength = length, trControl = ctrl,
                                 preProcess = process, ...)
        train.mod <- tune.mod
    }
    else if (method == "nnet") {
        tune.mod <- caret::train(trainx, trainy, method = method,
                                 tuneLength = length, trControl = ctrl,
                                 preProcess = process, trace = FALSE)

        size <- tune.mod[["bestTune"]][["size"]]
        tuneGrid <- expand.grid(
            .size = seq(grid_start(size, 1), size + 1, by = 1),
            .decay = 0.1^(seq(0.01, 0.08, 0.01)) * 0.11
        )

        train.mod <- caret::train(trainx, trainy, method,
                                  tuneGrid = tuneGrid, tuneLength = length,
                                  trControl = ctrl2, preProcess = process,
                                  trace = FALSE, ...)
    }
    else {
        tune.mod <- caret::train(trainx, trainy, method = method,
                                 tuneLength = length, trControl = ctrl,
                                 preProcess = process)
        best <- tune.mod[["bestTune"]]

        # Methods without a dedicated refinement get tuneGrid = NULL, so
        # caret falls back to `tuneLength` instead of failing on an
        # undefined `tuneGrid` object.
        tuneGrid <- switch(
            method,
            knn = expand.grid(
                .k = seq(grid_start(best[["k"]], 2), best[["k"]] + 2, by = 1)
            ),
            rf = expand.grid(
                .mtry = seq(grid_start(best[["mtry"]], 2),
                            best[["mtry"]] + 2, by = 1)
            ),
            svmLinear2 = expand.grid(
                .cost = seq(grid_start(best[["cost"]], 1),
                            best[["cost"]] + 1, by = 0.25)
            ),
            NULL
        )

        train.mod <- caret::train(trainx, trainy, method,
                                  tuneGrid = tuneGrid, tuneLength = length,
                                  trControl = ctrl2, preProcess = process, ...)
    }

    # ---- 5a. Classification output ---------------------------------------
    if (is.factor(target)) {
        if (missing(positive)) {
            warning("The positive class is not defined!", immediate. = TRUE, noBreaks. = TRUE)
            if (interactive()) {
                positive <- readline(prompt="Please define the positive class for the target variable: ")
            }
            else {
                # readline() returns "" when not interactive; fall back to
                # the first factor level, the documented usual choice.
                positive <- levels(target)[1]
            }
        }

        prob.mod <- as.data.frame(caret::predict.train(train.mod, testx,
                                                       type = "prob"))
        prob.newdf <- utils::stack(prob.mod)
        colnames(prob.newdf) <- c("Probability", "Class")
        prob.plot <- ggplot2::ggplot(prob.newdf,
                                     ggplot2::aes(x = Probability,
                                                  colour = Class,
                                                  fill = Class)) +
            ggplot2::geom_histogram(alpha = 0.4, size = 1,
                                    position = "identity") +
            ggplot2::theme_bw() +
            ggplot2::scale_colour_brewer(palette = "Dark2") +
            ggplot2::scale_fill_brewer(palette = "Dark2") +
            ggplot2::labs(y = "Count")

        # ROC / AUC only exist for two-class problems. The class count comes
        # from the target column itself, not from the first column of `data`.
        roc.parts <- list()
        if (nlevels(target) == 2) {
            negative <- prob.mod[, !names(prob.mod) %in% positive]
            g1 <- ggplot2::ggplot(prob.mod,
                                  ggplot2::aes(m = negative, d = testy)) +
                plotROC::geom_roc(n.cuts = 0) +
                ggplot2::coord_equal() +
                plotROC::style_roc()
            auc <- round(plotROC::calc_auc(g1)$AUC, 4)
            plot.roc <- g1 + ggplot2::annotate("text", x = 0.75, y = 0.25,
                                               label = paste("AUC =", auc))
            roc.parts <- list(`Area Under ROC Curve` = auc,
                              `ROC Curve` = plot.roc)
        }

        return(c(
            list(Tuning = tune.mod,
                 Model = train.mod,
                 `Class Probabilities` = prob.mod,
                 `Class Probabilities Plot` = prob.plot),
            roc.parts,
            list(`TrainingIndex` = Train_Index,
                 `Training Data` = trainset,
                 `Test Data` = testset)
        ))
    }

    # ---- 5b. Regression output -------------------------------------------
    if (is.numeric(target)) {
        predicted <- as.numeric(caret::predict.train(train.mod, testx))
        pred.df <- data.frame(Observed = testy,
                              Predicted = predicted,
                              Residual = testy - predicted)
        resid.plot <- ggplot2::ggplot(pred.df,
                                      ggplot2::aes(x = Predicted,
                                                   y = Residual)) +
            ggplot2::geom_point() +
            ggplot2::theme_bw() +
            ggplot2::labs(x = "Predicted", y = "Residual")

        return(list(Tuning = tune.mod,
                    Model = train.mod,
                    Predictions = pred.df,
                    `Residuals vs. Predicted Plot` = resid.plot,
                    `TrainingIndex` = Train_Index,
                    `Training Data` = trainset,
                    `Test Data` = testset))
    }

    # Neither factor nor numeric: nothing sensible to predict or plot.
    warning("Target variable is neither factor nor numeric; no predictions were generated.",
            call. = FALSE)
    invisible(NULL)
}

Try the icardaFIGSr package in your browser

Any scripts or data that you put into this service are public.

icardaFIGSr documentation built on Dec. 11, 2021, 9:21 a.m.