# R/LearnerClassifKKNN.R

#' @title k-Nearest-Neighbor Classification Learner
#'
#' @name mlr_learners_classif.kknn
#'
#' @description
#' k-Nearest-Neighbor classification.
#' Calls [kknn::kknn()] from package \CRANpkg{kknn}.
#'
#' @section Initial parameter values:
#' - `store_model`:
#'   - See note.
#'
#' @template note_kknn
#'
#' @templateVar id classif.kknn
#' @template learner
#'
#' @references
#' `r format_bib("hechenbichler_2004", "samworth_2012", "cover_1967")`
#'
#' @export
#' @template seealso_learner
#' @template example
LearnerClassifKKNN = R6Class("LearnerClassifKKNN",
  inherit = LearnerClassif,
  public = list(

    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    initialize = function() {
      ps = ps(
        k           = p_int(default = 7L, lower = 1L, tags = "train"),
        distance    = p_dbl(0, default = 2, tags = "train"),
        kernel      = p_fct(
          c("rectangular", "triangular", "epanechnikov", "biweight", "triweight",
            "cos", "inv", "gaussian", "rank", "optimal"),
          default = "optimal", tags = "train"),
        scale       = p_lgl(default = TRUE, tags = "train"),
        ykernel     = p_uty(default = NULL, tags = "train"),
        store_model = p_lgl(default = FALSE, tags = "train")
      )
      # `k` is set explicitly (to the same value as its declared default) so
      # that the sanity check in `.train()` normally has a value to compare.
      ps$values = list(k = 7L)

      super$initialize(
        id = "classif.kknn",
        param_set = ps,
        predict_types = c("response", "prob"),
        feature_types = c("logical", "integer", "numeric", "factor", "ordered"),
        properties = c("twoclass", "multiclass"),
        packages = c("mlr3learners", "kknn"),
        label = "k-Nearest-Neighbor",
        man = "mlr3learners::mlr_learners_classif.kknn"
      )
    }
  ),

  private = list(
    # Validates `k` against the task size and stores everything that
    # `.predict()` needs. No call to kknn happens here: the returned list only
    # holds the formula, the training data, the parameter values and an empty
    # `kknn` slot.
    #
    # Raises an error (via `stopf()`) if `k` is not smaller than the number of
    # rows in `task`.
    .train = function(task) {
      # https://github.com/mlr-org/mlr3learners/issues/191
      pv = self$param_set$get_values(tags = "train")

      # Look up `k` with `[[` (exact matching). `pv$k` would partially match
      # `kernel` whenever `k` has been removed from the values, and would be
      # NULL (making `if ()` fail with "argument is of length zero") when
      # neither is set. Fall back to the declared default for the check only;
      # `pv` itself is stored untouched.
      k = pv[["k"]]
      if (is.null(k)) {
        k = 7L
      }

      if (k >= task$nrow) {
        stopf("Parameter k = %i must be smaller than the number of observations (n = %i)",
          k, task$nrow)
      }

      list(
        formula = task$formula(),
        data = task$data(),
        pv = pv,
        kknn = NULL
      )
    },

    # Calls `kknn::kknn()` with the stored training data as `train` and the
    # features of `task` as `test`, then returns either the fitted class labels
    # (`response`) or the class probability matrix (`prob`), depending on
    # `self$predict_type`.
    .predict = function(task) {
      model = self$state$model
      # NOTE(review): `ordered_features()` is defined elsewhere; presumably it
      # returns the task's feature columns in the order seen during training.
      newdata = ordered_features(task, self)
      pv = insert_named(model$pv, self$param_set$get_values(tags = "predict"))

      with_package("kknn", { # https://github.com/KlausVigo/kknn/issues/16
        # `store_model` is a learner-level switch, not a kknn argument, so it
        # is dropped before the parameter values are forwarded.
        p = invoke(kknn::kknn,
          formula = model$formula, train = model$data,
          test = newdata, .args = remove_named(pv, "store_model"))
      })

      # Optionally keep the kknn result object in the learner state.
      if (isTRUE(pv$store_model)) {
        self$state$model$kknn = p
      }

      if (self$predict_type == "response") {
        list(response = p$fitted.values)
      } else {
        list(prob = p$prob)
      }
    }
  )
)

# Register the learner class under the key "classif.kknn" in `learners`.
# NOTE(review): `learners` is not defined in this file; presumably it is
# created in aaa.R, which the @include tag below asks roxygen to collate first.
#' @include aaa.R
learners[["classif.kknn"]] = LearnerClassifKKNN

# Try the mlr3learners package in your browser
#
# Any scripts or data that you put into this service are public.
#
# mlr3learners documentation built on Nov. 21, 2023, 5:07 p.m.