# R/classification.R

# Classification
#'
#' @title classification
#' @description This function performs the training of the chosen classifier
#'   through \code{caret::train}, optionally on a parallel cluster.
#' @param df.train   Training dataframe
#' @param formula   A formula of the form y ~ x1 + x2 + ... If no formula is
#'   given, the first column is used as the Y values and the remaining
#'   columns as x1, x2, ..., xn
#' @param preprocess Pre-processing methods, passed as \code{preProcess} to
#'   \code{caret::train} (e.g. \code{c("center", "scale")})
#' @param classifier Choice of classifier to be used to train the model. Uses
#'   algorithm names from the caret package.
#' @param rsample    Resampling method, passed as \code{method} to
#'   \code{caret::trainControl}. One of "boot", "boot632", "optimism_boot",
#'   "boot_all", "cv", "repeatedcv", "LOOCV", "LGOCV", "none", "oob",
#'   "timeslice", "adaptive_cv", "adaptive_boot" or "adaptive_LGOCV"
#' @param nfolds     Number of folds to be built in cross-validation (passed
#'   as \code{number} to \code{caret::trainControl})
#' @param repeats    Number of repeats. Must be a number greater than 1 when
#'   \code{rsample = "repeatedcv"}; not used by the other resampling methods
#' @param index      Passed as \code{index} to \code{caret::trainControl};
#'   not used when \code{rsample = "none"}
#' @param cpu_cores  Number of CPU cores to be used in parallel processing.
#'   A value of 0 (or less) trains sequentially without creating a cluster
#' @param tune_length  This argument is the number of levels for each tuning
#'   parameter that should be generated by train
#' @param search     Search option, "grid" or "random"
#' @param metric     Metric used to evaluate model fit, passed to
#'   \code{caret::train} (e.g. "Kappa")
#' @param seeds      Passed as \code{seeds} to \code{caret::trainControl};
#'   not used when \code{rsample = "none"}
#' @param verbose    Currently has no effect; kept for backward compatibility
#' @return The model object returned by \code{caret::train}, or \code{NULL}
#'   when training fails (the error is printed instead of being raised)
#' @keywords Train kappa
#' @importFrom parallel makePSOCKcluster stopCluster
#' @importFrom doParallel registerDoParallel
#' @importFrom foreach registerDoSEQ
#' @importFrom caret trainControl train
#' @details An unknown \code{rsample}, or an invalid \code{repeats} with
#'   \code{rsample = "repeatedcv"}, raises an error before any model is
#'   trained. When \code{cpu_cores > 0} a cluster is created for the
#'   duration of the call and stopped on exit.
#' @author Elpidio Filho, \email{elpidio@ufv.br}
#' @examples
#' \dontrun{
#' library(dplyr)
#' library(labgeo)
#'
#' data("iris")
#' d <- iris %>% select(Species, everything())
#' vt <- train_test(df = d, p = 0.75, seed = 313)
#' train <- vt$train
#' test <- vt$test
#' fit <- classification(df.train = train, preprocess = c("center", "scale"),
#'                       classifier = "rf", nfolds = 5, cpu_cores = 0,
#'                       metric = "Kappa", tune_length = 3,
#'                       verbose = TRUE)
#' pred <- predict(fit, test)
#' obs <- test[, 1]
#' plot_confusion_matrix(obs, pred)
#' }
#' @export

classification <- function(df.train,
                       formula = NULL,
                       preprocess = NULL,
                       classifier = "rf",
                       rsample = "cv",
                       nfolds = 10,
                       repeats =  NA,
                       index = NULL,
                       cpu_cores = 4,
                       tune_length = 5,
                       search = "grid",
                       metric = "Kappa",
                       seeds = NULL,
                       verbose = FALSE) {
  resample_methods <- c(
    "boot", "boot632", "optimism_boot", "boot_all", "cv",
    "repeatedcv", "LOOCV", "LGOCV", "none", "oob",
    "timeslice", "adaptive_cv", "adaptive_boot",
    "adaptive_LGOCV"
  )

  # switch() below needs a single string, so reject anything else up front.
  if (!is.character(rsample) || length(rsample) != 1L ||
      !(rsample %in% resample_methods)) {
    stop(paste(
      "resample method", paste(rsample, collapse = ", "), "does not exist"
    ))
  }

  # Validate `repeats` before touching caret. The default is NA, which must
  # be detected with is.na(): comparing NA with == yields NA and breaks `if`.
  if (rsample == "repeatedcv") {
    if (!is.numeric(repeats) || length(repeats) != 1L ||
        is.na(repeats) || repeats < 2) {
      stop("You must define the number of repeats greater then 1 ")
    }
  }

  # Build the resampling control. "repeatedcv" additionally needs `repeats`,
  # "none" takes no resampling options; every other method shares one form.
  tc <- switch(rsample,
    "repeatedcv" = caret::trainControl(
      method = rsample, number = nfolds, repeats = repeats,
      verboseIter = FALSE,
      index = index, seeds = seeds, search = search
    ),
    "none" = caret::trainControl(method = rsample, verboseIter = FALSE),
    caret::trainControl(
      method = rsample, number = nfolds, verboseIter = FALSE,
      index = index, seeds = seeds, search = search
    )
  )

  if (cpu_cores > 0) {
    # Forked workers are not available on Windows, so use the default
    # cluster type there.
    cl <- if (get_os() == "windows") {
      parallel::makeCluster(cpu_cores)
    } else {
      parallel::makeCluster(cpu_cores, type = "FORK")
    }
    doParallel::registerDoParallel(cl)
    # Restore the sequential backend and release the workers however the
    # function exits.
    on.exit({
      foreach::registerDoSEQ()
      parallel::stopCluster(cl)
    }, add = TRUE)
  } else {
    foreach::registerDoSEQ()
  }

  # Training errors are reported on the console and turned into a NULL
  # result rather than propagated to the caller.
  on_train_error <- function(e) {
    print(" ")
    print(e)
    NULL
  }

  fit <- tryCatch(
    if (is.null(formula)) {
      # `[[1]]` yields a plain vector for data frames and tibbles alike;
      # `drop = FALSE` keeps the predictors a data frame even with one column.
      caret::train(
        x = df.train[, -1, drop = FALSE], y = df.train[[1]],
        method = classifier, metric = metric,
        trControl = tc, tuneLength = tune_length, verbose = FALSE,
        preProcess = preprocess
      )
    } else {
      caret::train(
        form = formula, data = df.train, method = classifier,
        metric = metric, trControl = tc, tuneLength = tune_length,
        verbose = FALSE,
        preProcess = preprocess
      )
    },
    error = on_train_error
  )

  fit
}


# Report the host operating system as a lower-case string.
#
# Uses the `sysname` entry of Sys.info() when it is available, mapping
# "Darwin" to "osx". If Sys.info() returns NULL, falls back to
# .Platform$OS.type, refined to "osx" or "linux" from R.version$os.
get_os <- function() {
  info <- Sys.info()

  if (is.null(info)) {
    # No system information: derive the answer from the R build itself.
    platform <- .Platform$OS.type
    build_os <- R.version$os
    if (grepl("^darwin", build_os)) {
      platform <- "osx"
    }
    if (grepl("linux-gnu", build_os)) {
      platform <- "linux"
    }
    return(tolower(platform))
  }

  # Single-bracket indexing keeps the element name, as callers received before.
  name <- info["sysname"]
  if (name == "Darwin") {
    name <- "osx"
  }
  tolower(name)
}
# elpidiofilho/labgeo documentation built on May 14, 2019, 9:35 a.m.