# R/TrainLEs.R

#' Train Local Experts
#'
#' This function trains a list of local expert models on a binary matrix.
#'
#' This function calls the \code{\link[caret]{train}} function independently on
#' each column (list element) of the object supplied to the \code{bincols}
#' argument, which should have been generated by the \code{\link{BinCols}}
#' function. The default optimization metric is the kappa value, but accuracy
#' is also available. For the \code{method} argument, only classification
#' algorithms can be used (as opposed to regression).
#'
#' The random seed is reset with \code{set.seed(123)} before every model is
#' fit, so each fit starts from the same RNG state; note that this also changes
#' the caller's RNG state. The elapsed time for fitting all models is printed
#' before the function returns.
#'
#' @param x The feature matrix
#' @param bincols Binary target matrix created with 'BinCols()' function
#' @param trControl Optional train control object. If \code{NULL} (the
#'   default), a \code{n.folds}-fold cross-validation control is created with
#'   \code{returnData = FALSE}, \code{savePredictions = TRUE} and
#'   \code{classProbs = TRUE}. A supplied object must have \code{method} equal
#'   to \code{"cv"} and \code{number} of at least 2, otherwise an error is
#'   raised.
#' @param n.folds Number of folds if default train control object is used-
#'   defaults to 5. Ignored when \code{trControl} is supplied.
#' @param method Type of learning algorithm used for induction- defaults to lda
#' @param metric Optimization metric; can be either "Accuracy" or "Kappa"
#' @param JIT Whether or not just-in-time compilation is enabled. When
#'   \code{TRUE}, the JIT level is set to 3 for the duration of the call and
#'   the previously active level is restored on exit.
#' @param ... Additional parameters to pass to caret::train
#' @return Returns a list containing \code{train} type objects. List length is
#' dependent on the number of columns in the matrix supplied as the argument to
#' \code{bincols}
#' @export
TrainLEs <- function(x, bincols, trControl = NULL,
                     method = "lda",
                     n.folds = 5,
                     metric = "Kappa",
                     JIT = FALSE,
                     ...) {

  # enable just-in-time compilation for this call only: enableJIT() returns
  # the level that was active before, which is put back on exit instead of
  # unconditionally switching JIT off for the rest of the session
  if (isTRUE(as.logical(JIT))) {
    old.jit <- compiler::enableJIT(3)
    on.exit(compiler::enableJIT(old.jit), add = TRUE)
  }

  # default train control function if one is not specified
  if (is.null(trControl)) {
    trControl <- caret::trainControl(method = "cv", number = n.folds,
                                     returnData = FALSE,
                                     savePredictions = TRUE,
                                     classProbs = TRUE)
  }

  # make sure it's cross validated w/ n.folds > 1; identical()/isTRUE() keep
  # the check well-defined when `method` or `number` is missing or NA, so the
  # caller always gets the message below rather than an opaque `if` error
  if (!identical(trControl$method, "cv") || !isTRUE(trControl$number >= 2)) {
    stop('Training method must be cross-validation with at least 2 folds')
  }

  # start timer
  t.0 <- proc.time()

  # base function to train LEs; the seed is reset before each fit so every
  # model is trained from the same RNG state
  trainer <- function(y) {
    set.seed(123)
    caret::train(x = x, y = y, method = method,
                 metric = metric, trControl = trControl, ...)
  }

  # apply trainer across all columns
  models <- lapply(bincols, trainer)

  # stop timer, display time
  t.final <- proc.time() - t.0
  print(t.final)

  # output model list
  models
}
# nnormandin/localexpeRt documentation built on May 23, 2019, 9:29 p.m.