R/model-custom.R

Defines functions summary.kerastools.model.RModel print.kerastools.model.RModel keras_model_custom

Documented in keras_model_custom

#' (Deprecated) Create a Keras custom model
#'
#' `keras_model_custom()` is soft-deprecated. Please define custom models by
#' subclassing `keras$Model` directly using [`%py_class%`] or [`R6::R6Class()`],
#' or by calling `new_model_class()`.
#'
#' @param model_fn Function that is called once with the newly created model
#'   object; its return value becomes the function the model delegates its
#'   call to.
#' @param name Optional name for model
#'
#' @return A Keras model
#'
#' @keywords internal
#' @export
keras_model_custom <- function(model_fn, name = NULL) {

  # Fail fast, before any Python objects are created, if `model_fn` cannot
  # be called at all.
  if (!is.function(model_fn)) {
    stop("`model_fn` must be a function", call. = FALSE)
  }

  # Verify version: the tf.keras implementation must be >= 2.1.6 (reported to
  # the user as TensorFlow v1.9); standalone Keras must be >= 2.2.0.
  if (is_tensorflow_implementation()) {
    if (keras_version() < "2.1.6") {
      stop("Custom models require TensorFlow v1.9 or higher", call. = FALSE)
    }
  } else if (keras_version() < "2.2.0") {
    stop("Custom models require Keras v2.2 or higher", call. = FALSE)
  }

  # Instantiate the Python-side `RModel` class from the `kerastools` module
  # found in this package's installed `python` directory.
  python_path <- system.file("python", package = "keras")
  tools <- import_from_path("kerastools", path = python_path)
  model <- tools$model$RModel(name = name)

  # Let the user's function set up the model, then store what it returns in
  # the Python attribute `_r_call`, which the model delegates its call to.
  model$`_r_call` <- model_fn(model)

  model
}

# Print method for models created by keras_model_custom().
#
# If the model is not yet built, prints a one-line note and returns `x`
# invisibly; otherwise defers to the next print method via NextMethod().
#' @export
print.kerastools.model.RModel <- function(x, ...) {
  # isTRUE() guards against `built` being missing or not a single logical,
  # which would otherwise make the `if` condition error.
  if (!isTRUE(x$built)) {
    # Trailing newline so subsequent output (or the prompt) starts on a new line.
    cat("Custom Keras model: not yet fitted\n")
    return(invisible(x))
  }
  NextMethod()
}

# Summary method for models created by keras_model_custom().
#
# If the model is not yet built, prints an explanatory note and returns NULL
# invisibly; otherwise defers to the next summary method via NextMethod().
#' @export
summary.kerastools.model.RModel <- function(object, ...) {
  # isTRUE() guards against `built` being missing or not a single logical,
  # which would otherwise make the `if` condition error.
  if (!isTRUE(object$built)) {
    # Trailing newline so subsequent output (or the prompt) starts on a new line.
    cat("This custom model has not yet been built. To see a summary, compile and fit with some data.\n")
    return(invisible(NULL))
  }
  NextMethod()
}

Try the keras package in your browser

Any scripts or data that you put into this service are public.

keras documentation built on Aug. 16, 2023, 1:07 a.m.