Nothing
# Label the tests in this file as the "tfdatasets" group.
# NOTE(review): context() is deprecated in testthat 3rd edition -- confirm
# which edition this package targets before removing it.
context("tfdatasets")
# Smoke test: a sequential model can be fit, evaluated and used for prediction
# directly from a tfdatasets pipeline. No values are asserted; the test only
# exercises the calls.
test_succeeds("Use tfdatasets to train a keras model", {

  # A single dense unit on a one-column input, compiled with MSE loss and SGD.
  model <- layer_dense(keras_model_sequential(), units = 1, input_shape = 1)
  compile(model, loss = "mse", optimizer = "sgd")

  # One (x, y) tuple, repeated 100 times, shuffled, then batched by 10.
  xy_pair <- reticulate::tuple(list(1), list(1))
  dataset <- tfdatasets::tensors_dataset(xy_pair)
  dataset <- tfdatasets::dataset_repeat(dataset, 100)
  dataset <- tfdatasets::dataset_shuffle(dataset, buffer_size = 100)
  dataset <- tfdatasets::dataset_batch(dataset, 10)

  is_tf2 <- tensorflow::tf_version() >= "2.0"

  if (is_tf2) {
    # On TF >= 2.0 no explicit step counts are passed.
    fit(model, dataset, epochs = 2)
    evaluate(model, dataset)
    predictions <- predict(model, dataset)
  } else {
    # On older TensorFlow the step counts are given explicitly.
    fit(model, dataset, epochs = 2, steps_per_epoch = 5)
    evaluate(model, dataset, steps = 10)
    predictions <- predict(model, dataset, steps = 10)
  }
})
# fit() is expected to raise an error when `batch_size` is supplied together
# with a dataset that has already been batched via tfdatasets.
test_that("Error when specifying batch_size with tfdatasets", {

  skip_if_no_keras()

  if (!is_tensorflow_implementation()) {
    skip("Datasets need TensorFlow implementation.")
  }

  # A single dense unit on a one-column input, compiled with MSE loss and SGD.
  model <- layer_dense(keras_model_sequential(), units = 1, input_shape = 1)
  compile(model, loss = "mse", optimizer = "sgd")

  # One (x, y) tuple, repeated 100 times, shuffled, then batched by 10.
  xy_pair <- reticulate::tuple(list(1), list(1))
  dataset <- tfdatasets::tensors_dataset(xy_pair)
  dataset <- tfdatasets::dataset_repeat(dataset, 100)
  dataset <- tfdatasets::dataset_shuffle(dataset, buffer_size = 100)
  dataset <- tfdatasets::dataset_batch(dataset, 10)

  # Any error satisfies the expectation; no message or class is checked.
  expect_error(
    fit(model, dataset, epochs = 2, batch_size = 5)
  )
})
# Smoke test: a model created and compiled inside a MirroredStrategy scope can
# be fit on a tfdatasets pipeline.
test_succeeds("Works with tf$distribute", {

  if (tensorflow::tf_version() < "1.14.0") {
    skip("tf$distribute is not available in TF prior to v1.14")
  }

  strategy <- tensorflow::tf$distribute$MirroredStrategy()

  # The model is built and compiled while the strategy scope is active.
  with(strategy$scope(), {
    model <- layer_dense(keras_model_sequential(), units = 1, input_shape = 1)
    compile(model, loss = "mse", optimizer = "sgd")
  })

  # One (x, y) tuple, repeated 100 times, shuffled, then batched by 10.
  xy_pair <- reticulate::tuple(list(1), list(1))
  dataset <- tfdatasets::tensors_dataset(xy_pair)
  dataset <- tfdatasets::dataset_repeat(dataset, 100)
  dataset <- tfdatasets::dataset_shuffle(dataset, buffer_size = 100)
  dataset <- tfdatasets::dataset_batch(dataset, 10)

  # There is no clean way to silence the massively verbose output this test
  # produces: TF_CPP_MIN_LOG_LEVEL is only effective if it is set before tf
  # is initialized.
  # https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras
  # https://github.com/tensorflow/tensorflow/issues/45157
  #
  # Disabled sketch of the environment-variable approach, kept for reference:
  # (function() {
  #   ol <- Sys.getenv("TF_CPP_MIN_LOG_LEVEL")
  #   Sys.setenv("TF_CPP_MIN_LOG_LEVEL" = "3")
  #   on.exit(Sys.setenv("TF_CPP_MIN_LOG_LEVEL" = ol))
  #   <the fit() call below>
  # })()
  fit(model, dataset, epochs = 10, verbose = 0)
})
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.