tests/testthat/test-serialization.R

test_that("serialization works as expected", {

  torch::torch_manual_seed(1)
  set.seed(1)

  model <- get_model()
  dl <- get_dl()

  mod <- model %>%
    setup(
      loss = torch::nn_mse_loss(),
      optimizer = torch::optim_adam,
    )

  output <- mod %>%
    set_hparams(input_size = 10, output_size = 1) %>%
    fit(dl, verbose = FALSE, epochs = 1)

  path <- tempfile(pattern="rds")
  luz_save(output, path)

  test <- get_test_dl()

  expected_predictions <- predict(output, test)
  params <- output$model$state_dict()
  rm(output)
  gc()

  output <- luz_load(path)

  params2 <- output$model$state_dict()
  actual_predictions <- predict(output, test)

  expect_equal(
    as.array(actual_predictions$to(device="cpu")),
    as.array(expected_predictions$to(device="cpu"))
  )
})

test_that("serialization works when model uses `ctx` in forward", {

  model <- torch::nn_module(
    initialize = function() {
      self$fc <- torch::nn_linear(10, 1, bias = FALSE)
    },
    forward = function(x) {
      if (ctx$training)
        self$fc(torch::torch_ones_like(x))
      else
        self$fc(torch::torch_zeros_like(x))
    }
  )

  dl <- get_dl()

  mod <- model %>%
    setup(
      loss = torch::nn_mse_loss(),
      optimizer = torch::optim_adam,
    )

  output <- mod %>%
    set_hparams() %>%
    fit(dl, verbose = FALSE, epochs = 1)


  pred <- predict(output, dl)
  expect_equal(as.array(pred$cpu()), as.array(torch::torch_zeros(100,1)))

  output$ctx$training <- TRUE

  path <- tempfile(pattern="rds")
  luz_save(output, path)

  rm(output)
  gc()

  output <- luz_load(path)
  pred <- predict(output, dl)
  expect_equal(
    as.array(pred$to(device="cpu")),
    as.array(torch::torch_zeros(100,1))
  )
})

test_that("save and load model weights", {

  model <- get_model()
  dl <- get_dl()

  mod <- model %>%
    setup(
      loss = torch::nn_mse_loss(),
      optimizer = torch::optim_adam,
    )

  output <- mod %>%
    set_hparams(input_size = 10, output_size = 1) %>%
    fit(dl, verbose = FALSE, epochs = 1)

  pars <- lapply(output$model$parameters, function(x) x$clone())

  tmp <- tempfile(fileext = "rds")
  luz_save_model_weights(output, tmp)

  # modify parameters
  torch::with_no_grad({
    output$model$fc$weight$add_(5)
  })

  expect_not_equal_to_tensor(output$model$parameters$fc.weight, pars$fc.weight)

  #restore parameters
  luz_load_model_weights(output, tmp)

  pars2 <- output$model$parameters

  expect_equal_to_tensor(pars2$fc.weight, pars$fc.weight)
  expect_equal_to_tensor(pars2$fc.bias, pars$fc.bias)
})
mlverse/luz documentation built on Sept. 19, 2024, 11:20 p.m.