R/layer-custom.R

Defines functions r_to_py.keras_layer_wrapper create_layer_wrapper py_formals compat_custom_KerasLayer_handler

Documented in create_layer_wrapper

#' (Deprecated) Base R6 class for Keras layers
#'
#' Custom R6 layers can now inherit directly from `keras$layers$Layer` or other layers.
#'
#' @docType class
#'
#' @format An [R6Class] generator object
#'
#' @section Methods: \describe{ \item{\code{build(input_shape)}}{Creates the
#'   layer weights (must be implemented by all layers that have weights)}
#'   \item{\code{call(inputs,mask)}}{Call the layer on an input tensor.}
#'   \item{\code{compute_output_shape(input_shape)}}{Compute the output shape
#'   for the layer.}
#'   \item{\code{add_loss(losses, inputs)}}{Add losses to the layer.}
#'   \item{\code{add_weight(name,shape,dtype,initializer,regularizer,trainable,constraint)}}{Adds
#'   a weight variable to the layer.} }
#'
#' @return [KerasLayer].
#'
#' @export
KerasLayer <- R6Class("KerasLayer",

  public = list(

    # Hook for creating the layer weights. The base implementation does
    # nothing and returns NULL.
    build = function(input_shape) {
      NULL
    },

    # Apply the layer to an input tensor. The base implementation always
    # errors, so subclasses have to provide their own `call`.
    call = function(inputs, mask = NULL) {
      stop("Keras custom layers must implement the call function")
    },

    # Output shape of the layer; by default the input shape is returned
    # unchanged.
    compute_output_shape = function(input_shape) {
      input_shape
    },

    # Forward losses to the wrapper's `add_loss`. Arguments that are NULL
    # are left out of the call entirely.
    add_loss = function(losses, inputs = NULL) {
      loss_args <- Filter(
        Negate(is.null),
        list(losses = losses, inputs = inputs)
      )
      do.call(private$wrapper$add_loss, loss_args)
    },

    # Forward a weight definition to the wrapper's `add_weight`. Arguments
    # that are NULL are left out of the call entirely.
    add_weight = function(name, shape, dtype = NULL, initializer = NULL,
                          regularizer = NULL, trainable = TRUE, constraint = NULL) {
      weight_args <- Filter(
        Negate(is.null),
        list(
          name = name,
          shape = shape,
          dtype = dtype,
          initializer = initializer,
          regularizer = regularizer,
          trainable = trainable,
          constraint = constraint
        )
      )
      do.call(private$wrapper$add_weight, weight_args)
    },

    # Store the back reference to the python layer that wraps this object.
    .set_wrapper = function(wrapper) {
      private$wrapper <- wrapper
    },

    # Accessor for the stored wrapper (NULL until `.set_wrapper()` is called).
    python_layer = function() {
      private$wrapper
    }
  ),

  active = list(
    # Read or assign the `input` field of the wrapper.
    input = function(value) {
      if (!missing(value)) {
        private$wrapper$input <- value
      } else {
        private$wrapper$input
      }
    },
    # Read or assign the `output` field of the wrapper.
    output = function(value) {
      if (!missing(value)) {
        private$wrapper$output <- value
      } else {
        private$wrapper$output
      }
    }
  ),

  private = list(
    # Wrapper object set through `.set_wrapper()`.
    wrapper = NULL
  )
)


compat_custom_KerasLayer_handler <- function(layer_class, args) {
    # Instantiate a legacy R6 `KerasLayer` subclass together with the Python
    # `RLayer` object that wraps it.
    #
    # layer_class: R6 class generator; `layer_class$new` is called with the
    #   arguments left in `args` after the common ones are removed.
    # args: named list of constructor arguments supplied by the user.
    #
    # Returns an unnamed list of length two: the Python wrapper layer and the
    # remaining `args` (with the common layer parameters removed).

    # common layer parameters (e.g. "input_shape") need to be passed to the
    # Python Layer constructor rather than the R6 constructor. Here we
    # extract and set aside any of those arguments we find and set them to
    # NULL within the args list which will be passed to the R6 layer
    common_arg_names <- c("input_shape", "batch_input_shape", "batch_size",
                          "dtype", "name", "trainable", "weights")

    # Subsetting by name yields a NULL entry for every name that is absent
    # from `args`; those entries are dropped again here. vapply() is used so
    # that the index is always a logical vector.
    py_wrapper_args <- args[common_arg_names]
    py_wrapper_args[vapply(py_wrapper_args, is.null, logical(1))] <- NULL
    for (arg in names(py_wrapper_args)) {
      args[[arg]] <- NULL
    }

    # create the R6 layer
    r6_layer <- do.call(layer_class$new, args)

    # create the python wrapper (passing the extracted py_wrapper_args)
    python_path <- system.file("python", package = "keras")
    tools <- import_from_path("kerastools", path = python_path)
    py_wrapper_args$r_build <- r6_layer$build
    py_wrapper_args$r_call <-  reticulate::py_func(r6_layer$call)
    py_wrapper_args$r_compute_output_shape <- r6_layer$compute_output_shape
    layer <- do.call(tools$layer$RLayer, py_wrapper_args)

    # set back reference in R layer
    r6_layer$.set_wrapper(layer)
    list(layer, args)
}




py_formals <- function(py_obj) {
  # Returns the formals of a Python callable as a named list (like
  # base::formals(), but for Python functions/methods). For a Python class
  # the signature of its `__init__` method is used.
  #
  # py_obj: a Python function, method or class.
  #
  # Returns a named list with one entry per parameter. Parameters without a
  # default hold the empty symbol; `*args` / `**kwargs` are represented by a
  # single `...` entry.

  # Label for `py_obj` used in the warning below, taken from the caller's
  # expression. It has to be defined here: the message template references
  # `{py_obj_expr}`, and without it building the warning fails with an
  # "object not found" error instead of warning.
  py_obj_expr <- paste(deparse(substitute(py_obj)), collapse = " ")

  inspect <- reticulate::import("inspect")
  sig <- if (inspect$isclass(py_obj)) {
    inspect$signature(py_obj$`__init__`)
  } else {
    inspect$signature(py_obj)
  }

  # pairlist() is NULL; the first assignment below turns `args` into a list.
  args <- pairlist()
  it <- sig$parameters$items()$`__iter__`()
  repeat {
    x <- reticulate::iter_next(it)
    if (is.null(x))
      break

    name <- x[[1]]
    param <- x[[2]]

    # *args and **kwargs both map onto R's `...`
    if (param$kind == inspect$Parameter$VAR_KEYWORD ||
        param$kind == inspect$Parameter$VAR_POSITIONAL) {
      args[["..."]] <- quote(expr = )
      next
    }

    default <- param$default

    if (inherits(default, "python.builtin.object")) {
      if (default != inspect$Parameter$empty) {
        # must be something complex that failed to convert
        warning(glue::glue(
          "Failed to convert default arg {param} for {name} in {py_obj_expr}"
        ))
      }
      # no usable default: leave the formal empty (a required argument)
      args[[name]] <- quote(expr = )
      next
    }

    # `default` can be NULL (presumably a Python `None` default). Assigning
    # with `[<-` and a wrapping list keeps such a parameter in the result;
    # `args[[name]] <- NULL` would drop it. It also guarantees `args` is a
    # list even when the first stored default is an atomic scalar.
    args[name] <- list(default)
  }
  args
}




#' Create a Keras Layer wrapper
#'
#' @param LayerClass A R6 or Python class generator that inherits from
#'   `keras$layers$Layer`
#' @param modifiers A named list of functions used to modify user-supplied
#'   arguments before they are passed on to the class constructor (e.g.,
#'   `list(units = as.integer)`).
#' @param convert Boolean, whether the Python class and its methods should by
#'   default convert python objects to R objects.
#'
#' See guide 'making_new_layers_and_models_via_subclassing.Rmd' for example usage.
#'
#' @return An R function that behaves similarly to the builtin keras `layer_*`
#'   functions. When called, it will create the class instance, and also
#'   optionally call it on a supplied argument `object` if it is present. This
#'   enables keras layers to compose nicely with the pipe (`%>%`).
#'
#'   The R function will have arguments taken from the `initialize` (or
#'   `__init__`) method of the LayerClass.
#'
#'   If `LayerClass` is an R6 object, this will avoid initializing the python
#'   session, so it is safe to use in an R package.
#'
#' @export
#' @importFrom rlang %||%
create_layer_wrapper <- function(LayerClass, modifiers=NULL, convert=TRUE) {

  # Keep the class exactly as supplied; `LayerClass` itself may be rebound to
  # a lazily evaluated promise further down.
  LayerClass_in <- LayerClass
  force(modifiers)

  # Skeleton of the returned function. Only `object` is declared here; the
  # constructor arguments are appended to its formals below. `LayerClass` is
  # looked up in this enclosing environment when the wrapper is called.
  wrapper <- function(object) {
    args <- capture_args(match.call(), modifiers)
    args$object <- NULL
    create_layer(LayerClass, object, args)
  }

  formals(wrapper) <- local({
    ctor_formals <- if (inherits(LayerClass, "python.builtin.type")) {
      py_formals(LayerClass)
    } else {
      # Otherwise treat LayerClass as an R6 generator and read the
      # constructor signature straight from its public methods.
      public_methods <- LayerClass$public_methods
      formals(public_methods$initialize %||% public_methods$`__init__`)
    }
    # `self` is never supplied by the user
    ctor_formals$self <- NULL
    c(formals(wrapper), ctor_formals)
  })

  # create_layer() will call r_to_py() as needed, but rebinding `LayerClass`
  # to a promise here means the conversion is evaluated at most once, on
  # first use, rather than every time a class instance is created.
  if (!inherits(LayerClass, "python.builtin.type")) {
    delayedAssign("LayerClass", r_to_py(LayerClass_in, convert))
  }

  # Tag the function and attach the class as originally supplied.
  structure(
    wrapper,
    class = c("keras_layer_wrapper", "function"),
    Layer = LayerClass_in
  )
}


#' @export
r_to_py.keras_layer_wrapper <- function(fn, convert = FALSE) {
  # The wrapped class is read from the wrapper's "Layer" attribute
  # (exact name match only).
  layer_class <- attr(fn, "Layer", exact = TRUE)

  # Already a Python type: hand it back untouched.
  if (inherits(layer_class, "python.builtin.type")) {
    return(layer_class)
  }

  # Anything else is converted via r_to_py().
  py_class <- r_to_py(layer_class, convert)
  py_class
}

Try the keras package in your browser

Any scripts or data that you put into this service are public.

keras documentation built on Aug. 21, 2021, 9:07 a.m.