new_callback_cyclical_learning_rate: Initiate a new cyclical learning rate scheduler


View source: R/core.R

Description

This callback implements a cyclical learning rate (CLR) policy. The method cycles the learning rate between two boundaries with a constant frequency, as detailed in the paper 'Cyclical Learning Rates for Training Neural Networks' (Smith, 2017). In addition, the callback supports scaled learning-rate bandwidths (see the section 'Differences to Python implementation'). Note that this callback is very general in the range of learning-rate schedules it can be used to specify.

Usage

new_callback_cyclical_learning_rate(
  base_lr = 0.001,
  max_lr = 0.006,
  step_size = 2000,
  mode = "triangular",
  gamma = 1,
  scale_fn = NULL,
  scale_mode = "cycle",
  patience = Inf,
  factor = 0.9,
  decrease_base_lr = TRUE,
  cooldown = 2,
  verbose = 1
)

Arguments

base_lr

Initial learning rate which is the lower boundary in the cycle.

max_lr

Upper boundary in the cycle. Functionally, it defines the cycle amplitude (max_lr - base_lr). The learning rate at any point in a cycle is the sum of base_lr and some scaling of the amplitude; therefore max_lr may not actually be reached, depending on the scaling function.

step_size

Number of training iterations per half cycle. The authors suggest setting step_size to 2-8 times the number of training iterations per epoch.

mode

One of "triangular", "triangular2" or "exp_range". Default "triangular". Values correspond to policies detailed above. If scale_fn is not NULL, this argument is ignored.

gamma

Constant in exp_range scaling function: gamma^(cycle iterations)

scale_fn

Custom scaling policy defined by a single-argument anonymous function, where 0 <= scale_fn(x) <= 1 for all x >= 0. If supplied, the mode argument is ignored (see the sketch after this argument list).

scale_mode

Either "cycle" or "iterations". Defines whether scale_fn is evaluated on cycle number or cycle iterations (training iterations since start of cycle). Default is "cycle".

patience

The number of training epochs without improvement in the validation loss that the callback will wait before it adjusts base_lr and max_lr. If set to anything other than Inf, validation_data must be passed to the keras::fit() call.

factor

A numeric vector of length one by which max_lr and (if decrease_base_lr is TRUE) base_lr are multiplied after patience epochs without improvement in the validation loss.

decrease_base_lr

Boolean indicating whether base_lr should also be scaled by factor.

cooldown

Number of epochs to wait before resuming normal operation after the learning rate has been reduced.

verbose

Currently supported values are 0 (silent) and 1 (verbose).
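
The example below is a minimal sketch (not taken from the package documentation) of how a custom scaling policy can be supplied via scale_fn. The halving function shown mirrors how the "triangular2" policy is commonly expressed in CLR implementations; it is only meant to illustrate the interface.

library(KerasMisc)

# Sketch: a custom policy in the spirit of "triangular2" -- the amplitude is
# halved with every completed cycle. scale_fn must map any x >= 0 into [0, 1];
# with scale_mode = "cycle", x is the cycle number rather than the iteration.
callback_custom <- new_callback_cyclical_learning_rate(
  base_lr = 0.001,
  max_lr = 0.006,
  step_size = 2000,
  scale_fn = function(x) 1 / (2^(x - 1)),
  scale_mode = "cycle"
)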

Details

The amplitude of the cycle can be scaled on a per-iteration or per-cycle basis. The callback has three built-in policies ("triangular", "triangular2" and "exp_range"), as put forth in the paper.

For more details, please see the paper.
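
As an orientation for how these policies behave, the sketch below computes a learning-rate schedule the way the CLR paper and its reference implementation describe it: base_lr plus a scaled fraction of the amplitude, with the position in the half cycle derived from step_size. This is illustrative only and not the package's internal code; in particular, it ignores the bandwidth scaling described in the next section.

# Illustrative only: per-iteration learning rate under the three built-in
# policies, following the formulas commonly used in CLR implementations.
clr_at_iteration <- function(it, base_lr = 0.001, max_lr = 0.006,
                             step_size = 2000,
                             mode = c("triangular", "triangular2", "exp_range"),
                             gamma = 1) {
  mode <- match.arg(mode)
  cycle <- floor(1 + it / (2 * step_size))  # index of the current cycle
  x <- abs(it / step_size - 2 * cycle + 1)  # relative position within the cycle
  scaling <- switch(mode,
    triangular  = 1,                    # constant amplitude
    triangular2 = 1 / (2^(cycle - 1)),  # amplitude halved each cycle
    exp_range   = gamma^it              # amplitude decays every iteration
  )
  base_lr + (max_lr - base_lr) * max(0, 1 - x) * scaling
}

# Visualize four full cycles of the default "triangular" policy
plot(sapply(0:16000, clr_at_iteration), type = "l",
     xlab = "iteration", ylab = "learning rate")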

Differences to Python implementation

This implementation differs from the Python implementation it is based on most notably in its support for scaled learning-rate bandwidths: after patience epochs without improvement in the validation loss, max_lr (and, if decrease_base_lr is TRUE, base_lr) is multiplied by factor, and a cooldown period is observed before monitoring resumes.
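
A hedged configuration sketch of this bandwidth scaling, built only from the arguments documented above (the values are illustrative, not recommendations):

# Shrink the learning-rate bandwidth when the validation loss stalls.
# Requires validation_data in the fit() call, since improvement is judged
# on the validation loss.
callback_clr_shrink <- new_callback_cyclical_learning_rate(
  base_lr = 0.001,
  max_lr = 0.006,
  step_size = 2000,
  mode = "triangular",
  patience = 3,             # epochs without improvement to tolerate
  factor = 0.5,             # multiply max_lr (and base_lr) by this
  decrease_base_lr = TRUE,  # also shrink the lower boundary
  cooldown = 2              # epochs to wait before monitoring resumes
)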

Examples

library(keras)
library(KerasMisc)

# Boston housing data, standardized with the training-set statistics
dataset <- dataset_boston_housing()
c(c(train_data, train_targets), c(test_data, test_targets)) %<-% dataset

mean <- apply(train_data, 2, mean)
std <- apply(train_data, 2, sd)
train_data <- scale(train_data, center = mean, scale = std)
test_data <- scale(test_data, center = mean, scale = std)


# Small fully connected regression model
model <- keras_model_sequential() %>%
  layer_dense(
    units = 64, activation = "relu",
    input_shape = dim(train_data)[[2]]
  ) %>%
  layer_dense(units = 64, activation = "relu") %>%
  layer_dense(units = 1)
model %>% compile(
  optimizer = optimizer_rmsprop(lr = 0.001),
  loss = "mse",
  metrics = c("mae")
)

# CLR callback with exponentially decaying amplitude ("exp_range")
callback_clr <- new_callback_cyclical_learning_rate(
  step_size = 32,
  base_lr = 0.001,
  max_lr = 0.006,
  gamma = 0.99,
  mode = "exp_range"
)
model %>% fit(
  train_data, train_targets,
  validation_data = list(test_data, test_targets),
  epochs = 10, verbose = 1,
  callbacks = list(callback_clr)
)
# Inspect the recorded learning rates and plot the schedule
callback_clr$history
plot_clr_history(callback_clr, backend = "base")
