# R/utils.R

# Defines functions: capture_args, assert_all_dots_named, require_tf_version, as_integer_list, as_py_function, is_eager_tensor, is_tensor, is_dataset, column_names, validate_tf_version, with_session, as_tensor_shapes, as_integer_tensor

# Coerce `x` to integer form, recursing through containers.
#
# x:     NULL, a tensor, a scalar, a numeric vector, or a list of these.
# dtype: dtype used when casting tensors (only evaluated if a tensor is seen).
#
# Returns a list (for lists and multi-element numeric vectors, converted
# element by element), `x` unchanged when it is NULL, a cast tensor when
# `x` is a tensor, and otherwise `as.integer(x)`.
as_integer_tensor <- function(x, dtype = tf$int64) {

  # lists and numeric vectors of length > 1 are handled one element at a time
  is_container <- is.list(x) || (is.numeric(x) && length(x) > 1)
  if (is_container) {
    return(lapply(x, function(elem) as_integer_tensor(elem, dtype)))
  }

  if (is.null(x)) {
    return(x)
  }

  if (is_tensor(x)) {
    return(tf$cast(x, dtype = dtype))
  }

  # Remaining values stay plain R integers instead of becoming
  # tf$constant(as.integer(x), dtype = dtype); see
  # https://github.com/tensorflow/tensorflow/issues/71937
  as.integer(x)
}

# Convert a shape specification into int64 tensors.
#
# x: a list (converted recursively and wrapped with tuple()), a tensor
#    (cast to int64), any other Python object (returned untouched), NULL
#    (becomes the constant -1, presumably denoting an unknown dimension),
#    or anything else coercible with as.integer().
as_tensor_shapes <- function(x) {
  if (is.list(x)) {
    return(tuple(lapply(x, as_tensor_shapes)))
  }

  if (is_tensor(x)) {
    return(tf$cast(x, dtype = tf$int64))
  }

  # non-tensor Python objects are passed through as they are
  if (inherits(x, "python.builtin.object")) {
    return(x)
  }

  shape_value <- if (is.null(x)) -1L else as.integer(x)
  tf$constant(shape_value, dtype = tf$int64)
}

# Run `f` with a TensorFlow session.
#
# f:       function taking a single argument, the session.
# session: session to use. When NULL, the current default session is used;
#          if there is no default session either, a temporary session is
#          created and closed again when this function exits.
#
# Returns the value of `f(session)`.
#
# A session supplied by the caller is used as given and is never closed here.
with_session <- function(f, session = NULL) {

  # Only look up the default session when the caller did not supply one.
  # (Previously the `else` was attached to the is.null() check, so a supplied
  # session was overwritten by the default session.)
  if (is.null(session)) {
    if (tensorflow::tf_version() >= "1.14")
      session <- tensorflow::tf$compat$v1$get_default_session()
    else
      session <- tf$get_default_session()
  }

  # No default session available: create one that lives for this call only.
  if (is.null(session)) {
    if (tensorflow::tf_version() >= "1.14")
      session <- tf$compat$v1$Session()
    else
      session <- tf$Session()
    on.exit(session$close(), add = TRUE)
  }

  f(session)
}


# Stop with an informative error unless a sufficiently recent TensorFlow
# is available.
#
# required_ver: minimum TensorFlow version needed.
# feature_name: name of the feature shown in the "requires version" error.
#
# Errors when tensorflow::tf_version() is NULL (TensorFlow not installed) or
# lower than `required_ver`; otherwise returns NULL invisibly.
validate_tf_version <- function(required_ver = "1.4", feature_name = "tfdatasets") {
  installed_ver <- tensorflow::tf_version()

  if (is.null(installed_ver)) {
    stop(
      "You need to install TensorFlow to use tfdatasets ",
      "-- install with tensorflow::install_tensorflow()",
      call. = FALSE
    )
  }

  if (installed_ver < required_ver) {
    stop(
      feature_name, " requires version ", required_ver, " ",
      "of TensorFlow (you are currently running version ", installed_ver, ").",
      call. = FALSE
    )
  }
}

# Return the names of a dataset's outputs.
#
# On TensorFlow >= 2.0 the names are read from next_batch(dataset); on older
# versions from dataset$output_shapes. Errors when the outputs are not a
# named list.
column_names <- function(dataset) {

  outputs <- if (tensorflow::tf_version() >= "2.0") {
    next_batch(dataset)
  } else {
    dataset$output_shapes
  }

  has_named_outputs <- is.list(outputs) && !is.null(names(outputs))
  if (!has_named_outputs) {
    stop("Unable to resolve features for dataset that does not have named outputs", call. = FALSE)
  }

  names(outputs)
}

# TRUE when `x` inherits from either of the TensorFlow dataset classes
# (Dataset or DatasetV2).
is_dataset <- function(x) {
  dataset_classes <- c(
    "tensorflow.python.data.ops.dataset_ops.Dataset",
    "tensorflow.python.data.ops.dataset_ops.DatasetV2"
  )
  inherits(x, dataset_classes)
}

# TRUE when `x` inherits from the TensorFlow Tensor class.
is_tensor <- function(x) {
  tensor_class <- "tensorflow.python.framework.ops.Tensor"
  inherits(x, what = tensor_class)
}

# TRUE when `x` inherits from either of the EagerTensor class names.
is_eager_tensor <- function(x) {
  eager_classes <- c(
    "python.builtin.EagerTensor",
    "tensorflow.python.framework.ops.EagerTensor"
  )
  inherits(x, eager_classes)
}

# Return `x` unchanged when it is already a Python function; otherwise
# convert it with rlang::as_function().
as_py_function <- function(x) {
  already_python <- inherits(x, "python.builtin.function")
  if (!already_python) {
    return(rlang::as_function(x))
  }
  x
}


# Coerce `x` to integer and return the values as a list, one element each.
as_integer_list <- function(x) {
  int_values <- as.integer(x)
  as.list(int_values)
}


# assert_all_dots_named(), capture_args(), require_tf_version()
# copy-pasted from keras circa tf_version() 2.7
# Error unless the running TensorFlow version is at least `ver`.
#
# ver: minimum version, compared against tf_version().
# msg: text appended to the error after "required to use ".
require_tf_version <- function(ver, msg = "this function.") {
  version_too_old <- tf_version() < ver
  if (version_too_old) {
    stop("Tensorflow version >=", ver, " required to use ", msg)
  }
}

# Error unless every argument passed through `...` in `envir` is named.
#
# envir: environment (normally a function's frame) in which `...` is defined.
# cl:    the call to show in the error message; only used when the check fails.
#
# Returns NULL when `...` is empty or fully named.
assert_all_dots_named <- function(envir = parent.frame(), cl) {

  dots <- eval(quote(list(...)), envir)
  if (length(dots) == 0L) {
    return()
  }

  dot_names <- names(dots)
  fully_named <- is.character(dot_names) &&
    !anyNA(dot_names) &&
    !any(dot_names == "")
  if (fully_named) {
    return()
  }

  stop("All arguments provided to `...` must be named.\n",
       "Call with unnamed arguments in dots:\n  ",
       paste(deparse(cl, 500L), collapse = "\n"))
}

# TODO: should there be some default modifiers in capture_args() for standard
# layer args like input_shape, batch_input_shape, etc.?

# Collect the arguments supplied to the calling function into a named list.
#
# cl:        the matched call of the calling function (see the notes below on
#            where match.call() must be evaluated). Its names select which
#            arguments are captured, and it is shown in the error raised when
#            `...` contains unnamed arguments.
# modifiers: optional named list of functions; each captured argument with a
#            matching name is replaced by the result of calling the function
#            on it.
# ignore:    character vector of argument names to leave out of the result;
#            including "..." skips capturing the dots.
# envir:     environment in which the argument symbols and `...` are evaluated.
# fn:        function whose formals define the known argument names.
#
# Returns a named list of evaluated argument values.
capture_args <- function(cl, modifiers = NULL, ignore = NULL,
                         envir = parent.frame(), fn = sys.function(-1)) {

  ## bug: match.call() resolves incorrectly if the dots do not originate from
  ## the default sys.parent(),
  ## e.g., this fails if the dots originate from the caller's caller:
  #    cl <- eval(quote(match.call()), parent.frame())
  ## workaround: the caller must call match.call() from the correct frame

  ## note: capture_args() must always be called at the top level of the intended function body.
  ## sys.function(-1) resolves to the incorrect function if the capture_args()
  ## call is itself a promise in another call. E.g.:
  ##   do.call(foo, capture_args(match.call())) fails because fn resolves to do.call()

  # names that appear both in the call and in fn's formals, minus `ignore`
  fn_arg_nms <- names(formals(fn))
  known_args <- intersect(names(cl), fn_arg_nms)
  known_args <- setdiff(known_args, ignore)
  names(known_args) <- known_args
  # build the call list(<name> = <name>, ...) to evaluate in `envir`
  cl2 <- c(quote(list), lapply(known_args, as.symbol))

  if("..." %in% fn_arg_nms && !"..." %in% ignore) {
    assert_all_dots_named(envir, cl)
    # this might reorder args by assuming ... are last, but it doesn't matter
    # since everything is supplied as a keyword arg to the Python side anyway
    cl2 <- c(cl2, quote(...))
  }

  args <- eval(as.call(cl2), envir)

  # drop any ignored names still present in the result (known_args already
  # excludes `ignore`, so these would have come in through `...`)
  for(nm in intersect(names(args), ignore))
    args[[nm]] <- NULL

  # apply each modifier to the captured argument of the same name
  nms_to_modify <- intersect(names(args), names(modifiers))
  for (nm in nms_to_modify)
    args[nm] <- list(modifiers[[nm]](args[[nm]]))
   # list() so if modifier returns NULL, don't remove the arg

  args
}
# rstudio/tfdatasets documentation built on July 22, 2024, 12:41 a.m.