# R/train.R

#' @title Train a Learner on a Task
#'
#' @description
#' Fits the model of a learner on (a subset of) the data of a task.
#'
#' @param task ([Task()])\cr
#'   Object of type [Task()] that provides the training data.
#' @param learner ([Learner()])\cr
#'   Object of type [Learner()]; it is checked against `task`.
#' @param subset (`integer` | `logical`)\cr
#'   Subset of `task` to train the learner on; it is checked against `task`.
#'   Defaults to `NULL`.
#' @return [TrainResult()].
#' @export
train = function(task, learner, subset = NULL) {
  # Validate all inputs up front; learner and subset are checked against the task.
  assertTask(task)
  assertLearner(learner, for.task = task)
  assertIndexSet(subset, for.task = task)

  # Resolve the subset eagerly, before the worker starts timing/encapsulating.
  ids = translateSubset(task, subset)
  trainWorker(task, learner, ids)
}

# Trains `learner` on the rows `row.ids` of `task`, timing and logging the fit.
#
# The option "mlrng.train.encapsulation" (default "evaluate") selects how
# fitModel() is run:
#   - "none": fitModel() is called directly and the raw log stays empty.
#   - anything else: the code string from getTrainEvalString() is run through
#     evaluate::evaluate() and its output becomes the raw log.
#
# Training counts as failed if no model was produced or the log contains at
# least one error. In that case the option "mlrng.continue.on.learner.error"
# (default FALSE) decides whether a fallback learner is fitted instead or an
# error is raised via gstop().
#
# Returns a TrainResult built from the task, the learner, the fitted model,
# the row ids and the train log.
trainWorker = function(task, learner, row.ids) {
  encapsulation = assertString(getOption("mlrng.train.encapsulation", "evaluate"))
  result = NULL

  start.time = proc.time()[3L]
  if (encapsulation == "none") {
    result = fitModel(task, learner, row.ids)
    raw.log = list()
  } else {
    eval.string = getTrainEvalString(encapsulation)
    # The evaluated code reads `task`, `learner`, `row.ids` from and assigns
    # `result` in this function's frame, so the evaluation environment is
    # passed explicitly instead of relying on evaluate()'s default.
    raw.log = evaluate::evaluate(eval.string, envir = environment(), new_device = FALSE)
  }
  # Elapsed wall-clock time of the fit, rounded to 6 digits.
  train.time = round(proc.time()[3L] - start.time, 6L)

  train.log = TrainLog$new(raw.log, train.time)
  train.success = !is.null(result) && train.log$n.errors == 0L

  if (train.success) {
    ginfo("Trained {learner$id} on {task$id} with {train.log$n.errors} errors, {train.log$n.warnings} warnings and {train.log$n.messages} messages.")
  } else if (getOption("mlrng.continue.on.learner.error", FALSE)) {
    # Note: the fallback fit runs outside of the encapsulation and timing above.
    result = fitModel(task, createFallbackLearner(task), row.ids)
    ginfo("Training {learner$id} on {task$id} failed, fallback to dummy model.")
  } else {
    # Training also counts as failed when the learner returned NULL without
    # any logged error; then there is no first error whose message could be
    # reported, so a generic reason is used instead of indexing into an empty
    # error list.
    error.msg = if (length(train.log$errors) > 0L) {
      train.log$errors[[1L]]$message
    } else {
      "a NULL model and no logged error"
    }
    gstop("Training {learner$id} on {task$id} failed with {error.msg}.")
  }

  TrainResult$new(task, learner, result, row.ids, train.log)
}

# Fits `learner` on the rows `row.ids` of `task`.
#
# Passes `learner$packages` to requireNS() (presumably to make sure the
# required packages are available), then calls `learner$train` with `task`,
# `subset = row.ids` and the entries of `learner$par.vals` as arguments.
# Returns whatever `learner$train` returns.
fitModel = function(task, learner, row.ids) {
  requireNS(learner$packages)
  fixed.args = list(task = task, subset = row.ids)
  do.call(learner$train, c(fixed.args, learner$par.vals))
}

# Returns the R code (as a single string) that fits the model under the given
# encapsulation mode. The code refers to `task`, `learner` and `row.ids` and
# assigns the fitted model to `result`, so it has to be evaluated in an
# environment where those names are defined.
#
# Supported values of `encapsulation`:
#   - "evaluate": call fitModel() directly.
#   - "callr":    call mlrng:::fitModel() through callr::r().
# Any other value is reported via gstop().
getTrainEvalString = function(encapsulation) {
  if (encapsulation == "evaluate") {
    return("result = fitModel(task, learner, row.ids)")
  }

  if (encapsulation == "callr") {
    requireNS("callr")
    # The string below is kept verbatim, including its inner whitespace.
    return("result = callr::r(function(task, learner, row.ids) {
        library(mlrng)
        mlrng:::fitModel(task, learner, row.ids)
        }, list(task = task, learner = learner, row.ids = row.ids))")
  }

  gstop("Invalid encapsulation setting '{encapsulation}'", .call = FALSE)
}
# mlr-org/mlrng documentation built on May 4, 2019, 4:22 p.m.