# R/ResamplingCV.R

#' @title Cross Validation Resampling
#'
#' @usage NULL
#' @aliases mlr_resamplings_cv
#' @format [R6::R6Class] inheriting from [Resampling].
#' @include Resampling.R
#'
#' @section Construction:
#' ```
#' ResamplingCV$new()
#' mlr_resamplings$get("cv")
#' rsmp("cv")
#' ```
#'
#' @description
#' Splits data using a `folds`-folds (default: 10 folds) cross-validation.
#'
#' @section Fields:
#' See [Resampling].
#'
#' @section Methods:
#' See [Resampling].
#'
#' @section Parameters:
#' * `stratify` :: `logical(1)` | `character()`\cr
#'   Enables stratification. See [Resampling].
#' * `folds` :: `integer(1)`\cr
#'   Number of folds.
#'
#' @template seealso_resampling
#' @export
#' @examples
#' # Create a task with 10 observations
#' task = tsk("iris")
#' task$filter(1:10)
#'
#' # Instantiate Resampling
#' rcv = rsmp("cv", folds = 3)
#' rcv$instantiate(task)
#'
#' # Individual sets:
#' rcv$train_set(1)
#' rcv$test_set(1)
#' intersect(rcv$train_set(1), rcv$test_set(1))
#'
#' # Internal storage:
#' rcv$instance # table
ResamplingCV = R6Class("ResamplingCV", inherit = Resampling,
  public = list(
    # Constructor: builds the parameter set (`stratify`, `folds`), presets
    # `folds` to 10 and hands it to the parent constructor under the id "cv".
    initialize = function() {
      ps = ParamSet$new(list(
        ParamUty$new("stratify", default = NULL),
        # NOTE(review): `lower = 1L` admits `folds = 1`, for which `.sample()`
        # puts every row into fold 1 and `.get_train(1)` selects no rows --
        # confirm whether the lower bound should be 2.
        ParamInt$new("folds", lower = 1L, tags = "required")
      ))
      ps$values = list(folds = 10L)

      super$initialize(id = "cv", param_set = ps)
    }
  ),

  active = list(
    # Number of resampling iterations; equals the configured number of folds.
    iters = function() {
      as.integer(self$param_set$values$folds)
    }
  ),

  private = list(
    # Assigns each row id to a fold and returns a data.table with columns
    # `row_id` and `fold`, keyed by `fold`.
    # Fold numbers are produced by taking positions modulo `folds` (+ 1) and
    # shuffling the result; this presumably relies on `seq_along0()` yielding
    # 0-based positions (helper defined outside this file -- verify).
    .sample = function(ids) {
      data.table(
        row_id = ids,
        fold = shuffle(seq_along0(ids) %% as.integer(self$param_set$values$folds) + 1L),
        key = "fold"
      )
    },

    # Row ids of the training set for iteration `i`: all rows whose fold is
    # not `i` (anti-join on `fold`).
    .get_train = function(i) {
      self$instance[!list(i), "row_id", on = "fold"][[1L]]
    },

    # Row ids of the test set for iteration `i`: all rows in fold `i`.
    .get_test = function(i) {
      self$instance[list(i), "row_id", on = "fold"][[1L]]
    },

    # Row-binds several instance tables into one, matching columns by name.
    .combine = function(instances) {
      rbindlist(instances, use.names = TRUE)
    },

    # Per-field hook used by `$clone(deep = TRUE)`.
    # Defining this hook replaces R6's default deep-clone handling, so every
    # field with reference semantics has to be copied explicitly here:
    # * `instance` is a data.table and is duplicated with `copy()`.
    # * `param_set` is an R6 object; without cloning it, the original and the
    #   clone would share one parameter set, and changing e.g. `folds` on one
    #   would silently change the other.
    # All other fields are returned unchanged.
    deep_clone = function(name, value) {
      switch(name,
        instance = copy(value),
        param_set = if (inherits(value, "R6")) value$clone(deep = TRUE) else value,
        value
      )
    }
  )
)

# Register the class under the key "cv" in `mlr_resamplings` (an object
# defined outside this file), which is what makes `mlr_resamplings$get("cv")`
# and `rsmp("cv")` from the construction examples resolve to `ResamplingCV`.
#' @include mlr_resamplings.R
mlr_resamplings$add("cv", ResamplingCV)
# mllg/mlr3 documentation built on Sept. 27, 2019, 9:38 a.m.