In addition to sequential models and models created with the functional API, you can also define a model by implementing a custom `call()` (forward pass) operation.
To create a custom Keras model, call the `keras_model_custom()` function and pass it an R function that in turn returns another R function implementing the custom `call()` (forward pass) operation. The R function you pass receives a single argument (named `self` in the example below), which gives access to the underlying Keras model object; you use it, for instance, to attach the layers that the forward pass needs.
Typically, you'll wrap your call to `keras_model_custom()` in yet another function that lets callers easily instantiate your custom model.
The following example implements a simple custom model: a multi-layer perceptron with optional dropout and batch normalization:
```r
library(keras)

keras_model_simple_mlp <- function(num_classes,
                                   use_bn = FALSE, use_dp = FALSE,
                                   name = NULL) {

  # define and return a custom model
  keras_model_custom(name = name, function(self) {

    # create layers we'll need for the call (this code executes once)
    self$dense1 <- layer_dense(units = 32, activation = "relu")
    self$dense2 <- layer_dense(units = num_classes, activation = "softmax")
    if (use_dp)
      self$dp <- layer_dropout(rate = 0.5)
    if (use_bn)
      self$bn <- layer_batch_normalization(axis = -1)

    # implement call (this code executes during training & inference)
    function(inputs, mask = NULL, training = FALSE) {
      x <- self$dense1(inputs)
      if (use_dp)
        x <- self$dp(x)
      if (use_bn)
        x <- self$bn(x)
      self$dense2(x)
    }
  })
}
```
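The call function above receives a `training` flag but does not pass it on. In TensorFlow 2 Keras, nested layers may pick the flag up from the enclosing call. If you prefer to forward it explicitly, the following sketch shows one way to do so. It assumes the TensorFlow 2 era `keras` R package, where `keras_model_custom()` is available, and uses a hypothetical dropout-only model (`keras_model_dropout_only` is an illustrative name, not a package function) to contrast inference and training outputs.

```r
library(keras)

# Hypothetical model containing only a dropout layer, used to show the flag
keras_model_dropout_only <- function(rate = 0.5, name = NULL) {
  keras_model_custom(name = name, function(self) {
    self$dp <- layer_dropout(rate = rate)
    function(inputs, mask = NULL, training = FALSE) {
      # forward the flag explicitly so dropout is active only while training
      self$dp(inputs, training = training)
    }
  })
}

model <- keras_model_dropout_only()
x <- k_constant(matrix(1, nrow = 4, ncol = 8))

# dropout disabled: output equals the input (all ones)
infer <- as.array(model(x, training = FALSE))
# dropout enabled: each entry is either dropped (0) or scaled up
train <- as.array(model(x, training = TRUE))
```

With `rate = 0.5`, Keras scales the kept entries by 1 / (1 - 0.5) = 2, so the training output contains only 0s and 2s.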
Note that we include a `name` parameter so that users can optionally provide a human-readable name for the model.
Note also that the layers used in the forward pass are assigned to fields of the `self` object, so that Keras tracks them (and their weights) appropriately.
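To see the naming and tracking in action, the following sketch builds a one-layer custom model, runs one forward pass so that the weights get created, and then inspects the model's weights and name. It assumes the TensorFlow 2 era `keras` R package. The wrapper `keras_model_tiny` and the name `"tiny_model"` are hypothetical.

```r
library(keras)

# Hypothetical one-layer model used only to show weight tracking
keras_model_tiny <- function(name = NULL) {
  keras_model_custom(name = name, function(self) {
    # assigned to self, so Keras tracks this layer and its weights
    self$dense <- layer_dense(units = 4)
    function(inputs, mask = NULL, training = FALSE) {
      self$dense(inputs)
    }
  })
}

model <- keras_model_tiny(name = "tiny_model")

# weights are created lazily, on the first forward pass
out <- model(k_constant(matrix(1, nrow = 2, ncol = 3)))

n_weights <- length(model$weights)  # kernel and bias of the tracked dense layer
model_name <- model$name            # the name passed to the wrapper
```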
In `call()`, you may specify custom losses by calling `self$add_loss()`. You can also access any other members of the Keras model you need (or even add fields to the model) via `self$`.
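As a sketch of `self$add_loss()`, the hypothetical model below adds an activity penalty on every forward pass: the mean squared activation of the hidden layer, scaled by `penalty`. The wrapper name `keras_model_penalized_mlp` and the `penalty` argument are illustrative and not part of the package. The sketch assumes a TensorFlow 2 backend, where losses added in `call()` are collected in `model$losses` and added to the compiled loss during `fit()`.

```r
library(keras)

# Hypothetical MLP whose call() adds an extra loss term
keras_model_penalized_mlp <- function(num_classes, penalty = 0.01, name = NULL) {
  keras_model_custom(name = name, function(self) {
    self$dense1 <- layer_dense(units = 32, activation = "relu")
    self$dense2 <- layer_dense(units = num_classes, activation = "softmax")
    function(inputs, mask = NULL, training = FALSE) {
      x <- self$dense1(inputs)
      # activity penalty on the hidden layer, recorded on the model
      self$add_loss(penalty * k_mean(k_square(x)))
      self$dense2(x)
    }
  })
}

model <- keras_model_penalized_mlp(num_classes = 3)
out <- model(k_constant(matrix(runif(8 * 5), nrow = 8, ncol = 5)))

# one penalty term collected from the forward pass above
n_losses <- length(model$losses)
```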
To use a custom model, just call your model's high-level wrapper function. For example:
```r
library(keras)

# create the model
model <- keras_model_simple_mlp(num_classes = 10, use_dp = TRUE)

# compile the model
model %>% compile(
  loss = 'categorical_crossentropy',
  optimizer = optimizer_rmsprop(),
  metrics = c('accuracy')
)

# generate dummy data
data <- matrix(runif(1000*100), nrow = 1000, ncol = 100)
labels <- matrix(round(runif(1000, min = 0, max = 9)), nrow = 1000, ncol = 1)

# convert labels to categorical one-hot encoding
one_hot_labels <- to_categorical(labels, num_classes = 10)

# train the model
model %>% fit(data, one_hot_labels, epochs = 10, batch_size = 32)
```