mlr_callback_set | R Documentation
Base class from which callbacks should inherit (see section Inheriting). A callback set is a collection of functions that are executed at different stages of the training loop. They can be used to gain more control over the training process of a neural network without having to write everything from scratch.
When used in a torch learner, the CallbackSet is wrapped in a TorchCallback. The latter's parameter set represents the arguments of the CallbackSet's $initialize() method.
Inheriting:
For each available stage (see section Stages), a public method $on_<stage>() can be defined. The evaluation context (a ContextTorch) can be accessed via self$ctx, which contains the current state of the training loop. This context is assigned at the beginning of the training loop and removed afterwards. Different stages of a callback can communicate with each other by assigning values to fields of self.
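To make these mechanics concrete, here is a toy sketch in base R. It is not mlr3torch code: the callback set is mimicked by an environment holding on_<stage>() functions, and new_batch_counter(), call_stage() and toy_train() are hypothetical helpers. It shows the context being attached at the start of training and removed afterwards, and two stages (batch_end and epoch_end) communicating through fields of self.

```r
# Toy stand-in for a callback set (base R only, NOT mlr3torch code):
# an environment that holds on_<stage>() functions and a `ctx` slot.
new_batch_counter <- function() {
  self <- new.env()
  self$ctx <- NULL
  self$n_batches <- 0L         # written by on_batch_end(), read by on_epoch_end()
  self$per_epoch <- integer()  # number of batches counted in each epoch
  self$on_batch_end <- function() {
    self$n_batches <- self$n_batches + 1L
  }
  self$on_epoch_end <- function() {
    self$per_epoch[self$ctx$epoch] <- self$n_batches
    self$n_batches <- 0L
  }
  self
}

# Call cb$on_<stage>() if the callback defines it; undefined stages are skipped.
call_stage <- function(cb, stage) {
  fn <- cb[[paste0("on_", stage)]]
  if (is.function(fn)) fn()
  invisible(NULL)
}

# Hypothetical miniature training loop (no model, no data).
toy_train <- function(cb, epochs = 2L, batches = 3L) {
  ctx <- new.env()
  cb$ctx <- ctx              # context assigned when training starts ...
  on.exit(cb$ctx <- NULL)    # ... and removed when it ends
  call_stage(cb, "begin")
  for (e in seq_len(epochs)) {
    ctx$epoch <- e
    call_stage(cb, "epoch_begin")
    for (b in seq_len(batches)) {
      call_stage(cb, "batch_begin")
      call_stage(cb, "batch_end")
    }
    call_stage(cb, "epoch_end")
  }
  call_stage(cb, "end")
  invisible(cb)
}

counter <- new_batch_counter()
toy_train(counter, epochs = 2L, batches = 3L)
counter$per_epoch     # 3 3
is.null(counter$ctx)  # TRUE
```

In mlr3torch itself, the same pattern is expressed with methods of an R6 class, where self refers to the callback object.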
State:
To be able to store information in the $model slot of a LearnerTorch, callbacks support a state API. You can overload the public method $state_dict() to define what will be stored in learner$model$callbacks$<id> after training finishes. This also requires implementing a $load_state_dict(state_dict) method that defines how to load a previously saved callback state into a different callback. Note that $state_dict() should not include the parameter values that were used to initialize the callback.
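A minimal sketch of the state API under the same toy assumptions (base R environments instead of the actual R6 CallbackSet; new_loss_history() is hypothetical). The keep_last initialization parameter is deliberately left out of the state dict, while the recorded loss history is saved and then restored into a freshly constructed callback.

```r
# Toy callback with a state API (base R only, NOT mlr3torch code).
new_loss_history <- function(keep_last = 3L) {
  self <- new.env()
  self$ctx <- NULL
  self$keep_last <- keep_last  # initialization parameter: not part of the state
  self$history <- numeric()    # collected during training: part of the state
  self$on_epoch_end <- function() {
    self$history <- tail(c(self$history, self$ctx$loss), self$keep_last)
  }
  self$state_dict <- function() {
    list(history = self$history)
  }
  self$load_state_dict <- function(state_dict) {
    stopifnot(is.list(state_dict), is.numeric(state_dict$history))
    self$history <- state_dict$history
    invisible(self)
  }
  self
}

# "Train" a first callback by feeding it a fake context for four epochs.
old <- new_loss_history(keep_last = 3L)
for (loss in c(0.9, 0.7, 0.6, 0.5)) {
  old$ctx <- list(loss = loss)  # plain list standing in for the context
  old$on_epoch_end()
}
old$ctx <- NULL

# Save the state and restore it into a new callback.
state <- old$state_dict()
fresh <- new_loss_history(keep_last = 3L)
fresh$load_state_dict(state)
fresh$history  # 0.7 0.6 0.5
```

The new callback is constructed with its own initialization arguments; only what training produced travels through the state dict.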
To create custom callbacks, the function torch_callback() is recommended; it creates a CallbackSet and then wraps it in a TorchCallback. To create only a CallbackSet, the convenience function callback_set() can be used. These functions perform checks, for example that stage names are not accidentally misspelled.
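The kind of misspelling check mentioned above can be sketched as follows. This is a hypothetical validator, not the one used by torch_callback() or callback_set(). It assumes that the active stages are simply those for which an on_<stage>() function is defined, and it rejects any name that is not among the stages documented on this page.

```r
# Stage names as listed on this page.
valid_stages <- c("begin", "epoch_begin", "batch_begin", "after_backward",
                  "batch_end", "batch_valid_begin", "batch_valid_end",
                  "valid_end", "epoch_end", "end", "exit")

# Hypothetical check: derive the stages from functions named on_<stage>
# and fail on any name that is not a known stage.
active_stages <- function(fns) {
  hooks <- grep("^on_", names(fns), value = TRUE)
  stages <- sub("^on_", "", hooks)
  unknown <- setdiff(stages, valid_stages)
  if (length(unknown) > 0L) {
    stop("Unknown stage(s): ", paste(unknown, collapse = ", "), call. = FALSE)
  }
  stages
}

# Returns c("begin", "epoch_end")
active_stages(list(on_begin = function() NULL, on_epoch_end = function() NULL))

# A typo ("ned" instead of "end") raises an error instead of being silently ignored.
try(active_stages(list(on_epoch_ned = function() NULL)))
```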
Stages:
begin: Run before the training loop begins.
epoch_begin: Run at the beginning of each epoch.
batch_begin: Run before the forward call.
after_backward: Run after the backward call.
batch_end: Run after the optimizer step.
batch_valid_begin: Run before the forward call in the validation loop.
batch_valid_end: Run after the forward call in the validation loop.
valid_end: Run at the end of validation.
epoch_end: Run at the end of each epoch.
end: Run after the last epoch.
exit: Run last, using on.exit(), so it also runs when training ends with an error.
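The order in which the stages fire can be illustrated with a toy loop that only records stage names. This is a sketch, not mlr3torch's training loop; toy_loop() and record() are hypothetical, and the placement of the validation loop inside each epoch, after the training batches, is an assumption based on the order of the stage list. Because exit is registered with on.exit(), it is recorded even when the loop fails, whereas end is not.

```r
# Hypothetical toy loop that calls hook(<stage>) in the documented order.
toy_loop <- function(hook, epochs = 1L, batches = 1L, fail = FALSE) {
  on.exit(hook("exit"))           # runs when toy_loop() exits, even on error
  hook("begin")
  for (e in seq_len(epochs)) {
    hook("epoch_begin")
    for (b in seq_len(batches)) {
      hook("batch_begin")         # forward pass would follow here
      if (fail) stop("simulated failure in the forward pass")
      # loss computation and backward pass would happen here
      hook("after_backward")      # optimizer step would follow here
      hook("batch_end")
    }
    for (b in seq_len(batches)) { # validation loop
      hook("batch_valid_begin")   # forward pass on a validation batch
      hook("batch_valid_end")
    }
    hook("valid_end")
    hook("epoch_end")
  }
  hook("end")
  invisible(NULL)
}

seen <- character()
record <- function(stage) seen <<- c(seen, stage)

toy_loop(record)
seen
# "begin" "epoch_begin" "batch_begin" "after_backward" "batch_end"
# "batch_valid_begin" "batch_valid_end" "valid_end" "epoch_end" "end" "exit"

seen <- character()
try(toy_loop(record, fail = TRUE), silent = TRUE)
seen
# "begin" "epoch_begin" "batch_begin" "exit"
```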
Training can be stopped early by setting the field $terminate of the ContextTorch to TRUE. This field is checked at the end of every epoch; if it is TRUE, training stops. This can be used, for example, to implement custom early stopping.
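Early stopping via $terminate can be sketched in the same toy setting. new_early_stopper() and toy_fit() are hypothetical, the context is a plain environment, and the validation losses are made-up inputs; only the convention of setting terminate to TRUE and checking it at the end of every epoch comes from the description above.

```r
# Hypothetical early-stopping callback (base R only, NOT mlr3torch code).
new_early_stopper <- function(patience = 2L) {
  self <- new.env()
  self$ctx <- NULL
  self$patience <- patience
  self$best <- Inf
  self$bad_epochs <- 0L
  self$on_epoch_end <- function() {
    loss <- self$ctx$valid_loss
    if (loss < self$best) {
      self$best <- loss
      self$bad_epochs <- 0L
    } else {
      self$bad_epochs <- self$bad_epochs + 1L
    }
    # Ask the loop to stop after `patience` epochs without improvement.
    if (self$bad_epochs >= self$patience) self$ctx$terminate <- TRUE
  }
  self
}

# Toy loop: `valid_losses` replaces real validation; terminate is
# checked at the end of every epoch.
toy_fit <- function(cb, valid_losses) {
  ctx <- new.env()
  ctx$terminate <- FALSE
  cb$ctx <- ctx
  on.exit(cb$ctx <- NULL)
  for (epoch in seq_along(valid_losses)) {
    ctx$epoch <- epoch
    ctx$valid_loss <- valid_losses[epoch]
    cb$on_epoch_end()
    if (isTRUE(ctx$terminate)) break
  }
  ctx$epoch  # last epoch that was run
}

stopper <- new_early_stopper(patience = 2L)
toy_fit(stopper, c(1.0, 0.8, 0.85, 0.9, 0.7))  # stops after epoch 4
```

Epochs 3 and 4 do not improve on the best loss of 0.8, so the callback sets terminate and the fifth epoch never runs.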
Public fields:
ctx (ContextTorch or NULL)
The evaluation context for the callback. This field should always be NULL except during the $train() call of the torch learner.
stages (character())
The active stages of this callback set.
Methods:
print()
Prints the object.
Usage: CallbackSet$print(...)
... (any): Currently unused.
state_dict()
Returns information that is kept in the LearnerTorch's state after training. This information should be loadable into the callback using $load_state_dict() to be able to continue training. Returns NULL by default.
Usage: CallbackSet$state_dict()
load_state_dict()
Loads the state dict into the callback to continue training.
Usage: CallbackSet$load_state_dict(state_dict)
state_dict (any): The state dict as retrieved via $state_dict().
clone()
The objects of this class are cloneable with this method.
Usage: CallbackSet$clone(deep = FALSE)
deep: Whether to make a deep clone.
Other Callback: TorchCallback, as_torch_callback(), as_torch_callbacks(), callback_set(), mlr3torch_callbacks, mlr_callback_set.checkpoint, mlr_callback_set.progress, mlr_callback_set.tb, mlr_callback_set.unfreeze, mlr_context_torch, t_clbk(), torch_callback()