tests/testthat/helper-utils.R

library(tensorflow)


skip_if_no_tensorflow <- function(required_version = NULL) {
  if (!reticulate::py_module_available("tensorflow"))
    skip("TensorFlow not available for testing")
  else if (!is.null(required_version)) {
    if (tensorflow::tf_version() < required_version)
      skip(sprintf("Required version of TensorFlow (%s) not available for testing",
                   required_version))
  }
}

skip_if_eager <- function(message) {
  if (tf$executing_eagerly())
    skip(message)
}

skip_if_not_eager <- function() {
  if (!tf$executing_eagerly())
    skip("Eager execution is required. ")
}

skip_if_v2 <- function(message) {
  if (tensorflow::tf_version() >= "2.0")
    skip(message)
}

skip_tfestimators <- function() {
  skip("tfestimators deprecated")
}

py_capture_output <- reticulate::py_capture_output

test_succeeds <- function(desc, expr, required_version = NULL) {
  # IPython <- reticulate::import("IPython")
  # py_capture_output <- IPython$utils$capture$capture_output
  invisible(
    capture.output({
      py_capture_output({
        test_that(desc, {
          skip_if_no_tensorflow(required_version)
          expect_error(force(expr), NA)
        })
      })
    })
  )
}

csv_dataset <- function(file, ...) {
  csv_spec <- csv_record_spec(file, ...)
  text_line_dataset(file, record_spec = csv_spec)
}

mtcars_dataset <- function() {
  testing_data_filepath("mtcars.csv") %>%
    csv_dataset() %>%
    dataset_shuffle(50) %>%
    dataset_batch(10)
}

mtcars_dataset_nobatch <- function() {
  testing_data_filepath("mtcars.csv") %>%
    csv_dataset() %>%
    dataset_shuffle(50)
}


testing_data_filepath <- function(x = NULL) {
  # basename(getwd()) is either "tfdatasets" or "tests" if in R CMD check
  d <- if(file.exists("data/iris.csv"))
     "data" else "tests/testthat/data"
  if (is.null(x))
    d
  else
    file.path(d, x)
}
rstudio/tfdatasets documentation built on July 22, 2024, 12:41 a.m.