attic/test_TorchOpTabResNetBlocks.R

test_that("TorchOpTabResNetBlocks works", {
  op = top("tab_resnet_blocks")
  task = tsk("iris")

  # Draws one random hyperparameter configuration for the op. The list
  # elements are evaluated top to bottom, so the RNG is consumed in the same
  # order in which the parameter names appear here.
  draw_config = function() {
    list(
      n_blocks = sample(1:3, 1),
      d_main = sample(1:10, 1),
      d_hidden = sample(1:10, 1),
      dropout_first = runif(1),
      dropout_second = runif(1),
      activation = sample(c("relu", "elu"), 1),
      activation_args = list(inplace = sample(c(TRUE, FALSE), 1)),
      bn.momentum = 0.2,
      skip_connection = sample(c(TRUE, FALSE), 1)
    )
  }

  # NOTE(review): no seed is set, so a failing configuration may not
  # reproduce on rerun.
  for (repetition in seq_len(3L)) {
    config = draw_config()
    n_features = sample(1:10, 1)
    n_batch = sample(1:3, 1)
    # Random input tensor with n_batch rows and n_features columns; built
    # eagerly so the torch RNG is used before the op is configured.
    batch = torch_randn(n_batch, n_features)
    inputs = list(input = batch)

    # The same op object is reused across rounds; the freshly drawn values
    # are merged into its current parameter values via insert_named().
    op$param_set$values = insert_named(op$param_set$values, config)

    # Checks are delegated to the expect_torchop() helper (defined elsewhere
    # in the test suite).
    expect_torchop(
      op = op,
      inputs = inputs,
      task = task,
      "nn_tab_resnet_blocks"
    )
  }
})
mlr-org/mlr3torch documentation built on April 17, 2025, 8:22 p.m.