Nothing
#' Tensor shape utility
#'
#' This function can be used to get or create a tensor shape.
#'
#' # Examples
#' ```{r}
#' shape(1, 2, 3)
#' ```
#'
#' Three ways to specify an unknown dimension:
#' ```{r, results = "hold"}
#' shape(NA, 2, 3)
#' shape(NULL, 2, 3)
#' shape(-1, 2, 3)
#' ```
#'
#' Most functions that take a 'shape' argument also coerce with `shape()`
#' ```{r, results = "hold"}
#' layer_input(c(1, 2, 3))
#' layer_input(shape(1, 2, 3))
#' ```
#'
#' You can also use `shape()` to get the shape of a tensor
#' (except for scalar integer tensors; see below).
#' ```{r}
#' symbolic_tensor <- layer_input(shape(1, 2, 3))
#' shape(symbolic_tensor)
#'
#' eager_tensor <- op_ones(c(1,2,3))
#' shape(eager_tensor)
#' op_shape(eager_tensor)
#' ```
#'
#' Combine or expand shapes
#' ```{r}
#' shape(symbolic_tensor, 4)
#' shape(5, symbolic_tensor, 4)
#' ```
#'
#' Scalar integer tensors are treated as axis values. These are most commonly
#' encountered when tracing a function in graph mode, where an axis size might
#' be unknown.
#' ```{r}
#' tfn <- tensorflow::tf_function(function(x) {
#' print(op_shape(x))
#' x
#' },
#' input_signature = list(tensorflow::tf$TensorSpec(shape(1, NA, 3))))
#' invisible(tfn(op_ones(shape(1, 2, 3))))
#' ```
#'
#' A useful pattern is to unpack the `shape()` with `%<-%`, like this:
#' ```r
#' c(batch_size, seq_len, channels) %<-% shape(x)
#'
#' # `%<-%` also supports skipping values
#' # during unpacking with `.` and `...`. For example,
#' # to retrieve just the first and/or last dim:
#' c(batch_size, ...) %<-% shape(x)
#' c(batch_size, ., .) %<-% shape(x)
#' c(..., channels) %<-% shape(x)
#' c(batch_size, ..., channels) %<-% shape(x)
#' c(batch_size, ., channels) %<-% shape(x)
#' ```
#'
#' ```{r}
#' echo_print <- function(x) {
#' message("> ", deparse(substitute(x)));
#' if(!is.null(x)) print(x)
#' }
#' tfn <- tensorflow::tf_function(function(x) {
#' c(axis1, axis2, axis3) %<-% shape(x)
#' echo_print(str(list(axis1 = axis1, axis2 = axis2, axis3 = axis3)))
#'
#' echo_print(shape(axis1)) # use axis1 tensor as axis value
#' echo_print(shape(axis1, axis2, axis3)) # mix the axis1 tensor with integer dims
#'
#' # use shape() to compose a new shape, e.g., in multihead attention
#' n_heads <- 4
#' echo_print(shape(axis1, axis2, n_heads, axis3/n_heads))
#'
#' x
#' },
#' input_signature = list(tensorflow::tf$TensorSpec(shape(NA, 4, 16))))
#' invisible(tfn(op_ones(shape(2, 4, 16))))
#' ```
#'
#' If you want to resolve the shape of a tensor that can potentially be
#' a scalar integer, you can wrap the tensor in `I()`, or use [`op_shape()`].
#' ```{r}
#' (x <- op_convert_to_tensor(2L))
#'
#' # by default, shape() treats scalar integer tensors as axis values
#' shape(x)
#'
#' # to access the shape of a scalar integer,
#' # call `op_shape()`, or protect with `I()`
#' op_shape(x)
#' shape(I(x))
#' ```
#'
#' @param ... A shape specification. Numerics, `NULL` and tensors are valid.
#' `NULL`, `NA`, and `-1` can be used to specify an unknown dim size.
#' Tensors are dispatched to `op_shape()` to extract the tensor shape. Values
#' wrapped in `I()` are used as-is (see examples). R arrays contribute their
#' `dim()`, and vectors and lists are flattened into individual dims. All other
#' objects are coerced via `as.integer()`.
#'
#' @returns A list with a `"keras_shape"` class attribute. Each element of the
#' list will be either a) `NULL`, b) an R integer or c) a scalar integer tensor
#' (e.g., when supplied a TF tensor with an unspecified dimension in a function
#' being traced).
#'
#' @export
#' @seealso [op_shape()]
shape <- function(...) {
  # Python reports unknown dims as NULL; map them to NA so they survive the
  # unlist() below.
  null_to_na <- function(d) d %||% NA_integer_

  # Normalize one argument into an integer, NA_integer_ (unknown dim), a
  # scalar integer tensor, or a (possibly nested) list of these.
  as_dims <- function(spec) {
    if (is_py_object(spec)) {
      tensor_shape_cls <- "tensorflow.python.framework.tensor_shape.TensorShape"
      if (inherits(spec, tensor_shape_cls)) {
        dims <- as.list(as_r_value(spec$as_list()))
        return(map_int(dims, null_to_na))
      }

      py_dims <- keras$ops$shape(spec)

      # Subclassed tuples (e.g. torch.Size, as encountered with Torch) are
      # converted through builtins.tuple.
      if (inherits(py_dims, "python.builtin.tuple"))
        py_dims <- import("builtins")$tuple(py_dims)

      # A rank-0 integer tensor is treated as an axis value, unless the
      # caller protected it with I().
      is_axis_value <- identical(py_dims, list()) &&
        keras$backend$is_int_dtype(spec$dtype) &&
        !inherits(spec, "AsIs")
      if (is_axis_value)
        return(spec)

      # Common path: act as a tensor shape accessor.
      return(lapply(py_dims, null_to_na))
    }

    # R arrays contribute their dim() attribute.
    if (is.array(spec))
      return(dim(spec))

    # NULL, NA, and -1 all denote an unknown dim.
    is_unknown <- is.null(spec) ||
      identical(spec, NA_integer_) ||
      identical(spec, NA_real_) ||
      identical(spec, NA) ||
      (is.numeric(spec) && isTRUE(suppressWarnings(spec == -1L)))
    if (is_unknown)
      return(NA_integer_)

    # A single atomic value is one dim; anything else is walked recursively.
    if (is.atomic(spec) && length(spec) <= 1)
      return(as.integer(spec))

    lapply(spec, as_dims)
  }

  flat <- unlist(lapply(list(...), as_dims), use.names = FALSE)
  flat <- lapply(flat, function(d) if (identical(d, NA_integer_)) NULL else d)
  structure(flat, class = "keras_shape")
}
#' @export
#' @rdname shape
#' @param x,y A `keras_shape` object.
#' @param prefix Prefix used when formatting the shape. `TRUE` (the default)
#' uses `"shape"`, a single string is used as given, and any other value
#' (e.g., `FALSE`) omits the prefix.
format.keras_shape <- function(x, ..., prefix = TRUE) {
  # One label per dim; unknown (NULL) dims are rendered as "NA".
  labels <- vapply(x, function(d) format(d %||% "NA"), character(1))
  inner <- paste0(labels, collapse = ", ")

  # TRUE -> "shape"; a single string -> used as given; anything else -> none.
  lead <- if (isTRUE(prefix)) {
    "shape"
  } else if (is_string(prefix)) {
    prefix
  } else {
    ""
  }

  paste0(lead, "(", inner, ")")
}
#' @export
#' @rdname shape
print.keras_shape <- function(x, ...) {
  # Delegate rendering to format() so print() and format() always agree.
  rendered <- format(x, ...)
  writeLines(rendered)
  invisible(x)
}
#' @rdname shape
#' @export
`[.keras_shape` <- function(x, ...) {
  # Subset the underlying list, then restore the keras_shape class so the
  # result is still a shape.
  structure(unclass(x)[...], class = class(x))
}
#' @export
r_to_py.keras_shape <- function(x, convert = FALSE) {
  # Shapes cross into Python as a tuple built by reticulate's tuple().
  out <- tuple(x, convert = convert)
  out
}
#' @rdname shape
#' @export
as.integer.keras_shape <- function(x, ...) {
  # Unknown (NULL) dims become NA so the result is a plain integer vector.
  dims <- unclass(x)
  vapply(dims, function(d) if (is.null(d)) NA_integer_ else d, integer(1))
}
#' @importFrom zeallot destructure
#' @export
destructure.keras_shape <- function(x) {
  # Give zeallot's `%<-%` the underlying list of dims (NULL = unknown).
  dims <- unclass(x)
  dims
}
#' @rdname shape
#' @export
as.list.keras_shape <- function(x, ...) {
  # A keras_shape is already a list; dropping the class is enough.
  dims <- unclass(x)
  dims
}
#' @rdname shape
#' @export
`==.keras_shape` <- function(x, y) {
  # Either operand may be a plain shape spec (e.g. an integer vector);
  # normalize it with shape() so both sides compare as keras_shape lists.
  as_shape <- function(v) if (inherits(v, "keras_shape")) v else shape(v)
  lhs <- as_shape(x)
  rhs <- as_shape(y)
  identical(lhs, rhs)
}
#' @rdname shape
#' @export
`!=.keras_shape` <- function(x, y) {
  # Defined as the exact negation of `==`, so the two can never disagree.
  is_equal <- `==.keras_shape`(x, y)
  !is_equal
}
# ' @rdname shape
# ' @export
# c.keras_shape <- function(...) shape(...)
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.