tests/testthat/test-callbacks.R

test_that("callback lr scheduler", {

  skip_on_os("windows")

  torch::torch_manual_seed(1)
  set.seed(1)

  model <- get_model()
  dl <- get_dl()

  mod <- model %>%
    setup(
      loss = torch::nn_mse_loss(),
      optimizer = torch::optim_adam,
    )

  expect_snapshot({
    expect_message({
      output <- mod %>%
        set_hparams(input_size = 10, output_size = 1) %>%
        fit(dl, verbose = FALSE, epochs = 5, callbacks = list(
          luz_callback_lr_scheduler(torch::lr_multiplicative, verbose = TRUE,
                                    lr_lambda = function(epoch) 0.5)
        ))
    })
  })

  expect_snapshot({
    expect_message({
      output <- mod %>%
        set_hparams(input_size = 10, output_size = 1) %>%
        fit(dl, verbose = FALSE, epochs = 20, callbacks = list(
          luz_callback_lr_scheduler(
            torch::lr_reduce_on_plateau,
            verbose = TRUE,
            patience = 2,
            threshold = 1e-1
          )
        ))
    })
  })

})

test_that("csv callback", {

  model <- get_model()
  dl <- get_dl()

  mod <- model %>%
    setup(
      loss = torch::nn_mse_loss(),
      optimizer = torch::optim_adam,
    )

  tmp <- tempfile()

  output <- mod %>%
    set_hparams(input_size = 10, output_size = 1) %>%
    fit(dl, verbose = FALSE, epochs = 5, callbacks = list(
      luz_callback_csv_logger(tmp)
    ))

  x <- read.table(tmp, header = TRUE, sep = ",")
  expect_equal(nrow(x), 5)
  expect_equal(names(x), c("epoch", "set", "loss"))

  output <- mod %>%
    set_hparams(input_size = 10, output_size = 1) %>%
    fit(dl, verbose = FALSE, epochs = 5, valid_data = dl, callbacks = list(
      luz_callback_csv_logger(tmp)
    ))

  x <- read.table(tmp, header = TRUE, sep = ",")

  expect_equal(nrow(x), 10)
  expect_equal(names(x), c("epoch", "set", "loss"))

})

test_that("progressbar appears with training and validation", {

  torch::torch_manual_seed(1)
  set.seed(1)

  model <- get_model()
  dl <- get_test_dl(len = 500)

  mod <- model %>%
    setup(
      loss = torch::nn_mse_loss(),
      optimizer = torch::optim_adam,
    )

  withr::with_options(list(luz.force_progress_bar = TRUE,
                           luz.show_progress_bar_eta = FALSE,
                           width = 80), {
    expect_snapshot({
      expect_message({
        output <- mod %>%
          set_hparams(input_size = 10, output_size = 1) %>%
          fit(dl, verbose = TRUE, epochs = 2, valid_data = dl)
      })
    })
  })

})

test_that("gradient clip works correctly", {

  model <- get_model()
  dl <- get_test_dl(len = 500)

  mod <- model %>%
    setup(
      loss = torch::nn_mse_loss(),
      optimizer = torch::optim_adam,
    )

  output <- mod %>%
    set_hparams(input_size = 10, output_size = 1) %>%
    fit(dl, verbose = FALSE, epochs = 2, valid_data = dl,
        callbacks = list(luz_callback_gradient_clip(max_norm = 0)))

  # we expect that no learning happened thus the loss is identicall
  # acrosss all metrics.
  expect_length(unique(get_metrics(output)$value), 1)
  expect_length(get_metrics(output)$value, 4)

  expect_error(luz_callback_gradient_clip(max_norm = "a"), "max_norm")
  expect_error(luz_callback_gradient_clip(norm_type = "a"), "norm_type")
})

test_that("improve error message when you provide a unitinitilized callback", {

  skip_on_os("windows")

  x <- torch_randn(1000, 10)
  y <- torch_randn(1000, 1)

  model <- nn_linear %>%
    setup(optimizer = optim_sgd, loss = nnf_mse_loss) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    set_opt_hparams(lr = 0.01)

  expect_snapshot_error({
    model %>% fit(list(x, y), callbacks = list(luz_callback_auto_resume))
  })

})

test_that("can get progress when using iterable datasets", {

  torch::torch_manual_seed(1)
  set.seed(1)

  model <- get_model()

  get_iterable_ds <- torch::iterable_dataset(
    "iterable_ds",
    initialize = function(len = 100, x_size = 10, y_size = 1, fixed_values = FALSE) {
      self$len <- len
      self$x <- torch::torch_randn(size = c(len, x_size))
      self$y <- torch::torch_randn(size = c(len, y_size))
    },
    .iter = function() {
      i <- 0
      function() {
        i <<- i + 1

        if (i > self$len) {
          return(coro::exhausted())
        }

        list(
          x = self$x[i,..],
          y = self$y[i,..]
        )
      }
    }
  )

  ds <- get_iterable_ds()
  dl <- torch::dataloader(ds, batch_size = 32)

  mod <- model %>%
    setup(
      loss = torch::nn_mse_loss(),
      optimizer = torch::optim_adam,
    )


  withr::with_options(list(
    luz.force_progress_bar = TRUE,
    luz.show_progress_bar_eta = FALSE,
    width = 80), {

      expect_snapshot({
        expect_message({
          output <- mod %>%
            set_hparams(input_size = 10, output_size = 1) %>%
            fit(
              get_iterable_ds(),
              verbose = TRUE,
              epochs = 2,
              valid_data = get_iterable_ds(),
            )
        })
      })
    })



})
mlverse/torchlight documentation built on Sept. 19, 2024, 11:22 p.m.