context("iterators")
test_that("make_iterator_one_shot works", {
skip_if_no_tensorflow()
batch <- mtcars_dataset() %>%
make_iterator_one_shot() %>%
iterator_get_next()
res <- if (tf$executing_eagerly()) {
as.array(batch$disp)
} else {
with_session(function (sess) {
sess$run(batch)$disp
})
}
expect_type(res, "double")
})
test_that("make_iterator_initializable works", {
skip_if_no_tensorflow()
skip_if_eager("dataset.make_initializable_iterator is not supported when eager execution is enabled.")
with_session(function(sess) {
if (tensorflow::tf_version() < "1.14")
max_value <- tf$placeholder(tf$int64, shape = shape())
else
max_value <- tf$compat$v1$placeholder(tf$int64, shape = shape())
range_ds <- range_dataset(from = 1, to = max_value)
iterator <- range_ds %>%
make_iterator_initializable()
next_element <- iterator_get_next(iterator)
iterator %>%
iterator_initializer() %>%
sess$run(feed_dict = dict(max_value = 10L))
for (i in 1L:9L) {
value <- sess$run(next_element)
expect_equal(i, value)
}
iterator %>%
iterator_initializer() %>%
sess$run(feed_dict = dict(max_value = 20L))
for (i in 1L:19L) {
value <- sess$run(next_element)
expect_equal(i, value)
}
})
})
test_succeeds("make_iterator_from_structure works", {
skip_if_no_tensorflow()
training_dataset <- range_dataset(from = 1, to = 100) %>%
dataset_map(function(x) {
x + tfr_random_uniform(shape(), -10L, 10L, tf$int64)
})
validation_dataset = range_dataset(from = 1, to = 50)
iterator <- make_iterator_from_structure(output_types(training_dataset),
output_shapes(training_dataset))
if (tf$executing_eagerly()) {
for (i in 1:20) {
# TODO: this emits a warning about incorrect context for Iterator.get_next(), investigate
# Initialize an iterator over the training dataset.
iterator_make_initializer(iterator, training_dataset)
for (j in 1:99)
iterator_get_next(iterator)
# Initialize an iterator over the validation dataset.
iterator_make_initializer(iterator, validation_dataset)
for (j in 1:49)
iterator_get_next(iterator)
}
} else {
with_session(function(sess) {
next_element <- iterator_get_next(iterator)
training_init_op <- iterator_make_initializer(iterator, training_dataset)
validation_init_op <- iterator_make_initializer(iterator, validation_dataset)
for (i in 1:20) {
# Initialize an iterator over the training dataset.
sess$run(training_init_op)
for (j in 1:99)
sess$run(next_element)
# Initialize an iterator over the validation dataset.
sess$run(validation_init_op)
for (j in 1:49)
sess$run(next_element)
}
})
}
})
test_succeeds("make_iterator_from_string_handle works", {
skip_if_no_tensorflow()
skip_if_eager("EagerIterator object has no attribute string_handle")
with_session(function(sess) {
training_dataset <- range_dataset(from = 1, to = 100) %>%
dataset_map(function(x) {
x + tfr_random_uniform(shape(), -10L, 10L, tf$int64)
}) %>%
dataset_repeat()
validation_dataset = range_dataset(from = 1, to = 50)
if (tensorflow::tf_version() < "1.14")
handle <- tf$placeholder(tf$string, shape = shape())
else
handle <- tf$compat$v1$placeholder(tf$string, shape = shape())
iterator <- make_iterator_from_string_handle(
handle,
output_types(training_dataset),
output_shapes(training_dataset)
)
next_element <- iterator_get_next(iterator)
training_iterator <- make_iterator_one_shot(training_dataset)
validation_iterator <- make_iterator_initializable(validation_dataset)
training_handle <- sess$run(iterator_string_handle(training_iterator))
validation_handle <- sess$run(iterator_string_handle(validation_iterator))
for (i in 1:2) {
for (j in 1:199)
sess$run(next_element, feed_dict = dict(handle = training_handle))
sess$run(iterator_initializer(validation_iterator))
for (j in 1:49)
sess$run(next_element, feed_dict = dict(handle = validation_handle))
}
})
})
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.