library(mlr3torch)

The torch callback mechanism lets you customize the training loop of a neural network. While mlr3torch ships predefined callbacks for common use cases, this vignette shows how to write your own custom callback.

Building Blocks

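At the heart of the callback mechanism are three R6 classes: CallbackSet, which contains the methods that are executed at the different stages of the training loop; TorchCallback, which wraps a CallbackSet and annotates it with meta information, including a ParamSet; and ContextTorch, which defines which information about the training process a CallbackSet can access.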

When using predefined callbacks, one usually only interacts with the TorchCallback class, which is constructed using t_clbk(<id>):

tc_hist = t_clbk("history")
tc_hist

The wrapped CallbackSet is accessible via the field $generator, and we can create a new instance by calling $new():

cbs_hist = tc_hist$generator$new()

Within the training loop of a torch model, different stages exist at which a callback can execute code. Below, we list all stages that are available, but they are also described in the documentation of CallbackSet.

mlr_reflections$torch$callback_stages

For the history callback, we see that it runs code at the beginning, before validation starts, and at the end of an epoch.

cbs_hist
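To make the idea of stages more concrete, the following is a toy sketch in base R of how a training loop might dispatch to stage methods. It is not the mlr3torch implementation: call_stage(), recorder, and stage_log are hypothetical names used only for illustration, and the loop body is a placeholder.

```r
# Toy dispatcher (an assumption, not mlr3torch internals):
# call a stage method only if the callback defines it.
call_stage = function(callback, stage, ctx) {
  f = callback[[stage]]
  if (is.function(f)) f(ctx)
  invisible(NULL)
}

# Hypothetical callback that records which stages were executed.
stage_log = character()
recorder = list(
  on_batch_end = function(ctx) {
    stage_log <<- c(stage_log, sprintf("batch_end(epoch %d)", ctx$epoch))
  },
  on_before_valid = function(ctx) {
    stage_log <<- c(stage_log, sprintf("before_valid(epoch %d)", ctx$epoch))
  }
)

# The context is a shared, mutable object that the loop updates.
ctx = new.env()
for (epoch in 1:2) {
  ctx$epoch = epoch
  for (batch in 1:2) {
    # forward pass, loss, backward pass and optimizer step would go here
    call_stage(recorder, "on_batch_end", ctx)
  }
  call_stage(recorder, "on_before_valid", ctx)
  call_stage(recorder, "on_epoch_end", ctx)  # not defined by recorder, so skipped
}

stage_log
```

The point of the sketch is that a callback only implements the stages it cares about, and every stage method sees the same context object.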

Writing a Custom Logger

In order to define our own custom callback, we can use the torch_callback() helper function. As an example, we will create a custom logger that keeps track of an exponential moving average of the train loss and prints it at the end of every epoch. This callback takes one argument, alpha, which is the smoothing parameter, and stores the moving average in self$moving_loss. The value of alpha can later be configured in the Learner.
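The update rule behind an exponential moving average can be sketched in base R, independently of mlr3torch. The helper ema_update() and the loss values are made up for illustration; the first observed loss initializes the average, and every later loss is blended in with weight alpha.

```r
# Hypothetical helper: one step of an exponential moving average.
ema_update = function(moving, loss, alpha) {
  if (is.null(moving)) {
    return(loss)  # the first observation initializes the average
  }
  alpha * loss + (1 - alpha) * moving
}

losses = c(1.0, 0.5, 0.25)  # made-up batch losses
moving = NULL
for (loss in losses) {
  moving = ema_update(moving, loss, alpha = 0.1)
}
moving  # 0.88
```

A small alpha gives a smooth but slowly reacting average, while a value close to 1 follows the most recent batch loss almost exactly.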

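Then, we implement the main logic of the callback using two stages: on_batch_end(), which is called after each batch and updates the moving average, and on_before_valid(), which runs before validation starts and prints the current moving average.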

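Finally, in order to make the final value of the moving average accessible in the Learner after the training finishes, we implement the state_dict() and load_state_dict() methods: state_dict() returns the moving average, and load_state_dict() sets it from a previously returned state.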

Note that we only do this here to have access to the final $moving_loss via the learner, but we would otherwise not have to implement these methods.

custom_logger = torch_callback("custom_logger",
  initialize = function(alpha = 0.1) {
    # smoothing parameter and storage for the moving average
    self$alpha = alpha
    self$moving_loss = NULL
  },
  on_batch_end = function() {
    if (is.null(self$moving_loss)) {
      # first batch: initialize the average with the current loss
      self$moving_loss = self$ctx$last_loss
    } else {
      # later batches: blend the current loss into the average
      self$moving_loss = self$alpha * self$ctx$last_loss + (1 - self$alpha) * self$moving_loss
    }
  },
  on_before_valid = function() {
    cat(sprintf("Epoch %s: %.2f\n", self$ctx$epoch, self$moving_loss))
  },
  state_dict = function() {
    # the returned value is stored with the trained model
    self$moving_loss
  },
  load_state_dict = function(state_dict) {
    self$moving_loss = state_dict
  }
)

This created a TorchCallback object and an associated CallbackSet R6 class for us:

custom_logger
custom_logger$generator
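The state_dict() and load_state_dict() pair follows a save-and-restore pattern that can be sketched without mlr3torch. In the base R sketch below, make_logger() is a hypothetical stand-in for a callback instance, and the value 0.42 is made up; nothing here reflects how mlr3torch stores state internally.

```r
# Hypothetical stand-in for a callback: an environment with the two methods.
make_logger = function() {
  self = new.env()
  self$moving_loss = NULL
  self$state_dict = function() self$moving_loss
  self$load_state_dict = function(state_dict) {
    self$moving_loss = state_dict
    invisible(self)
  }
  self
}

trained = make_logger()
trained$moving_loss = 0.42     # pretend that training produced this value
saved = trained$state_dict()   # the state that would be kept after training

restored = make_logger()       # a fresh instance starts without state
restored$load_state_dict(saved)
restored$moving_loss
```

Whatever state_dict() returns has to be sufficient for load_state_dict() to rebuild the relevant state, which here is just a single number.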

Using the Custom Logger

In order to showcase its effects, we train a simple multilayer perceptron and pass it the callback. We set the smoothing parameter alpha to 0.05.

task = tsk("iris")
mlp = lrn("classif.mlp",
  callbacks = custom_logger, cb.custom_logger.alpha = 0.05,
  epochs = 5, batch_size = 64, neurons = 20)

mlp$train(task)

The information that is returned by state_dict() is now accessible via the Learner's $model slot:

mlp$model$callbacks$custom_logger


mlr-org/mlr3torch documentation built on April 17, 2025, 8:22 p.m.