luz_callback_auto_resume | R Documentation |
This callback allows you to resume training a model.
luz_callback_auto_resume(path = "./state.pt")
path | Path to save state files for the model. |
When this callback is used, the model weights and optimizer state are serialized at the end of each epoch. If something fails during training, simply re-running the same script resumes training from the epoch right after the last epoch that was serialized.
By default, the model, optimizer state and records are serialized. Callbacks can be used to customize serialization by implementing the state_dict() and load_state_dict() methods. If those methods are implemented, then state_dict() is called at the end of each epoch and load_state_dict() is called when the model is resumed.
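As a minimal sketch of the customization described above, the callback below keeps a counter and exposes it through state_dict() and load_state_dict(). The names epoch_counter and epochs_seen are hypothetical and chosen only for illustration, and the single-argument form of load_state_dict() is an assumption rather than a documented signature.

```r
library(luz)

# Hypothetical callback that counts how many epochs have finished.
# `epoch_counter` and `epochs_seen` are illustrative names, not part of luz.
epoch_counter <- luz_callback(
  "epoch_counter",
  epochs_seen = 0,
  on_epoch_end = function() {
    self$epochs_seen <- self$epochs_seen + 1
  },
  # Return whatever should be saved alongside the default state.
  state_dict = function() {
    list(epochs_seen = self$epochs_seen)
  },
  # Restore the saved fields. Assumption: the saved list is passed
  # as a single argument when training is resumed.
  load_state_dict = function(state) {
    self$epochs_seen <- state$epochs_seen
  }
)
```

An instance created with epoch_counter() would be passed in the callbacks list together with luz_callback_auto_resume(), so that the counter survives an interrupted run.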
In general you will want to add this callback as the last in the callbacks list. This way, the serialized state is likely to contain all changes that other callbacks could have made at 'on_epoch_end'. The default weight attribute of this callback is Inf.
Read the checkpointing article in the pkgdown website for more information.
Other luz_callbacks: luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid(), luz_callback()
if (torch::torch_is_installed()) {
  library(torch)
  library(luz)

  x <- torch_randn(1000, 10)
  y <- torch_randn(1000, 1)

  model <- nn_linear %>%
    setup(optimizer = optim_sgd, loss = nnf_mse_loss) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    set_opt_hparams(lr = 0.01)

  # simulate a failure at the end of epoch 5 that happens only once.
  callback_stop <- luz_callback(
    "interrupt",
    failed = FALSE,
    on_epoch_end = function() {
      if (ctx$epoch == 5 && !self$failed) {
        self$failed <- TRUE
        stop("Error on epoch 5")
      }
    }
  )

  path <- tempfile()
  autoresume <- luz_callback_auto_resume(path = path)
  interrupt <- callback_stop()

  # try once and the model fails
  try({
    results <- model %>% fit(
      list(x, y),
      callbacks = list(autoresume, interrupt),
      verbose = FALSE
    )
  })

  # model resumes and completes
  results <- model %>% fit(
    list(x, y),
    callbacks = list(autoresume, interrupt),
    verbose = FALSE
  )

  get_metrics(results)
}