library(mlr3torch)
The torch callback mechanism lets you customize the training loop of a neural network.
While mlr3torch has some predefined callbacks for common use cases, this vignette shows how to write your own custom callback.
At the heart of the callback mechanism are three R6 classes:
CallbackSet contains the methods that are executed at different stages of the training loop.
TorchCallback wraps a CallbackSet and annotates it with meta information, including a ParamSet.
ContextTorch defines which information about the training process the CallbackSet has access to.
When using predefined callbacks, one usually only interacts with the TorchCallback class, which is constructed using t_clbk(<id>):
tc_hist = t_clbk("history")
tc_hist
The wrapped CallbackSet is accessible via the field $generator, and we can create a new instance by calling $new():
cbs_hist = tc_hist$generator$new()
Within the training loop of a torch model, different stages exist at which a callback can execute code.
Below, we list all available stages; they are also described in the documentation of CallbackSet.
mlr_reflections$torch$callback_stages
For the history callback, we see that it runs code at the beginning of training, before validation starts, and at the end of each epoch.
cbs_hist
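To make the idea of stages more concrete, here is a minimal base R sketch of how a training loop could dispatch to callbacks. It is a toy illustration under simplifying assumptions, not mlr3torch's actual training loop: the helpers run_stage(), toy_train() and make_recorder() are hypothetical, callbacks are plain lists of functions, and only a few of the stage names are used.

```r
# Toy illustration only: this is NOT mlr3torch's training loop.
# A "callback" here is a plain list of functions named after stages.
run_stage = function(callbacks, stage, ctx) {
  for (cb in callbacks) {
    # A callback only runs at the stages it implements
    if (is.function(cb[[stage]])) cb[[stage]](ctx)
  }
  invisible(NULL)
}

# Hypothetical loop that fires a handful of stages in a fixed order
toy_train = function(callbacks, epochs = 2, batches = 3) {
  ctx = new.env()  # shared context, loosely analogous to ContextTorch
  run_stage(callbacks, "on_begin", ctx)
  for (epoch in seq_len(epochs)) {
    ctx$epoch = epoch
    for (batch in seq_len(batches)) {
      ctx$step = batch
      run_stage(callbacks, "on_batch_end", ctx)
    }
    run_stage(callbacks, "on_before_valid", ctx)
    run_stage(callbacks, "on_epoch_end", ctx)
  }
  run_stage(callbacks, "on_end", ctx)
  invisible(ctx)
}

# Hypothetical callback that records the order in which its stages fire
make_recorder = function() {
  seen = character(0)
  list(
    on_begin = function(ctx) seen <<- c(seen, "on_begin"),
    on_before_valid = function(ctx) {
      seen <<- c(seen, sprintf("on_before_valid (epoch %d)", ctx$epoch))
    },
    on_epoch_end = function(ctx) {
      seen <<- c(seen, sprintf("on_epoch_end (epoch %d)", ctx$epoch))
    },
    stages_seen = function() seen
  )
}

recorder = make_recorder()
toy_train(list(recorder), epochs = 2, batches = 3)
print(recorder$stages_seen())
```

The recorder implements only three stages, so the loop silently skips it at all others. This mirrors how a CallbackSet only needs to define methods for the stages it cares about.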
In order to define our own custom callback, we can use the torch_callback() helper function.
As an example, we will create a custom logger that keeps track of an exponential moving average of the train loss and prints it at the end of every epoch.
This callback takes one argument, alpha, which is the smoothing parameter, and stores the moving average in self$moving_loss.
The value of alpha can later be configured in the Learner.
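The moving average itself can be sketched in base R, independent of any callback. The helper ema_update() and the loss values below are made up for illustration; the update rule is the standard exponential moving average, where the first observed loss initializes the average.

```r
# Exponential moving average of a sequence of losses.
# `alpha` is the smoothing parameter; the first loss initializes the average.
ema_update = function(moving_loss, last_loss, alpha) {
  if (is.null(moving_loss)) {
    return(last_loss)
  }
  alpha * last_loss + (1 - alpha) * moving_loss
}

losses = c(1.0, 0.8, 0.6, 0.4)  # made-up per-batch losses
moving_loss = NULL
for (loss in losses) {
  moving_loss = ema_update(moving_loss, loss, alpha = 0.5)
}
moving_loss
```

A larger alpha makes the average follow recent batches more closely, while a smaller alpha gives a smoother but slower-moving estimate.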
Then, we implement the main logic of the callback using two stages:
on_batch_end(), which is called after the optimizer updates the network parameters using a mini-batch. Here, we access the last loss via the $last_loss field of the ContextTorch (accessible via self$ctx) and update the value self$moving_loss.
on_before_valid(), which is run before the validation loop. At this point we simply print the exponential moving average of the training loss.
Finally, in order to make the final value of the moving average accessible in the Learner after training finishes, we implement the state_dict() method to return this value and the load_state_dict(state_dict) method, which takes a previously retrieved state dict and sets it in the callback.
Note that we only do this here to have access to the final $moving_loss via the learner; otherwise we would not have to implement these methods.
custom_logger = torch_callback("custom_logger",
  initialize = function(alpha = 0.1) {
    self$alpha = alpha
    self$moving_loss = NULL
  },
  on_batch_end = function() {
    # The first batch initializes the moving average
    if (is.null(self$moving_loss)) {
      self$moving_loss = self$ctx$last_loss
    } else {
      self$moving_loss = self$alpha * self$ctx$last_loss + (1 - self$alpha) * self$moving_loss
    }
  },
  on_before_valid = function() {
    cat(sprintf("Epoch %s: %.2f\n", self$ctx$epoch, self$moving_loss))
  },
  state_dict = function() {
    self$moving_loss
  },
  load_state_dict = function(state_dict) {
    self$moving_loss = state_dict
  }
)
This created a TorchCallback object and an associated CallbackSet R6 class for us:
custom_logger
custom_logger$generator
In order to showcase its effects, we train a simple multi-layer perceptron and pass it the callback.
We set the smoothing parameter alpha to 0.05.
task = tsk("iris")
mlp = lrn("classif.mlp",
  callbacks = custom_logger,
  cb.custom_logger.alpha = 0.05,
  epochs = 5,
  batch_size = 64,
  neurons = 20)
mlp$train(task)
The information that is returned by state_dict() is now accessible via the Learner's $model slot:
mlp$model$callbacks$custom_logger