# R/sofia.model.R

# Build a constructor for a linear model trained with RSofia (sofia-ml).
#
# Calling sofia.model(...) returns a zero-argument function. Each call of
# that function creates an independent model object: a list with
#   train_(X_train, y)  fit the optional caret preprocessing, then sofia.fit
#   predict_(X_test)    apply the stored preprocessing and score X_test
#   name                the string "sofia"
#   description         the `description` argument, unchanged
#
# Arguments:
#   preProcess   character vector of methods for caret::preProcess(), or
#                NULL to skip preprocessing entirely.
#   description  arbitrary value stored as-is in the returned object.
#   ...          forwarded to sofia.fit() on every train_() call. The
#                optional `learner_type` entry is also read in predict_():
#                "logreg-pegasos" selects logistic predictions, anything
#                else (including absent) selects linear predictions.
sofia.model <- function(preProcess = c("center", "scale", "zv"),
                        description = NULL, ...)
{
  library(RSofia)
  library(caret)

  function()
  {
    model.name <- "sofia"
    preProcess_ <- NULL
    model_ <- NULL

    # Fit the (optional) preprocessing on X_train, then the sofia model.
    # Stores the fitted objects in preProcess_ / model_ of this closure and
    # returns invisible NULL.
    train_ <- function(X_train, y)
    {
      if (!is.null(preProcess))
      {
        preProcess_ <<- caret::preProcess(X_train, method = preProcess)
        X_train <- predict(preProcess_, X_train)
      }

      # If y is a factor, as.numeric(y) gives its level codes 1, 2, ..., so
      # the labels passed on are 0, 1, ...
      # NOTE(review): assumes y is a two-level factor. sofia-ml's binary
      # learners may expect labels in {-1, +1} rather than {0, 1} -- confirm
      # against the RSofia documentation.
      model_ <<- sofia.fit(x = as.matrix(X_train),
                           y = as.numeric(y) - 1,
                           ...)
      invisible()
    }

    # Score X_test with the trained model. Returns a numeric matrix with one
    # row per prediction and columns named "sofia_1", "sofia_2", ...
    predict_ <- function(X_test)
    {
      if (is.null(model_))
      {
        stop("sofia model has not been trained: call train_() before predict_()",
             call. = FALSE)
      }
      if (!is.null(preProcess_))
      {
        X_test <- predict(preProcess_, X_test)
      }

      # `learner_type` is optional in `...`. identical() is FALSE when it is
      # absent, whereas `NULL == "..."` is logical(0) and makes if() error.
      is_logistic <- identical(list(...)[["learner_type"]], "logreg-pegasos")
      prediction_type <- if (is_logistic) "logistic" else "linear"

      # Score a matrix, matching the as.matrix() input used in train_().
      predictions <- predict(model_, as.matrix(X_test),
                             prediction_type = prediction_type)
      predictions <- as.matrix(predictions)
      colnames(predictions) <- paste(model.name, seq_len(ncol(predictions)),
                                     sep = "_")
      predictions
    }

    list(
      train_ = train_,
      predict_ = predict_,
      name = model.name,
      description = description
    )
  }
}
# rladeira/stacking documentation built on May 27, 2019, 9:28 a.m.