inst/doc/custom_models.R

## ----setup, include = FALSE---------------------------------------------------
# Attach keras for the vignette session, then set knitr's default chunk
# options: `comment = NA` leaves printed output without a prefix, and
# `eval = FALSE` means the example chunks below are displayed but not run.
library("keras")
knitr::opts_chunk$set(
  comment = NA,
  eval = FALSE
)

## -----------------------------------------------------------------------------
#  library(keras)
#  
#  keras_model_simple_mlp <- function(num_classes,
#                                     use_bn = FALSE, use_dp = FALSE,
#                                     name = NULL) {
#  
#    # define and return a custom model
#    keras_model_custom(name = name, function(self) {
#  
#      # create layers we'll need for the call (this code executes once)
#      self$dense1 <- layer_dense(units = 32, activation = "relu")
#      self$dense2 <- layer_dense(units = num_classes, activation = "softmax")
#      if (use_dp)
#        self$dp <- layer_dropout(rate = 0.5)
#      if (use_bn)
#        self$bn <- layer_batch_normalization(axis = -1)
#  
#      # implement call (this code executes during training & inference)
#      function(inputs, mask = NULL, training = FALSE) {
#        x <- self$dense1(inputs)
#        if (use_dp)
#          x <- self$dp(x)
#        if (use_bn)
#          x <- self$bn(x)
#        self$dense2(x)
#      }
#    })
#  }

## -----------------------------------------------------------------------------
#  library(keras)
#  
#  # create the model
#  model <- keras_model_simple_mlp(num_classes = 10, use_dp = TRUE)
#  
#  # compile the model: configure loss, optimizer, and metrics
#  model %>% compile(
#    loss = 'categorical_crossentropy',
#    optimizer = optimizer_rmsprop(),
#    metrics = c('accuracy')
#  )
#  
#  # Generate dummy data
#  data <- matrix(runif(1000*100), nrow = 1000, ncol = 100)
#  labels <- matrix(round(runif(1000, min = 0, max = 9)), nrow = 1000, ncol = 1)
#  
#  # Convert labels to categorical one-hot encoding
#  one_hot_labels <- to_categorical(labels, num_classes = 10)
#  
#  # Train the model
#  model %>% fit(data, one_hot_labels, epochs=10, batch_size=32)

Try the keras package in your browser

Any scripts or data that you put into this service are public.

keras documentation built on Aug. 16, 2023, 1:07 a.m.