knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
library(luz)
library(torch)

Luz is a high-level API for torch that encapsulates the training loop in a set of reusable pieces of code. It reduces the boilerplate needed to train a model with torch, avoids the error-prone zero_grad() - backward() - step() sequence of calls, and simplifies moving data and models between CPUs and GPUs. Its layered API is designed to stay useful whatever level of control you need over your training loop.

Luz is heavily inspired by other higher-level frameworks for deep learning.

Training a nn_module

As much as possible, luz reuses the existing structures from torch. A model in luz is defined exactly as it would be in raw torch. As a concrete example, here is the definition of a feed-forward CNN that can be used to classify digits from the MNIST dataset:

net <- nn_module(
  "Net",
  initialize = function(num_class) {
    self$conv1 <- nn_conv2d(1, 32, 3, 1)
    self$conv2 <- nn_conv2d(32, 64, 3, 1)
    self$dropout1 <- nn_dropout2d(0.25)
    self$dropout2 <- nn_dropout2d(0.5)
    self$fc1 <- nn_linear(9216, 128)
    self$fc2 <- nn_linear(128, num_class)
  },
  forward = function(x) {
    x <- self$conv1(x)
    x <- nnf_relu(x)
    x <- self$conv2(x)
    x <- nnf_relu(x)
    x <- nnf_max_pool2d(x, 2)
    x <- self$dropout1(x)
    x <- torch_flatten(x, start_dim = 2)
    x <- self$fc1(x)
    x <- nnf_relu(x)
    x <- self$dropout2(x)
    x <- self$fc2(x)
    x
  }
)
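The training call that follows expects two dataloaders named train_dl and test_dl. A minimal sketch of how they could be built is shown here, assuming random tensors with MNIST's shape (1 x 28 x 28 images, labels 1 to 10) in place of the real data; make_fake_mnist() is a hypothetical helper introduced only for illustration:

```r
library(torch)

torch_manual_seed(123)

# Hypothetical helper: a synthetic stand-in for MNIST.
# Images are 1-channel 28x28 tensors; labels are 1-based class indices.
make_fake_mnist <- function(n) {
  x <- torch_rand(n, 1, 28, 28)
  y <- torch_tensor(sample(1:10, n, replace = TRUE), dtype = torch_long())
  tensor_dataset(x, y)
}

# Each batch is a list: element 1 is the input, element 2 is the target.
train_dl <- dataloader(make_fake_mnist(512), batch_size = 128, shuffle = TRUE)
test_dl <- dataloader(make_fake_mnist(128), batch_size = 128)
```

With real data, a dataset object such as the one returned by torchvision::mnist_dataset() (assuming the torchvision package is installed) would take the place of make_fake_mnist().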

We can now train this model on the train_dl dataloader and validate it on the test_dl dataloader (both created with torch::dataloader()):

fitted <- net %>%
  setup(
    loss = nn_cross_entropy_loss(),
    optimizer = optim_adam,
    metrics = list(
      luz_metric_accuracy()
    )
  ) %>%
  set_hparams(num_class = 10) %>% 
  set_opt_hparams(lr = 0.003) %>% 
  fit(train_dl, epochs = 10, valid_data = test_dl)

Let's understand what happens in this chunk of code:

  1. setup() configures the loss (objective) function and the optimizer used to train your model. Optionally, you can pass a list of metrics to track during training. Note: the loss function can be any function that takes input and target tensors and returns a scalar tensor, and the optimizer can be any core torch optimizer or a custom one created with torch::optimizer().
  2. set_hparams() sets hyper-parameters that are passed to the module's initialize() method. In this case we pass num_class = 10.
  3. set_opt_hparams() sets hyper-parameters that are passed to the optimizer function. For example, optim_adam() accepts an lr parameter specifying the learning rate, which we set with lr = 0.003.
  4. fit() takes the model specification produced by setup() and runs the training procedure using the given training and validation torch::dataloader() objects and the number of epochs. Note: we again reuse core torch data structures instead of providing our own data loading functionality.
  5. The returned object, fitted, contains the trained model as well as the record of metrics and losses produced during training. It can also be used to produce predictions and to evaluate the trained model on other datasets.

When fitting, luz uses the fastest available accelerator: if a CUDA-capable GPU is available it is used; otherwise luz falls back to the CPU. It also automatically moves data, optimizers, and models to the selected device, so you don't need to handle this manually (which is generally very error-prone).
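If you want to override the automatic device selection, for instance to force CPU training while debugging, one option is the accelerator argument of fit(). The sketch below assumes a toy regression problem; tiny_net and the synthetic data are illustrative and not part of the MNIST example:

```r
library(luz)
library(torch)

torch_manual_seed(1)

# Toy data: 32 observations with 4 features and one numeric target.
x <- torch_randn(32, 4)
y <- torch_randn(32, 1)
dl <- dataloader(tensor_dataset(x, y), batch_size = 8)

# Illustrative one-layer model.
tiny_net <- nn_module(
  initialize = function() {
    self$fc <- nn_linear(4, 1)
  },
  forward = function(x) {
    self$fc(x)
  }
)

# accelerator(cpu = TRUE) asks luz to keep everything on the CPU,
# even when a GPU is available.
fitted_cpu <- tiny_net %>%
  setup(loss = nn_mse_loss(), optimizer = optim_sgd) %>%
  set_opt_hparams(lr = 0.01) %>%
  fit(dl, epochs = 1, accelerator = accelerator(cpu = TRUE), verbose = FALSE)
```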

To create predictions from the trained model you can use the predict method:

predictions <- predict(fitted, test_dl)
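For a classifier such as the network above, the model outputs one score per class for each observation, so turning predictions into class labels takes one more step. The sketch below assumes that predict() returns those raw scores as a tensor; scores is a random stand-in for that result so the snippet runs on its own:

```r
library(torch)

torch_manual_seed(1)

# Stand-in for the output of predict(): 5 observations, 10 class scores each.
scores <- torch_randn(5, 10)

# Pick the highest-scoring class per row. Indices in R torch are 1-based,
# so the result lies in 1..10 (digit d corresponds to index d + 1).
classes <- scores$argmax(dim = 2)

# Optionally convert scores to probabilities that sum to 1 per row.
probs <- nnf_softmax(scores, dim = 2)
```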

The training loop

You now have a general idea of how to use fit(), so it helps to have an overview of what happens inside it. The pseudocode below is not fully detailed, but it should help you build intuition:

# -> Initialize objects: model, optimizers.
# -> Select fitting device.
# -> Move data, model, optimizers to the selected device.
# -> Start training
for (epoch in 1:epochs) {
  # -> Training procedure
  for (batch in train_dl) {
    # -> Calculate model `forward` method.
    # -> Calculate the loss
    # -> Update weights
    # -> Update metrics and tracking loss
  }
  # -> Validation procedure
  for (batch in valid_dl) {
    # -> Calculate model `forward` method.
    # -> Calculate the loss
    # -> Update metrics and tracking loss
  }
}
# -> End training
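For comparison, the training half of this pseudocode corresponds roughly to the hand-written torch loop below. It is a simplified sketch, not luz's actual implementation: it assumes a toy linear model on synthetic data, and leaves out device placement, validation, metrics, and callbacks:

```r
library(torch)

torch_manual_seed(1)

# Synthetic regression data: the target is the sum of the 4 features.
x <- torch_randn(64, 4)
y <- x$sum(dim = 2, keepdim = TRUE)
dl <- dataloader(tensor_dataset(x, y), batch_size = 16)

model <- nn_linear(4, 1)
opt <- optim_adam(model$parameters, lr = 0.01)

losses <- c()
for (epoch in 1:2) {
  model$train()
  coro::loop(for (batch in dl) {
    opt$zero_grad()                          # reset accumulated gradients
    pred <- model(batch[[1]])                # forward pass
    loss <- nnf_mse_loss(pred, batch[[2]])   # compute the loss
    loss$backward()                          # backpropagate
    opt$step()                               # update weights
    losses <- c(losses, loss$item())         # track the loss
  })
}
```

The zero_grad(), backward(), step() sequence inside the loop is the part that is easy to get wrong by hand and that fit() takes care of.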

Metrics

One of the most important parts of a machine learning project is choosing the evaluation metric. Luz lets you track many different metrics during training with minimal code changes.

In order to track metrics, you only need to modify the metrics parameter in the setup function:

fitted <- net %>%
  setup(
    ...
    metrics = list(
      luz_metric_accuracy()
    )
  ) %>%
  fit(...)

Luz provides implementations of a few of the most commonly used metrics. If a metric is not available, you can always implement a new one using the luz_metric() function.
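As an illustration, here is a minimal sketch of a custom metric that tracks the classification error rate. It assumes the luz_metric() interface with initialize(), update(preds, target), and compute() methods plus an abbrev field; the name metric_error_rate is hypothetical:

```r
library(luz)
library(torch)

metric_error_rate <- luz_metric(
  abbrev = "Err",  # short name shown in progress output
  initialize = function() {
    self$wrong <- 0
    self$total <- 0
  },
  # Called once per batch with the model output and the targets.
  update = function(preds, target) {
    pred <- torch_argmax(preds, dim = 2)
    n_wrong <- (pred != target)$to(dtype = torch_float())$sum()$item()
    self$wrong <- self$wrong + n_wrong
    self$total <- self$total + pred$numel()
  },
  # Called to obtain the metric value accumulated so far.
  compute = function() {
    self$wrong / self$total
  }
)
```

It could then be passed to setup() as metrics = list(metric_error_rate()), alongside or instead of the built-in metrics.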


Evaluate
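A fitted model can be scored on a dataset it was not trained on. The sketch below assumes luz's evaluate() and get_metrics() functions and the built-in luz_metric_mae() metric; the toy model, the synthetic data, and the make_dl() helper are illustrative only:

```r
library(luz)
library(torch)

torch_manual_seed(1)

# Hypothetical helper: a dataloader over synthetic regression data.
make_dl <- function(n) {
  x <- torch_randn(n, 4)
  y <- x$sum(dim = 2, keepdim = TRUE)
  dataloader(tensor_dataset(x, y), batch_size = 16)
}

tiny_net <- nn_module(
  initialize = function() {
    self$fc <- nn_linear(4, 1)
  },
  forward = function(x) {
    self$fc(x)
  }
)

fitted_tiny <- tiny_net %>%
  setup(
    loss = nn_mse_loss(),
    optimizer = optim_adam,
    metrics = list(luz_metric_mae())
  ) %>%
  fit(make_dl(64), epochs = 2, verbose = FALSE)

# Compute the loss and the metrics attached to the model on new data.
evaluation <- evaluate(fitted_tiny, make_dl(32), verbose = FALSE)
eval_metrics <- get_metrics(evaluation)
eval_metrics
```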


Customizing with callbacks

Luz provides different ways to customize the training process, depending on the level of control you need over the training loop. The fastest and most reusable way, in the sense that you can create training modifications that apply in many different situations, is via callbacks.

The training loop in luz has many breakpoints that can call arbitrary R functions. This functionality allows you to customize the training process without having to modify the general training logic.

Luz implements 3 default callbacks that occur in every training procedure:

You can also implement custom callbacks that modify or act specifically for your training procedure. For example:
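A minimal sketch of such a callback, assuming the luz_callback() constructor and its on_train_batch_end and on_epoch_end hooks; the name print_callback and its message argument are illustrative:

```r
library(luz)

print_callback <- luz_callback(
  name = "print_callback",
  # Arguments given when the callback is created are stored on self.
  initialize = function(message) {
    self$message <- message
  },
  # Runs after every training batch; ctx exposes the training state.
  on_train_batch_end = function() {
    cat("Iteration ", ctx$iter, "\n")
  },
  # Runs once at the end of every epoch.
  on_epoch_end = function() {
    cat(self$message, "\n")
  }
)

# Create an instance; it would be passed to fit() through the
# callbacks argument, e.g. callbacks = list(cb).
cb <- print_callback(message = "Done!")
```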



Attributes in ctx can be used to produce the desired behavior of callbacks. You can find information about the context object using help("ctx"). In our example, we use the ctx$iter attribute to print the iteration number for each training batch.

Next steps

In this article you learned how to train your first model using luz and the basics of customization using both custom metrics and callbacks.

Luz also allows more flexible modifications of the training loop, described in vignette("custom-loop").

You should now be able to follow the examples marked with the 'basic' category in the examples gallery.



mlverse/luz documentation built on Sept. 19, 2024, 11:20 p.m.