# callback that simulates a crash: on the first run it raises an error at the
# end of epoch 5, after stashing the metrics collected so far
interrupt <- luz_callback(
"interrupt",
weight = Inf,
failed = FALSE,
on_epoch_end = function() {
if (ctx$epoch == 5 && !self$failed) {
self$failed <- TRUE
self$metrics <- ctx$get_metrics_df()
stop("Error on epoch 5")
}
}
)
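# recursively clone every torch tensor found in a (possibly nested) list so
# that snapshots don't alias tensors that are later updated in place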
clone_tensors <- function(x) {
if (is.list(x))
lapply(x, clone_tensors)
else if (inherits(x, "torch_tensor"))
x$clone()
else
x
}
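# callback that snapshots the model and optimizer state dicts at the start of
# each epoch; with on_end = TRUE the end-of-epoch state overwrites that
# epoch's snapshot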
track_weights <- luz_callback(
"track_weights",
weights = list(),
opt = list(),
initialize = function(on_end = TRUE) {
self$on_end <- on_end
},
on_epoch_begin = function() {
self$weights[[ctx$epoch]] <- lapply(ctx$model$state_dict(), function(x) x$clone())
self$opt[[ctx$epoch]] <- clone_tensors(lapply(ctx$optimizers, function(opt) opt$state_dict()))
},
on_epoch_end = function() {
# this is only called when no saved model exists; otherwise the epoch is
# skipped by the auto-resume callback.
if (self$on_end) {
self$on_epoch_begin()
}
}
)
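# pattern under test: fit() crashes mid-training, and a second fit() call with
# the same luz_callback_auto_resume(path = ...) picks up from the last
# completed epoch instead of starting over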
test_that("resume a simple model", {
x <- torch_randn(1000, 10)
y <- torch_randn(1000, 1)
model <- nn_linear %>%
setup(optimizer = optim_sgd, loss = nnf_mse_loss) %>%
set_hparams(in_features = 10, out_features = 1) %>%
set_opt_hparams(lr = 0.01)
temp <- tempfile()
autoresume <- luz_callback_auto_resume(path = temp)
inter <- interrupt()
tr_w <- track_weights()
# simulate an error during training
expect_error(regexp = "Error on", {
results <- model %>% fit(
list(x, y),
callbacks = list(tr_w, autoresume, inter),
verbose = FALSE
)
})
tr_w_resume <- track_weights()
# rerunning, this time making sure that no error happens
results_resume <- model %>% fit(
list(x, y),
callbacks = list(tr_w_resume, autoresume, inter),
verbose = FALSE
)
metrics <- get_metrics(results_resume)
expect_true(nrow(metrics) == 10)
expect_true(all.equal(metrics[1:5,], inter$metrics))
# expect that the first five weights are identical to the last one from
# the first run.
for(i in 1:5) {
expect_true(torch_allclose(tr_w$weights[[5]]$weight, tr_w_resume$weights[[i]]$weight))
expect_true(torch_allclose(tr_w$weights[[5]]$bias, tr_w_resume$weights[[i]]$bias))
expect_identical(tr_w$opt[[i]], tr_w_resume$opt[[i]])
}
# Now that the run is complete, rerunning will trigger a completely new run.
results_resume2 <- model %>% fit(
list(x, y),
callbacks = list(autoresume),
verbose = FALSE
)
# we expect no identical metrics at all.
expect_true(!identical(get_metrics(results_resume2), metrics))
})
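# a module's set_optimizers() can return a named list of optimizers (here one
# Adam instance per parameter); resuming should restore the state of every
# optimizer, not just the first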
test_that("resume a model with more than one optimizer", {
x <- torch_randn(1000, 10)
y <- torch_randn(1000, 1)
module <- nn_module(
inherit = nn_linear,
set_optimizers = function(lr = 2*1e-4, betas = c(0.5, 0.999)) {
list(
weight = optim_adam(list(super$parameters$weight), lr = lr, betas = betas),
bias = optim_adam(list(super$parameters$bias), lr = lr, betas = betas)
)
}
)
model <- module %>%
setup(loss = nnf_mse_loss) %>%
set_hparams(in_features = 10, out_features = 1) %>%
set_opt_hparams(lr = 1e-4)
temp <- tempfile()
autoresume <- luz_callback_auto_resume(path = temp)
tr_w <- track_weights()
inter <- interrupt()
# simulate an error during training
expect_error(regexp = "Error on", {
results <- model %>% fit(
list(x, y),
callbacks = list(tr_w, autoresume, inter),
verbose = FALSE
)
})
tr_w_resume <- track_weights()
results_resume <- model %>% fit(
list(x, y),
callbacks = list(tr_w_resume, autoresume, inter),
verbose = FALSE
)
# the optimizer states tracked during the replayed epochs should all match
# the state saved at the end of the interrupted run
for (i in 1:5) {
expect_recursive_equal(tr_w$opt[[5]], tr_w_resume$opt[[i]])
}
})
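# callbacks that implement state_dict() and load_state_dict() should have
# their state captured and restored by the auto-resume snapshots; the counter
# below makes that observable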
test_that("resume a model with learning rate scheduler", {
cb_with_state <- luz_callback(
"cb_with_state",
weight = Inf,
initialize = function() {
self$i <- 1
},
on_epoch_end = function() {
self$i <- self$i + 1
},
state_dict = function() {
list(i = self$i)
},
load_state_dict = function(d) {
self$i <- d$i
}
)
x <- torch_randn(1000, 10)
y <- torch_randn(1000, 1)
model <- nn_linear %>%
setup(optimizer = optim_sgd, loss = nnf_mse_loss) %>%
set_hparams(in_features = 10, out_features = 1) %>%
set_opt_hparams(lr = 0.01)
temp <- tempfile()
autoresume <- luz_callback_auto_resume(path = temp)
inter <- interrupt()
cb_state <- cb_with_state()
# simulate an error during training
expect_error(regexp = "Error on", {
results <- model %>% fit(
list(x, y),
callbacks = list(autoresume, cb_state, inter),
verbose = FALSE
)
})
cb_state2 <- cb_with_state()
results_resume <- model %>% fit(
list(x, y),
callbacks = list(autoresume, cb_state2, inter),
verbose = FALSE
)
# if the callback state were not correctly recovered, the counter would
# restart from 1 and end at a different value
expect_equal(cb_state2$i, 10)
expect_equal(cb_state$i, 6)
})
test_that("resume works when model has been explicitly interrupted", {
# sometimes we want to stop early; in that case we need to make sure that
# this interruption doesn't count as 'not finished training'.
x <- torch_randn(1000, 10)
y <- torch_randn(1000, 1)
model <- nn_linear %>%
setup(optimizer = optim_sgd, loss = nnf_mse_loss) %>%
set_hparams(in_features = 10, out_features = 1) %>%
set_opt_hparams(lr = 0.01)
temp <- tempfile()
autoresume <- luz_callback_auto_resume(path = temp)
early_stop <- luz_callback_early_stopping(monitor = "train_loss", patience = 1)
results <- model %>% fit(
list(x, y),
callbacks = list(autoresume, early_stop),
verbose = FALSE,
epochs = 100
)
results2 <- model %>% fit(
list(x, y),
callbacks = list(autoresume, early_stop),
verbose = FALSE,
epochs = 100
)
# values would be identical if results2 had been resumed from results
expect_true(get_metrics(results2)$value[1] != get_metrics(results)$value[1])
})
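# luz_callback_resume_from_checkpoint() starts a fresh fit() from the weights
# saved by luz_callback_model_checkpoint(), so the first tracked state of the
# second run should equal the last state of the first run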
test_that("can use the resume_from callback", {
x <- torch_randn(1000, 10)
y <- torch_randn(1000, 1)
model <- nn_linear %>%
setup(optimizer = optim_sgd, loss = nnf_mse_loss) %>%
set_hparams(in_features = 10, out_features = 1) %>%
set_opt_hparams(lr = 0.01)
temp <- tempfile()
checkpoint <- luz_callback_model_checkpoint(
path = temp,
monitor = "train_loss"
)
tr <- track_weights()
result <- model %>% fit(
list(x, y),
callbacks = list(tr, checkpoint),
verbose = FALSE
)
tr2 <- track_weights(on_end = FALSE)
resume_from <- luz_callback_resume_from_checkpoint(path = temp)
result2 <- model %>% fit(
list(x, y),
callbacks = list(tr2, resume_from),
verbose = FALSE
)
expect_recursive_equal(
tr$weights[[10]],
tr2$weights[[1]]
)
})
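# the learning-rate schedule must also survive a resume: after restarting, the
# scheduler should keep stepping from where it left off rather than from
# epoch 1; the lr progress callback below logs the lr so the snapshot can
# verify this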
test_that("resuming a model with a lr scheduler callback is correct", {
x <- torch_randn(1000, 10)
y <- torch_randn(1000, 1)
model <- nn_linear %>%
setup(optimizer = optim_sgd, loss = nnf_mse_loss) %>%
set_hparams(in_features = 10, out_features = 1) %>%
set_opt_hparams(lr = 0.01)
luz_callback_lr_progress <- luz_callback(
on_epoch_begin = function() {
rlang::inform(glue::glue("lr={ctx$opt$param_groups[[1]]$lr}"))
}
)
luz_callback_simulate_failure <- luz_callback(
initialize = function(at_epoch) {
self$at_epoch <- at_epoch
},
on_epoch_begin = function() {
if (ctx$epoch >= self$at_epoch) rlang::abort("simulated failure")
}
)
autoresume <- luz_callback_auto_resume(path = tempfile())
expect_error(regexp = "simulated failure", {
result <- model %>% fit(
list(x, y),
callbacks = list(
autoresume,
luz_callback_lr_scheduler(lr_step, step_size = 1L),
luz_callback_simulate_failure(at_epoch = 5L),
luz_callback_lr_progress()
),
verbose = FALSE
)
})
expect_snapshot({
result <- model %>% fit(
list(x, y),
callbacks = list(
autoresume,
luz_callback_lr_scheduler(lr_step, step_size = 1L),
luz_callback_simulate_failure(at_epoch = 11L),
luz_callback_lr_progress()
),
verbose = FALSE
)
})
})