knitr::opts_chunk$set(collapse = TRUE, comment = "#>")
library(luz)
library(torch)
Luz is a high-level API for torch that aims to encapsulate the training loop into a set of reusable pieces of code.
Luz reduces the boilerplate code required to train a model with torch and avoids the error-prone zero_grad() - backward() - step() sequence of calls. It also simplifies moving data and models between CPUs and GPUs.
Luz is designed to be highly flexible by providing a layered API that allows it to be useful no matter the level of control you need for your training loop.
Luz is heavily inspired by other higher-level frameworks for deep learning. To cite a few:
FastAI: we are heavily inspired by the FastAI library, especially the Learner object and the callbacks API.
Keras: we are also heavily inspired by Keras, especially callback names. The lightning module interface is similar to compile, too.
PyTorch Lightning: the idea of the luz_module being a subclass of nn_module is inspired by the LightningModule object in lightning.
HuggingFace Accelerate: the internal device placement API is heavily inspired by Accelerate, but is much more modest in features. Currently only CPU and single GPU are supported.
Training a nn_module
As much as possible, luz tries to reuse the existing structures from torch. A model in luz is defined exactly as you would define it in raw torch. As a concrete example, this is the definition of a feed-forward CNN that can be used to classify digits from the MNIST dataset:
net <- nn_module(
  "Net",
  initialize = function(num_class) {
    self$conv1 <- nn_conv2d(1, 32, 3, 1)
    self$conv2 <- nn_conv2d(32, 64, 3, 1)
    self$dropout1 <- nn_dropout2d(0.25)
    self$dropout2 <- nn_dropout2d(0.5)
    self$fc1 <- nn_linear(9216, 128)
    self$fc2 <- nn_linear(128, num_class)
  },
  forward = function(x) {
    x <- self$conv1(x)
    x <- nnf_relu(x)
    x <- self$conv2(x)
    x <- nnf_relu(x)
    x <- nnf_max_pool2d(x, 2)
    x <- self$dropout1(x)
    x <- torch_flatten(x, start_dim = 2)
    x <- self$fc1(x)
    x <- nnf_relu(x)
    x <- self$dropout2(x)
    x <- self$fc2(x)
    x
  }
)
We can now train this model on the train_dl and validate it on the test_dl torch::dataloaders() with:
fitted <- net %>%
  setup(
    loss = nn_cross_entropy_loss(),
    optimizer = optim_adam,
    metrics = list(
      luz_metric_accuracy
    )
  ) %>%
  set_hparams(num_class = 10) %>%
  set_opt_hparams(lr = 0.003) %>%
  fit(train_dl, epochs = 10, valid_data = test_dl)
Let's understand what happens in this chunk of code:
The setup() function allows you to configure the loss (objective) function and the optimizer used to train your model. Optionally, you can pass a list of metrics that are tracked during training. Note: the loss function can be any function that takes input and target tensors and returns a scalar tensor, and the optimizer can be any core torch optimizer or a custom one created with the torch::optimizer() function.
The set_hparams() function allows you to set hyperparameters that are passed to the module's initialize() method. In this case we pass num_class = 10.
The set_opt_hparams() function allows you to pass hyperparameters that are used by the optimizer function. For example, optim_adam() accepts an lr parameter specifying the learning rate, and we set it with lr = 0.003.
The fit() method takes the model specification provided by setup() and runs the training procedure using the specified training and validation torch::dataloaders() as well as the number of epochs. Note: we again reuse core torch data structures instead of providing our own data loading functionality.
The returned object, fitted, contains the trained model as well as the record of metrics and losses produced during training. It can also be used for producing predictions and for evaluating the trained model on other datasets.

When fitting, luz uses the fastest available accelerator: if a CUDA-capable GPU is available it is used, otherwise luz falls back to the CPU. It also automatically moves data, optimizers, and models to the selected device, so you don't need to handle this manually (which is generally very error-prone).
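The train_dl and test_dl objects used above are ordinary torch dataloaders. As a minimal sketch, assuming random stand-in data shaped like MNIST images (the helper make_dl and the sample sizes are hypothetical and only for illustration), they could be built like this:

```r
library(torch)

# Hypothetical helper: builds a dataloader over random data shaped like
# MNIST (1 channel, 28 x 28 pixels) with integer labels in 1..10.
make_dl <- function(n, batch_size = 32) {
  x <- torch_randn(n, 1, 28, 28)
  y <- torch_tensor(sample(1:10, n, replace = TRUE), dtype = torch_long())
  dataloader(tensor_dataset(x, y), batch_size = batch_size, shuffle = TRUE)
}

train_dl <- make_dl(256)  # 256 samples -> 8 batches of 32
test_dl <- make_dl(64)    # 64 samples -> 2 batches of 32
```

With 28 x 28 inputs, the two 3 x 3 convolutions and the 2 x 2 max pooling in the network above produce 64 * 12 * 12 = 9216 features, which matches the input size of fc1.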
To create predictions from the trained model you can use the predict method:
predictions <- predict(fitted, test_dl)
You now have a general idea of how to use the fit function, so it is worth getting an overview of what happens inside it.
In pseudocode, here's what fit does. This is not fully detailed, but it should help you build intuition:
# -> Initialize objects: model, optimizers.
# -> Select fitting device.
# -> Move data, model, optimizers to the selected device.
# -> Start training
for (epoch in 1:epochs) {
  # -> Training procedure
  for (batch in train_dl) {
    # -> Calculate model `forward` method.
    # -> Calculate the loss
    # -> Update weights
    # -> Update metrics and tracking loss
  }
  # -> Validation procedure
  for (batch in valid_dl) {
    # -> Calculate model `forward` method.
    # -> Calculate the loss
    # -> Update metrics and tracking loss
  }
}
# -> End training
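For comparison, the training part of this pseudocode could look roughly as follows in raw torch. This is a sketch on toy regression data, not what luz runs internally: the model, the data, and the variable names are illustrative, and the validation loop, metrics, and device handling are left out.

```r
library(torch)

# Toy data and model (illustrative only).
x <- torch_randn(64, 10)
y <- torch_randn(64, 1)
toy_dl <- dataloader(tensor_dataset(x, y), batch_size = 16)

model <- nn_linear(10, 1)
optimizer <- optim_adam(model$parameters, lr = 0.003)

losses <- c()
for (epoch in 1:2) {
  model$train()
  coro::loop(for (batch in toy_dl) {
    optimizer$zero_grad()                    # reset accumulated gradients
    pred <- model(batch[[1]])                # forward pass
    loss <- nnf_mse_loss(pred, batch[[2]])   # calculate the loss
    loss$backward()                          # backpropagate
    optimizer$step()                         # update weights
    losses <- c(losses, loss$item())         # track the loss
  })
}
```

Forgetting or misordering any of the zero_grad(), backward(), and step() calls silently breaks training, which is the bookkeeping that fit takes over.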
One of the most important parts in machine learning projects is choosing the evaluation metric. Luz allows tracking many different metrics during training with minimal code changes.
To track metrics, you only need to modify the metrics parameter in the setup function:
fitted <- net %>%
  setup(
    ...,
    metrics = list(
      luz_metric_accuracy
    )
  ) %>%
  fit(...)
Luz provides implementations of a few of the most used metrics. If a metric is not available, you can always implement a new one using the luz_metric function.
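As a rough sketch of what a custom metric might look like, assuming luz_metric accepts an abbrev field plus initialize, update, and compute methods, the following accumulates a mean absolute error across batches. The name metric_mae_sketch and its internal fields are hypothetical.

```r
library(luz)
library(torch)

# Hypothetical custom metric: mean absolute error accumulated over batches.
metric_mae_sketch <- luz_metric(
  abbrev = "MAE",  # short label shown when printing progress
  # Reset the internal state; assumed to run at the start of each epoch.
  initialize = function() {
    self$sum_abs_error <- 0
    self$n <- 0
  },
  # Called for each batch with the model predictions and the targets.
  update = function(preds, target) {
    self$sum_abs_error <- self$sum_abs_error +
      torch_sum(torch_abs(preds - target))$item()
    self$n <- self$n + preds$numel()
  },
  # Turn the accumulated state into the reported metric value.
  compute = function() {
    self$sum_abs_error / self$n
  }
)
```

Such a metric would then be added to the metrics list passed to setup, next to the built-in ones.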
Luz provides different ways to customize the training process depending on the level of control you need over the training loop. The fastest and most reusable way, in the sense that you can create training modifications that work in many different situations, is via callbacks.
The training loop in luz has many breakpoints that can call arbitrary R functions. This functionality allows you to customize the training process without having to modify the general training logic.
Luz implements three default callbacks that run in every training procedure:
train-eval callback: sets the model to train() or eval() depending on whether the procedure is doing training or validation.
metrics callback: evaluates metrics during the training and validation process.
progress callback: implements a progress bar and prints progress information during training.
You can also implement custom callbacks that modify or act specifically for your training procedure. For example:
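A minimal sketch of such a callback, assuming the luz_callback() constructor and the on_train_batch_end and on_epoch_end breakpoints; the name print_callback and its message argument are illustrative:

```r
library(luz)

# Illustrative callback: prints the iteration number after every training
# batch and a custom message at the end of every epoch.
print_callback <- luz_callback(
  name = "print_callback",
  initialize = function(message) {
    self$message <- message
  },
  on_train_batch_end = function() {
    # `ctx` is the context object luz makes available inside callbacks.
    cat("Iteration ", ctx$iter, "\n")
  },
  on_epoch_end = function() {
    cat(self$message, "\n")
  }
)

# Instantiate it; the result can be passed to fit(), for example
# fit(..., callbacks = list(cb)).
cb <- print_callback(message = "Done!")
```

Each method is named after the breakpoint at which it should run, so the training logic itself stays untouched.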
Attributes in ctx can be used to produce the desired behavior of callbacks. You can find information about the context object using help("ctx").
In our example, we use the ctx$iter attribute to print the iteration number for each training batch.
In this article you learned how to train your first model using luz and the basics of customization using both custom metrics and callbacks.
Luz also allows more flexible modifications of the training loop, described in vignette("custom-loop").
You should now be able to follow the examples marked with the 'basic' category in the examples gallery.