A callback is a set of functions to be applied at given stages of the training procedure. You can use callbacks to get a view of internal states and statistics of the model during training. You can pass a list of callbacks (as the `callbacks` argument) to the `fit()` function. The relevant methods of the callbacks will then be called at each stage of the training.
For example:
```r
library(keras)

# generate dummy training data
data <- matrix(rexp(1000 * 784), nrow = 1000, ncol = 784)
# one-hot encoded labels for 10 classes
labels <- to_categorical(sample(0:9, 1000, replace = TRUE), num_classes = 10)

# create model
model <- keras_model_sequential()

# add layers and compile
model %>%
  layer_dense(32, input_shape = c(784)) %>%
  layer_activation('relu') %>%
  layer_dense(10) %>%
  layer_activation('softmax') %>%
  compile(
    loss = 'categorical_crossentropy',
    optimizer = optimizer_sgd(),
    metrics = 'accuracy'
  )

# fit with callbacks; validation_split provides the "val_loss" they monitor
model %>% fit(
  data, labels,
  validation_split = 0.2,
  callbacks = list(
    callback_model_checkpoint("checkpoints.h5"),
    callback_reduce_lr_on_plateau(monitor = "val_loss", factor = 0.1)
  )
)
```
The following built-in callbacks are available as part of Keras:
| Callback | Description |
|---|---|
| `callback_progbar_logger()` | Callback that prints metrics to stdout. |
| `callback_model_checkpoint()` | Save the model after every epoch. |
| `callback_early_stopping()` | Stop training when a monitored quantity has stopped improving. |
| `callback_remote_monitor()` | Callback used to stream events to a server. |
| `callback_learning_rate_scheduler()` | Learning rate scheduler. |
| `callback_tensorboard()` | TensorBoard basic visualizations. |
| `callback_reduce_lr_on_plateau()` | Reduce learning rate when a metric has stopped improving. |
| `callback_csv_logger()` | Callback that streams epoch results to a CSV file. |
| `callback_lambda()` | Create a custom callback. |
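As a sketch of combining several built-in callbacks, the following trains a small model on randomly generated dummy data. The data, layer sizes, `patience = 3`, and the `val_losses` and `log_file` variables are illustrative assumptions, not part of the original documentation. `callback_early_stopping()` halts training once the validation loss stops improving, `callback_csv_logger()` writes one row per epoch, and `callback_lambda()` records the validation loss without defining a full R6 class.

```r
library(keras)

# Dummy data for illustration only: 500 samples, 20 features, 3 classes
x <- matrix(runif(500 * 20), nrow = 500, ncol = 20)
y <- to_categorical(sample(0:2, 500, replace = TRUE), num_classes = 3)

model <- keras_model_sequential() %>%
  layer_dense(units = 16, activation = "relu", input_shape = c(20)) %>%
  layer_dense(units = 3, activation = "softmax")

model %>% compile(
  loss = "categorical_crossentropy",
  optimizer = "adam",
  metrics = "accuracy"
)

# Hypothetical variable collecting val_loss at the end of every epoch
val_losses <- c()
record_val_loss <- callback_lambda(
  on_epoch_end = function(epoch, logs) {
    val_losses <<- c(val_losses, logs[["val_loss"]])
  }
)

# Per-epoch metrics are streamed to this temporary CSV file
log_file <- tempfile(fileext = ".csv")

history <- model %>% fit(
  x, y,
  epochs = 20,
  batch_size = 32,
  validation_split = 0.2,  # needed so that "val_loss" exists
  verbose = 0,
  callbacks = list(
    # stop once val_loss has not improved for 3 consecutive epochs
    callback_early_stopping(monitor = "val_loss", patience = 3),
    callback_csv_logger(log_file),
    record_val_loss
  )
)

# Number of epochs actually run (at most 20)
length(val_losses)
```

Because early stopping may end training before epoch 20, the number of recorded epochs can differ from run to run; the CSV log and the lambda callback should agree on it.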
You can create a custom callback by creating a new R6 class that inherits from the `KerasCallback` class.
Here's a simple example saving a list of losses over each batch during training:
```r
library(keras)

# define custom callback class
LossHistory <- R6::R6Class("LossHistory",
  inherit = KerasCallback,

  public = list(

    losses = NULL,

    on_batch_end = function(batch, logs = list()) {
      self$losses <- c(self$losses, logs[["loss"]])
    }
))

# dummy training data: 1000 samples, 784 features, 10 one-hot classes
X_train <- matrix(runif(1000 * 784), nrow = 1000, ncol = 784)
Y_train <- to_categorical(sample(0:9, 1000, replace = TRUE), num_classes = 10)

# define model
model <- keras_model_sequential()

# add layers and compile
model %>%
  layer_dense(units = 10, input_shape = c(784)) %>%
  layer_activation(activation = 'softmax') %>%
  compile(
    loss = 'categorical_crossentropy',
    optimizer = 'rmsprop'
  )

# create history callback object and use it during training
history <- LossHistory$new()
model %>% fit(
  X_train, Y_train,
  batch_size = 128, epochs = 20, verbose = 0,
  callbacks = list(history)
)

# print the accumulated losses
history$losses
```
```
[1] 0.6604760 0.3547246 0.2595316 0.2590170 ...
```
Custom callback objects have access to the current model and its training parameters via the following fields:
self$params
: Named list with training parameters (e.g. verbosity, batch size, number of epochs).
self$model
: Reference to the Keras model being trained.
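A minimal sketch of reading these fields, assuming a hypothetical `TrainingInfo` class and dummy regression data: the callback captures the planned number of epochs from `self$params` and the parameter count of `self$model` when training begins. The exact entries in `self$params` can vary between Keras versions; this sketch assumes `epochs` is present.

```r
library(keras)

# Hypothetical callback that inspects self$params and self$model at train start
TrainingInfo <- R6::R6Class("TrainingInfo",
  inherit = KerasCallback,
  public = list(
    planned_epochs = NULL,
    n_params = NULL,
    on_train_begin = function(logs = list()) {
      # self$params: named list of training parameters passed by fit()
      self$planned_epochs <- self$params$epochs
      # self$model: the Keras model currently being trained
      self$n_params <- count_params(self$model)
    }
  )
)

# Dummy regression data: 200 samples, 10 features, 1 target
x <- matrix(runif(200 * 10), nrow = 200, ncol = 10)
y <- matrix(runif(200), ncol = 1)

model <- keras_model_sequential() %>%
  layer_dense(units = 1, input_shape = c(10))
model %>% compile(loss = "mse", optimizer = "sgd")

info <- TrainingInfo$new()
model %>% fit(x, y, epochs = 3, verbose = 0, callbacks = list(info))

info$planned_epochs  # 3
info$n_params        # 10 kernel weights + 1 bias = 11
```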
Custom callback objects can implement one or more of the following methods:
on_epoch_begin(epoch, logs)
: Called at the beginning of each epoch.
on_epoch_end(epoch, logs)
: Called at the end of each epoch.
on_batch_begin(batch, logs)
: Called at the beginning of each batch.
on_batch_end(batch, logs)
: Called at the end of each batch.
on_train_begin(logs)
: Called at the beginning of training.
on_train_end(logs)
: Called at the end of training.
on_train_batch_begin(batch, logs)
: Called at the beginning of every training batch.
on_train_batch_end(batch, logs)
: Called at the end of every training batch.
on_predict_batch_begin(batch, logs)
: Called at the beginning of a batch in predict methods.
on_predict_batch_end(batch, logs)
: Called at the end of a batch in predict methods.
on_predict_begin(logs)
: Called at the beginning of prediction.
on_predict_end(logs)
: Called at the end of prediction.
on_test_batch_begin(batch, logs)
: Called at the beginning of a batch in evaluate methods. Also called at the beginning of a validation batch in the fit methods, if validation data is provided.
on_test_batch_end(batch, logs)
: Called at the end of a batch in evaluate methods. Also called at the end of a validation batch in the fit methods, if validation data is provided.
on_test_begin(logs)
: Called at the beginning of evaluation or validation.
on_test_end(logs)
: Called at the end of evaluation or validation.
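Several of these hooks can be combined in one class. As an illustrative sketch (the `EpochTimer` class, its fields, and the dummy data are hypothetical), this callback measures each epoch's wall-clock duration with `on_epoch_begin()` and `on_epoch_end()` and prints a summary from `on_train_end()`:

```r
library(keras)

# Hypothetical callback that times each epoch
EpochTimer <- R6::R6Class("EpochTimer",
  inherit = KerasCallback,
  public = list(
    times = NULL,  # seconds per completed epoch
    start = NULL,  # timestamp of the current epoch's start
    on_epoch_begin = function(epoch, logs = list()) {
      self$start <- Sys.time()
    },
    on_epoch_end = function(epoch, logs = list()) {
      elapsed <- as.numeric(difftime(Sys.time(), self$start, units = "secs"))
      self$times <- c(self$times, elapsed)
    },
    on_train_end = function(logs = list()) {
      cat(sprintf("Trained %d epochs, mean %.3f s/epoch\n",
                  length(self$times), mean(self$times)))
    }
  )
)

# Dummy classification data: 300 samples, 8 features, binary labels
x <- matrix(runif(300 * 8), nrow = 300, ncol = 8)
y <- matrix(sample(0:1, 300, replace = TRUE), ncol = 1)

model <- keras_model_sequential() %>%
  layer_dense(units = 4, activation = "relu", input_shape = c(8)) %>%
  layer_dense(units = 1, activation = "sigmoid")
model %>% compile(loss = "binary_crossentropy", optimizer = "adam")

timer <- EpochTimer$new()
model %>% fit(x, y, epochs = 5, verbose = 0, callbacks = list(timer))

timer$times
```

Storing timestamps on `self` keeps the state with the callback object, so the timings remain available after `fit()` returns.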