tests/testthat/test-utils-data-sampler.R

# Group every test in this file under the data-sampler utilities label.
# NOTE(review): context() is deprecated under testthat's 3rd edition; drop
# this call if/when the package opts into edition 3.
context(desc = "utils-data-sampler")

test_that("sampler's length", {
  # A 1000-row dataset; the sampler lengths below are all derived from it.
  n <- 1000
  features <- torch_randn(n, 10)
  targets <- torch_randn(n)
  ds <- tensor_dataset(features, targets)

  # Sequential and default random samplers visit every row exactly once per
  # epoch; an explicit num_samples overrides that for the random sampler.
  expect_length(SequentialSampler(ds), n)
  expect_length(RandomSampler(ds, num_samples = 10), 10)

  rand <- RandomSampler(ds)
  expect_length(rand, n)

  # Each case: batch size, whether the trailing partial batch is dropped,
  # and the number of batches the BatchSampler should report.
  batch_cases <- list(
    list(size = 32, drop = TRUE, expected = n %/% 32),
    list(size = 32, drop = FALSE, expected = n %/% 32 + 1),
    list(size = 100, drop = FALSE, expected = 10),
    list(size = 1000, drop = FALSE, expected = 1),
    list(size = 1001, drop = FALSE, expected = 1),
    list(size = 1001, drop = TRUE, expected = 0)
  )

  for (case in batch_cases) {
    batched <- BatchSampler(
      sampler = rand,
      batch_size = case$size,
      drop_last = case$drop
    )
    expect_length(batched, case$expected)
  }
})

test_that("Random sampler, replacement = TRUE", {
  # Two-row dataset: every index the sampler draws must lie in [1, 2].
  x <- torch_randn(2, 10)
  y <- torch_randn(2)
  data <- tensor_dataset(x, y)

  # Keep the sampler in its own variable instead of shadowing the tensor `x`.
  sampler <- RandomSampler(data, replacement = TRUE)
  it <- sampler$.iter()

  # seq_len() rather than 1:length(): a zero-length sampler then runs the
  # loop zero times instead of iterating over c(1, 0).
  for (i in seq_len(length(sampler))) {
    k <- it()
    expect_true(k >= 1 && k <= 2)
  }

  # After length(sampler) draws, the iterator must signal exhaustion.
  expect_equal(it(), coro::exhausted())
})

test_that("Batch sampler", {
  n <- 100
  batch_size <- 32

  # Both tensors share the same first dimension (n rows), so every index the
  # sampler yields is valid for each tensor in the dataset.
  x <- torch_randn(n, 10)
  y <- torch_randn(n)
  data <- tensor_dataset(x, y)

  sampler <- RandomSampler(data, replacement = FALSE)
  batches <- BatchSampler(
    sampler = sampler,
    batch_size = batch_size,
    drop_last = TRUE
  )
  it <- batches$.iter()

  # With drop_last = TRUE only the n %/% batch_size full batches are yielded;
  # the trailing partial batch (n %% batch_size indices) is discarded.
  for (i in seq_len(n %/% batch_size)) {
    expect_length(it(), batch_size)
  }

  expect_equal(it(), coro::exhausted())
})

Try the torch package in your browser

Any scripts or data that you put into this service are public.

torch documentation built on June 7, 2023, 6:19 p.m.