# tests/testthat/test-callbacks-amp.R

test_that("Can use mixed precision callback", {

  # Decide the device once so the data, the accelerator and the expected
  # autocast dtype all agree with each other.
  use_cuda <- torch::cuda_is_available()
  device <- if (use_cuda) "cuda" else "cpu"

  # Dtype matmul is expected to produce while autocast is active:
  # float16 ("Half") on CUDA, bfloat16 ("BFloat16") on CPU.
  autocast_dtype <- if (use_cuda) "Half" else "BFloat16"

  # `x` is also used directly inside the callback below, so it has to live on
  # the device autocast is active for; `y` is kept on the same device.
  x <- torch_randn(1000, 10, device = device)
  y <- torch_randn(1000, 1, device = device)

  model <- nn_linear %>%
    setup(loss = nnf_mse_loss, optimizer = optim_adam) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    set_opt_hparams(lr = 1e-4)

  callback_for_testing <- luz_callback(
    # The mixed precision callback must install its own optimizer step in
    # place of luz's default one.
    on_fit_begin = function() {
      expect_false(identical(ctx$step_opt, default_step_opt))
    },
    # Checked only on the very first training batch. The mixed precision
    # callback (listed first in `callbacks`) is expected to have enabled
    # autocast by now, so matmul runs in reduced precision.
    on_train_batch_begin = function() {
      if (ctx$iter == 1 && ctx$epoch == 1) {
        result <- torch_matmul(x, x$t())
        expect_equal(result$dtype$.type(), autocast_dtype)
      }
    },
    # Autocast is expected to be switched off again before backward, so
    # matmul is back to full (float32) precision.
    on_train_batch_before_backward = function() {
      if (ctx$iter == 1 && ctx$epoch == 1) {
        result <- torch_matmul(x, x$t())
        expect_equal(result$dtype$.type(), "Float")
      }
    }
  )

  fitted <- model %>%
    fit(
      list(x, y),
      valid_data = 0.2,
      callbacks = list(
        # NOTE(review): `enabled` is forwarded by the callback; it is FALSE on
        # CPU-only machines, yet the BFloat16 check above still expects
        # autocast on CPU. Confirm against luz_callback_mixed_precision docs.
        luz_callback_mixed_precision(enabled = use_cuda),
        callback_for_testing()
      ),
      accelerator = accelerator(cpu = !use_cuda),
      verbose = FALSE
    )

  # Training ran to completion and produced a fitted luz model.
  expect_s3_class(fitted, "luz_module_fitted")
})
# mlverse/luz documentation built on Sept. 19, 2024, 11:20 p.m.