#' @title Context for Torch Learner
#'
#' @name mlr_context_torch
#'
#' @description
#' Context for training a torch learner.
#' This is the (mostly read-only) information that callbacks can access through the argument `ctx`.
#' For more information on callbacks, see [`CallbackSet`].
#'
#' @family Callback
#' @export
ContextTorch = R6Class("ContextTorch",
  # Unlocked so that additional fields can be attached to a context object after construction.
  lock_objects = FALSE,
  public = list(
    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    #' @param learner ([`Learner`][mlr3::Learner])\cr
    #'   The torch learner.
    #' @param task_train ([`Task`][mlr3::Task])\cr
    #'   The training task.
    #' @param task_valid ([`Task`][mlr3::Task] or `NULL`)\cr
    #'   The validation task.
    #' @param loader_train ([`torch::dataloader`])\cr
    #'   The data loader for training.
    #' @param loader_valid ([`torch::dataloader`] or `NULL`)\cr
    #'   The data loader for validation.
    #' @param measures_train (`list()` of [`Measure`][mlr3::Measure]s or `NULL`)\cr
    #'   Measures used for training. Must be uniquely named. Default is `NULL`, which is stored as an empty `list()`.
    #' @param measures_valid (`list()` of [`Measure`][mlr3::Measure]s or `NULL`)\cr
    #'   Measures used for validation. Must be uniquely named. Default is `NULL`, which is stored as an empty `list()`.
    #' @param network ([`torch::nn_module`])\cr
    #'   The torch network.
    #' @param optimizer ([`torch::optimizer`])\cr
    #'   The optimizer.
    #' @param loss_fn ([`torch::nn_module`])\cr
    #'   The loss function.
    #' @param total_epochs (`integer(1)`)\cr
    #'   The total number of epochs the learner is trained for. Must be a single non-negative integerish value.
    #' @param prediction_encoder (`function()`)\cr
    #'   The learner's prediction encoder. Must accept the arguments `predict_tensor` and `task`.
    #'   See section *Inheriting* of [`LearnerTorch`].
    #' @param eval_freq (`integer(1)`)\cr
    #'   The evaluation frequency. Must be at least `1`.
    #' @param device (`character(1)`)\cr
    #'   The device. Must be one of `mlr_reflections$torch$devices`; it is stored as a `torch_device`.
    initialize = function(learner, task_train, task_valid = NULL, loader_train, loader_valid = NULL,
      measures_train = NULL, measures_valid = NULL, network, optimizer, loss_fn, total_epochs, prediction_encoder,
      eval_freq = 1L, device) {
      self$learner = assert_r6(learner, "Learner")
      self$task_train = assert_r6(task_train, "Task")
      self$task_valid = assert_r6(task_valid, "Task", null.ok = TRUE)
      self$loader_train = assert_class(loader_train, "dataloader")
      self$loader_valid = assert_class(loader_valid, "dataloader", null.ok = TRUE)
      # `NULL` measures are normalized to an empty list so callers can always iterate over them.
      self$measures_train = assert_list(measures_train, names = "unique", any.missing = FALSE, types = "Measure",
        null.ok = TRUE) %??% list()
      self$measures_valid = assert_list(measures_valid, names = "unique", any.missing = FALSE, types = "Measure",
        null.ok = TRUE) %??% list()
      self$network = assert_class(network, "nn_module")
      self$optimizer = assert_class(optimizer, "torch_optimizer")
      self$loss_fn = assert_class(loss_fn, "nn_module")
      # A single value is required: `assert_integerish()` would silently accept vectors of any length.
      self$total_epochs = assert_int(total_epochs, lower = 0L)
      # Start as *named* empty lists so that `names()` returns `character(0)` rather than `NULL`.
      self$last_scores_train = structure(list(), names = character(0))
      self$last_scores_valid = structure(list(), names = character(0))
      self$prediction_encoder = assert_function(prediction_encoder, args = c("predict_tensor", "task"))
      self$eval_freq = assert_int(eval_freq, lower = 1L)
      self$terminate = FALSE
      self$device = torch_device(assert_choice(device, mlr_reflections$torch$devices))
    },
    #' @field learner ([`Learner`][mlr3::Learner])\cr
    #'   The torch learner.
    learner = NULL,
    #' @field task_train ([`Task`][mlr3::Task])\cr
    #'   The training task.
    task_train = NULL,
    #' @field task_valid ([`Task`][mlr3::Task] or `NULL`)\cr
    #'   The validation task.
    task_valid = NULL,
    #' @field loader_train ([`torch::dataloader`])\cr
    #'   The data loader for training.
    loader_train = NULL,
    #' @field loader_valid ([`torch::dataloader`] or `NULL`)\cr
    #'   The data loader for validation.
    loader_valid = NULL,
    #' @field measures_train (`list()` of [`Measure`][mlr3::Measure]s)\cr
    #'   Measures used for training. Empty `list()` if none were given.
    measures_train = NULL,
    #' @field measures_valid (`list()` of [`Measure`][mlr3::Measure]s)\cr
    #'   Measures used for validation. Empty `list()` if none were given.
    measures_valid = NULL,
    #' @field network ([`torch::nn_module`])\cr
    #'   The torch network.
    network = NULL,
    #' @field optimizer ([`torch::optimizer`])\cr
    #'   The optimizer.
    optimizer = NULL,
    #' @field loss_fn ([`torch::nn_module`])\cr
    #'   The loss function.
    loss_fn = NULL,
    #' @field total_epochs (`integer(1)`)\cr
    #'   The total number of epochs the learner is trained for.
    total_epochs = NULL,
    #' @field last_scores_train (named `list()` or `NULL`)\cr
    #'   The scores from the last training batch. Names are the ids of the training measures.
    #'   Initialized to an empty named list.
    #'   If [`LearnerTorch`] sets `eval_freq` different from `1`, this is `NULL` in all epochs
    #'   that don't evaluate the model.
    last_scores_train = NULL,
    #' @field last_scores_valid (named `list()` or `NULL`)\cr
    #'   The scores from the last validation batch. Names are the ids of the validation measures.
    #'   Initialized to an empty named list.
    #'   If [`LearnerTorch`] sets `eval_freq` different from `1`, this is `NULL` in all epochs
    #'   that don't evaluate the model.
    last_scores_valid = NULL,
    #' @field last_loss (`numeric(1)`)\cr
    #'   The loss from the last training batch.
    last_loss = NULL,
    #' @field epoch (`integer(1)`)\cr
    #'   The current epoch. Not set by `$initialize()`, so it is `NULL` until assigned.
    epoch = NULL,
    #' @field step (`integer(1)`)\cr
    #'   The current iteration. Not set by `$initialize()`, so it is `NULL` until assigned.
    step = NULL,
    #' @field prediction_encoder (`function()`)\cr
    #'   The learner's prediction encoder.
    prediction_encoder = NULL,
    #' @field batch (named `list()` of `torch_tensor`s)\cr
    #'   The current batch. Not set by `$initialize()`, so it is `NULL` until assigned.
    batch = NULL,
    #' @field terminate (`logical(1)`)\cr
    #'   If this field is set to `TRUE` at the end of an epoch, training stops.
    #'   Initialized to `FALSE`.
    terminate = NULL,
    #' @field device (`torch::torch_device`)\cr
    #'   The device.
    device = NULL
  )
)
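# NOTE(review): the two lines below are residue from a documentation website's embed
# widget, not part of this source file; commented out so the file parses as R.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.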