Nothing
# Label the tests below as belonging to the dataloader utilities.
# NOTE(review): testthat's 3rd edition deprecates context() -- confirm which
# edition this test suite runs under before relying on it.
context("utils-data-dataloader")
# Smoke test for a dataloader over a tensor dataset: checks its length, the
# `is_dataloader()` predicate, the shape of the first batch, and that the
# iterator reports exhaustion once every batch has been consumed.
test_that("dataloader works", {
  features <- torch_randn(1000, 100)
  targets <- torch_randn(1000, 1)
  loader <- dataloader(
    dataset = tensor_dataset(features, targets),
    batch_size = 32
  )

  # 1000 rows at 32 per batch: 31 full batches plus one partial batch.
  expect_length(loader, 1000 %/% 32 + 1)
  expect_true(is_dataloader(loader))

  # The first batch carries 32 rows of each tensor.
  first_batch <- loader$.iter()$.next()
  expect_tensor_shape(first_batch[[1]], c(32, 100))
  expect_tensor_shape(first_batch[[2]], c(32, 1))

  # Drain a fresh iterator; the call after the 32nd batch must signal the end.
  fresh_iter <- loader$.iter()
  for (step in seq_len(32)) {
    fresh_iter$.next()
  }
  expect_equal(fresh_iter$.next(), coro::exhausted())
})
# Exercises the two ways of walking a dataloader used in this file: the
# explicit make_iter/next pair and the deprecated `enumerate()` helper.
test_that("dataloader iteration", {
  inputs <- torch_randn(100, 100)
  targets <- torch_randn(100, 1)
  loader <- dataloader(
    dataset = tensor_dataset(inputs, targets),
    batch_size = 32
  )

  # Manual iteration: keep pulling batches until `dataloader_next()` yields
  # NULL, which this test treats as the end-of-data marker.
  it <- dataloader_make_iter(loader)
  repeat {
    batch <- dataloader_next(it)
    if (is.null(batch)) {
      break
    }
    expect_tensor(batch[[1]])
    expect_tensor(batch[[2]])
  }

  # `enumerate()` must still work, but is expected to raise a deprecation
  # warning.
  expect_warning(
    {
      for (batch in enumerate(loader)) {
        expect_tensor(batch[[1]])
        expect_tensor(batch[[2]])
      }
    },
    class = "deprecated"
  )
})
# Items made of plain R objects (a double matrix, a scalar index, an integer
# vector) must be collated into tensors with the expected shapes and dtypes.
test_that("can have datasets that don't return tensors", {
  plain_r_dataset <- dataset(
    initialize = function() {},
    .getitem = function(index) {
      list(matrix(runif(10), ncol = 10), index, 1:10)
    },
    .length = function() {
      100
    }
  )
  loader <- dataloader(plain_r_dataset(), batch_size = 32, drop_last = TRUE)

  # Every batch has exactly 32 items because `drop_last = TRUE` is set.
  expect_warning(
    {
      for (batch in enumerate(loader)) {
        expect_tensor_shape(batch[[1]], c(32, 1, 10))
        expect_true(batch[[1]]$dtype == torch_float())
        expect_tensor_shape(batch[[2]], c(32))
        expect_tensor_shape(batch[[3]], c(32, 10))
        expect_true(batch[[3]]$dtype == torch_long())
      }
    },
    class = "deprecated"
  )

  # `batch` still refers to the last batch produced by the loop above.
  expect_true(batch[[1]]$dtype == torch_float32())
  expect_true(batch[[2]]$dtype == torch_int64())
  expect_true(batch[[3]]$dtype == torch_int64())
})
# With `shuffle = TRUE` the batch shapes must still be right, both when the
# batch size divides the dataset evenly and when a short final batch remains.
test_that("dataloader that shuffles", {
  x <- torch_randn(100, 100)
  y <- torch_randn(100, 1)
  ds <- tensor_dataset(x, y)

  # 100 rows in batches of 50: every batch is full.
  loader <- dataloader(dataset = ds, batch_size = 50, shuffle = TRUE)
  expect_warning(
    {
      for (batch in enumerate(loader)) {
        expect_tensor_shape(batch[[1]], c(50, 100))
      }
    },
    class = "deprecated"
  )

  # 100 rows in batches of 30: three full batches, then 10 leftover rows.
  loader <- dataloader(dataset = ds, batch_size = 30, shuffle = TRUE)
  n_seen <- 0
  expect_warning(
    {
      for (batch in enumerate(loader)) {
        n_seen <- n_seen + 1
        expected_rows <- if (n_seen == 4) 10 else 30
        expect_tensor_shape(batch[[1]], c(expected_rows, 100))
      }
    },
    class = "deprecated"
  )
})
# Names on the list returned by `.getitem` must survive both direct indexing
# of the dataset and batching through a dataloader.
test_that("named outputs", {
  make_ds <- dataset(
    initialize = function() {},
    .getitem = function(i) {
      list(x = i, y = 2 * i)
    },
    .length = function() {
      1000
    }
  )
  ds <- make_ds()
  expected_names <- c("x", "y")

  expect_named(ds[1], expected_names)

  loader <- dataloader(ds, batch_size = 4)
  first_batch <- dataloader_next(dataloader_make_iter(loader))
  expect_named(first_batch, expected_names)
})
# A dataloader can be consumed directly by `coro::loop()`, and every batch
# keeps the names returned by `.getitem`.
test_that("can use a dataloader with coro", {
  ds <- dataset(
    initialize = function() {},
    .getitem = function(i) {
      list(x = i, y = 2 * i)
    },
    .length = function() {
      10
    }
  )()
  expect_named(ds[1], c("x", "y"))

  # 10 items in batches of 5: each batch holds 5 values per field.
  dl <- dataloader(ds, batch_size = 5)

  # Namespace-qualified, like the other coro loops in this file.
  coro::loop(for (batch in dl) {
    expect_named(batch, c("x", "y"))
    expect_tensor_shape(batch$x, 5)
    expect_tensor_shape(batch$y, 5)
  })
})
# With two workers, each item reports the id of the worker that produced it.
# The test expects the i-th batch to be filled entirely with the value i.
test_that("dataloader works with num_workers", {
  # Skipped on Windows only when CUDA is available.
  if (cuda_is_available()) {
    skip_on_os("windows")
  }

  ds <- dataset(
    .length = function() {
      20
    },
    initialize = function() {},
    .getitem = function(id) {
      # `.worker_info` is not defined in this file; it is read as a free
      # variable when the item is fetched.
      list(x = .worker_info$id)
    }
  )

  # No separate iterator is created here: the loop below iterates through
  # `enumerate(dl)`, so an extra `dataloader_make_iter(dl)` would be unused.
  dl <- dataloader(ds(), batch_size = 10, num_workers = 2)

  i <- 1
  expect_warning(class = "deprecated", {
    for (batch in enumerate(dl)) {
      expect_equal_to_tensor(batch$x, i * torch_ones(10))
      i <- i + 1
    }
  })
})
# An error raised inside `.getitem` on a worker must reach the caller as a
# "runtime_error" whose message still contains the original text.
test_that("dataloader catches errors on workers", {
  if (cuda_is_available()) {
    skip_on_os("windows")
  }

  failing_ds <- dataset(
    initialize = function() {},
    .getitem = function(id) {
      stop("the error id is 5567")
      # Never reached: the call above always aborts.
      list(x = .worker_info$id)
    },
    .length = function() {
      20
    }
  )

  loader <- dataloader(failing_ds(), batch_size = 10, num_workers = 2)
  it <- dataloader_make_iter(loader)

  expect_error(
    dataloader_next(it),
    regexp = "5567",
    class = "runtime_error"
  )
})
# `worker_init_fn` receives the worker id and stores twice that value in
# `theid`. Every item then returns `theid`, so the i-th batch is expected to
# be filled with the value 2 * i.
test_that("worker init function is respected", {
  # Skipped on Windows only when CUDA is available.
  if (cuda_is_available()) {
    skip_on_os("windows")
  }

  ds <- dataset(
    .length = function() {
      20
    },
    initialize = function() {},
    .getitem = function(id) {
      list(x = theid)
    }
  )

  # `<<-` is deliberate: `.getitem` reads `theid` as a free variable, so the
  # init function has to define it outside its own frame.
  worker_init_fn <- function(id) {
    theid <<- id * 2
  }

  dl <- dataloader(ds(),
    batch_size = 10, num_workers = 2,
    worker_init_fn = worker_init_fn
  )

  i <- 1
  expect_warning(class = "deprecated", {
    for (batch in enumerate(dl)) {
      expect_equal_to_tensor(batch$x, i * 2 * torch_ones(10))
      i <- i + 1
    }
  })
})
# A worker that needs longer than the configured timeout must make
# `dataloader_next()` fail with a "timed out" runtime error.
test_that("dataloader timeout is respected", {
  if (cuda_is_available()) {
    skip_on_os("windows")
  }

  slow_ds <- dataset(
    initialize = function() {},
    .getitem = function(id) {
      Sys.sleep(10)
      list(x = 1)
    },
    .length = function() {
      20
    }
  )

  # (timeout is in milliseconds)
  loader <- dataloader(
    slow_ds(),
    batch_size = 10,
    num_workers = 2,
    timeout = 5
  )
  it <- dataloader_make_iter(loader)

  expect_error(
    dataloader_next(it),
    regexp = "timed out",
    class = "runtime_error"
  )
})
# Items that are already tensors must come back intact from a multi-worker
# dataloader: ten scalar ones are collated into a length-10 tensor of ones.
test_that("can return tensors in multiworker dataloaders", {
  if (cuda_is_available()) {
    skip_on_os("windows")
  }

  scalar_ds <- dataset(
    initialize = function() {},
    .getitem = function(id) {
      list(x = torch_scalar_tensor(1))
    },
    .length = function() {
      20
    }
  )
  loader <- dataloader(scalar_ds(), batch_size = 10, num_workers = 2)

  expect_warning(
    {
      for (batch in enumerate(loader)) {
        expect_equal_to_tensor(batch$x, torch_ones(10))
      }
    },
    class = "deprecated"
  )
})
# Resetting the R seed before creating an iterator must reproduce the same
# first batch, for both the base-R random draw and the torch random draw.
test_that("can make reproducible runs", {
  if (cuda_is_available()) {
    skip_on_os("windows")
  }

  random_ds <- dataset(
    initialize = function() {},
    .getitem = function(id) {
      list(x = runif(1), y = torch_randn(1))
    },
    .length = function() {
      20
    }
  )
  loader <- dataloader(random_ds(), batch_size = 10, num_workers = 2)

  # Seed, build a fresh iterator, and fetch its first batch.
  seeded_first_batch <- function() {
    set.seed(1)
    it <- dataloader_make_iter(loader)
    dataloader_next(it)
  }

  first_run <- seeded_first_batch()
  second_run <- seeded_first_batch()

  expect_equal(first_run$x, second_run$x)
  expect_equal_to_tensor(first_run$y, second_run$y)
})
# Each item records whether "coro" is attached where `.getitem` runs. It is
# expected to be absent by default and present once it is requested through
# `worker_packages`.
test_that("load packages in dataloader", {
  pkg_probe_ds <- dataset(
    initialize = function() {},
    .getitem = function(id) {
      torch_tensor("coro" %in% (.packages()))
    },
    .length = function() {
      20
    }
  )

  # Without `worker_packages`: no item may report coro as attached.
  loader <- dataloader(pkg_probe_ds(), batch_size = 10, num_workers = 2)
  it <- dataloader_make_iter(loader)
  flags <- dataloader_next(it)
  expect_equal(torch_any(flags)$item(), FALSE)

  # With `worker_packages = "coro"`: every item must report it as attached.
  loader <- dataloader(
    pkg_probe_ds(),
    batch_size = 10,
    num_workers = 2,
    worker_packages = "coro"
  )
  it <- dataloader_make_iter(loader)
  flags <- dataloader_next(it)
  expect_equal(torch_all(flags)$item(), TRUE)
})
# `.getitem` calls `hello_fn()`, which it does not define itself. The test
# walks through four situations: the function is missing, the requested
# globals do not exist, the function is passed as a named list, and the
# function is passed by name.
test_that("globals can be found", {
  needs_global_ds <- dataset(
    initialize = function() {},
    .getitem = function(id) {
      hello_fn()
    },
    .length = function() {
      20
    }
  )

  # 1. `hello_fn` does not exist yet, so fetching a batch must fail.
  loader <- dataloader(needs_global_ds(), batch_size = 10, num_workers = 2)
  it <- dataloader_make_iter(loader)
  expect_error(
    b1 <- dataloader_next(it)
  )

  # 2. Asking for globals by names that are not defined must fail when the
  #    dataloader is created.
  expect_error(
    loader <- dataloader(
      needs_global_ds(),
      batch_size = 10,
      num_workers = 2,
      worker_globals = c("hello", "world")
    ),
    class = "runtime_error"
  )

  # From here on the function exists in the test's environment.
  hello_fn <- function() {
    torch_randn(5, 5)
  }

  # 3. Globals supplied as a named list of objects.
  loader <- dataloader(
    needs_global_ds(),
    batch_size = 10,
    num_workers = 2,
    worker_globals = list(hello_fn = hello_fn)
  )
  it <- dataloader_make_iter(loader)
  expect_tensor_shape(dataloader_next(it), c(10, 5, 5))

  # 4. Globals supplied as a character vector of names.
  loader <- dataloader(
    needs_global_ds(),
    batch_size = 10,
    num_workers = 2,
    worker_globals = "hello_fn"
  )
  it <- dataloader_make_iter(loader)
  expect_tensor_shape(dataloader_next(it), c(10, 5, 5))
})
# A dataset may define `.getbatch`, which receives all indexes of a batch at
# once, in place of `.getitem`. The dataloader must yield its two tensors
# with the batch dimension first.
test_that("datasets can use an optional .getbatch method for speedups", {
  batched_ds <- dataset(
    initialize = function() {},
    .getbatch = function(indexes) {
      n <- length(indexes)
      list(torch_randn(n, 10), torch_randn(n, 1))
    },
    .length = function() {
      100
    }
  )
  loader <- dataloader(batched_ds(), batch_size = 10)

  coro::loop(for (batch in loader) {
    expect_length(batch, 2)
    expect_tensor_shape(batch[[1]], c(10, 10))
    expect_tensor_shape(batch[[2]], c(10, 1))
  })
})
# Same as the `.getbatch` tensor case, but the method returns base R arrays.
# The dataloader is expected to hand back tensors of matching shape.
test_that("dataloaders handle .getbatch that don't necessarily return a torch_tensor", {
  array_ds <- dataset(
    initialize = function() {},
    .getbatch = function(indexes) {
      n <- length(indexes)
      list(
        array(0, dim = c(n, 10)),
        array(0, dim = c(n, 1))
      )
    },
    .length = function() {
      100
    }
  )
  loader <- dataloader(array_ds(), batch_size = 10)

  coro::loop(for (batch in loader) {
    expect_length(batch, 2)
    expect_tensor_shape(batch[[1]], c(10, 10))
    expect_tensor_shape(batch[[2]], c(10, 1))
  })
})
# A `.getbatch` that returns a bare character string cannot be turned into a
# batch; the failure must surface as a "value_error" with a clear message.
test_that("a value error is returned when its not possible to convert", {
  string_ds <- dataset(
    initialize = function() {},
    .getbatch = function(indexes) {
      "a"
    },
    .length = function() {
      100
    }
  )

  # Creation, iterator setup and the first fetch all stay inside the
  # expectation, so an error from any of the three steps is captured.
  expect_error(
    {
      loader <- dataloader(string_ds(), batch_size = 10)
      dataloader_next(dataloader_make_iter(loader))
    },
    class = "value_error",
    regexp = "Can't convert data of class.*"
  )
})
# Builds a dataset that keeps tensors in a public field, a private field and
# a list, and also exposes an active binding returning a tensor. Creating a
# multi-worker dataloader over it is expected to warn about the
# "parallel dataloader".
test_that("warning tensor", {
  tensor_holding_ds <- dataset(
    initialize = function() {
      self$x <- torch_randn(100, 100)
      private$k <- torch_randn(10, 10)
      self$z <- list(k = torch_tensor(1), torch_tensor(2))
    },
    .getitem = function(i) {
      torch_randn(1, 1)
    },
    .length = function() {
      100
    },
    active = list(
      y = function() {
        torch_randn(1)
      }
    ),
    private = list(k = 1)
  )
  ds_instance <- tensor_holding_ds()

  expect_warning(
    loader <- dataloader(ds_instance, batch_size = 2, num_workers = 10),
    regexp = "parallel dataloader"
  )
})
# Logical values returned by `.getitem` must be collated into a boolean
# tensor, not some other dtype.
test_that("collate works with bool data", {
  tensors <- replicate(10, torch_randn(5))

  example_ds <- dataset(
    "example_dataset",
    initialize = function(numbers) {
      self$numbers <- numbers
    },
    .getitem = function(i) {
      item <- self$numbers[[i]]
      is_positive_mean <- as.array(torch_mean(item) > 0)
      list(x = item, n = is_positive_mean)
    },
    .length = function() {
      length(self$numbers)
    }
  )
  ds_instance <- example_ds(numbers = tensors)

  # A single item carries a plain R logical...
  expect_true(is.logical(ds_instance[1]$n))

  # ...and the first collated batch must carry a bool tensor.
  loader <- dataloader(ds_instance, batch_size = 2)
  first_batch <- coro::collect(loader, 1)[[1]]
  collated_flags <- first_batch$n
  expect_true(collated_flags$dtype == torch_bool())
})
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.