test_that("Can use mixed precision callback", {
  # Query CUDA availability once. The result selects the device for `x`, the
  # dtype expected by the first probe, the `enabled` argument of the
  # mixed-precision callback, and the `cpu` argument of the accelerator.
  has_cuda <- torch::cuda_is_available()
  data_device <- if (has_cuda) "cuda" else "cpu"
  batch_begin_dtype <- if (has_cuda) "Half" else "BFloat16"

  # Random data: 1000 rows with 10 inputs and 1 target. Only `x` is created
  # on `data_device`; `y` is created without a device argument.
  x <- torch_randn(1000, 10, device = data_device)
  y <- torch_randn(1000, 1)

  # Linear module configured with MSE loss and the Adam optimizer.
  model <- nn_linear %>%
    setup(loss = nnf_mse_loss, optimizer = optim_adam) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    set_opt_hparams(lr = 1e-4)

  # Callback that carries the expectations of this test. The dtype probes run
  # only on the first iteration of the first epoch and multiply `x` by its own
  # transpose, then inspect the dtype of the product.
  # NOTE(review): presumably the mixed-precision callback turns autocast on at
  # batch start and off again before backward -- confirm against
  # luz_callback_mixed_precision.
  dtype_probe_callback <- luz_callback(
    on_fit_begin = function() {
      # The optimizer step stored on the context must differ from the default.
      expect_true(!identical(ctx$step_opt, default_step_opt))
    },
    on_train_batch_begin = function() {
      if (ctx$iter == 1 && ctx$epoch == 1) {
        gram <- torch_matmul(x, x$t())
        # "Half" when CUDA is available, "BFloat16" otherwise.
        expect_equal(gram$dtype$.type(), batch_begin_dtype)
      }
    },
    on_train_batch_before_backward = function() {
      if (ctx$iter == 1 && ctx$epoch == 1) {
        gram <- torch_matmul(x, x$t())
        # At this hook the same product is expected to be full precision.
        expect_equal(gram$dtype$.type(), "Float")
      }
    }
  )

  # Fit with 20% of the data held out for validation. The mixed-precision
  # callback is listed before the probing callback.
  fit_result <- model %>%
    fit(
      list(x, y),
      valid_data = 0.2,
      callbacks = list(
        luz_callback_mixed_precision(enabled = has_cuda),
        dtype_probe_callback()
      ),
      accelerator = accelerator(cpu = !has_cuda),
      verbose = FALSE
    )
})
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.