Nothing
test_that("max with indices", {
  # Reducing along `dim` yields a list of (values, indices); the index of the
  # largest element of c(5, 6, 7, 8) is expected to be 4, i.e. 1-based.
  input <- torch_tensor(c(5, 6, 7, 8))
  reduced <- torch_max(input, dim = 1)
  max_index <- reduced[[2]]$to(dtype = torch_int())
  expect_equal_to_r(max_index, 4)

  # Supplying `other` instead of `dim` gives an elementwise maximum.
  pairwise <- torch_max(c(2, 1), other = c(1, 2))
  expect_equal_to_r(pairwise, c(2, 2))
})
test_that("min with indices", {
  # Reducing along `dim` yields a list of (values, indices); the index of the
  # smallest element of c(5, 6, 7, 8) is expected to be 1, i.e. 1-based.
  input <- torch_tensor(c(5, 6, 7, 8))
  reduced <- torch_min(input, dim = 1)
  min_index <- reduced[[2]]$to(dtype = torch_int())
  expect_equal_to_r(min_index, 1)

  # Supplying `other` instead of `dim` gives an elementwise minimum.
  pairwise <- torch_min(c(2, 1), other = c(1, 2))
  expect_equal_to_r(pairwise, c(1, 1))
})
test_that("argsort", {
  # The free function and the tensor method must agree, and the returned
  # positions are expected to be 1-based.

  # 1-D input, default (ascending) order.
  decreasing_input <- torch_tensor(c(3, 2, 1))
  expect_equal_to_r(torch_argsort(decreasing_input), c(3, 2, 1))
  expect_equal_to_r(decreasing_input$argsort(), c(3, 2, 1))

  # 1-D input, descending order.
  increasing_input <- torch_tensor(c(1, 2, 3))
  expect_equal_to_r(
    torch_argsort(increasing_input, descending = TRUE),
    c(3, 2, 1)
  )
  expect_equal_to_r(increasing_input$argsort(descending = TRUE), c(3, 2, 1))

  # 2-D input: `dim` selects the axis sorted along. The expectations below
  # encode that both columns and rows of this 5x2 view are already ascending.
  grid <- torch_tensor(1:10)$view(c(5, 2))
  expect_equal_to_r(torch_argsort(grid, dim = 1)[, 1], 1:5)
  expect_equal_to_r(grid$argsort(dim = 1)[, 1], 1:5)
  expect_equal_to_r(torch_argsort(grid, dim = 2)[, 1], rep(1, 5))
  expect_equal_to_r(grid$argsort(dim = 2)[, 1], rep(1, 5))
})
test_that("argmax", {
  # Position of the maximum, 1-based, via both the function and the method.
  ascending <- torch_tensor(c(1, 2, 3))
  expect_equal_to_r(torch_argmax(ascending), 3)
  expect_equal_to_r(ascending$argmax(), 3)

  descending <- torch_tensor(c(3, 2, 1))
  expect_equal_to_r(torch_argmax(descending), 1)
  expect_equal_to_r(descending$argmax(), 1)

  # Reduction along dim 2 of a 3x3 matrix; `keepdim` retains that axis as
  # size 1 instead of dropping it.
  mat <- torch_tensor(1:9)$reshape(c(3, 3))
  expect_equal_to_r(torch_argmax(mat, dim = 2), c(3, 3, 3))
  kept <- torch_argmax(mat, dim = 2, keepdim = TRUE)
  expect_equal(kept$shape, c(3, 1))
})
test_that("argmin", {
  # Position of the minimum, 1-based, via both the function and the method.
  ascending <- torch_tensor(c(1, 2, 3))
  expect_equal_to_r(torch_argmin(ascending), 1)
  expect_equal_to_r(ascending$argmin(), 1)

  descending <- torch_tensor(c(3, 2, 1))
  expect_equal_to_r(torch_argmin(descending), 3)
  expect_equal_to_r(descending$argmin(), 3)

  # Reduction along dim 2 of a 3x3 matrix; `keepdim` retains that axis as
  # size 1 instead of dropping it.
  mat <- torch_tensor(1:9)$reshape(c(3, 3))
  expect_equal_to_r(torch_argmin(mat, dim = 2), c(1, 1, 1))
  kept <- torch_argmin(mat, dim = 2, keepdim = TRUE)
  expect_equal(kept$shape, c(3, 1))
})
test_that("sort", {
  # torch_sort() / $sort() return a list of (sorted values, indices). Compare
  # both parts against base R on a random permutation of 1:100; a permutation
  # has no ties, so the ordering returned by order() is unique.
  x <- torch_tensor(sample(100))
  r_x <- as.integer(x)

  ascending <- torch_sort(x)
  descending <- torch_sort(x, descending = TRUE)
  method_ascending <- x$sort()
  method_descending <- x$sort(descending = TRUE)

  # Indices (element 2) must match base::order().
  expect_equal_to_r(ascending[[2]], order(r_x))
  expect_equal_to_r(descending[[2]], order(r_x, decreasing = TRUE))
  expect_equal_to_r(method_ascending[[2]], order(r_x))
  expect_equal_to_r(method_descending[[2]], order(r_x, decreasing = TRUE))

  # Values (element 1) must match base::sort(); previously unchecked.
  expect_equal_to_r(ascending[[1]], sort(r_x))
  expect_equal_to_r(descending[[1]], sort(r_x, decreasing = TRUE))
  expect_equal_to_r(method_ascending[[1]], sort(r_x))
  expect_equal_to_r(method_descending[[1]], sort(r_x, decreasing = TRUE))
})
test_that("bincount is 1 indexed", {
  # With 1-based bins, values 1, 2, 3 occupy bins 1..3, so c(1, 2, 3, 1)
  # yields 3 bins with counts 2, 1, 1. Checking the counts (not just the
  # length) is what actually pins the indexing convention.
  x <- torch_tensor(c(1, 2, 3, 1), dtype = torch_int64())

  out <- torch_bincount(x)
  expect_length(out, 3)
  expect_equal_to_r(out, c(2, 1, 1))

  out <- x$bincount()
  expect_length(out, 3)
  expect_equal_to_r(out, c(2, 1, 1))

  # A 0 has no valid bin under 1-based indexing and must be rejected.
  x <- torch_tensor(c(1, 2, 3, 1, 0), dtype = torch_int64())
  expect_error(
    torch_bincount(x),
    regexp = "Indexing starts at 1 but found a 0."
  )
})
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.