Nothing
# Sets the testthat context label ("utils-data") for the tests that follow.
# NOTE(review): context() is an edition-2 API and is deprecated under testthat
# 3rd edition; remove it if the package switches to edition 3.
context("utils-data")
# tensor_dataset(): class, length, and vector indexing along the first dim.
test_that("tensor dataset", {
  inputs <- torch_randn(1000, 10)
  targets <- torch_randn(1000)
  ds <- tensor_dataset(inputs, targets)

  expect_s3_class(ds, "tensor_dataset")
  expect_s3_class(ds, "dataset")
  expect_length(ds, 1000)

  # A vector index yields one tensor per input, each sliced on dim 1.
  batch <- ds[1:2]
  expect_tensor_shape(batch[[1]], c(2, 10))
  expect_tensor_shape(batch[[2]], c(2))
})
# Generators accept `private` and `active` lists; dataset methods reach them
# through `private$` and `self$` respectively.
test_that("create a dataset with private and active methods", {
  gen <- dataset(
    "my_dataset",
    initialize = function() {},
    .getitem = function(i) private$pvt_method(),
    .length = function() self$actv_len,
    private = list(
      pvt_method = function() 1:10
    ),
    active = list(
      actv_len = function() 10
    )
  )

  obj <- gen()
  expect_equal(obj[1], 1:10)
  expect_equal(length(obj), 10)
})
# dataset_subset() restricts a dataset to the requested indices.
test_that("dataset_subset works", {
  ds <- tensor_dataset(torch_randn(50, 10), torch_randn(50))
  sub <- dataset_subset(ds, 1:14)
  expect_equal(length(sub), 14)

  first_four <- sub[1:4][[1]]
  expect_tensor(first_four)
  expect_equal(nrow(first_four), 4)
})
# Indexing with a list of numbers must hand `.getbatch` an integer vector.
test_that("getbatch will get a vector of integers", {
  gen <- dataset(
    initialize = function() {},
    # Presumably returns the checked value (TRUE) from the inner expectation,
    # which the outer expect_true() below relies on.
    .getbatch = function(indexes) expect_true(is.integer(indexes)),
    .length = function() 100
  )
  result <- gen()[list(1, 2, 3, 4)]
  expect_true(result)
})
# A dataset that only implements `.getitem` (no `.getbatch`) must still be
# subsettable, and the subset must still be loadable through a dataloader.
test_that("subset without getbatch works", {
  x <- torch_randn(50)
  ds <- dataset("my_dataset",
    initialize = function(x) {
      self$x <- x
    },
    # Rejects vector indices, so this test fails if anything passes more
    # than one index at a time.
    .getitem = function(i) {
      if (length(i) > 1) {
        stop("Can only get a single item!")
      }
      self$x[i]
    },
    .length = function() {
      length(self$x)
    }
  )
  data <- ds(x)
  ds_subset <- dataset_subset(data, indices = 20:30)
  # Compare tensor values (not object identity), consistent with the
  # getbatch variant of this test, at both ends of the subset range.
  expect_equal_to_tensor(data[20], ds_subset[1])
  expect_equal_to_tensor(data[30], ds_subset[11])
  # 11 items with batch_size = 10 -> one full batch plus one partial batch.
  dl <- dataloader(ds_subset, batch_size = 10)
  expect_length(coro::collect(dl), 2)
})
# Like the test above, except the parent dataset implements `.getbatch`,
# so vector indices are valid on both the parent and the subset.
test_that("subset works with getbatch", {
  values <- torch_randn(50)
  gen <- dataset("my_dataset",
    initialize = function(x) {
      self$x <- x
    },
    .getbatch = function(idx) self$x[idx],
    .length = function() length(self$x)
  )
  parent <- gen(values)
  sub <- dataset_subset(parent, indices = 20:30)

  parent_idx <- c(20, 21, 22)
  sub_idx <- c(1, 2, 3)
  expect_equal_to_tensor(parent[parent_idx], sub[sub_idx])

  # 11 items in batches of 5 -> 5 + 5 + 1.
  loader <- dataloader(sub, batch_size = 5)
  expect_length(coro::collect(loader), 3)
})
# Printing a dataset generator mentions "dataset_generator".
test_that("datasets have a custom print method", {
  gen <- dataset(
    initialize = function() {},
    .getbatch = function(indexes) expect_true(is.integer(indexes)),
    .length = function() 100,
    parent_env = .GlobalEnv
  )
  expect_output(print(gen), regex = "dataset_generator")
})
# A subset's class vector gains `_subset`-suffixed versions of the parent's
# classes, ahead of the generic dataset classes.
test_that("dataset subset adds more classes", {
  minimal_dataset <- dataset(
    "minimal",
    initialize = function() {
      self$data <- torch_tensor(1:5)
    },
    .length = function() self$data$size()[[1]]
  )

  parent <- minimal_dataset()
  expect_equal(class(parent), c("minimal", "dataset", "R6"))

  expected_sub_classes <- c(
    "minimal_subset", "dataset_subset", "dataset", "R6"
  )
  expect_equal(class(dataset_subset(parent, 1:2)), expected_sub_classes)
})
# Batches keep their trailing dimension, including the final partial batch.
test_that("dimensions are preserved", {
  ds <- tensor_dataset(torch_rand(11, 3), torch_rand(11, 1))
  loader <- dataloader(ds, batch_size = 10, shuffle = TRUE)
  it <- dataloader_make_iter(loader)

  full <- dataloader_next(it)
  expect_equal(dim(full[[1]]), c(10, 3))
  expect_equal(dim(full[[2]]), c(10, 1))

  # 11 rows in batches of 10 leave a single-row batch; it must stay 2-d.
  tail_batch <- dataloader_next(it)
  expect_equal(dim(tail_batch[[1]]), c(1, 3))
  expect_equal(dim(tail_batch[[2]]), c(1, 1))
})
test_that("can get a single element using `[[`", {
  # `[[` should call getitem and drop the batch dimension when possible.
  ds <- tensor_dataset(torch_rand(11, 3), torch_rand(11, 1))
  item <- ds[[1]]
  expect_equal(dim(item[[1]]), 3)
})
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.