# tests/testthat/test-tensorflow-rpkg-stability.R

# Fix the RNG state for this file. The seed was originally spelled as the
# date-like expression `2020 - 02 - 11`, which R evaluates arithmetically to
# 2007; that resulting value is written out directly here.
set.seed(2007)

# Checks what dim() reports for scalar tensors of several dtypes, for a shape
# object, and for a tensor built from a length-3 vector.
test_that("tensorflow returns appropriate thing with 'dim'", {
  skip_if_not(check_tf_version())

  # Build a tensor from a single value with the requested dtype and return
  # whatever dim() reports for it.
  dim_of_scalar <- function(value, dtype) {
    dim(tensorflow::as_tensor(x = value, dtype))
  }

  # Each single-value tensor is expected to have a zero-length integer dim.
  expect_equal(dim_of_scalar(42, "int32"), integer(0))
  expect_equal(dim_of_scalar(42, "int64"), integer(0))
  expect_equal(dim_of_scalar(42, "float32"), integer(0))
  expect_equal(dim_of_scalar(42.2, "float32"), integer(0))

  # dim() of a shape object is expected to be NULL.
  expect_null(dim(shape(1, 2, 3)))

  # A tensor built from a length-3 vector is expected to have dim 3.
  expect_equal(dim(tensorflow::as_tensor(1:3)), 3)
})

# Snapshot test for a single-value int32 tensor: records length(), dim(), and
# the result of tf$reshape() when the target shape is passed as a tensor.
#
# expect_snapshot() stores the code text alongside its output, so renaming `x`
# or reordering these calls would invalidate the stored snapshots.
test_that("Tensor behaves as we expect",{
  # Skip unless check_tf_version() returns TRUE.
  skip_if_not(check_tf_version())
  x <- tensorflow::as_tensor(42, "int32")
  expect_snapshot(length(x))
  expect_snapshot(dim(x))
  # The target shape mixes double and integer literals and has -1 as its first
  # entry; it is converted to a tensor before being handed to tf$reshape().
  expect_snapshot(
    tf$reshape(x, tensorflow::as_tensor(shape(-1, 1, 1L, 1, 1L)))
  )
})

# Records how shape() objects print for a range of constructor inputs, then
# checks that dim() is NULL for those inputs.
test_that("shape returns right thing", {
  skip_if_not(check_tf_version())

  # Printed form of each construction. The snapshotted expressions and their
  # order are part of the stored snapshot, so they are kept as they were.
  expect_snapshot(shape())
  expect_snapshot(shape(NULL))
  expect_snapshot(shape(NA))
  expect_snapshot(shape(dims = NULL))
  expect_snapshot(shape(3, 4))
  expect_snapshot(shape(NA, 4))
  expect_snapshot(shape(dims = c(NA, 4)))
  expect_snapshot(shape(1, 1, 1))

  # dim() is expected to be NULL however the shape was constructed.
  expect_null(dim(shape()))
  expect_null(dim(shape(NULL)))
  expect_null(dim(shape(NA)))
  expect_null(dim(shape(dims = NULL)))
  expect_null(dim(shape(3, 4)))
  expect_null(dim(shape(NA, 4)))
  expect_null(dim(shape(dims = c(NA, 4))))
})

# TODO (TF1 vs TF2): decide whether we still need something equivalent to the
# commented-out placeholder/session test below.
# test_that("placeholder and friends behave the same way", {
#   skip_if_not(check_tf_version())
#
#   x <- tf$compat$v1$placeholder(tf$float64, list(2, NULL))
#   y <- x * 2
#   sess <- tf$compat$v1$Session()
#   dict <- dict(x = array(1, c(2, 10)))
#
#   expect_snapshot_output(dim(x))
#   expect_snapshot_output(length(x))
#   expect_snapshot_output(dim(y))
#   expect_snapshot_output(length(y))
#   expect_snapshot_output(length(dict))
#   expect_snapshot_output(dim(dict))
#   expect_snapshot_output(length(sess$run(y, feed_dict = dict)))
#   expect_snapshot_output(dim(sess$run(y, feed_dict = dict)))
#
#   dict <- dict(x = array(1, c(2, 4)))
#
#   expect_snapshot_output(dim(dict))
#   expect_snapshot_output(length(dict))
#   expect_snapshot_output(length(sess$run(y, feed_dict = dict)))
#   expect_snapshot_output(dim(sess$run(y, feed_dict = dict)))
#
# })

# Snapshot test for converting a two-entry shape, whose first entry is NA, to
# a list, to an integer vector and to a tensor, plus `[[` extraction of each
# entry.
#
# expect_snapshot() stores the code text alongside its output, so renaming `x`
# or reordering these calls would invalidate the stored snapshots.
test_that("TensorShape conversions remain stable", {
  # Skip unless check_tf_version() returns TRUE.
  skip_if_not(check_tf_version())
  x <- shape(NA, 3)
  expect_snapshot(as.list(x))
  expect_snapshot(as.integer(x))
  expect_snapshot(tensorflow::as_tensor(x))
  # Element-wise extraction: first the NA entry, then the entry given as 3.
  expect_snapshot(x[[1]])
  expect_snapshot(x[[2]])
})

# Snapshot test for shape(): printing for various constructor inputs,
# combining with c(), length(), numeric conversions, and the == / !=
# comparison operators.
#
# expect_snapshot() stores the code text alongside its output, so editing the
# snapshotted expressions or reordering these calls would invalidate the
# stored snapshots.
test_that("shape returns appropriate TensorShape object", {
  # Skip unless check_tf_version() returns TRUE.
  skip_if_not(check_tf_version())
  # Printed form for empty, NULL, NA and fully/partially specified inputs.
  expect_snapshot(shape())
  expect_snapshot(shape(NULL))
  expect_snapshot(shape(NA))
  expect_snapshot(shape(dims = NULL))
  expect_snapshot(shape(3, 4))
  expect_snapshot(shape(NA, 4))
  expect_snapshot(shape(dims = c(NA, 4)))
  # Combining a shape with a plain number via c().
  expect_snapshot(c(shape(1), 3))
  # length() for one-entry and two-entry shapes.
  expect_snapshot(length(shape(1)))
  expect_snapshot(length(shape(1, 3)))
  # Conversions to base numeric vector types.
  expect_snapshot(as.integer(shape(1, 3)))
  expect_snapshot(as.numeric(shape(1, 3)))
  expect_snapshot(as.double(shape(1, 3)))
  # Equality and inequality, each with matching and non-matching operands.
  expect_snapshot(shape(1, 3) == shape(1,3))
  expect_snapshot(shape(1, 3) == shape(1,2))
  expect_snapshot(shape(1, 3) != shape(1,3))
  expect_snapshot(shape(1, 3) != shape(1,2))
})

# Snapshot test for `[` and `[[` extraction on a three-entry shape, and for
# `[<-` assignment with a number, a shape and a list as replacement value.
#
# expect_snapshot() stores the code text alongside its output, so renaming
# `x_extract` or reordering these calls would invalidate the stored snapshots.
test_that("[, [[, and assignment returns right object", {
  # Skip unless check_tf_version() returns TRUE.
  skip_if_not(check_tf_version())
  x_extract <- shape(1,2,3)
  # Extraction: single index with `[` and `[[`, a range, and a negative index.
  expect_snapshot(x_extract[1])
  expect_snapshot(x_extract[[1]])
  expect_snapshot(x_extract[2:3])
  expect_snapshot(x_extract[-1])
  # Assignment: each block replaces the first entry and then prints it.
  # NOTE(review): the assignments presumably modify `x_extract` in this test's
  # environment, so later blocks start from the already-modified object —
  # confirm if the order of these blocks is ever changed.
  expect_snapshot({
    x_extract[1] <- 11
    x_extract[1]
    })
  expect_snapshot({
    x_extract[1] <- shape(11)
    x_extract[1]
    })
  expect_snapshot({
    x_extract[1] <- list(11)
    x_extract[1]
    })
})


# other parts to test:
# batch_size <- tf$shape(x)[[0]]
# shape_list <- c(list(batch_size), as.integer(to_shape(dims_out)))
# shape_out <- tf$stack(shape_list)
#
# tf$reshape(ref[, idx, ], tensorflow::as_tensor(shape(-1, length(idx), 1)))
#
# update_list <- lapply(run_id, function(i) {
#   tf$reshape(updates[, i - 1, ], tensorflow::as_tensor(shape(-1, 1, 1)))
# })
#
# has_batch(x)
#
#
# test_that("multiplication works", {
#   expect_equal(2 * 2, 4)
# })
# greta-dev/greta documentation built on Dec. 21, 2024, 5:03 a.m.