Training Callbacks


Overview

A callback is a set of functions to be applied at given stages of the training procedure. You can use callbacks to get a view of the internal states and statistics of the model during training. To use them, pass a list of callbacks as the `callbacks` argument of `fit()`; the relevant methods of each callback are then called at the corresponding stage of training.

For example:

library(keras)

# generate dummy training data: 1000 samples with 784 features and
# one-hot encoded labels for 10 classes
data <- matrix(rexp(1000*784), nrow = 1000, ncol = 784)
labels <- to_categorical(sample(0:9, 1000, replace = TRUE), num_classes = 10)

# create model
model <- keras_model_sequential()

# add layers and compile
model %>%
  layer_dense(32, input_shape = c(784)) %>%
  layer_activation('relu') %>%
  layer_dense(10) %>%
  layer_activation('softmax') %>%
  compile(
    loss = 'categorical_crossentropy',
    optimizer = optimizer_sgd(),
    metrics = 'accuracy'
  )

# fit with callbacks; validation_split provides the "val_loss" metric
# monitored by callback_reduce_lr_on_plateau()
model %>% fit(data, labels, validation_split = 0.2, callbacks = list(
  callback_model_checkpoint("checkpoints.h5"),
  callback_reduce_lr_on_plateau(monitor = "val_loss", factor = 0.1)
))

Built in Callbacks

The following built-in callbacks are available as part of Keras:

`callback_progbar_logger()`

Callback that prints metrics to stdout.

`callback_model_checkpoint()`

Save the model after every epoch.

`callback_early_stopping()`

Stop training when a monitored quantity has stopped improving.

`callback_remote_monitor()`

Callback used to stream events to a server.

`callback_learning_rate_scheduler()`

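Set the learning rate at the start of each epoch using a schedule function of the epoch index and the current learning rate.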

`callback_tensorboard()`

Write logs for TensorBoard basic visualizations.

`callback_reduce_lr_on_plateau()`

Reduce learning rate when a metric has stopped improving.

`callback_csv_logger()`

Callback that streams epoch results to a CSV file.

`callback_lambda()`

Create a simple custom callback on the fly from individual functions (for example `on_epoch_end`).
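Several built-in callbacks can be combined in a single `callbacks` list. The sketch below is one possible illustration, assuming randomly generated dummy data (so the metrics themselves mean nothing); the file name `training_log.csv`, the `step_decay()` schedule, and the `seen_epochs` vector are hypothetical names chosen for the example. It stops training when `val_loss` stops improving, halves the learning rate every epoch after the third, writes one CSV row per epoch, and uses `callback_lambda()` to record each finished epoch index.

```r
library(keras)

# dummy data: 1000 samples with 784 features and 10 one-hot encoded classes
x <- matrix(runif(1000 * 784), nrow = 1000, ncol = 784)
y <- to_categorical(sample(0:9, 1000, replace = TRUE), num_classes = 10)

model <- keras_model_sequential() %>%
  layer_dense(units = 32, activation = "relu", input_shape = c(784)) %>%
  layer_dense(units = 10, activation = "softmax")

model %>% compile(
  loss = "categorical_crossentropy",
  optimizer = optimizer_sgd(),
  metrics = "accuracy"
)

# hypothetical step-decay schedule: keep the learning rate for the first
# 3 epochs (the epoch index starts at 0), then halve it every epoch
step_decay <- function(epoch, lr) {
  if (epoch < 3) lr else lr * 0.5
}

# epoch indices recorded by the lambda callback
seen_epochs <- integer(0)

history <- model %>% fit(
  x, y,
  epochs = 10, batch_size = 32, verbose = 0,
  validation_split = 0.2,
  callbacks = list(
    # stop once val_loss has not improved for 2 consecutive epochs
    callback_early_stopping(monitor = "val_loss", patience = 2),
    # apply step_decay() at the start of every epoch
    callback_learning_rate_scheduler(step_decay),
    # write one row of metrics per epoch
    callback_csv_logger("training_log.csv"),
    # append the index of each finished epoch to seen_epochs
    callback_lambda(
      on_epoch_end = function(epoch, logs) {
        seen_epochs <<- c(seen_epochs, epoch)
      }
    )
  )
)

# number of epochs actually run (fewer than 10 if early stopping fired)
length(history$metrics$loss)
```

Because early stopping may end training before the tenth epoch, the length of `history$metrics$loss`, the number of rows in `training_log.csv`, and the length of `seen_epochs` all reflect the number of epochs that actually ran.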

Custom Callbacks

You can create a custom callback by creating a new R6 class that inherits from the KerasCallback class.

Here's a simple example saving a list of losses over each batch during training:

library(keras)

# define custom callback class
LossHistory <- R6::R6Class("LossHistory",
  inherit = KerasCallback,

  public = list(

    losses = NULL,

    on_batch_end = function(batch, logs = list()) {
      self$losses <- c(self$losses, logs[["loss"]])
    }
))

# define model
model <- keras_model_sequential() 

# add layers and compile
model %>% 
  layer_dense(units = 10, input_shape = c(784)) %>% 
  layer_activation(activation = 'softmax') %>% 
  compile(
    loss = 'categorical_crossentropy', 
    optimizer = 'rmsprop'
  )

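# example training data: MNIST, with each 28x28 image flattened to 784
# features scaled to [0, 1] and the digit labels one-hot encoded
mnist <- dataset_mnist()
X_train <- array_reshape(mnist$train$x, c(nrow(mnist$train$x), 784)) / 255
Y_train <- to_categorical(mnist$train$y, 10)

# create history callback object and use it during training
history <- LossHistory$new()
model %>% fit(
  X_train, Y_train,
  batch_size = 128, epochs = 20, verbose = 0,
  callbacks = list(history)
)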

# print the accumulated losses
history$losses
[1] 0.6604760 0.3547246 0.2595316 0.2590170 ...

Fields

Custom callback objects have access to the current model and its training parameters through the following fields:

self$params

: Named list with training parameters (e.g. verbosity, batch size, number of epochs).

self$model

: Reference to the Keras model being trained.
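Both fields can be used from any callback method. The following is a minimal sketch, assuming dummy random data; the `LossThresholdStopper` class and its `threshold` argument are hypothetical. It reads the planned number of epochs from `self$params` (the exact entries of this list depend on the Keras/TensorFlow version), counts the model's weights through `self$model`, and sets `self$model$stop_training` to end `fit()` early once the training loss falls below the threshold.

```r
library(keras)

# hypothetical callback: reads the epoch count from self$params and stops
# training through self$model once the training loss drops below a threshold
LossThresholdStopper <- R6::R6Class("LossThresholdStopper",
  inherit = KerasCallback,

  public = list(
    threshold = NULL,
    planned_epochs = NULL,
    n_params = NULL,

    initialize = function(threshold) {
      self$threshold <- threshold
    },

    on_train_begin = function(logs = list()) {
      # training parameters passed on by fit()
      self$planned_epochs <- self$params$epochs
      # the model being trained
      self$n_params <- count_params(self$model)
    },

    on_epoch_end = function(epoch, logs = list()) {
      if (!is.null(logs[["loss"]]) && logs[["loss"]] < self$threshold) {
        # ask fit() to stop after the current epoch
        self$model$stop_training <- TRUE
      }
    }
  )
)

# dummy data: 100 samples, 784 features, 10 one-hot encoded classes
x <- matrix(runif(100 * 784), nrow = 100, ncol = 784)
y <- to_categorical(sample(0:9, 100, replace = TRUE), num_classes = 10)

model <- keras_model_sequential() %>%
  layer_dense(units = 10, activation = "softmax", input_shape = c(784))
model %>% compile(loss = "categorical_crossentropy", optimizer = "rmsprop")

# a deliberately loose threshold, so training stops after the first epoch
stopper <- LossThresholdStopper$new(threshold = 10)
history <- model %>% fit(x, y, epochs = 5, verbose = 0, callbacks = list(stopper))

stopper$planned_epochs        # 5
stopper$n_params              # 784 * 10 weights + 10 biases = 7850
length(history$metrics$loss)  # 1
```

A realistic threshold would be much smaller; 10 is used here only so that the early stop is visible on random data.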

Methods

Custom callback objects can implement one or more of the following methods:

on_epoch_begin(epoch, logs)

: Called at the beginning of each epoch.

on_epoch_end(epoch, logs)

: Called at the end of each epoch.

on_batch_begin(batch, logs)

: Called at the beginning of each batch.

on_batch_end(batch, logs)

: Called at the end of each batch.

on_train_begin(logs)

: Called at the beginning of training.

on_train_end(logs)

: Called at the end of training.

on_train_batch_begin(batch, logs)

: Called at the beginning of every training batch in the fit methods.

on_train_batch_end(batch, logs)

: Called at the end of every training batch in the fit methods.

on_predict_batch_begin(batch, logs)

: Called at the beginning of a batch in predict methods.

on_predict_batch_end(batch, logs)

: Called at the end of a batch in predict methods.

on_predict_begin(logs)

: Called at the beginning of prediction.

on_predict_end(logs)

: Called at the end of prediction.

on_test_batch_begin(batch, logs)

: Called at the beginning of a batch in evaluate methods. Also called at the beginning of a validation batch in the fit methods, if validation data is provided.

on_test_batch_end(batch, logs)

: Called at the end of a batch in evaluate methods. Also called at the end of a validation batch in the fit methods, if validation data is provided.

on_test_begin(logs)

: Called at the beginning of evaluation or validation.

on_test_end(logs)

: Called at the end of evaluation or validation.
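To see which hooks fire during `fit()`, `evaluate()`, and `predict()`, a callback can simply count calls to the batch-level methods. This is a rough sketch assuming dummy random data; the `HookCounter` class and its `counts` field are hypothetical. Passing `validation_split` to `fit()` should make the test hooks fire during training as well, as described above.

```r
library(keras)

# hypothetical callback that counts calls to three batch-level hooks
HookCounter <- R6::R6Class("HookCounter",
  inherit = KerasCallback,

  public = list(
    counts = NULL,

    initialize = function() {
      self$counts <- c(train = 0, test = 0, predict = 0)
    },

    # training batches in fit()
    on_train_batch_end = function(batch, logs = list()) {
      self$counts[["train"]] <- self$counts[["train"]] + 1
    },

    # batches in evaluate(), plus validation batches in fit()
    on_test_batch_end = function(batch, logs = list()) {
      self$counts[["test"]] <- self$counts[["test"]] + 1
    },

    # batches in predict()
    on_predict_batch_end = function(batch, logs = list()) {
      self$counts[["predict"]] <- self$counts[["predict"]] + 1
    }
  )
)

# dummy data: 100 samples, 784 features, 10 one-hot encoded classes
x <- matrix(runif(100 * 784), nrow = 100, ncol = 784)
y <- to_categorical(sample(0:9, 100, replace = TRUE), num_classes = 10)

model <- keras_model_sequential() %>%
  layer_dense(units = 10, activation = "softmax", input_shape = c(784))
model %>% compile(loss = "categorical_crossentropy", optimizer = "rmsprop")

counter <- HookCounter$new()

# validation_split holds out 20% of the samples, so the validation batches
# trigger on_test_batch_end during fit()
model %>% fit(x, y, epochs = 2, batch_size = 20, verbose = 0,
              validation_split = 0.2, callbacks = list(counter))
after_fit <- counter$counts

model %>% evaluate(x, y, batch_size = 50, verbose = 0, callbacks = list(counter))
predictions <- model %>% predict(x, batch_size = 25, verbose = 0,
                                 callbacks = list(counter))

after_fit       # counts after fit() only
counter$counts  # counts after fit(), evaluate(), and predict()
```

The exact counts depend on the batch sizes; what the sketch is meant to show is that `after_fit` already has a nonzero `test` count from validation, while `predict` stays at zero until `predict()` is called.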




keras documentation built on May 29, 2024, 3:20 a.m.