
#' K Nearest Neighbours Trainer
#' @description Trains a k nearest neighbour model using fast search algorithms. KNN is a supervised learning
#'              algorithm which is used for both regression and classification problems.
#' @format \code{\link{R6Class}} object.
#' @section Usage:
#' For usage details see \bold{Methods, Arguments and Examples} sections.
#' \preformatted{
#' bst = KNNTrainer$new(k=1, prob=FALSE, algorithm=NULL, type="class")
#' bst$fit(X_train, X_test, "target")
#' bst$predict(type)
#' }
#' @section Methods:
#' \describe{
#'     \item{\code{$new()}}{Initialise the instance of the trainer}
#'     \item{\code{$fit()}}{trains the knn model and stores the test prediction}
#'     \item{\code{$predict()}}{returns predictions}
#' }
#' @section Arguments:
#' \describe{
#'     \item{k}{number of neighbours to predict}
#'     \item{prob}{if probability should be computed, default=FALSE}
#'     \item{algorithm}{algorithm used to train the model, possible values are 'kd_tree','cover_tree','brute'}
#'     \item{type}{type of problem to solve i.e. regression or classification, possible values are 'reg' or 'class'}
#' }
#' @export
#' @examples
#' data("iris")
#' iris$Species <- as.integer(as.factor(iris$Species))
#' xtrain <- iris[1:100,]
#' xtest <- iris[101:150,]
#' bst <- KNNTrainer$new(k=3, prob=TRUE, type="class")
#' bst$fit(xtrain, xtest, 'Species')
#' pred <- bst$predict(type="raw")
KNNTrainer <- R6Class("KNNTrainer", public = list(

    #' @field k number of neighbours to predict
    k = 1,
    #' @field prob if probability should be computed, default=FALSE
    prob = FALSE,
    #' @field algorithm algorithm used to train the model, possible values are 'kd_tree','cover_tree','brute'
    algorithm = NULL,
    #' @field type type of problem to solve i.e. regression or classification, possible values are 'reg' or 'class'
    type = "class",
    #' @field model for internal use; holds the fitted result, NA until `fit()` has been called
    model = NA,

    #' @details
    #' Create a new `KNNTrainer` object. Arguments that are not supplied keep
    #' the field defaults (k = 1, prob = FALSE, algorithm = NULL, type = "class").
    #' @param k k number of neighbours to predict
    #' @param prob if probability should be computed, default=FALSE
    #' @param algorithm algorithm used to train the model, possible values are 'kd_tree','cover_tree','brute'
    #' @param type type of problem to solve i.e. regression or classification, possible values are 'reg' or 'class'
    #' @return A `KNNTrainer` object.
    #' @examples
    #' data("iris")
    #' iris$Species <- as.integer(as.factor(iris$Species))
    #' xtrain <- iris[1:100,]
    #' xtest <- iris[101:150,]
    #' bst <- KNNTrainer$new(k=3, prob=TRUE, type="class")
    #' bst$fit(xtrain, xtest, 'Species')
    #' pred <- bst$predict(type="raw")

    initialize = function(k, prob, algorithm, type){
        # only override a field default when the caller passed that argument
        if (!missing(k)) self$k <- k
        if (!missing(prob)) self$prob <- prob
        if (!missing(algorithm)) self$algorithm <- algorithm
        if (!missing(type)) self$type <- type
    },

    #' @details
    #' Trains the KNNTrainer model. The neighbour search for `test` is run
    #' here and its result is stored in the `model` field.
    #' @param train data.frame or matrix
    #' @param test data.frame or matrix
    #' @param y character, name of target variable
    #' @return invisibly, the fitted result (the same object stored in `model`)
    #' @examples
    #' data("iris")
    #' iris$Species <- as.integer(as.factor(iris$Species))
    #' xtrain <- iris[1:100,]
    #' xtest <- iris[101:150,]
    #' bst <- KNNTrainer$new(k=3, prob=TRUE, type="class")
    #' bst$fit(xtrain, xtest, 'Species')

    fit = function(train, test, y){

        # fail fast instead of silently leaving `model` untouched
        valid_type <- is.character(self$type) && length(self$type) == 1L &&
            self$type %in% c("class", "reg")
        if (!valid_type)
            stop("type must be either 'class' or 'reg'.")

        data <- private$prepare_data(train, test, y)

        # `algorithm` (possibly NULL) is passed through to FNN unchanged
        if (self$type == "class") {
            self$model <- FNN::knn(train = data$train
                                  ,test = data$test
                                  ,cl = data$y
                                  ,k = self$k
                                  ,prob = self$prob
                                  ,algorithm = self$algorithm)
        } else {
            self$model <- FNN::knn.reg(train = data$train
                                      ,test = data$test
                                      ,y = data$y
                                      ,k = self$k
                                      ,algorithm = self$algorithm)
        }
    },

    #' @details
    #' Predicts the nearest neighbours for test data
    #' @param type character, 'raw' for labels else 'prob'. Ignored when the
    #'   trainer was created with type = 'reg'.
    #' @return for classification, the object returned by `FNN::knn`
    #'   (predicted labels) when type = 'raw', or its "prob" attribute when
    #'   type = 'prob'; for regression, the `pred` component of the
    #'   `FNN::knn.reg` result.
    #' @examples
    #' data("iris")
    #' iris$Species <- as.integer(as.factor(iris$Species))
    #' xtrain <- iris[1:100,]
    #' xtest <- iris[101:150,]
    #' bst <- KNNTrainer$new(k=3, prob=TRUE, type="class")
    #' bst$fit(xtrain, xtest, 'Species')
    #' pred <- bst$predict(type="raw")

    predict = function(type="raw"){

        # `model` keeps its NA default until fit() has run
        if (identical(self$model, NA))
            stop("The model has not been trained yet. Call fit() first.")

        if (self$type == "reg")
            return(self$model$pred)

        if (self$type != "class")
            stop("type must be either 'class' or 'reg'.")

        type <- match.arg(type, c("raw", "prob"))
        if (type == "raw")
            return(self$model)

        probs <- attr(self$model, "prob")
        # NOTE(review): FNN::knn appears to attach "prob" only when fitted
        # with prob = TRUE; warn rather than silently returning NULL
        if (is.null(probs))
            warning("No probabilities available; set prob = TRUE before calling fit().")
        probs
    }),

    private = list(

        # Validate the inputs and split them into feature tables and a target.
        # Returns list(train = <data.table>, test = <data.table>, y = <vector>),
        # where the target column `y` has been removed from both tables.
        prepare_data = function(train, test, y){

            if (!is.character(y) || length(y) != 1L)
                stop("y must be a single column name.")

            train <- as.data.table(train)
            test <- as.data.table(test)

            if (!(y %in% names(train)))
                stop(sprintf("%s not available in training data", y))

            # keep the dependent variable aside before dropping it
            target <- train[[y]]

            # independent features only; the dependent variable is also
            # dropped from test, just in case it is present there
            train <- train[, setdiff(names(train), y), with = FALSE]
            test <- test[, setdiff(names(test), y), with = FALSE]

            if (ncol(test) != ncol(train))
                stop("Train and test data have unequal independent variables.")

            # features are matched by position once converted to matrices,
            # so align test columns to the train order when the names agree
            same_names <- !anyDuplicated(names(train)) &&
                setequal(names(test), names(train))
            if (same_names)
                test <- test[, names(train), with = FALSE]

            private$check_numeric(train, "Train")
            private$check_numeric(test, "Test")

            # check in case target variable contains NA or float values
            if (anyNA(target))
                stop("The target variable contains NA values.")

            if (self$type == "class" && is.numeric(target) &&
                !all(target == floor(target)))
                stop("The target variable contains float values")

            list(train = train, test = test, y = target)
        },

        # Stop when `dt` has any factor or character column; `label`
        # ("Train" / "Test") names the table in the error message.
        check_numeric = function(dt, label){
            non_numeric <- vapply(dt,
                                  function(col) is.factor(col) || is.character(col),
                                  logical(1))
            if (any(non_numeric))
                stop(label, " data contains non-numeric variables. ",
                     "Please convert them into integer.")
            invisible(NULL)
        }
    )
)


# Try the superml package in your browser
#
# Any scripts or data that you put into this service are public.
#
# superml documentation built on Nov. 14, 2022, 9:05 a.m.