# tests/testthat/test-indexing.R

# Label the expectations in this file as belonging to the "indexing" context.
context("indexing")

# Exercises `[` on tensors against base R array indexing: scalar picks,
# ranges, stepped slices, negative indices, `N`, `..`, `!!!` splicing,
# vector indices, `NULL` / `newaxis`, and `drop = FALSE`.
test_that("[ works", {
  # Positional indices and ranges agree with base R array indexing.
  x <- torch_randn(c(10, 10, 10))
  r <- as_array(x)
  expect_equal(as_array(x[1, 1, 1]), r[1, 1, 1])
  expect_equal(as_array(x[1, , ]), r[1, , ])
  expect_equal(as_array(x[1:5, , ]), r[1:5, , ])
  # `start:end:step` is compared with the equivalent `seq()` call.
  expect_equal(as_array(x[1:10:2, , ]), r[seq(1, 10, by = 2), , ])

  # Negative indices count from the end; `N` marks the last position.
  x <- torch_tensor(0:9)
  expect_equal(as_array(x[-1]$to(dtype = torch_int())), 9)
  expect_equal(as_array(x[-2:10]$to(dtype = torch_int())), c(8, 9))
  expect_equal(as_array(x[2:N]$to(dtype = torch_int())), c(1:9))

  # `..` stands in for every dimension that is not indexed explicitly.
  x <- torch_randn(c(10, 10, 10, 10))
  r <- as_array(x)
  expect_equal(as_array(x[1, ..]), r[1, , , ])
  expect_equal(as_array(x[1, 1, ..]), r[1, 1, , ])
  expect_equal(as_array(x[.., 1]), r[, , , 1])
  expect_equal(as_array(x[.., 1, 1]), r[, , 1, 1])

  # `!!!` splices an R vector into consecutive index positions.
  x <- torch_randn(c(10, 10, 10, 10))
  r <- as_array(x)
  i <- c(1, 2, 3, 4)
  expect_equal(as_array(x[!!!i]), r[1, 2, 3, 4])
  i <- c(1, 2)
  expect_equal(as_array(x[!!!i, 3, 4]), r[1, 2, 3, 4])

  # Numeric vectors pick elements in the order given.
  x <- torch_tensor(1:10)
  y <- 1:10
  expect_equal_to_r(x[c(1, 3, 2, 5)]$to(dtype = torch_int()), y[c(1, 3, 2, 5)])

  # The index may also be held in a variable.
  index <- 1:3
  expect_equal_to_r(x[index]$to(dtype = torch_int()), y[index])

  # Vector indices on two dimensions at once. The original evaluated this
  # expression once more on its own line and discarded the result; that
  # statement asserted nothing and has been dropped.
  x <- torch_randn(10, 10)
  expect_equal_to_r(
    x[c(2, 3, 1), c(3, 2, 1)],
    as_array(x)[c(2, 3, 1), c(3, 2, 1)]
  )

  # A trailing `..` on a 1-d tensor leaves the slice unchanged.
  x <- torch_randn(10)
  expect_equal_to_tensor(x[1:5, ..], x[1:5])

  # `NULL` inserts a dimension of size 1 at its position.
  x <- torch_randn(10)
  expect_tensor_shape(x[, NULL], c(10, 1))
  expect_tensor_shape(x[NULL, , NULL], c(1, 10, 1))
  expect_tensor_shape(x[NULL, , NULL, NULL], c(1, 10, 1, 1))

  # `newaxis` yields the same shapes as `NULL`.
  x <- torch_randn(10)
  expect_tensor_shape(x[, newaxis], c(10, 1))
  expect_tensor_shape(x[newaxis, , newaxis], c(1, 10, 1))
  expect_tensor_shape(x[newaxis, , newaxis, newaxis], c(1, 10, 1, 1))

  # `drop = FALSE` keeps the indexed dimension with size 1.
  x <- torch_randn(10, 10)
  expect_tensor_shape(x[1, , drop = FALSE], c(1, 10))
  expect_tensor_shape(x[.., 1, drop = FALSE], c(10, 1))
  expect_tensor_shape(x[.., -1, drop = FALSE], c(10, 1))
})

# Invalid indices are expected to raise an error instead of returning a value.
test_that("indexing error expectations", {
  # Five indices on a tensor created with four dimensions.
  four_d <- torch_randn(c(10, 10, 10, 10))
  expect_error(four_d[1, 1, 1, 1, 1])

  # Zero is rejected, both alone and inside an index vector.
  single <- torch_tensor(10)
  expect_error(single[0])
  expect_error(single[c(0, 1)])
})

# A boolean mask tensor used as the sole index selects the matching elements;
# combining a 2-d mask with a further index is expected to error.
test_that("indexing with boolean tensor", {
  vec <- torch_tensor(c(-1, -2, 0, 1, 2))
  negatives <- vec[vec < 0]
  expect_equal_to_r(negatives, c(-1, -2))

  mat <- torch_tensor(
    matrix(
      c(
        -1, -2, 0, 1, 2,
        2, 1, 0, -1, -2
      ),
      nrow = 2,
      byrow = TRUE
    )
  )

  # The expected values list the matches of the first row, then the second.
  expect_equal_to_r(mat[mat < 0], c(-1, -2, -1, -2))

  expect_error(mat[mat < 0, 1])
})

# Negative bounds in a slice, and negative entries in an index vector, are
# resolved relative to the end of the tensor.
test_that("slice with negative indexes", {
  v <- torch_tensor(c(1, 2, 3))

  # Slices whose start and/or end are negative.
  from_second <- v[2:-1]
  expect_equal_to_r(from_second, c(2, 3))
  last_two <- v[-2:-1]
  expect_equal_to_r(last_two, c(2, 3))
  first_two <- v[-3:-2]
  expect_equal_to_r(first_two, c(1, 2))

  # Negative values inside a vector index: last element, then second to last.
  reversed_tail <- v[c(-1, -2)]
  expect_equal_to_r(reversed_tail, c(3, 2))
})

# Checks `[<-` on tensors with tensor, double, integer, logical, boolean-mask
# and vector right-hand sides.
test_that("subset assignment", {
  # Tensor on the right-hand side. The original followed this assignment
  # with a bare `x` whose value was discarded; that no-op has been removed.
  x <- torch_randn(2, 2)
  x[1, 1] <- torch_tensor(0)
  expect_equal_to_r(x[1, 1], 0)

  # R double scalar.
  x[1, 2] <- 0
  expect_equal_to_r(x[1, 2], 0)

  # R integer scalar written into the same (randn-created) tensor.
  x[1, 2] <- 1L
  expect_equal_to_r(x[1, 2], 1)

  # R logical scalar into a boolean tensor.
  x <- torch_tensor(c(TRUE, FALSE))
  x[2] <- TRUE
  expect_equal_to_r(x[2], TRUE)

  # Boolean mask on the left-hand side: every non-positive entry becomes 1,
  # so afterwards all entries must be strictly positive.
  x <- torch_tensor(rbind(
    c(-1, -2, 0, 1, 2),
    c(2, 1, 0, -1, -2)
  ))

  x[x <= 0] <- 1
  expect_true(as_array(torch_all(x > 0)))

  # R vector assigned into a slice of matching length.
  x <- torch_tensor(c(1, 2, 3, 4, 5))
  x[1:2] <- c(0, 0)
  expect_equal_to_r(x[1:2], c(0, 0))
})

# Assigns a tensor through an index that pairs a boolean mask tensor with a
# scalar position, and compares with the same assignment on an R array.
test_that("Subset assignment, scalar + tensor", {
  # Mask on the first dimension, scalar on the second.
  tall <- torch_randn(3, 2)
  tall_expected <- as.array(tall)

  row_mask <- torch_tensor(c(TRUE, FALSE, TRUE))
  new_values <- torch_tensor(c(1, 2))
  tall[row_mask, 1] <- new_values

  tall_expected[c(TRUE, FALSE, TRUE), 1] <- c(1, 2)
  expect_equal_to_r(tall, tall_expected)

  # Same check with the positions swapped: scalar first, mask second.
  wide <- torch_randn(2, 3)
  wide_expected <- as.array(wide)

  col_mask <- torch_tensor(c(TRUE, FALSE, TRUE))
  more_values <- torch_tensor(c(1, 2))
  wide[1, col_mask] <- more_values

  wide_expected[1, c(TRUE, FALSE, TRUE)] <- c(1, 2)
  expect_equal_to_r(wide, wide_expected)
})

# Assignment through an index made of two tensors is expected to raise
# "Only one tensor is allowed in the index", for both an R scalar and a
# tensor on the right-hand side.
test_that("subset assignment tensor x tensor", {
  # The original also stored `as.array(x)` in a local that was never read;
  # that unused variable has been removed.
  x <- torch_randn(3, 2)

  expect_error(
    x[torch_tensor(c(1L,3L)), torch_tensor(c(1L,2L))] <- 2,
    "Only one tensor is allowed in the index"
  )

  expect_error(
    x[torch_tensor(c(1L,3L)), torch_tensor(c(1L,2L))] <- torch_tensor(2),
    "Only one tensor is allowed in the index"
  )
})

# Plain R logical vectors used as indices select the TRUE positions and keep
# the indexed dimension.
test_that("indexing with R boolean vectors", {
  pair <- torch_tensor(c(1, 2))
  expect_equal_to_r(pair[c(TRUE, FALSE)], 1)

  single <- torch_tensor(c(1))
  expect_equal_to_r(single[TRUE], 1)

  # On a 2 x 2 tensor, one TRUE per dimension leaves that dimension at size 1.
  zeros <- torch_zeros(2, 2)
  first_row <- zeros[c(TRUE, FALSE), ]
  expect_equal(dim(first_row), c(1, 2))
  first_cell <- zeros[c(TRUE, FALSE), c(TRUE, FALSE)]
  expect_equal(dim(first_cell), c(1, 1))
})

# Long-dtype tensors used as indices: 1-d and scalar tensors, positive and
# negative positions, and rejection of a zero entry.
test_that("indexing with long tensors", {
  mat <- torch_randn(4, 4)

  # A one-element 1-d index tensor keeps both dimensions (shape 1 x 1).
  first_1d <- torch_tensor(1, dtype = torch_long())
  expect_equal(mat[first_1d, first_1d]$item(), mat[1, 1]$item())
  expect_tensor_shape(mat[first_1d, first_1d], c(1, 1))

  # A scalar index tensor matches indexing with the plain number.
  first_0d <- torch_scalar_tensor(1, dtype = torch_long())
  expect_equal_to_tensor(mat[first_0d, first_0d], mat[1, 1])

  # The same two checks with -1.
  last_1d <- torch_tensor(-1, dtype = torch_long())
  expect_equal(mat[last_1d, last_1d]$item(), mat[-1, -1]$item())
  expect_tensor_shape(mat[last_1d, last_1d], c(1, 1))

  last_0d <- torch_scalar_tensor(-1, dtype = torch_long())
  expect_equal_to_tensor(mat[last_0d, last_0d], mat[-1, -1])

  # Mixed negative and positive entries match the equivalent R vector index.
  mixed <- torch_tensor(c(-1, 1), dtype = torch_long())
  expect_equal_to_tensor(mat[mixed, mixed], mat[c(-1, 1), c(-1, 1)])

  # A zero among the entries must raise the 1-based indexing error.
  with_zero <- torch_tensor(c(-1, 0, 1), dtype = torch_long())
  expect_error(mat[with_zero, ], regexp = "Indexing starts at 1")
})

# `slc(start, end, step)` used as an index is compared with the equivalent
# `seq()` index on the R array, and with the `2:N` range syntax.
test_that("can use the slc construct", {
  t2d <- torch_randn(10, 10)
  arr <- as_array(t2d)
  # R-side counterpart of slc(start = 1, end = 5, step = 2).
  stepped <- seq(1, 5, by = 2)

  # Slice on the first dimension, second dimension left open.
  expect_equal_to_r(t2d[slc(start = 1, end = 5, step = 2), ], arr[stepped, ])

  # Slice combined with a scalar index.
  expect_equal_to_r(t2d[slc(start = 1, end = 5, step = 2), 1], arr[stepped, 1])

  # Slices on both dimensions.
  expect_equal_to_r(
    t2d[slc(start = 1, end = 5, step = 2), slc(start = 1, end = 5, step = 2)],
    arr[stepped, stepped]
  )

  # An `Inf` end gives the same result as ending the range at `N`.
  expect_equal_to_tensor(t2d[slc(2, Inf), ], t2d[2:N, ])
})

# Snapshot test for the printed form of the slice built by `slc(1, 3, 5)`.
test_that("print slice", {
  # Use testthat's 3rd edition for this test (presumably because snapshot
  # expectations require it and the rest of the file does not opt in).
  testthat::local_edition(3)
  # The printed output is compared against the stored snapshot.
  expect_snapshot(print(slc(1, 3, 5)))
})

# Vector indices `c(1, 2)` mixed with ranges, empty positions and `newaxis`
# must give the same tensor as using the range `1:2` in their place.
test_that("mix vector indexing with slices and others", {
  cube <- torch_randn(3, 3, 3)

  # Vectors on the outer dimensions, range in the middle.
  expect_equal_to_tensor(cube[c(1, 2), 1:2, c(1, 2)], cube[1:2, 1:2, 1:2])

  # Same, with a new axis after the first index.
  expect_equal_to_tensor(
    cube[c(1, 2), newaxis, 1:2, c(1, 2)],
    cube[1:2, newaxis, 1:2, 1:2]
  )

  # New axes both before and after the first vector index.
  expect_equal_to_tensor(
    cube[newaxis, c(1, 2), newaxis, 1:2, c(1, 2)],
    cube[newaxis, 1:2, newaxis, 1:2, 1:2]
  )

  # Two leading vectors, last dimension left open.
  expect_equal_to_tensor(cube[c(1, 2), c(1, 2), ], cube[1:2, 1:2, ])

  # Vectors around an open middle dimension.
  expect_equal_to_tensor(cube[c(1, 2), , c(1, 2)], cube[1:2, , 1:2])

  # Vectors on every dimension.
  expect_equal_to_tensor(cube[c(1, 2), c(1, 2), c(1, 2)], cube[1:2, 1:2, 1:2])

  # Vectors on every dimension with a new axis before the last one.
  expect_equal_to_tensor(
    cube[c(1, 2), c(1, 2), newaxis, c(1, 2)],
    cube[1:2, 1:2, newaxis, 1:2]
  )
})

# The same R logical vector applied to all three dimensions must select the
# same elements from the tensor as from its R array counterpart.
test_that("boolean vector indexing works as expected", {
  cube <- torch_randn(4, 4, 4)
  selected <- c(TRUE, FALSE, TRUE, FALSE)

  from_tensor <- cube[selected, selected, selected]
  from_array <- as_array(cube)[selected, selected, selected]
  expect_equal_to_r(from_tensor, from_array)
})

# Indexing a tensor with an R vector must leave that vector unchanged.
test_that("regression test for #691", {
  mat <- torch_randn(c(6, 4))
  rows <- c(1, 2, 3)
  # Result intentionally discarded; only the effect on `rows` is checked.
  mat[rows]
  # Compared with a fresh literal, not with a copy taken from `rows`.
  expect_equal(rows, c(1, 2, 3))
})

# A long index tensor combined with `..` must match the equivalent R array
# indexing with `c(1, 3)`.
test_that("regression test for #695", {
  narrow <- torch_randn(c(3, 4, 2))
  pick <- torch_tensor(c(1, 3), dtype = torch_long())

  # Index tensor on the middle dimension, after `..`.
  expect_equal_to_r(narrow[.., pick, ], as.array(narrow)[, c(1, 3), ])

  wide <- torch_randn(c(3, 4, 3))

  # Index tensor on the last two dimensions.
  expect_equal_to_r(wide[.., pick, pick], as.array(wide)[, c(1, 3), c(1, 3)])

  # Index tensor on the first and last dimensions, `..` in between.
  expect_equal_to_r(wide[pick, .., pick], as.array(wide)[c(1, 3), , c(1, 3)])
})

# A tensor built from NULL is expected to be an empty boolean tensor, and
# indexing into an empty tensor must error instead of crashing.
test_that("NULL tensor", {
  empty <- torch_tensor(NULL)
  is_bool <- empty$dtype == torch_bool()
  expect_true(is_bool)
  expect_equal(empty$shape, 0)

  # subsetting shouldn't crash
  expect_error(empty[1], regexp = "out of bounds")
  # integer(0) is the same value as as.integer(NULL).
  expect_error(torch_tensor(integer(0))[1], regexp = "out of bounds")
})

# Indexing with an R numeric or logical matrix must match indexing with the
# corresponding tensor, on the default device and on "mps".
test_that("works with numeric /logic matrix", {
  # Regression test for: https://github.com/mlverse/torch/issues/1181
  values <- torch_randn(4, 4)
  idx <- rbind(c(1, 1), c(1, 2))

  # Numeric matrix index versus the same index as a long tensor.
  by_matrix <- values[idx]
  by_long_tensor <- values[torch_tensor(idx, dtype = "long")]
  expect_true(torch_allclose(by_matrix, by_long_tensor))

  # Boolean tensor mask versus the same mask as an R logical array.
  by_mask_tensor <- values[values > 0]
  by_mask_array <- values[as.array(values > 0)]
  expect_true(torch_allclose(by_mask_tensor, by_mask_array))

  # also test if it works when the tensor is in a different device
  skip_if_not_m1_mac()
  on_mps <- values$to(device = "mps")

  mps_by_matrix <- on_mps[idx]
  mps_by_long_tensor <- on_mps[torch_tensor(idx, dtype = "long")]
  expect_true(torch_allclose(mps_by_matrix, mps_by_long_tensor))

  mps_by_mask_tensor <- on_mps[on_mps > 0]
  mps_by_mask_array <- on_mps[as.array(on_mps > 0)]
  expect_true(torch_allclose(mps_by_mask_tensor, mps_by_mask_array))
})

# Try the torch package in your browser

# Any scripts or data that you put into this service are public.

# torch documentation built on Nov. 5, 2025, 7:02 p.m.