#' Train Local Experts
#'
#' Trains one local expert (LE) classifier per column of a binary target
#' matrix.
#'
#' This function calls \code{\link[caret]{train}} independently on each column
#' of the object supplied to the \code{bincols} argument, which should have
#' been generated by the \code{\link{BinCols}} function. The default
#' optimization metric is the kappa value, but accuracy is also available. For
#' the \code{method} argument, only classification algorithms can be used (as
#' opposed to regression).
#'
#' Before each model is trained the random seed is reset with
#' \code{set.seed(123)}, so every training run is reproducible. Note that this
#' modifies the global random number generator state. The elapsed time for
#' training all models is printed when training finishes.
#'
#' @param x The feature matrix.
#' @param bincols Binary target matrix created with the \code{BinCols()}
#'   function. A data frame (or list) is processed column by column. A plain
#'   matrix is first converted to a data frame, so that its columns (not its
#'   individual cells) are used as targets.
#' @param trControl Optional train control object (see
#'   \code{\link[caret]{trainControl}}). It must use \code{method = "cv"} with
#'   at least 2 folds. Defaults to \code{n.folds}-fold cross-validation that
#'   saves predictions and class probabilities.
#' @param method Type of learning algorithm used for induction; defaults to
#'   \code{"lda"}.
#' @param n.folds Number of folds used when the default train control object
#'   is built; defaults to 5. Ignored when \code{trControl} is supplied.
#' @param metric Optimization metric; can be either \code{"Accuracy"} or
#'   \code{"Kappa"} (the default).
#' @param JIT Whether to enable just-in-time compilation (level 3) while
#'   training. The caller's previous JIT level is restored on exit.
#' @param ... Additional parameters passed on to \code{caret::train}.
#' @return A list of \code{train} objects, one per column of \code{bincols}.
#' @export
#'
TrainLEs <- function(x, bincols, trControl = NULL,
                     method = "lda",
                     n.folds = 5,
                     metric = "Kappa",
                     JIT = FALSE,
                     ...) {
  # Optionally enable JIT compilation. enableJIT() returns the previous
  # level, which is restored on exit (even on error) rather than forcing 0.
  if (isTRUE(JIT)) {
    old_jit <- compiler::enableJIT(3)
    on.exit(compiler::enableJIT(old_jit), add = TRUE)
  }
  # Default train control: n.folds-fold CV that keeps the held-out
  # predictions and class probabilities.
  if (is.null(trControl)) {
    trControl <- caret::trainControl(method = "cv", number = n.folds,
                                     returnData = FALSE,
                                     savePredictions = TRUE,
                                     classProbs = TRUE)
  }
  # Only plain k-fold cross-validation with k >= 2 is accepted.
  # identical()/isTRUE() keep the check well-defined (and the error message
  # meaningful) when an element is missing or malformed.
  if (!identical(trControl[["method"]], "cv") ||
      !isTRUE(trControl[["number"]] >= 2)) {
    stop('Training method must be cross-validation with at least 2 folds')
  }
  # lapply() on a plain matrix would iterate over cells, not columns, so
  # convert it to a data frame first.
  if (is.matrix(bincols)) {
    bincols <- as.data.frame(bincols, stringsAsFactors = TRUE)
  }
  t0 <- proc.time()
  # Trains a single expert for one target column. `...` is resolved from the
  # enclosing TrainLEs() call and forwarded to caret::train.
  train_expert <- function(y) {
    set.seed(123)
    caret::train(x = x, y = y, method = method, metric = metric,
                 trControl = trControl, ...)
  }
  models <- lapply(bincols, train_expert)
  # Report the total elapsed training time.
  print(proc.time() - t0)
  models
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.