
# Keep TensorFlow's native (C++) logging quiet while the tests run.
# NOTE(review): "2" presumably filters INFO and WARNING messages, per the
# TensorFlow docs for TF_CPP_MIN_LOG_LEVEL; confirm for the TF version used.
Sys.setenv("TF_CPP_MIN_LOG_LEVEL" = "2")
# Optional toggles, left disabled:
# Sys.setenv(CUDA_VISIBLE_DEVICES = 0)
# Sys.setenv(TF_XLA_FLAGS='--tf_xla_cpu_global_jit')

# Wrap an R function with TensorFlow's tf.function.
#
# `fn` is first turned into a Python callable with reticulate::py_func().
# tf.function's own AutoGraph pass is switched off (autograph = FALSE).
# Any extra arguments in `...` are forwarded to tf.function unchanged.
tf_function <- function(fn, ...) {
  tf_fn <- tf$`function`
  tf_fn(reticulate::py_func(fn), ..., autograph = FALSE)
}

# Unprefixed alias for reticulate::tuple() used throughout the tests.
tuple <- getExportedValue("reticulate", "tuple")

# Convert `x` to a TensorFlow tensor with tf$convert_to_tensor().
# Extra arguments in `...` are forwarded unchanged.
as_tensor <- function(x, ...) {
  convert <- tf$convert_to_tensor
  convert(x, ...)
}

# Pull tensor values back into R.
# NOTE(review): this definition looks truncated. Neither the `if` block nor
# the function body is ever closed, so as written every later definition in
# the file parses as part of this body. TODO restore the missing lines.
grab <- function(x) {
  # Eager mode: replace every element matching class "tensorflow.tensor"
  # with the result of its $numpy() method, keeping the surrounding list
  # structure (how = "replace"). `x` is wrapped in list() so a bare tensor
  # also works with rapply(); `[[1]]` removes that wrapper again.
  if(tf$executing_eagerly()) {
    return(rapply(list(x), function(tensor) tensor$numpy(),
                  classes = "tensorflow.tensor",
                  how = "replace")[[1]])

    # NOTE(review): as written, this line sits after return() inside the
    # eager branch and is unreachable. It presumably belongs to a non-eager
    # (graph mode) path that creates a global v1 Session in `.SESS` and then
    # runs `x` (e.g. `.SESS$run(x)`). That run call is not visible here;
    # confirm against upstream.
    .SESS <<- tf$compat$v1$Session()

# Call `fun` on `inputs`, convert the result to R with grab(), and assert
# that it equals `expected` using expect_equal().
#
# @param fun Function under test.
# @param inputs A single argument, or a list of arguments spliced into the
#   call with do.call(). A non-list value is wrapped in list() first, the
#   same convention expect_ag_equivalent() uses.
# @param expected Expected R value of the grabbed result.
expect_result <- function(fun, inputs, expected) {
  # Only wrap non-lists. Wrapping unconditionally would turn a list of
  # arguments into one single list argument.
  if (!is.list(inputs)) {
    inputs <- list(inputs)
  }
  tensor_result <- do.call(fun, inputs)
  result <- grab(tensor_result)
  expect_equal(result, expected)
}

# Assert that the R value obtained from `tensor` via grab() equals `value`.
# The grab(tensor) expression is passed to expect_equal() directly so that
# failure messages still label the actual value as "grab(tensor)".
expect_grabbed_result_equal <- function(tensor, value) {
  expect_equal(object = grab(tensor), expected = value)
}

# Zero-based analogue of seq_len(): the integers 0, 1, ..., x - 1.
#
# @param x A single non-negative whole number.
# @return An integer vector of length x; integer(0) when x is 0.
#   A negative or NA `x` is an error, as with seq_len(). The previous
#   `0L:(x - 1L)` form silently returned a descending sequence such as
#   c(0, -1, -2, -3) for x = -2.
seq_len0 <- function(x) {
  seq_len(x) - 1L
}

# Compound-assignment helpers written as replacement functions.
# For example, `add(x) <- 1` rebinds x to x + 1, and `divide(x) <- 2`
# rebinds x to x / 2. The arithmetic operator dispatches on the class of x.
`add<-` <- function(x, value) {
  x + value
}

`subtract<-` <- function(x, value) {
  x - value
}

`multiply<-` <- function(x, value) {
  x * value
}

`divide<-` <- function(x, value) {
  x / value
}

# Check that autograph() preserves the behavior of `fn`.
#
# `fn` is evaluated three ways:
#   1. called directly on the plain R `input`;
#   2. autograph(fn) called on `input` converted with as_tensor();
#   3. autograph(fn) wrapped in tf_function(), also called on the tensors.
# The results of 2 and 3 are converted back to R with grab(). Each is then
# asserted (with expect_equal()) to equal the result of 1.
#
# @param fn Function under test.
# @param input A single argument, or a list of arguments passed to `fn`
#   via do.call(). A non-list value is wrapped in list() first.
expect_ag_equivalent <- function(fn, input) {
  if (!is.list(input)) {
    input <- list(input)
  }
  ag_fn <- autograph(fn)
  tf_ag_fn <- tf_function(ag_fn)

  ag_input <- lapply(input, as_tensor)

  res <- do.call(fn, input)
  ag_res <- grab(do.call(ag_fn, ag_input))
  tf_ag_res <- grab(do.call(tf_ag_fn, ag_input))
  expect_equal(ag_res, res)
  expect_equal(tf_ag_res, res)
}

# Build a float32 NumPy array whose dimensions are the values passed in
# `...`. It holds 1, 2, ..., prod(dims), placed in the order array() fills
# them (column-major).
np_arr <- function(...) {
  dims <- c(...)
  values <- seq_len(prod(...))
  reticulate::np_array(array(values, dims), "float32")
}

# Tensor counterpart of np_arr(): the same float32 array of shape c(...)
# holding 1, 2, ..., prod(...), converted with as_tensor().
# NOTE(review): the original definition had no body, so R parsed the next
# top-level expression (the skip_if_no_tensorflow assignment) as its body.
# This body is a reconstruction mirroring np_arr(); confirm it against
# upstream.
tf_arr <- function(...) {
  as_tensor(np_arr(...))
}

# Skip the calling test unless the Python "tensorflow" module is reachable
# through reticulate. A missing reticulate package is also treated as "not
# available" instead of raising an error from the `::` lookup.
skip_if_no_tensorflow <- function() {
  has_tf <- requireNamespace("reticulate", quietly = TRUE) &&
    reticulate::py_module_available("tensorflow")
  if (!has_tf)
    skip("TensorFlow not available for testing")
}

# Try the tfautograph package in your browser
#
# Any scripts or data that you put into this service are public.
#
# tfautograph documentation built on Sept. 18, 2021, 1:07 a.m.