#' @title Tuning and Training a Model
#' @description tuneTrain splits the data into training and test sets, automatically tunes and trains a model, and makes predictions on the test set. It returns a list containing the fitted model object, the data frames from the split, and a diagnostic plot.
#' @param data object of class "data.frame" with target variable and predictor variables.
#' @param y character. Target variable.
#' @param p numeric. Proportion of data to be used for training. Default: 0.7
#' @param method character. Type of model to use for classification or regression.
#' @param length integer. Number of values to output for each tuning parameter. If \code{search = "random"} is passed to \code{\link[caret]{trainControl}} through \code{...}, this becomes the maximum number of tuning parameter combinations that are generated by the random search. Default: 10.
#' @param control character. Resampling method to use. Choices include: "boot", "boot632", "optimism_boot", "boot_all", "cv", "repeatedcv", "LOOCV", "LGOCV", "none", "oob", "timeslice", "adaptive_cv", "adaptive_boot", or "adaptive_LGOCV". Default: "repeatedcv". See \code{\link[caret]{train}} for specific details on the resampling methods.
#' @param number integer. Number of cross-validation folds or number of resampling iterations. Default: 10.
#' @param repeats integer. Number of folds for repeated k-fold cross-validation if "repeatedcv" is chosen as the resampling method in \code{control}. Default: 10.
#' @param summary expression. Computes performance metrics across resamples. For numeric \code{y}, the root mean squared error and R-squared are calculated. For factor \code{y}, the overall accuracy and Kappa are calculated. See \code{\link[caret]{trainControl}} and \code{\link[caret]{defaultSummary}} for details on specification and summary options. Default: multiClassSummary.
#' @param process character. Defines the pre-processing transformation of predictor variables to be done. Options are: "BoxCox", "YeoJohnson", "expoTrans", "center", "scale", "range", "knnImpute", "bagImpute", "medianImpute", "pca", "ica", or "spatialSign". See \code{\link[caret]{preProcess}} for specific details on each pre-processing transformation. Default: c('center', 'scale').
#' @param positive character. The positive class for the target variable if \code{y} is factor. Usually, it is the first level of the factor.
#' @param parallelComputing logical. Whether to use parallel processing. Default: FALSE.
#' @param ... additional arguments to be passed to \code{createDataPartition}, \code{trainControl} and \code{train} functions in the package \code{caret}.
#' @return A list object with results from tuning and training the model selected in \code{method}, together with predictions and class probabilities. The training and test data sets obtained from splitting the data are also returned.
#'
#' If \code{y} is factor, class probabilities are calculated for each class. If \code{y} is numeric, predicted values are calculated.
#'
#' A ROC curve is created if \code{y} is factor. Otherwise, a plot of residuals versus predicted values is created if \code{y} is numeric.
#'
#' \code{tuneTrain} relies on packages \code{caret}, \code{ggplot2} and \code{plotROC} to perform the modelling and plotting.
#' @details Types of classification and regression models available for use with \code{tuneTrain} can be found using \code{names(getModelInfo())}. The results given depend on the type of model used.
#'
#' For classification models, class probabilities and ROC curve are given in the results. For regression models, predictions and residuals versus predicted plot are given. \code{y} should be converted to either factor if performing classification or numeric if performing regression before specifying it in \code{tuneTrain}.
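#'
#' For example (a minimal sketch; \code{df} and its column \code{target} are hypothetical placeholders), a classification target can be prepared before the call with:
#' \preformatted{
#' df$target <- as.factor(df$target)   # classification target
#' mod <- tuneTrain(data = df, y = 'target', method = 'knn',
#'                  positive = levels(df$target)[1])
#' }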
#'
#' @author Zakaria Kehel, Bancy Ngatia, Khadija Aziz
#' @examples
#' if(interactive()){
#' data(septoriaDurumWC)
#' knn.mod <- tuneTrain(data = septoriaDurumWC, y = 'ST_S', method = 'knn', positive = 'R')
#'
#' nnet.mod <- tuneTrain(data = septoriaDurumWC, y = 'ST_S', method = 'nnet', positive = 'R')
#'
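#' # A hedged regression sketch: 'myData' and its numeric target 'Yield' are
#' # hypothetical placeholders rather than objects shipped with this package.
#' # svm.mod <- tuneTrain(data = myData, y = 'Yield', method = 'svmLinear2')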
#' }
#' @seealso
#' \code{\link[caret]{createDataPartition}},
#' \code{\link[caret]{trainControl}},
#' \code{\link[caret]{train}},
#' \code{\link[caret]{predict.train}},
#' \code{\link[ggplot2]{ggplot}},
#' \code{\link[plotROC]{geom_roc}},
#' \code{\link[plotROC]{calc_auc}}
#' @rdname tuneTrain
#' @export
#' @importFrom caret createDataPartition trainControl train predict.train multiClassSummary
#' @importFrom utils stack
#' @importFrom ggplot2 ggplot aes geom_histogram theme_bw scale_colour_brewer scale_fill_brewer labs coord_equal annotate geom_point
#' @importFrom plotROC geom_roc style_roc calc_auc
#' @importFrom stats resid
#' @importFrom foreach registerDoSEQ
#' @importFrom doParallel registerDoParallel
#' @importFrom parallel detectCores makeCluster stopCluster
tuneTrain <- function(data, y, p = 0.7, method, parallelComputing = FALSE,
                      length = 10, control = "repeatedcv", number = 10,
                      repeats = 10, process = c("center", "scale"),
                      summary = multiClassSummary, positive, ...)
{
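  # Seed the RNG so the train/test partition below is reproducible.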
set.seed(1234)
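  # Separate the predictors from the target, then create a stratified
  # train/test split of proportion p over the rows.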
  x = data[which(colnames(data) != y)]
yvec = data[[y]]
trainIndex = caret::createDataPartition(y = yvec, p = p,list = FALSE)
data.train = as.data.frame(data[trainIndex, ])
data.test = as.data.frame(data[-trainIndex, ])
split.data = list(trainset = data.train, testset = data.test)
trainset = split.data$trainset
testset = split.data$testset
Train_Index <- row.names(data.train)
trainx = trainset[colnames(trainset) %in% colnames(x)]
trainy = trainset[[y]]
testx = testset[colnames(testset) %in% colnames(x)]
testy = testset[[y]]
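  # Optionally register a parallel backend so caret's resampling iterations
  # are distributed across worker processes.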
if (parallelComputing == TRUE) {
cores <- parallel::detectCores()
    cls <- parallel::makeCluster(max(cores - 4, 1))  # leave headroom, but always request at least one worker
doParallel::registerDoParallel(cls)
}
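  # Resampling scheme for the initial, coarse tuning pass.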
ctrl = caret::trainControl(method = control, number = number,
repeats = repeats)
if (method == "treebag") {
tune.mod = caret::train(trainx, trainy, method = method,
tuneLength = length, trControl = ctrl, preProcess = process , ...)
train.mod <- tune.mod
}
else if (method == "nnet") {
tune.mod = caret::train(trainx, trainy, method = method,
tuneLength = length, trControl = ctrl,
preProcess = process, trace = FALSE)
size <- tune.mod[["bestTune"]][["size"]]
if (size - 1 <= 0) {
seqStart <- size
}
else {
seqStart <- size - 1
}
seqStop <- size + 1
seqInt <- 1
tuneGrid <- expand.grid(.size = seq(seqStart, seqStop, by=seqInt),
.decay = 0.1^(seq(0.01, 0.08, 0.01))*0.11)
ctrl2 = caret::trainControl(method = control, number = number,
repeats = repeats, classProbs = TRUE,
summaryFunction = summary)
train.mod = caret::train(trainx, trainy, method,
tuneGrid = tuneGrid, tuneLength = length, trControl = ctrl2,
preProcess = process, trace = FALSE, ...)
}
else {
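    # All remaining methods: run a coarse tuneLength search, then (for knn,
    # rf and svmLinear2) build a refined grid around the best value found.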
tune.mod = caret::train(trainx, trainy, method = method,
tuneLength = length, trControl = ctrl, preProcess = process)
if (method == "knn") {
k <- tune.mod[["bestTune"]][["k"]]
if (k - 2 <= 0) {
seqStart <- k
}
else {
seqStart <- k - 2
}
seqStop <- k + 2
seqInt <- 1
tuneGrid <- expand.grid(.k = seq(seqStart, seqStop, by=seqInt))
}
else if (method == "rf") {
mtry <- tune.mod[["bestTune"]][["mtry"]]
if (mtry - 2 <= 0) {
seqStart <- mtry
}
else {
seqStart <- mtry - 2
}
seqStop <- mtry + 2
seqInt <- 1
tuneGrid <- expand.grid(.mtry = seq(seqStart, seqStop, by=seqInt))
}
else if (method == "svmLinear2") {
cost <- tune.mod[["bestTune"]][["cost"]]
      if (cost <= 1) {
        seqStart <- cost
      }
      else {
        seqStart <- cost - 1
      }
seqStop <- cost + 1
seqInt <- 0.25
tuneGrid <- expand.grid(.cost = seq(seqStart, seqStop, by=seqInt))
}
ctrl2 = caret::trainControl(method = control, number = number,
repeats = repeats, classProbs = TRUE,
summaryFunction = summary)
train.mod = caret::train(trainx, trainy, method,
tuneGrid = tuneGrid, tuneLength = length, trControl = ctrl2,
preProcess = process, ...)
}
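  # Release the parallel workers and restore sequential execution.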
if (parallelComputing == TRUE) {
parallel::stopCluster(cls)
registerDoSEQ()
}
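  # Classification results: class probabilities on the test set, a histogram
  # of those probabilities and, for binary targets, a ROC curve with its AUC.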
if (is.factor(data[[y]])) {
if (missing(positive)) {
warning("The positive class is not defined!", immediate. = TRUE, noBreaks. = T)
positive <- readline(prompt="Please define the positive class for the target variable: ")
}
prob.mod = as.data.frame(caret::predict.train(train.mod,testx, type = "prob"))
prob.newdf = utils::stack(prob.mod)
colnames(prob.newdf) = c("Probability", "Class")
prob.hist = ggplot2::ggplot(prob.newdf, ggplot2::aes(x = Probability,
colour = Class, fill = Class))
    prob.plot = prob.hist + ggplot2::geom_histogram(alpha = 0.4,
        linewidth = 1, position = "identity") +
ggplot2::theme_bw() +
ggplot2::scale_colour_brewer(palette = "Dark2") +
ggplot2::scale_fill_brewer(palette = "Dark2") +
ggplot2::labs(y = "Count")
negative = prob.mod[, !names(prob.mod) %in% positive]
    if (nlevels(data[[y]]) == 2) {
g1 = ggplot2::ggplot(prob.mod, ggplot2::aes(m = negative,
d = testy)) + plotROC::geom_roc(n.cuts = 0) +
ggplot2::coord_equal() + plotROC::style_roc()
      auc = round(plotROC::calc_auc(g1)$AUC, 4)
      plot.roc = g1 + ggplot2::annotate("text", x = 0.75, y = 0.25,
                                        label = paste("AUC =", auc))
x = list(Tuning = tune.mod,
Model = train.mod,
`Class Probabilities` = prob.mod,
`Class Probabilities Plot` = prob.plot,
`Area Under ROC Curve` = auc,
`ROC Curve` = plot.roc,
`TrainingIndex`= Train_Index,
`Training Data` = trainset,
`Test Data` = testset)
}
    else {
x = list(Tuning = tune.mod,
Model = train.mod,
`Class Probabilities` = prob.mod,
`Class Probabilities Plot` = prob.plot,
`TrainingIndex`= Train_Index,
`Training Data` = trainset,
`Test Data` = testset)
}
return(x)
}
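  else if (is.numeric(data[[y]])) {
    # Reconstructed sketch of the numeric-target branch, following the
    # documented behaviour above: predictions on the test set plus a
    # residuals-versus-predicted plot. Object names here are illustrative.
    pred.mod = caret::predict.train(train.mod, testx)
    resid.df = data.frame(Predicted = pred.mod,
                          Residuals = testy - pred.mod)
    resid.plot = ggplot2::ggplot(resid.df,
        ggplot2::aes(x = Predicted, y = Residuals)) +
      ggplot2::geom_point() +
      ggplot2::theme_bw() +
      ggplot2::labs(x = "Predicted Values", y = "Residuals")
    x = list(Tuning = tune.mod,
             Model = train.mod,
             Predictions = pred.mod,
             `Residuals vs Predicted Plot` = resid.plot,
             `TrainingIndex` = Train_Index,
             `Training Data` = trainset,
             `Test Data` = testset)
    return(x)
  }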
}