# tests/testthat/test-module.R

test_that("Fully managed", {

  model <- get_model()
  dl <- get_dl()

  mod <- model %>%
    setup(
      loss = torch::nn_mse_loss(),
      optimizer = torch::optim_adam
    )

  expect_s3_class(mod, "luz_module_generator")

  output <- mod %>%
    set_hparams(input_size = 10, output_size = 1) %>%
    fit(dl, valid_data = dl, verbose = FALSE)

  expect_s3_class(output, "luz_module_fitted")
})

test_that("Custom optimizer", {

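  # the module defines its own set_optimizers() method below, so setup() is called without an optimizer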
  model <- get_model()
  model <- torch::nn_module(
    inherit = model,
    set_optimizers = function() {
      torch::optim_adam(self$parameters, lr = 0.01)
    }
  )
  dl <- get_dl()

  mod <- model %>%
    setup(
      loss = torch::nn_mse_loss()
    )

  expect_s3_class(mod, "luz_module_generator")

  output <- mod %>%
    set_hparams(input_size = 10, output_size = 1) %>%
    fit(dl, valid_data = dl, verbose = FALSE)

  expect_s3_class(output, "luz_module_fitted")
})

test_that("Multiple optimizers", {

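  # with multiple optimizers, set_optimizers() returns a named list, one optimizer per submodule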
  model <- get_model()
  module <- torch::nn_module(
    initialize = function(input_size = 10, output_size = 1) {
      self$model1 <- model(input_size, output_size)
      self$model2 <- model(input_size, output_size)
    },
    forward = function(x) {
      self$model1(x) + self$model2(x)
    },
    set_optimizers = function() {
      list(
        one = torch::optim_adam(self$model1$parameters, lr = 0.01),
        two = torch::optim_adam(self$model2$parameters, lr = 0.01)
      )
    }
  )
  dl <- get_dl()

  mod <- module %>%
    setup(
      loss = torch::nn_mse_loss()
    )

  expect_s3_class(mod, "luz_module_generator")

  output <- mod %>%
    set_hparams(input_size = 10, output_size = 1) %>%
    fit(dl, valid_data = dl, verbose = FALSE)

  expect_s3_class(output, "luz_module_fitted")
})

test_that("can train without a validation dataset", {

  model <- get_model()
  dl <- get_dl()

  mod <- model %>%
    setup(
      loss = torch::nn_mse_loss(),
      optimizer = torch::optim_adam
    )

  expect_s3_class(mod, "luz_module_generator")

  output <- mod %>%
    set_hparams(input_size = 10, output_size = 1) %>%
    fit(dl, verbose = FALSE)

  expect_s3_class(output, "luz_module_fitted")
})

test_that("predict works for modules", {

  model <- get_model()
  dl <- get_dl()

  mod <- model %>%
    setup(
      loss = torch::nn_mse_loss(),
      optimizer = torch::optim_adam
    )

  output <- mod %>%
    set_hparams(input_size = 10, output_size = 1) %>%
    fit(dl, verbose = FALSE)

  pred <- predict(output, dl)
  pred2 <- predict(output, dl)

  expect_equal(pred$shape, c(100, 1))
  expect_equal(as.array(pred$to(device = "cpu")), as.array(pred2$to(device = "cpu")))

  # try with a different dataloader
  dl <- get_dl()
  pred <- predict(output, dl)
  pred2 <- predict(output, dl)
  expect_equal(as.array(pred$to(device = "cpu")), as.array(pred2$to(device = "cpu")))

})

test_that("predict can use a progress bar", {

  model <- get_model()
  dl <- get_dl()

  mod <- model %>%
    setup(
      loss = torch::nn_mse_loss(),
      optimizer = torch::optim_adam
    )

  output <- mod %>%
    set_hparams(input_size = 10, output_size = 1) %>%
    set_opt_hparams(lr = 0.001) %>%
    fit(dl, epochs = 1, verbose = FALSE)

  dl <- get_dl(len = 500)

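  # force the progress bar, hide the ETA and fix the console width so the snapshot output is reproducible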
  withr::with_options(
    list(luz.force_progress_bar = TRUE,
         luz.show_progress_bar_eta = FALSE,
         width = 80),
    {
      expect_snapshot(
        pred <- predict(output, dl, verbose = TRUE)
      )
    }
  )

  expect_equal(output$ctx$hparams$input_size, 10)
  expect_equal(output$ctx$opt_hparams$lr, 0.001)
})

test_that("valid_data works", {

  model <- get_model()
  model <- model %>%
    setup(
      loss = torch::nn_mse_loss(),
      optimizer = torch::optim_adam
    ) %>%
    set_hparams(input_size = 10, output_size = 1) %>%
    set_opt_hparams(lr = 0.001)

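  # a numeric valid_data between 0 and 1 is the fraction of the data held out for validation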
  fitted <- model %>% fit(
    list(torch::torch_randn(100, 10), torch::torch_randn(100, 1)),
    epochs = 10,
    valid_data = 0.1,
    verbose = FALSE
  )

  expect_true("valid" %in% get_metrics(fitted)$set)

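  # fractions outside the (0, 1) interval are rejected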
  expect_error(class = "value_error", regexp = "2", {
    model %>% fit(
      list(torch::torch_randn(100, 10), torch::torch_randn(100, 1)),
      epochs = 10,
      valid_data = 2,
      verbose = FALSE
    )
  })

  expect_error(class = "value_error", regexp = "-1", {
    model %>% fit(
      list(torch::torch_randn(100, 10), torch::torch_randn(100, 1)),
      epochs = 10,
      valid_data = -1,
      verbose = FALSE
    )
  })

  dl <- get_dl()
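  # a validation fraction can't be used when the training data is already a dataloader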
  expect_error(class = "value_error", regexp = "dataloader", {
    model %>% fit(
      dl,
      epochs = 10,
      valid_data = 0.2,
      verbose = FALSE
    )
  })

})

test_that("we can pass dataloader_options", {

  iter_callback <- luz_callback(
    initialize = function() {
      self$iter <- 0
    },
    on_train_batch_end = function() {
      self$iter <- self$iter + 1
    }
  )

  model <- get_model()
  model <- model %>%
    setup(
      loss = torch::nn_mse_loss(),
      optimizer = torch::optim_adam
    ) %>%
    set_hparams(input_size = 10, output_size = 1) %>%
    set_opt_hparams(lr = 0.001)

  x <- list(torch::torch_randn(100, 10), torch::torch_randn(100, 1))

  iter <- iter_callback()
  fitted <- model %>% fit(
    x,
    epochs = 1,
    valid_data = 0.1,
    verbose = FALSE,
    dataloader_options = list(batch_size = 2, shuffle = FALSE),
    callbacks = iter
  )

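  # 10% of the 100 rows go to validation, leaving 90 training rows; batch_size = 2 gives 45 batches per epoch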
  expect_equal(iter$iter, 45)

  dl <- get_dl()
  expect_error(regexp = "already a dataloader", {
    model %>% fit(
      dl,
      epochs = 1,
      verbose = FALSE,
      dataloader_options = list(batch_size = 2, shuffle = FALSE)
    )
  })

  expect_warning(regexp = "already a dataloader", {
    model %>% fit(
      x,
      epochs = 1,
      verbose = FALSE,
      valid_data = dl,
      dataloader_options = list(batch_size = 2, shuffle = FALSE)
    )
  })

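  # with batch_size = 45 and drop_last = TRUE only 2 complete batches of the 100 rows are kept, hence 90 predictions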
  pred <- predict(fitted, x, dataloader_options = list(batch_size = 45, drop_last = TRUE))
  expect_tensor_shape(pred, c(90, 1))

  expect_warning(regexp = "already a dataloader", {
    predict(fitted, dl, dataloader_options = list(batch_size = 45, drop_last = TRUE))
  })

  expect_warning(regexp = "ignored for predictions", {
    predict(fitted, x, dataloader_options = list(shuffle = TRUE))
  })

})

test_that("evaluate works", {

  set.seed(1)
  torch_manual_seed(1)

  model <- get_model()
  model <- model %>%
    setup(
      loss = torch::nn_mse_loss(),
      optimizer = torch::optim_adam,
      metrics = list(
        luz_metric_mae(),
        luz_metric_mse(),
        luz_metric_rmse()
      )
    ) %>%
    set_hparams(input_size = 10, output_size = 1) %>%
    set_opt_hparams(lr = 0.001)

  x <- list(torch::torch_randn(100, 10), torch::torch_randn(100, 1))

  fitted <- model %>% fit(
    x,
    epochs = 1,
    verbose = FALSE,
    dataloader_options = list(batch_size = 2, shuffle = FALSE)
  )

  e <- evaluate(fitted, x)

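  # one row each for loss, mae, mse and rmse; two columns: metric and value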
  expect_equal(nrow(get_metrics(e)), 4)
  expect_equal(ncol(get_metrics(e)), 2)

  expect_snapshot(print(e))
})

test_that("cutom backward", {

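  # a custom backward receives the value to call backward() on (the loss); ctx gives access to the current iteration and epoch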
  model <- get_model()
  dl <- get_dl()

  mod <- model %>%
    setup(
      loss = torch::nn_mse_loss(),
      optimizer = torch::optim_adam,
      backward = function(x) {
        x$backward()
        if (ctx$iter == 1 && ctx$epoch == 1) {
          print("hello")
        }
      }
    )

  expect_s3_class(mod, "luz_module_generator")

  expect_output(regexp = "hello", {
    output <- mod %>%
      set_hparams(input_size = 10, output_size = 1) %>%
      fit(dl, valid_data = dl, verbose = FALSE)
  })
  expect_s3_class(output, "luz_module_fitted")
})

test_that("luz module has a device arg", {

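  # modules created through setup() expose a device field (everything here stays on the CPU);
  # a module that already defines its own device active binding keeps it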
  mod <- get_model() %>%
    setup(
      loss = torch::nn_mse_loss(),
      optimizer = torch::optim_adam
    )

  modul <- mod(1, 1)

  expect_true(
    modul$device == torch_device("cpu")
  )

  mod <- nn_module(
    initialize = function() {
      self$par <- torch::nn_parameter(torch_randn(10, 10))
    },
    forward = function(x) {
      self$par
    },
    active = list(
      device = function() {
        "hello"
      }
    )
  )

  model <- mod %>%
    setup(
      loss = torch::nn_mse_loss(),
      optimizer = torch::optim_adam
    )

  modul <- model()
  expect_equal(modul$device, "hello")

})

test_that("evaluate allows setting metrics for it", {

  model <- get_model()
  model <- model %>%
    setup(
      loss = torch::nn_mse_loss(),
      optimizer = torch::optim_adam,
      metrics = list(
        luz_metric_mae(),
        luz_metric_mse(),
        luz_metric_rmse()
      )
    ) %>%
    set_hparams(input_size = 10, output_size = 1) %>%
    set_opt_hparams(lr = 0.001)

  x <- list(torch::torch_randn(100, 10), torch::torch_randn(100, 1))

  fitted <- model %>% fit(
    x,
    epochs = 1,
    verbose = FALSE,
    dataloader_options = list(batch_size = 2, shuffle = FALSE)
  )

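  # metrics passed to evaluate() are used for that call only; loss is always reported alongside them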
  e1 <- get_metrics(evaluate(fitted, x))
  e2 <- get_metrics(evaluate(fitted, x, metrics = list(luz_metric_mae())))
  e3 <- get_metrics(evaluate(fitted, x))

  expect_equal(e1, e3)
  expect_equal(e2, e1[e1$metric %in% c("loss", "mae"), ])

})