tests/testthat/test-dataset-mnist.R

# No explicit `context()` call: testthat derives the context label from the
# file name ("test-dataset-mnist.R" -> "dataset-mnist"), and `context()` is
# deprecated under testthat's 3rd edition.

# Shared root directory for every dataset test in this file. Tests download
# into it once (`download = TRUE`) and later reload from the same files
# (e.g. with `transform = transform_to_tensor`). Because `local_tempdir()` is
# called at the top level of the test file, its cleanup is deferred until the
# file has finished running.
dir <- withr::local_tempdir()

test_that("tests for the mnist dataset", {

  # An empty root without `download = TRUE` must fail because the data is
  # missing. Matching the message keeps unrelated errors (e.g. a bad argument)
  # from silently satisfying this expectation.
  expect_error(
    mnist_dataset(tempfile()),
    "Dataset not found"
  )

  # First construction downloads into the shared `dir`.
  ds <- mnist_dataset(dir, download = TRUE)

  item <- ds[1]
  expect_equal(dim(item[[1]]), c(28, 28))
  expect_equal(item[[2]], 6)
  expect_equal(length(ds), 60000)

  # Reload from the files downloaded above, converting images to tensors.
  ds <- mnist_dataset(dir, transform = transform_to_tensor)
  dl <- torch::dataloader(ds, batch_size = 32)
  expect_length(dl, 1875) # 60000 samples / batch_size 32
  iter <- dataloader_make_iter(dl)
  batch <- dataloader_next(iter)
  expect_tensor_shape(batch[[1]], c(32, 28, 28))
  expect_tensor_shape(batch[[2]], 32)
  # Pixel values produced by transform_to_tensor must not exceed 1.
  expect_true((torch_max(batch[[1]]) <= 1)$item())
  expect_named(batch, c("x", "y"))

})

test_that("tests for the kmnist dataset", {

  # An empty root without `download = TRUE` must fail because the data is
  # missing. Matching the message keeps unrelated errors (e.g. a bad argument)
  # from silently satisfying this expectation.
  expect_error(
    kmnist_dataset(tempfile()),
    "Dataset not found"
  )

  # First construction downloads into the shared `dir`.
  ds <- kmnist_dataset(dir, download = TRUE)

  item <- ds[1]
  expect_equal(dim(item[[1]]), c(28, 28))
  expect_equal(item[[2]], 6)
  expect_equal(length(ds), 60000)

  # Reload from the files downloaded above, converting images to tensors.
  ds <- kmnist_dataset(dir, transform = transform_to_tensor)
  dl <- torch::dataloader(ds, batch_size = 32)
  expect_length(dl, 1875) # 60000 samples / batch_size 32
  iter <- dataloader_make_iter(dl)
  batch <- dataloader_next(iter)
  expect_tensor_shape(batch[[1]], c(32, 28, 28))
  expect_tensor_shape(batch[[2]], 32)
  # Pixel values produced by transform_to_tensor must not exceed 1.
  expect_true((torch_max(batch[[1]]) <= 1)$item())
  expect_named(batch, c("x", "y"))
})

test_that("fashion_mnist_dataset loads correctly", {

  # Download the training split into the shared `dir`.
  ds <- fashion_mnist_dataset(
    root = dir,
    train = TRUE,
    download = TRUE
  )

  expect_s3_class(ds, "fashion_mnist_dataset")

  # Fetch the first sample once and run every per-item check against it,
  # instead of calling `.getitem(1)` again for each expectation.
  item <- ds$.getitem(1)
  expect_type(item, "list")
  expect_named(item, c("x", "y"))
  expect_equal(dim(as.array(item$x)), c(28, 28))
  # The label must be one of the 10 class indices 1..10.
  expect_true(item$y >= 1 && item$y <= 10)

  # Reload from the files downloaded above with a tensor transform and check
  # the shape of one dataloader batch.
  ds2 <- fashion_mnist_dataset(dir, transform = transform_to_tensor)
  dl <- torch::dataloader(ds2, batch_size = 32)
  iter <- dataloader_make_iter(dl)
  batch <- dataloader_next(iter)
  expect_tensor_shape(batch$x, c(32, 28, 28))
  expect_tensor_shape(batch$y, 32)
  expect_named(batch, c("x", "y"))
})

test_that("tests for the emnist dataset", {
  skip_on_cran()

  # Opt-in only. `unset` is documented as a character string, so compare
  # strings explicitly rather than relying on numeric-to-character coercion.
  skip_if(
    Sys.getenv("TEST_LARGE_DATASETS", unset = "0") != "1",
    "Skipping test: set TEST_LARGE_DATASETS=1 to enable tests requiring large downloads."
  )

  # Without `download = TRUE` an empty root must error.
  expect_error(
    emnist_collection(root = tempfile())
  )

  # Shared checks for a collection whose items are returned as plain arrays:
  # sample count, item names, array type, and the label of the first item.
  # `info` names the collection so a failure says which one broke. Returns the
  # first item invisibly for any extra, collection-specific checks.
  expect_emnist_arrays <- function(ds, n, first_label, info) {
    expect_equal(length(ds), n, info = info)
    first_item <- ds[1]
    expect_named(first_item, c("x", "y"), info = info)
    expect_true(inherits(first_item$x, "array"), info = info)
    expect_equal(first_item[[2]], first_label, info = info)
    invisible(first_item)
  }

  emnist <- emnist_collection(dir, dataset = "balanced", download = TRUE)
  expect_emnist_arrays(emnist, 18800, 42, info = "balanced")

  emnist <- emnist_collection(
    dir, dataset = "byclass", split = "test", download = TRUE
  )
  first_item <- expect_emnist_arrays(emnist, 116323, 19, info = "byclass/test")
  expect_equal(dim(first_item$x), c(28, 28))

  emnist <- emnist_collection(dir, dataset = "bymerge", download = TRUE)
  expect_emnist_arrays(emnist, 116323, 25, info = "bymerge")

  # With a tensor transform, items are tensors carrying a leading singleton
  # dimension rather than plain arrays.
  emnist <- emnist_collection(
    dir, dataset = "letters", split = "train", download = TRUE,
    transform = transform_to_tensor
  )
  expect_equal(length(emnist), 124800)
  first_item <- emnist[1]
  expect_named(first_item, c("x", "y"))
  expect_tensor(first_item$x)
  expect_tensor_shape(first_item$x, c(1, 28, 28))
  expect_equal(first_item[[2]], 24)

  emnist <- emnist_collection(dir, dataset = "digits", download = TRUE)
  expect_emnist_arrays(emnist, 40000, 1, info = "digits")

  emnist <- emnist_collection(
    dir, dataset = "mnist", split = "train", download = TRUE
  )
  expect_emnist_arrays(emnist, 60000, 5, info = "mnist/train")

  # Batch through a dataloader with the tensor transform.
  ds2 <- emnist_collection(
    root = dir,
    dataset = "balanced",
    split = "test",
    transform = transform_to_tensor,
    download = TRUE
  )
  dl <- torch::dataloader(ds2, batch_size = 32)
  iter <- torch::dataloader_make_iter(dl)
  batch <- torch::dataloader_next(iter)
  expect_tensor_shape(batch$x, c(32, 1, 28, 28))
  expect_tensor_shape(batch$y, 32)
  expect_named(batch, c("x", "y"))
})


test_that("tests for the qmnist dataset", {

  # Without `download = TRUE` an empty root must fail with a "not found" error.
  expect_error(
    qmnist_dataset(tempfile()),
    "Dataset not found."
  )

  # Named `split_name` (not `split`) so it does not mask base::split(). Every
  # testthat expectation gets `info` so a failure reports which split broke.
  for (split_name in c("train", "test", "nist")) {
    info <- paste0("qmnist split: ", split_name)

    ds <- qmnist_dataset(dir, split = split_name, download = TRUE)

    item <- ds[1]
    expect_equal(dim(item[[1]]), c(28, 28), info = info)
    expect_true(item[[2]] %in% 1:10, info = info)

    expect_true(length(ds) > 0, info = info)

    # Reload the same split from disk with a tensor transform.
    ds <- qmnist_dataset(dir, split = split_name, transform = transform_to_tensor)
    dl <- torch::dataloader(ds, batch_size = 32)
    iter <- dataloader_make_iter(dl)
    batch <- dataloader_next(iter)
    expect_tensor_shape(batch[[1]], c(32, 28, 28))
    expect_tensor_shape(batch[[2]], 32)
    expect_true((torch_max(batch[[1]]) <= 1)$item(), info = info)
    expect_named(batch, c("x", "y"), info = info)
  }
})

test_that("tests for the emnist_dataset is deprecated", {
  skip_on_cran()

  # Opt-in only. `unset` is documented as a character string, so compare
  # strings explicitly rather than relying on numeric-to-character coercion.
  skip_if(
    Sys.getenv("TEST_LARGE_DATASETS", unset = "0") != "1",
    "Skipping test: set TEST_LARGE_DATASETS=1 to enable tests requiring large downloads."
  )

  # The deprecated wrapper must warn. The warning does not abort the call, so
  # the download still runs.
  # NOTE(review): no root is passed here, so data goes to emnist_dataset()'s
  # default location rather than the shared `dir` -- confirm that is intended.
  expect_warning(
    emnist_dataset(kind = "digits", download = TRUE),
    "'emnist_dataset' is deprecated."
  )
})

Try the torchvision package in your browser

Any scripts or data that you put into this service are public.

torchvision documentation built on Nov. 6, 2025, 9:07 a.m.