tests/testthat/test-ivalue.R

test_that("ivalue works for bool", {
  expect_equal(ivalue_test_function(TRUE), TRUE)
  expect_equal(ivalue_test_function(FALSE), FALSE)
})

test_that("ivalue works for bool list", {
  x <- sample(c(TRUE, FALSE), replace = TRUE, 10)

  expect_equal(ivalue_test_function(x), x)
})

test_that("works for int", {
  x <- 1L
  expect_equal(ivalue_test_function(x), x)
})

test_that("works for int list", {
  x <- 1:100
  expect_equal(ivalue_test_function(x), x)
})

test_that("works for double", {
  x <- runif(1)
  expect_equal(ivalue_test_function(x), x)
})

test_that("works for double list", {
  x <- runif(100)
  expect_equal(ivalue_test_function(x), x)
})

test_that("tensor works", {
  x <- torch_randn(100, 100)
  expect_equal_to_tensor(ivalue_test_function(x), x)
})

test_that("tensor list works", {
  x <- list(torch_tensor(1), torch_tensor(2))
  y <- ivalue_test_function(x)

  expect_equal_to_tensor(y[[1]], x[[1]])
  expect_equal_to_tensor(y[[2]], x[[2]])
})

test_that("works for strings", {
  x <- "hello"
  expect_equal(ivalue_test_function(x), x)
})

test_that("works for unamed lists", {
  x <- list(
    1L, 1:10, 1.1, c(2, 3.3, 4.4), torch_tensor(1),
    list(torch_tensor(1), torch_tensor(2)),
    TRUE, c(TRUE, FALSE)
  )
  x[[length(x) + 1]] <- x
  expect_equal(ivalue_test_function(x), x)
})

test_that("works for dicts of tensors", {
  x <- list(a = torch_tensor(1), b = torch_tensor(2))
  expect_equal(ivalue_test_function(x), x)
})

test_that("works for named list of different types", {
  x <- list(a = torch_tensor(1), b = 1)
  expect_equal(ivalue_test_function(x), x)
})

test_that("works for named tuples", {
  x <- list(a = torch_tensor(1), b = torch_tensor(2))
  class(x) <- "jit_tuple"
  expect_equal(ivalue_test_function(x), unclass(x))
})

test_that("works for tuples", {
  x <- list(torch_tensor(1), torch_tensor(2))
  class(x) <- "jit_tuple"
  expect_equal(ivalue_test_function(x), unclass(x))
})

test_that("works with NULL", {
  expect_equal(ivalue_test_function(NULL), NULL)
})

test_that("works with scalars", {
  expect_equal(ivalue_test_function(jit_scalar(1)), 1)
  expect_equal(ivalue_test_function(jit_scalar(1L)), 1L)
  expect_equal(ivalue_test_function(jit_scalar(TRUE)), TRUE)
})

Try the torch package in your browser

Any scripts or data that you put into this service are public.

torch documentation built on June 7, 2023, 6:19 p.m.