# Exercises `shape()`, which builds
# `tensorflow.python.framework.tensor_shape.TensorShape` objects, together
# with the operations used on them below: conversion, comparison,
# subsetting/replacement, `c()`, `merge()`, `print()` and `format()`.
test_that("shape() works", {
  skip_if_no_tensorflow()

  # Assert that `x` is a TensorShape. When `dims` is supplied:
  #   - `dims = NULL` means `x` must have unknown rank, so both `x$rank` and
  #     `x$dims` are NULL;
  #   - otherwise `dims` is the expected value of `as.list(x$as_list())`,
  #     with NULL standing in for an unknown dimension.
  expect_tensor_shape <- function(x, dims) {
    expect_s3_class(x, "tensorflow.python.framework.tensor_shape.TensorShape")
    if (missing(dims)) {
      return()
    }
    if (is.null(dims)) {
      expect_null(x$rank)
      expect_null(x$dims)
    } else {
      expect_identical(as.list(x$as_list()), dims)
    }
  }

  x <- shape(1, NA, 2, NULL, 3)
  expect_tensor_shape(x, list(1L, NULL, 2L, NULL, 3L))
  expect_identical(as.list(x), list(1L, NULL, 2L, NULL, 3L))
  expect_true(shape() == tf$TensorShape(list()))
  expect_true(shape(dims = NULL) == tf$TensorShape(NULL))

  # --- construct ---
  # `shape()` has no dimensions; `shape(dims = NULL)` has unknown rank.
  # Both NA and NULL denote an unknown dimension.
  expect_tensor_shape(shape(), list())
  expect_tensor_shape(shape(NULL), list(NULL))
  expect_tensor_shape(shape(dims = NULL), NULL)
  expect_tensor_shape(shape(3, 4), list(3L, 4L))
  expect_tensor_shape(shape(NA, 4), list(NULL, 4L))
  expect_tensor_shape(shape(dims = c(NA, 4)), list(NULL, 4L))

  # --- inspect ---
  # length() is NA for a shape of unknown rank, otherwise the number of dims.
  expect_identical(length(shape(dims = NULL)), NA_integer_)
  expect_identical(length(shape(1, 2, 3, NA)), 4L)

  # --- convert ---
  x <- shape(dims = list(3L, 5L))
  expect_identical(as.list(x), list(3L, 5L))
  expect_identical(as.integer(x), c(3L, 5L))
  expect_identical(as.numeric(x), c(3, 5))
  expect_identical(as.double(x), c(3, 5))

  # Unknown dimensions become NULL in lists and NA in atomic vectors.
  x <- shape(NA, 3)
  expect_identical(as.list(x), list(NULL, 3L))
  expect_identical(as.integer(x), c(NA, 3L))
  expect_identical(as.double(x), c(NA, 3))

  # as_tensor() yields an int32 tensor with one element per dimension.
  x2 <- as_tensor(shape(NA, 3))
  expect_equal(x2$numpy(), array(c(-1L, 3L)))
  expect_identical(x2$dtype$name, "int32")
  expect_identical(as.list(x2$shape), list(2L))

  # A shape of unknown rank cannot be converted.
  x <- shape(dims = NULL)
  expect_error(as.list(x))
  expect_error(as.numeric(x))
  expect_error(as_tensor(x))

  x <- shape(NA, 3)
  # as_tensor() converts undefined dims to -1.
  expect_identical(as.integer(as_tensor(x)), c(-1L, 3L))
  # Shapes can round-trip: shape -> tensor -> shape.
  expect_tensor_shape(shape(dims = as_tensor(x)), list(NULL, 3L))

  # --- compare ---
  # Fully known shapes are equal if and only if every dimension is equal.
  expect_true(shape(3, 4) == shape(3, 4))
  expect_false(shape(3, 4) == shape(4, 4))
  # Partially known shapes: before TF 2.9, `==` returns FALSE here; from
  # TF 2.9 on, `shape(NA, 4) == shape(NA, 4)` is TRUE.
  if (tf_version() >= "2.9") {
    expect_true(shape(NA, 4) == shape(NA, 4))
  } else {
    expect_false(shape(NA, 4) == shape(NA, 4))
  }
  expect_false(shape(NA, 4) == shape(3, 4))
  # Two shapes of unknown rank compare equal.
  expect_true(shape(dims = NULL) == shape(dims = NULL))
  # A shape of unknown rank does not equal a partially or fully known shape.
  expect_false(shape(dims = NULL) == shape(NULL))
  expect_false(shape(dims = NULL) == shape(4))
  if (tf_version() < "2.9") {
    # From TF 2.9 on, `!=` is simply the negation of `==`. Earlier versions
    # mostly invert `==`, with one difference: comparing against a shape of
    # unknown rank raises
    # "ValueError: The inequality of unknown TensorShapes is undefined."
    expect_error(shape(dims = NULL) != shape(dims = NULL))
    expect_error(shape(dims = NULL) != shape())
  }

  # --- extract or replace ---
  # Regular R-list semantics for `[`, `[[`, `[<-`, `[[<-`.
  x <- shape(1, 2, 3)
  expect_tensor_shape(x[1], list(1L))
  expect_identical(x[[1]], 1L)
  x_slice <- x[2:3]
  expect_tensor_shape(x_slice, list(2L, 3L))
  expect_true(x_slice == c(2, 3))
  expect_true(x_slice == x[-1])

  # Replacement values may be numbers, shapes, or lists.
  x <- shape(1, 2, 3)
  x[1] <- 11
  expect_tensor_shape(x, list(11L, 2L, 3L))
  expect_true(x == c(11, 2, 3))
  x[1] <- shape(22)
  expect_tensor_shape(x, list(22L, 2L, 3L))
  expect_true(x == c(22, 2, 3))
  x[1] <- list(33)
  expect_tensor_shape(x, list(33L, 2L, 3L))
  expect_true(x == c(33, 2, 3))
  x[[1]] <- 44
  expect_true(x == c(44, 2, 3))
  x[1:2] <- c(NA, 99)
  expect_identical(as.numeric(x), c(NA, 99, 3))
  x[1:2] <- shape(33, 44)
  expect_tensor_shape(x, list(33L, 44L, 3L))
  expect_identical(as.numeric(x), c(33, 44, 3))

  # --- concatenate ---
  # Expected: TensorShape([1, 2, 3, 4, None])
  x <- c(shape(1), shape(2, 3), shape(4, NA))
  expect_identical(as.list(x), list(1L, 2L, 3L, 4L, NULL))

  # --- merge ---
  # Expected: TensorShape([1, 2])
  x <- merge(shape(NA, 2), shape(1, 2))
  expect_tensor_shape(x, list(1L, 2L))
  expect_true(x == c(1, 2))
  # ValueError: Shapes (2, 2) and (1, 2) are not compatible
  expect_error(merge(shape(2, 2), shape(1, 2)))

  # --- print / format ---
  expect_output(print(shape(3)), "TensorShape([3])", fixed = TRUE)
  expect_output(print(shape(3, NA)), "TensorShape([3, None])", fixed = TRUE)
  expect_output(print(shape(3, NULL)), "TensorShape([3, None])", fixed = TRUE)
  expect_equal(format(shape(3)), "(3)")
  expect_equal(format(shape(3, NA)), "(3, NA)")
  expect_equal(format(shape(3, NULL)), "(3, NA)")

  # shape() accepts tf.TensorShapes as arguments and flattens them.
  expect_equal(as.list(shape(shape(3))), list(3L))
  expect_equal(as.list(shape(shape(3, 4))), list(3L, 4L))
  expect_equal(as.list(shape(shape(3, 4), 5)), list(3L, 4L, 5L))
  expect_equal(as.list(shape(NA, shape(3, 4), 5)), list(NULL, 3L, 4L, 5L))
})
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.