# tests/testthat/test-callbacks-resume.R

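# A callback that simulates a crash: it errors once at the end of epoch 5
# (guarded by `failed`) and keeps a copy of the metrics collected so far so
# the tests can compare them against the resumed run.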
interrupt <- luz_callback(
  "interrupt",
  weight = Inf,
  failed = FALSE,
  on_epoch_end = function() {
    if (ctx$epoch == 5 && !self$failed) {
      self$failed <- TRUE
      self$metrics <- ctx$get_metrics_df()
      stop("Error on epoch 5")
    }
  }
)

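# Recursively deep-clone every tensor in a nested list. Optimizers mutate
# their state tensors in place, so snapshots of `state_dict()`s must be
# cloned, or every epoch's snapshot would point at the same, latest values.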
clone_tensors <- function(x) {
  if (is.list(x))
    lapply(x, clone_tensors)
  else if (inherits(x, "torch_tensor"))
    x$clone()
  else
    x
}

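# Snapshots the model weights and optimizer state_dicts at the start of every
# epoch. With `on_end = TRUE` the snapshot is re-taken at epoch end, which
# matters on resumed runs where skipped epochs fire only part of the callback
# cycle (see the note in `on_epoch_end` below).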
track_weights <- luz_callback(
  "track_weights",
  weights = list(),
  opt = list(),
  initialize = function(on_end = TRUE) {
    self$on_end <- on_end
  },
  on_epoch_begin = function() {
    self$weights[[ctx$epoch]] <- lapply(ctx$model$state_dict(), function(x) x$clone())
    self$opt[[ctx$epoch]] <- clone_tensors(lapply(ctx$optimizers, function(opt) opt$state_dict()))
  },
  on_epoch_end = function() {
    # this is actually only called when no saved model exists; otherwise the
    # epoch is skipped by the auto_resume callback.
    if (self$on_end) {
      self$on_epoch_begin()
    }
  }
)

test_that("resume a simple model", {

  x <- torch_randn(1000, 10)
  y <- torch_randn(1000, 1)

  model <- nn_linear %>%
    setup(optimizer = optim_sgd, loss = nnf_mse_loss) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    set_opt_hparams(lr = 0.01)

  temp <- tempfile()
  autoresume <- luz_callback_auto_resume(path = temp)
  inter <- interrupt()
  tr_w <- track_weights()
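
  # auto_resume serializes the training state to `temp` at the end of every
  # epoch; calling fit() again with the same callback resumes from the last
  # completed epoch instead of starting over.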

  # simulate an error during training
  expect_error(regexp = "Error on", {
    results <- model %>% fit(
      list(x, y),
      callbacks = list(tr_w, autoresume, inter),
      verbose = FALSE
    )
  })

  tr_w_resume <- track_weights()
  # rerunning; this time no error is raised because `inter$failed` is TRUE
  results_resume <- model %>% fit(
    list(x, y),
    callbacks = list(tr_w_resume, autoresume, inter),
    verbose = FALSE
  )

  metrics <- get_metrics(results_resume)

  expect_true(nrow(metrics) == 10)
  expect_true(all.equal(metrics[1:5,], inter$metrics))

  # the first five snapshots of the resumed run (the skipped epochs) should
  # equal the weights from the end of the interrupted run.
  for(i in 1:5) {
    expect_true(torch_allclose(tr_w$weights[[5]]$weight, tr_w_resume$weights[[i]]$weight))
    expect_true(torch_allclose(tr_w$weights[[5]]$bias, tr_w_resume$weights[[i]]$bias))

    expect_identical(tr_w$opt[[i]], tr_w_resume$opt[[i]])
  }

  # Now that the run is complete, rerunning will trigger a completely new run.
  results_resume2 <- model %>% fit(
    list(x, y),
    callbacks = list(autoresume),
    verbose = FALSE
  )

  # we expect no identical metrics at all.
  expect_true(!identical(get_metrics(results_resume2), metrics))
})

test_that("resume a model with more than one optimizer", {

  x <- torch_randn(1000, 10)
  y <- torch_randn(1000, 1)

  module <- nn_module(
    inherit = nn_linear,
    set_optimizers = function(lr = 2*1e-4, betas = c(0.5, 0.999)) {
      list(
        weight = optim_adam(list(super$parameters$weight), lr = lr, betas = betas),
        bias = optim_adam(list(super$parameters$bias), lr = lr, betas = betas)
      )
    }
  )
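
  # Unlike SGD above, Adam keeps per-parameter state tensors (step counters
  # and moment estimates), so this exercises restoring real tensor state for
  # every optimizer in `ctx$optimizers`.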

  model <- module %>%
    setup(loss = nnf_mse_loss) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    set_opt_hparams(lr = 1e-4)

  temp <- tempfile()
  autoresume <- luz_callback_auto_resume(path = temp)
  tr_w <- track_weights()
  inter <- interrupt()

  # simulate an error during training
  expect_error(regexp = "Error on", {
    results <- model %>% fit(
      list(x, y),
      callbacks = list(tr_w, autoresume, inter),
      verbose = FALSE
    )
  })

  tr_w_resume <- track_weights()
  results_resume <- model %>% fit(
    list(x, y),
    callbacks = list(tr_w_resume, autoresume, inter),
    verbose = FALSE
  )

  # as in the simple-model test, the first five optimizer snapshots of the
  # resumed run should equal the last snapshot of the interrupted run.
  for (i in 1:5) {
    expect_recursive_equal(tr_w$opt[[5]], tr_w_resume$opt[[i]])
  }

})

test_that("resume a model with learning rate scheduler", {
  cb_with_state <- luz_callback(
    weight = Inf,
    initialize = function() {
      self$i <- 1
    },
    on_epoch_end = function() {
      self$i <- self$i + 1
    },
    state_dict = function() {
      list(i = self$i)
    },
    load_state_dict = function(d) {
      self$i <- d$i
    }
  )
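
  # A minimal callback with persistent state: auto_resume serializes it via
  # `state_dict()` and restores it via `load_state_dict()`, the same
  # mechanism that lets stateful callbacks such as lr schedulers survive a
  # resume.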


  x <- torch_randn(1000, 10)
  y <- torch_randn(1000, 1)

  model <- nn_linear %>%
    setup(optimizer = optim_sgd, loss = nnf_mse_loss) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    set_opt_hparams(lr = 0.01)

  temp <- tempfile()
  autoresume <- luz_callback_auto_resume(path = temp)
  inter <- interrupt()
  cb_state <- cb_with_state()

  # simulate an error during training
  expect_error(regexp = "Error on", {
    results <- model %>% fit(
      list(x, y),
      callbacks = list(autoresume, cb_state, inter),
      verbose = FALSE
    )
  })

  cb_state2 <- cb_with_state()
  results_resume <- model %>% fit(
    list(x, y),
    callbacks = list(autoresume, cb_state2, inter),
    verbose = FALSE
  )

  # had training restarted from scratch instead of resuming, all 10 epochs
  # would have incremented `i`, leaving it larger than 10
  expect_equal(cb_state2$i, 10)
  expect_equal(cb_state$i, 6)
})


test_that("resume works when model has been explicitly interrupted", {
  # sometimes we want to stop early; in that case we need to make sure the
  # interruption doesn't count as 'not finished training'.

  x <- torch_randn(1000, 10)
  y <- torch_randn(1000, 1)

  model <- nn_linear %>%
    setup(optimizer = optim_sgd, loss = nnf_mse_loss) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    set_opt_hparams(lr = 0.01)

  temp <- tempfile()
  autoresume <- luz_callback_auto_resume(path = temp)
  early_stop <- luz_callback_early_stopping(monitor = "train_loss", patience = 1)

  results <- model %>% fit(
    list(x, y),
    callbacks = list(autoresume, early_stop),
    verbose = FALSE,
    epochs = 100
  )

  results2 <- model %>% fit(
    list(x, y),
    callbacks = list(autoresume, early_stop),
    verbose = FALSE,
    epochs = 100
  )

  # the values would be identical if results2 had resumed from the first run
  expect_true(get_metrics(results2)$value[1] != get_metrics(results)$value[1])
})

test_that("can use the resume_from callback", {

  x <- torch_randn(1000, 10)
  y <- torch_randn(1000, 1)

  model <- nn_linear %>%
    setup(optimizer = optim_sgd, loss = nnf_mse_loss) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    set_opt_hparams(lr = 0.01)

  temp <- tempfile()
  checkpoint <- luz_callback_model_checkpoint(
    path = temp,
    monitor = "train_loss"
  )
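
  # fit() runs for its default of 10 epochs, so the last checkpoint written
  # corresponds to the epoch-10 weights compared below.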

  tr <- track_weights()
  result <- model %>% fit(
    list(x, y),
    callbacks = list(tr, checkpoint),
    verbose = FALSE
  )

  tr2 <- track_weights(on_end = FALSE)
  resume_from <- luz_callback_resume_from_checkpoint(path = temp)
  result2 <- model %>% fit(
    list(x, y),
    callbacks = list(tr2, resume_from),
    verbose = FALSE
  )

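  # tr (on_end = TRUE) re-snapshots at each epoch end, so weights[[10]] holds
  # the final weights of the first run; tr2 (on_end = FALSE) records only at
  # epoch begin, so weights[[1]] holds exactly what resume_from loaded.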
  expect_recursive_equal(
    tr$weights[[10]],
    tr2$weights[[1]]
  )
})

test_that("resuming a model with a lr scheduler callback is correct", {

  x <- torch_randn(1000, 10)
  y <- torch_randn(1000, 1)

  model <- nn_linear %>%
    setup(optimizer = optim_sgd, loss = nnf_mse_loss) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    set_opt_hparams(lr = 0.01)

  luz_callback_lr_progress <- luz_callback(
    on_epoch_begin = function() {
      rlang::inform(glue::glue("lr={ctx$opt$param_groups[[1]]$lr}"))
    }
  )

  luz_callback_simulate_failure <- luz_callback(
    initialize = function(at_epoch) {
      self$at_epoch <- at_epoch
    },
    on_epoch_begin = function() {
      if (ctx$epoch >= self$at_epoch) rlang::abort("simulated failure")
    }
  )

  autoresume <- luz_callback_auto_resume(path = tempfile())

  expect_error(regexp = "simulated failure", {
    result <- model %>% fit(
      list(x, y),
      callbacks = list(
        autoresume,
        luz_callback_lr_scheduler(lr_step, step_size = 1L),
        luz_callback_simulate_failure(at_epoch=5L),
        luz_callback_lr_progress()
      ),
      verbose = FALSE
    )
  })

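  # lr_step(step_size = 1) decays the learning rate every epoch (gamma
  # defaults to 0.1), so the lrs printed by luz_callback_lr_progress and
  # captured in the snapshot must continue the decay where the failed run
  # stopped; a scheduler that restarted from the initial lr would change the
  # snapshot.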
  expect_snapshot({
    result <- model %>% fit(
      list(x, y),
      callbacks = list(
        autoresume,
        luz_callback_lr_scheduler(lr_step, step_size = 1L),
        luz_callback_simulate_failure(at_epoch=11L),
        luz_callback_lr_progress()
      ),
      verbose = FALSE
    )
  })

})