Nothing
# Skip when the running R version is 4.0.0 or older.
# NOTE(review): the reason for the version gate is not visible in this file;
# presumably the tensor comparison support exercised below needs a newer R.
skip_if(getRversion() <= "4.0.0")
test_that("comparison takes requires_grad into account", {
  testthat::local_edition(3)

  # Build a fresh scalar tensor on every call, so each comparison below
  # involves two distinct tensor objects rather than one object compared
  # with itself.
  tensor_with_grad_flag <- function(flag) {
    torch_tensor(1)$requires_grad_(flag)
  }

  # Matching flags: the expectation passes.
  expect_equal(tensor_with_grad_flag(FALSE), tensor_with_grad_flag(FALSE))
  expect_equal(tensor_with_grad_flag(TRUE), tensor_with_grad_flag(TRUE))

  # Mismatching flags on otherwise identical values: expect_equal() must fail.
  expect_failure(
    expect_equal(tensor_with_grad_flag(FALSE), tensor_with_grad_flag(TRUE))
  )
})
test_that("comparison takes tensor's value into account", {
  testthat::local_edition(3)

  # Two tensors built the same way but holding different values.
  holds_one <- torch_tensor(1)
  holds_two <- torch_tensor(2)

  # expect_equal() must report a failure for them.
  expect_failure(expect_equal(holds_one, holds_two))
})
test_that("comparison takes tensor's dimension into account", {
  testthat::local_edition(3)

  # Same single value, reshaped once with c(1, 1) and once with 1.
  # Each side starts from its own torch_tensor(1).
  shaped_1x1 <- torch_tensor(1)$reshape(c(1, 1))
  shaped_1 <- torch_tensor(1)$reshape(1)

  # The differing shapes must make expect_equal() fail.
  expect_failure(expect_equal(shaped_1x1, shaped_1))
})
test_that("grad_fn is respected", {
  testthat::local_edition(3)

  # With requires_grad set, the clone must NOT compare equal to its source.
  # Per the original author's note, grad_fn is changed after cloning.
  tracked <- torch_tensor(1)$requires_grad_(TRUE)
  expect_failure(expect_equal(tracked, tracked$clone()))

  # Without requires_grad, grad_fn is not changed by cloning, so the clone
  # must compare equal to its source.
  untracked <- torch_tensor(1)
  expect_equal(untracked, untracked$clone())
})
test_that("compare tensors using cuda", {
  # Skip check runs first, before any tensor is created.
  skip_if_cuda_not_available()
  testthat::local_edition(3)

  # Same value, one tensor moved with $cuda() and one with $cpu().
  on_cuda <- torch_tensor(1)$cuda()
  on_cpu <- torch_tensor(1)$cpu()

  # Tensors on different devices must not compare as equal.
  expect_failure(expect_equal(on_cuda, on_cpu))
})
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.