R/knn.model.R

# Factory for a k-nearest-neighbour classifier built on FNN::knn.
#
# knn.model() returns a constructor (a zero-argument function). Each call to
# that constructor yields a fresh model object: a list with
#   train_(X_train, y) - stores the (optionally pre-processed) training data;
#   predict_(X_test)   - returns a matrix of per-class neighbour proportions,
#                        one row per test row and one column per class, with
#                        columns renamed by rename.prediction.columns();
#   name               - "knn_<k>";
#   description        - the `description` argument, unchanged.
#
# Arguments:
#   preProcess  - methods passed to caret::preProcess(); NULL disables
#                 pre-processing.
#   description - free-form description stored on the model object.
#   ...         - forwarded to FNN::knn() at prediction time. `k` is also read
#                 from here to build the model name; when absent it defaults
#                 to 1, matching FNN::knn's own default.
knn.model <- function(preProcess = c("center", "scale", "zv"),
                      description = NULL, ...)
{
  library(FNN)
  library(caret)

  function()
  {
    k <- list(...)[["k"]]
    if (is.null(k))
    {
      # FNN::knn uses a single neighbour when `k` is not supplied; mirror that
      # so the model name is never empty.
      k <- 1L
    }
    model.name <- sprintf("knn_%d", k)
    preProcess_ <- NULL
    info_ <- NULL

    train_ <- function(X_train, y)
    {
      if (!is.null(preProcess))
      {
        preProcess_ <<- caret::preProcess(X_train, method = preProcess)
        X_train <- predict(preProcess_, X_train)
      }

      # Keep labels as a factor so every class gets a probability column and
      # the column order matches the levels FNN::knn uses internally.
      info_ <<- list(X_train = X_train,
                     y = as.factor(y))
      invisible()
    }

    predict_ <- function(X_test)
    {
      if (!is.null(preProcess_))
      {
        X_test <- predict(preProcess_, X_test)
      }

      # Qualified call: class::knn shares the name but does not attach the
      # "nn.index" attribute relied on below.
      predictions <- FNN::knn(
        train = info_$X_train,
        cl = info_$y,
        test = X_test,
        algorithm = "kd_tree",
        ...)

      # One row per test observation, one column per neighbour; entries are
      # row indices into the training data.
      nn.index <- attr(predictions, "nn.index")
      class.levels <- levels(info_$y)
      neighbour.classes <- matrix(as.integer(info_$y)[nn.index],
                                  nrow = nrow(nn.index))

      # Proportion of neighbours in each class. Dividing via rowMeans uses the
      # number of neighbours actually returned (FNN caps k at the training
      # size), and filling a preallocated matrix keeps the result
      # (n_test x n_classes) even for a single test row or a single class.
      probabilities <- matrix(0,
                              nrow = nrow(nn.index),
                              ncol = length(class.levels),
                              dimnames = list(NULL, class.levels))
      for (j in seq_along(class.levels))
      {
        probabilities[, j] <- rowMeans(neighbour.classes == j)
      }
      probabilities <- rename.prediction.columns(probabilities, model.name)
      probabilities
    }

    list(
      train_ = train_,
      predict_ = predict_,
      name = model.name,
      description = description
    )
  }
}
rladeira/stacking documentation built on May 27, 2019, 9:28 a.m.