# tests/testthat/test_PipeOpTorch.R

test_that("Basic checks", {
  # A debug PipeOpTorch with two input and two output channels. The
  # expectations below inspect its channel tables and metadata fields only;
  # nothing is trained, so no task is needed here.
  obj = PipeOpTorchDebug$new(id = "debug", inname = paste0("input", 1:2), outname = paste0("output", 1:2))
  expect_pipeop(obj)
  expect_class(obj, "PipeOpTorch")

  # during training every channel is typed as ModelDescriptor
  expect_equal(unique(obj$input$train), "ModelDescriptor")
  expect_equal(unique(obj$output$train), "ModelDescriptor")

  # during prediction every channel is typed as Task
  expect_equal(unique(obj$input$predict), "Task")
  expect_equal(unique(obj$output$predict), "Task")

  # metadata fields exposed by the object
  expect_class(obj$module_generator, "nn_module_generator")
  expect_equal(obj$tags, "torch")
  expect_set_equal(obj$packages, c("mlr3torch", "torch", "mlr3pipelines"))
})

test_that("cloning works", {
  # PipeOpTorchDebug is deliberately not used here: according to the original
  # note the check then fails for some reason related to a missing help page.
  original = PipeOpTorchReLU$new()
  deep_copy = original$clone(deep = TRUE)

  expect_deep_clone(original, deep_copy)
})

test_that("single input and output", {
  task = tsk("iris")

  # train: build a ModelDescriptor from an ingress followed by optimizer, loss
  # and callback configuration
  md = (po("torch_ingress_num") %>>%
    po("torch_optimizer") %>>%
    po("torch_loss", "cross_entropy") %>>%
    po("torch_callbacks", "checkpoint"))$train(task)[[1L]]

  obj = po("nn_linear", out_features = 10)

  mdout = obj$train(list(md))[[1L]]

  # the output descriptor is expected to reuse the very same graph, loss,
  # optimizer and callback objects as the input descriptor ...
  expect_identical(address(md$graph), address(mdout$graph))
  expect_equal(address(md$loss), address(mdout$loss))
  expect_equal(address(md$optimizer), address(mdout$optimizer))
  expect_equal(address(md$callbacks[[1L]]), address(mdout$callbacks[[1L]]))

  # ... while pointer and pointer shape must have moved on to the new PipeOp
  expect_false(identical(md$pointer, mdout$pointer))
  expect_false(identical(md$pointer_shape, mdout$pointer_shape))
  expect_equal(mdout$pointer, c("nn_linear", "output"))
  expect_equal(mdout$pointer_shape, c(NA, 10))

  expect_true(obj$is_trained)

  # the linear layer was added to the graph as a PipeOpModule wrapping nn_linear
  expect_true("nn_linear" %in% names(mdout$graph$pipeops))
  expect_class(mdout$graph$pipeops$nn_linear, "PipeOpModule")
  expect_class(mdout$graph$pipeops$nn_linear$module, "nn_linear")

  # exactly one edge: ingress output -> linear input
  expect_equal(
    data.table(
      src_id = "torch_ingress_num",
      src_channel = "output",
      dst_id = "nn_linear",
      dst_channel = "input"
    ),
    mdout$graph$edges
  )

  # predict: with one input and one output channel the task is expected to be
  # passed through, so the returned object must be the input task itself.
  # (The previous assertion compared the output's address with itself and
  # therefore could never fail.)
  taskout = obj$predict(list(task))
  expect_identical(address(task), address(taskout[[1L]]))
})

test_that("train handles multiple input channels correctly", {
  task = tsk("iris")

  # --- vararg input channel: two ingress branches fed into nn_merge_sum ---
  obj = po("nn_merge_sum")

  graph = as_graph(list(
    po("select_1", selector = selector_grep("Sepal")) %>>% po("torch_ingress_num_1"),
    po("select_2", selector = selector_grep("Petal")) %>>% po("torch_ingress_num_2"))
  )

  mds = graph$train(task)
  mdsout = obj$train(mds)
  expect_true(obj$is_trained)

  # The resulting graph must contain both ingress PipeOps and the newly added
  # merge PipeOp. (The previous assertion compared the graph's address with
  # itself and therefore could never fail.)
  expect_true(all(
    c("torch_ingress_num_1", "torch_ingress_num_2", "nn_merge_sum") %in% names(mdsout[[1L]]$graph$pipeops)
  ))
  expect_equal(mdsout[[1L]]$pointer, c("nn_merge_sum", "output"))
  expect_equal(mdsout[[1L]]$pointer_shape, c(NA, 2))

  # both ingress outputs are connected to the vararg channel "..."
  expect_equal(
    data.table(
      src_id = c("torch_ingress_num_1", "torch_ingress_num_2"),
      src_channel = c("output", "output"),
      dst_id = "nn_merge_sum",
      dst_channel = c("...", "...")
    ),
    mdsout[[1L]]$graph$edges
  )

  # --- two named inputs, two named outputs ---
  obj = PipeOpTorchDebug$new(id = "nn_debug", inname = paste0("input", 1:2), outname = paste0("output", 1:2))
  obj$param_set$set_values(d_out1 = 2, d_out2 = 3, bias = TRUE)

  mdin1 = (po("select", selector = selector_grep("Petal")) %>>% po("torch_ingress_num_1"))$train(task)[[1L]]
  mdin2 = (po("select", selector = selector_grep("Sepal")) %>>% po("torch_ingress_num_2"))$train(task)[[1L]]

  mdouts = obj$train(list(input1 = mdin1, input2 = mdin2))
  mdout1 = mdouts[["output1"]]
  mdout2 = mdouts[["output2"]]

  # both output descriptors refer to the same graph object; each one points at
  # its own output channel with the shape configured via d_out1 / d_out2
  expect_equal(address(mdout1$graph), address(mdout2$graph))
  expect_equal(mdout1$pointer, c("nn_debug", "output1"))
  expect_equal(mdout2$pointer, c("nn_debug", "output2"))
  expect_equal(mdout1$pointer_shape, c(NA, 2))
  expect_equal(mdout2$pointer_shape, c(NA, 3))
})

test_that("shapes_out", {
  linear = po("nn_linear", out_features = 3)
  expected_linear = list(output = c(NA, 3))

  # single input: a bare shape vector, an unnamed list and lists with
  # arbitrary names must all yield the same named output shape
  single_inputs = list(
    c(NA, 1),
    list(c(NA, 1)),
    list(input = c(NA, 1)),
    list(x = c(NA, 1))
  )
  for (shape_in in single_inputs) {
    expect_equal(linear$shapes_out(shape_in), expected_linear)
  }

  # multiple inputs: the debug PipeOp maps two input shapes to two output shapes
  debug_op = PipeOpTorchDebug$new()
  debug_op$param_set$set_values(d_out1 = 2, d_out2 = 3)

  two_shapes = list(c(NA, 99), c(NA, 3))
  expect_equal(
    debug_op$shapes_out(two_shapes),
    list(output1 = c(NA, 2), output2 = c(NA, 3))
  )

  # passing only one shape to a two-input PipeOp is rejected
  expect_error(debug_op$shapes_out(list(c(NA, 99))), regexp = "number of input")
})

test_that("Multiple NAs are allowed in the shape", {
  iris_task = tsk("iris")
  ingress_graph = as_graph(po("torch_ingress_num"))
  descriptor = ingress_graph$train(iris_task)[[1L]]

  # unknown dimension in a non-leading position: the shape must survive
  # training an nn_relu on top of the descriptor
  shape_2d = c(4, NA)
  descriptor$pointer_shape = shape_2d
  relu = po("nn_relu")
  descriptor = relu$train(list(descriptor))[[1L]]
  expect_equal(descriptor$pointer_shape, shape_2d)

  # NOTE(review): no PipeOp is trained on this second shape; the expectation
  # only reads back the value that was just assigned.
  shape_3d = c(NA, NA, 4)
  descriptor$pointer_shape = shape_3d
  expect_equal(descriptor$pointer_shape, shape_3d)
})

test_that("only_batch_unknown", {
  shape_in = list(c(NA, NA, 1))
  linear = nn("linear", out_features = 10)

  # as constructed, the linear PipeOp accepts a shape with two unknown dimensions
  expect_equal(linear$shapes_out(shape_in), list(output = c(NA, NA, 10)))

  # after switching the private flag on, the same shape must be rejected
  linear$.__enclos_env__$private$.only_batch_unknown = TRUE
  expect_error(
    linear$shapes_out(shape_in),
    regexp = "Invalid shape: (NA,NA,1)",
    fixed = TRUE
  )
})

test_that("NA in second dimension", {
  # dataset holding 10 tensors, each with a randomly drawn number of rows
  # (between 1 and 10) and 3 columns
  ragged_dataset = dataset(
    initialize = function() {
      self$xs = lapply(1:10, function(idx) torch_randn(sample(1:10, 1), 3))
    },
    .getitem = function(i) {
      list(x = self$xs[[i]])
    },
    .length = function() {
      length(self$xs)
    }
  )
  ds = ragged_dataset()

  # regression task with a lazy tensor feature whose second dimension is unknown
  lazy_x = as_lazy_tensor(ds, dataset_shapes = list(x = c(NA, NA, 3)))
  task = as_task_regr(
    data.table(x = lazy_x, y = rnorm(10)),
    target = "y",
    id = "test"
  )

  ingress = po("torch_ingress_ltnsr")
  linear = po("nn_linear", out_features = 10)
  graph = ingress %>>% linear

  md = graph$train(task)[[1L]]

  # only the last dimension is changed by the linear layer
  expect_equal(md$pointer_shape, c(NA, NA, 10))

  # the resulting module accepts inputs that differ in the first two dimensions
  net = model_descriptor_to_module(md)
  out_first = net(torch_randn(1, 2, 3))
  expect_equal(out_first$shape, c(1, 2, 10))
  out_second = net(torch_randn(2, 1, 3))
  expect_equal(out_second$shape, c(2, 1, 10))
})

# Try the mlr3torch package in your browser
#
# Any scripts or data that you put into this service are public.
#
# mlr3torch documentation built on Aug. 26, 2025, 5:09 p.m.