# Classification
#'
#' @title classification
#' @description Trains the chosen classifier with \code{caret::train()},
#'   optionally in parallel, and returns the fitted model.
#' @param df.train Training data frame. When \code{formula} is \code{NULL}
#'   the first column is used as the response and all remaining columns as
#'   predictors.
#' @param formula A formula of the form \code{y ~ x1 + x2 + ...}. If it is
#'   \code{NULL} (the default), the first column of \code{df.train} is used
#'   as the Y values and the other columns as x1, x2, ..., xn.
#' @param preprocess Value forwarded to the \code{preProcess} argument of
#'   \code{caret::train()}; \code{NULL} means no pre-processing.
#' @param classifier Choice of classifier used to train the model. Uses the
#'   algorithm names from the caret package (e.g. \code{"rf"}).
#' @param nfolds Number of folds to build in cross-validation. \code{0}
#'   disables resampling (method \code{"none"}, \code{tune_length} forced to
#'   1); a value equal to \code{nrow(df.train)} selects leave-one-out
#'   cross-validation.
#' @param index Value forwarded to the \code{index} argument of
#'   \code{caret::trainControl()}.
#' @param repeats Currently not used by this function (kept for interface
#'   compatibility).
#' @param cpu_cores Number of CPU cores to be used in parallel processing.
#'   A value of \code{0} or less trains without creating a cluster.
#' @param tune_length Number of levels for each tuning parameter that should
#'   be generated by \code{caret::train()}.
#' @param metric Metric used to evaluate model fit, forwarded to
#'   \code{caret::train()}. Defaults to \code{"Kappa"}; for numeric outcomes
#'   use e.g. \code{"RMSE"} or \code{"Rsquared"}.
#' @param seeds Value forwarded to the \code{seeds} argument of
#'   \code{caret::trainControl()}.
#' @param verbose If \code{TRUE}, prints the response variable name, the
#'   elapsed time and the resampled performance of the fitted model.
#' @return The object returned by \code{caret::train()}, or \code{NULL}
#'   (with a warning) when training fails.
#' @keywords Train kappa
#' @importFrom parallel makePSOCKcluster stopCluster
#' @importFrom doParallel registerDoParallel
#' @importFrom caret trainControl train
#' @details Training errors are caught: the function emits a warning that
#'   carries the original error message and returns \code{NULL} instead of
#'   stopping. The parallel cluster, when one is created, is stopped when
#'   the function exits.
#' @author Elpidio Filho, \email{elpidio@ufv.br}
#' @examples
#' \dontrun{
#' fit <- classification(train, classifier = "rf", nfolds = 10, cpu_cores = 6)
#' }
#' @export
classification <- function(df.train,
                           formula = NULL,
                           preprocess = NULL,
                           classifier = "rf",
                           nfolds = 10,
                           repeats = 1,
                           index = NULL,
                           cpu_cores = 4,
                           tune_length = 5,
                           metric = "Kappa",
                           seeds = NULL,
                           verbose = FALSE) {
  # Pick the resampling scheme. This must be a single if / else-if chain:
  # with two separate `if` statements the "none" choice made for
  # nfolds == 0 was always overwritten by the second one.
  if (nfolds == 0) {
    method <- "none"
    # Without resampling only a single tuning candidate can be fitted.
    tune_length <- 1
  } else if (nfolds == nrow(df.train)) {
    method <- "LOOCV"
  } else {
    method <- "CV"
  }

  inicio <- Sys.time()
  tc <- caret::trainControl(
    method = method,
    number = nfolds,
    index = index,
    seeds = seeds
  )

  if (cpu_cores > 0) {
    cl <- parallel::makePSOCKcluster(cpu_cores)
    # Register the cleanup right after the cluster exists so the workers are
    # released even if backend registration or training fails.
    on.exit(parallel::stopCluster(cl), add = TRUE)
    doParallel::registerDoParallel(cl)
  }

  use_formula <- !is.null(formula)

  fit <- tryCatch(
    suppressMessages(
      if (use_formula) {
        caret::train(
          formula,
          data = df.train,
          method = classifier,
          metric = metric,
          trControl = tc,
          tuneLength = tune_length,
          preProcess = preprocess
        )
      } else {
        caret::train(
          # drop = FALSE keeps a single predictor as a one-column table;
          # drop = TRUE yields a plain response vector for tibbles as well.
          x = df.train[, -1, drop = FALSE],
          y = df.train[, 1, drop = TRUE],
          method = classifier,
          metric = metric,
          trControl = tc,
          tuneLength = tune_length,
          preProcess = preprocess
        )
      }
    ),
    error = function(e) {
      # Keep the best-effort contract (NULL on failure) but tell the caller
      # why training failed instead of swallowing the error silently.
      warning(
        "classification(): model training failed: ", conditionMessage(e),
        call. = FALSE
      )
      NULL
    }
  )

  if (isTRUE(verbose) && !is.null(fit)) {
    # Report the actual response: the left-hand side of the formula when one
    # was supplied, otherwise the first column of the training data.
    response <- if (use_formula) all.vars(formula)[1] else names(df.train)[1]
    print(paste("Classification variable ", response))
    print(paste("time elapsed : ", hms_span(inicio, Sys.time())))
    print(caret::getTrainPerf(fit))
  }

  fit
}
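# NOTE: the two lines below are web-page boilerplate, not R code; they are
# kept as comments so that the file parses.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.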