# tests/testthat/test_TorchOptimizer.R

test_that("Basic checks", {
  # Descriptor that declares an additional package dependency ("mypackage").
  adam_desc = TorchOptimizer$new(
    torch_optimizer = optim_ignite_adam,
    label = "Adam",
    packages = "mypackage"
  )
  expect_r6(adam_desc, "TorchOptimizer")
  expect_equal(adam_desc$id, "optim_ignite_adam")
  expect_equal(adam_desc$label, "Adam")
  expect_set_equal(adam_desc$packages, c("mypackage", "torch", "mlr3torch"))
  # Without a custom param_set, the ids are the generator's formals minus `params`.
  generator_args = setdiff(formalArgs(optim_ignite_adam), "params")
  expect_set_equal(adam_desc$param_set$ids(), generator_args)
  # generate() fails because "mypackage" cannot be loaded.
  expect_error(adam_desc$generate(), regexp = "could not be loaded: mypackage", fixed = TRUE)

  # A user-supplied param_set may not contain `params`.
  expect_error(
    TorchOptimizer$new(torch_optimizer = optim_sgd, param_set = ps(lr = p_uty(), params = p_uty())),
    regexp = "The name 'params' is reserved for the network parameters.",
    fixed = TRUE
  )

  # Every id of a user-supplied param_set must exist in the generator.
  expect_error(
    TorchOptimizer$new(optim_adam, id = "mse", param_set = ps(par = p_uty())),
    regexp = "Parameter values with ids 'par' are missing in generator.",
    fixed = TRUE
  )

  sgd_desc = TorchOptimizer$new(
    torch_optimizer = optim_sgd,
    label = "Stochastic Gradient Descent",
    id = "Sgd"
  )
  sgd_desc$param_set$set_values(lr = 0.9191)
  expect_set_equal(sgd_desc$packages, c("torch", "mlr3torch"))
  expect_equal(sgd_desc$label, "Stochastic Gradient Descent")
  expect_equal(sgd_desc$id, "Sgd")
  expect_equal(sgd_desc$param_set$values$lr, 0.9191)

  # The configured learning rate reaches the instantiated torch optimizer.
  sgd_instance = sgd_desc$generate(nn_linear(1, 1)$parameters)
  expect_class(sgd_instance, "torch_optimizer")
  expect_equal(sgd_instance$defaults$lr, 0.9191)

  # A custom param_set restricts the exposed ids to exactly those given.
  lr_only_desc = TorchOptimizer$new(torch_optimizer = optim_sgd, param_set = ps(lr = p_uty()))
  expect_equal(lr_only_desc$param_set$ids(), "lr")
})


test_that("dictionary retrieval works", {
  # Single lookup; extra arguments are applied as parameter values.
  adam_desc = t_opt("adam", lr = 0.99)
  expect_r6(adam_desc, "TorchOptimizer")
  expect_class(adam_desc$generator, "optim_ignite_adam")
  expect_equal(adam_desc$param_set$values$lr, 0.99)

  # Vectorized lookup returns a list of descriptors in the requested order.
  wanted_keys = c("adam", "sgd")
  retrieved = t_opts(wanted_keys)
  expect_list(retrieved, types = "TorchOptimizer")
  expect_identical(ids(retrieved), wanted_keys)

  # Called without a key, both sugar functions return the dictionary itself.
  for (sugar_fn in list(t_opt, t_opts)) {
    expect_class(sugar_fn(), "DictionaryMlr3torchOptimizers")
  }
})


test_that("dictionary can be converted to a table", {
  # The table is keyed by "key" and has exactly the columns key, label, packages.
  optimizer_table = as.data.table(mlr3torch_optimizers)
  expect_data_table(optimizer_table, ncols = 3, key = "key")
  expect_equal(colnames(optimizer_table), c("key", "label", "packages"))
})

test_that("Cloning works", {
  # A deep clone must not share mutable state with its source.
  source_desc = t_opt("adam")
  expect_deep_clone(source_desc, source_desc$clone(deep = TRUE))
})

test_that("Printer works", {
  # Exact, line-by-line expected print() output for the "adam" descriptor.
  expected_lines = c(
    "<TorchOptimizer:adam> Adaptive Moment Estimation",
    "* Generator: optim_ignite_adam",
    "* Parameters: list()",
    "* Packages: torch,mlr3torch"
  )
  printed_lines = capture.output(print(t_opt("adam")))
  expect_identical(printed_lines, expected_lines)
})


test_that("Converters are correctly implemented", {
  # A character key is resolved through the dictionary.
  expect_r6(as_torch_optimizer("adam"), "TorchOptimizer")

  # Wrapping a bare torch optimizer: id and label default to its name.
  wrapped = as_torch_optimizer(optim_adam)
  expect_r6(wrapped, "TorchOptimizer")
  expect_equal(wrapped$id, "optim_adam")
  expect_equal(wrapped$label, "optim_adam")

  # Converting an existing descriptor with clone = TRUE yields a deep copy.
  expect_deep_clone(wrapped, as_torch_optimizer(wrapped, clone = TRUE))

  # Construction arguments are forwarded when wrapping a torch optimizer.
  customized = as_torch_optimizer(optim_adam,
    id = "myopt",
    label = "Custom",
    man = "my_opt",
    param_set = ps(lr = p_uty(tags = "train"))
  )
  expect_r6(customized, "TorchOptimizer")
  expect_equal(customized$id, "myopt")
  expect_equal(customized$label, "Custom")
  expect_equal(customized$man, "my_opt")
  expect_equal(customized$param_set$ids(), "lr")

  # NOTE(review): repeats the default id/label check on a fresh conversion.
  rewrapped = as_torch_optimizer(optim_adam)
  expect_equal(rewrapped$id, "optim_adam")
  expect_equal(rewrapped$label, "optim_adam")
})


test_that("Parameter test: adam", {
  adam_desc = t_opt("adam")
  # Check the param_set against the generator's arguments, skipping `params`.
  check = expect_paramset(adam_desc$param_set, adam_desc$generator, exclude = "params")
  expect_paramtest(check)
})

test_that("Parameter test: sgd", {
  sgd_desc = t_opt("sgd")
  # `lr` is set to `optim_required()` in the generator, so it is skipped
  # alongside `params`.
  skipped = c("params", "lr")
  check = expect_paramset(sgd_desc$param_set, sgd_desc$generator, exclude = skipped)
  expect_paramtest(check)
})

test_that("Parameter test: rmsprop", {
  rmsprop_desc = t_opt("rmsprop")
  # Check the param_set against the generator's arguments, skipping `params`.
  check = expect_paramset(rmsprop_desc$param_set, rmsprop_desc$generator, exclude = "params")
  expect_paramtest(check)
})

test_that("Parameter test: adagrad", {
  adagrad_desc = t_opt("adagrad")
  # Check the param_set against the generator's arguments, skipping `params`.
  check = expect_paramset(adagrad_desc$param_set, adagrad_desc$generator, exclude = "params")
  expect_paramtest(check)
})

test_that("phash works", {
  # Helper: construct a descriptor from the dictionary and return its phash.
  phash_of = function(key, ...) t_opt(key, ...)$phash

  # Parameter values do not enter the phash.
  expect_equal(phash_of("adam", lr = 2), phash_of("adam", lr = 1))

  # Different generators yield different phashes.
  expect_true(phash_of("sgd") != phash_of("adam"))
  # NOTE(review): the next two compare different generators as well, so they
  # would pass even if id/label did not contribute to the phash.
  expect_true(phash_of("sgd", id = "a") != phash_of("adam", id = "b"))
  expect_true(phash_of("sgd", label = "a") != phash_of("adam", label = "b"))
})

test_that("can train with every optimizer", {
  # Keep only the first row so each training run stays cheap.
  task = tsk("iris")$filter(1)

  # Train a one-epoch MLP with the optimizer registered under `opt_id`.
  test_optimizer = function(opt_id) {
    opt = t_opt(opt_id, lr = 0.1)
    expect_learner(lrn("classif.mlp", optimizer = opt, batch_size = 1, epochs = 1)$train(task))
  }

  # Use the public, sorted key listing instead of the dictionary's internal
  # `items` environment.
  for (opt_id in mlr3torch_optimizers$keys()) {
    test_optimizer(opt_id)
  }
})

# Try the mlr3torch package in your browser
#
# Any scripts or data that you put into this service are public.
#
# mlr3torch documentation built on April 4, 2025, 3:03 a.m.