tests/testthat/test-callbacks-interrupt.R

test_that("callback interrupt works", {

  model <- get_model()
  dl <- get_dl()

  mod <- model %>%
    setup(
      loss = torch::nn_mse_loss(),
      optimizer = torch::optim_adam,
    )

  hello <- 0
  output <- mod %>%
    set_hparams(input_size = 10, output_size = 1) %>%
    fit(dl, verbose = FALSE, epochs = 25, callbacks = list(
      luz_callback(on_epoch_begin = function() {
        if (ctx$epoch == 3)
          rlang::interrupt()
      })(),
      luz_callback(on_interrupt = function() {
        hello <<- 1
      })()
    ))

  expect_equal(length(output$records$metrics$train), 2)
  expect_equal(hello, 1)

})
mlverse/torchlight documentation built on Sept. 19, 2024, 11:22 p.m.