R/resample.R

Defines functions resample

Documented in resample

#' @title Resample a Learner on a Task
#'
#' @description
#' Runs a resampling (possibly in parallel):
#' Repeatedly apply [Learner] `learner` on a training set of [Task] `task` to train a model,
#' then use the trained model to predict observations of a test set.
#' Training and test sets are defined by the [Resampling] `resampling`.
#'
#' @param task ([Task]).
#' @param learner ([Learner]).
#' @param resampling ([Resampling]).
#' @template param_store_models
#' @template param_store_backends
#' @template param_encapsulate
#' @template param_allow_hotstart
#' @template param_clone
#' @return [ResampleResult].
#'
#' @template section_predict_sets
#' @template section_parallelization
#' @template section_progress_bars
#' @template section_logging
#'
#' @note
#' The fitted models are discarded after the predictions have been computed in order to reduce memory consumption.
#' If you need access to the models for later analysis, set `store_models` to `TRUE`.
#'
#' @template seealso_resample
#' @export
#' @examples
#' task = tsk("penguins")
#' learner = lrn("classif.rpart")
#' resampling = rsmp("cv")
#'
#' # Explicitly instantiate the resampling for this task for reproduciblity
#' set.seed(123)
#' resampling$instantiate(task)
#'
#' rr = resample(task, learner, resampling)
#' print(rr)
#'
#' # Retrieve performance
#' rr$score(msr("classif.ce"))
#' rr$aggregate(msr("classif.ce"))
#'
#' # merged prediction objects of all resampling iterations
#' pred = rr$prediction()
#' pred$confusion
#'
#' # Repeat resampling with featureless learner
#' rr_featureless = resample(task, lrn("classif.featureless"), resampling)
#'
#' # Convert results to BenchmarkResult, then combine them
#' bmr1 = as_benchmark_result(rr)
#' bmr2 = as_benchmark_result(rr_featureless)
#' print(bmr1$combine(bmr2))
resample = function(task, learner, resampling, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling")) {
  assert_subset(clone, c("task", "learner", "resampling"))
  task = assert_task(as_task(task, clone = "task" %in% clone))
  learner = assert_learner(as_learner(learner, clone = "learner" %in% clone))
  resampling = assert_resampling(as_resampling(resampling, clone = "resampling" %in% clone))
  assert_flag(store_models)
  assert_flag(store_backends)
  assert_learnable(task, learner)

  set_encapsulation(list(learner), encapsulate)
  if (!resampling$is_instantiated) {
    resampling = resampling$instantiate(task)
  }
  n = resampling$iters
  pb = if (isNamespaceLoaded("progressr")) {
    # NB: the progress bar needs to be created in this env
    pb = progressr::progressor(steps = n)
  } else {
    NULL
  }
  lgr_threshold = map_int(mlr_reflections$loggers, "threshold")

  grid = if (allow_hotstart) {

    lg$debug("Resampling with hotstart enabled.")

    hotstart_grid = map_dtr(seq_len(n), function(iteration) {
      if (!is.null(learner$hotstart_stack)) {
        # search for hotstart learner
        task_hashes = task_hashes(task, resampling)
        start_learner = get_private(learner$hotstart_stack)$.start_learner(learner$clone(), task_hashes[iteration])
      }
      if (is.null(learner$hotstart_stack) || is.null(start_learner)) {
        # no hotstart learners stored or no adaptable model found
        lg$debug("Resampling with hotstarting not possible. Not start learner found.")
        mode = "train"
      } else {
        # hotstart learner found
        lg$debug("Resampling with hotstarting.")
        start_learner$param_set$values = insert_named(start_learner$param_set$values, learner$param_set$values)
        learner = start_learner
        mode = "hotstart"
      }
      data.table(learner = list(learner), mode = mode)
    })

    # null hotstart stack to reduce overhead in parallelization
    walk(hotstart_grid$learner, function(learner) {
      learner$hotstart_stack = NULL
      learner
    })
    hotstart_grid
  } else {
    data.table(learner = replicate(n, learner), mode = "train")
  }

  res = future_map(n, workhorse, iteration = seq_len(n), learner = grid$learner, mode = grid$mode,
    MoreArgs = list(task = task, resampling = resampling, store_models = store_models, lgr_threshold = lgr_threshold, pb = pb)
  )

  data = data.table(
    task = list(task),
    learner = grid$learner,
    learner_state = map(res, "learner_state"),
    resampling = list(resampling),
    iteration = seq_len(n),
    prediction = map(res, "prediction"),
    uhash = UUIDgenerate(),
    param_values = map(res, "param_values"),
    learner_hash = map_chr(res, "learner_hash")
  )

  ResampleResult$new(ResultData$new(data, store_backends = store_backends))
}

Try the mlr3 package in your browser

Any scripts or data that you put into this service are public.

mlr3 documentation built on Nov. 17, 2023, 5:07 p.m.