R/layer.R

Defines functions as_nullable_integer layer_hub

Documented in layer_hub

#' Hub Layer
#'
#' Wraps a Hub module (or a similar callable) for TF2 as a Keras Layer.
#'
#' This layer wraps a callable object for use as a Keras layer. The callable
#' object can be passed directly, or be specified by a string with a handle
#' that gets passed to `hub_load()`.
#'
#' The callable object is expected to follow the conventions detailed below.
#' (These are met by TF2-compatible modules loaded from TensorFlow Hub.)
#'
#' The callable is invoked with a single positional argument set to one tensor or
#' a list of tensors containing the inputs to the layer. If the callable accepts
#' a `training` argument, a boolean is passed for it. It is `TRUE` if this layer
#' is marked trainable and is called for training.
#'
#' If present, the following attributes of the callable are understood to have
#' special meanings:
#'
#' - `variables`: a list of all `tf.Variable` objects that the callable depends on.
#' - `trainable_variables`: those elements of `variables` that are reported as
#'   trainable variables of this Keras Layer when the layer is trainable.
#' - `regularization_losses`: a list of callables to be added as losses of this
#'   Keras Layer when the layer is trainable. Each one must accept zero arguments
#'   and return a scalar tensor.
#'
#' @param object Model or layer object
#' @param handle a callable object (subject to the conventions above), or a string
#'   for which `hub_load()` returns such a callable. A string is required to save
#'   the Keras config of this Layer.
#' @param trainable Boolean controlling whether this layer is trainable.
#' @param arguments optionally, a list with additional keyword arguments passed to
#'   the callable. These must be JSON-serializable to save the Keras config of
#'   this layer.
#' @param ... Other arguments passed on to the `KerasLayer` constructor, such as
#'   `input_shape`. Keyword arguments for the callable itself go in `arguments`.
#'
#' @examples
#'
#' \dontrun{
#'
#' library(keras)
#'
#' model <- keras_model_sequential() %>%
#'  layer_hub(
#'    handle = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4",
#'    input_shape = c(224, 224, 3)
#'  ) %>%
#'  layer_dense(1)
#'
#' }
#'
#' @export
layer_hub <- function(object, handle, trainable = FALSE, arguments = NULL, ...) {

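  # Collect the extra arguments that are forwarded to the KerasLayer constructor.
  args <- list(...)

  # Shape dimensions must be integers; NULL (unknown) dimensions are kept as-is.
  if (!is.null(args$input_shape))
    args$input_shape <- lapply(args$input_shape, as_nullable_integer)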

  keras::create_layer(
    tfhub$KerasLayer,
    object,
    append(
      list(
        handle = handle,
        trainable = trainable,
        arguments = arguments
      ),
      args
    )
  )
}

# Coerce `x` to integer, passing `NULL` through unchanged.
as_nullable_integer <- function(x) {
  if (is.null(x))
    x
  else
    as.integer(x)
}
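
# A minimal sketch of how the `input_shape` handling in `layer_hub()` behaves,
# using base R only. The helper is repeated here so the snippet runs on its own;
# the shapes below are illustrative assumptions, not values taken from the package.

```r
# Local copy of the helper defined above, so this sketch is self-contained.
as_nullable_integer <- function(x) {
  if (is.null(x)) x else as.integer(x)
}

# Hypothetical shape with unknown height and width and three channels.
# list() keeps NULL entries, and lapply() preserves them in its result.
input_shape <- list(NULL, NULL, 3)
converted <- lapply(input_shape, as_nullable_integer)

# A fully specified shape given as a numeric vector becomes a list of integers.
fixed <- lapply(c(224, 224, 3), as_nullable_integer)

str(converted)
str(fixed)
```

# Doubles such as 224 are coerced to integers, while NULL dimensions survive the
# conversion unchanged.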

tfhub documentation built on Dec. 19, 2021, 9:07 a.m.