tests/testthat/test-folder-dataset.R

test_that("image_folder dataset", {

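  # Build a dataset from the image folders under assets/class (each
  # subdirectory is treated as one class), converting every image to a
  # tensor and resizing it to 32x32.
  ds <- image_folder_dataset(
    root = "assets/class",
    transform = . %>% transform_to_tensor %>%
      transform_resize(c(32, 32))
  )
  # Each item is a pair: the image tensor (x) and its class index (y).
  expect_length(ds[1], 2)

  # Batch items in pairs; drop_last = TRUE ensures every batch is full.
  dl <- torch::dataloader(ds, batch_size = 2, drop_last = TRUE)
  coro::loop(for(batch in dl) {
    # Batch elements are accessible both by position and by name.
    expect_tensor_shape(batch[[1]], c(2, 3, 32, 32))
    expect_tensor_shape(batch[[2]], 2)
    expect_tensor_shape(batch$x, c(2, 3, 32, 32))
    expect_tensor_shape(batch$y, 2)
  })

  # The fixture folder holds 12 images in total.
  expect_length(ds, 12)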

})

torchvision documentation built on June 22, 2024, 11:25 a.m.