KerasMisc

The goal of KerasMisc is to provide a collection of tools that enhance the R implementation of Keras. Currently, the package features Keras callbacks, such as a cyclical learning rate callback.

Contributions welcome.

Installation

You can install the development version of KerasMisc from GitHub with:

remotes::install_github("lorenzwalthert/KerasMisc")

Features

Keras callbacks

First, let’s create and compile a model for the Boston housing regression task:

library(keras)
library(KerasMisc)

# Load the Boston housing data and split it into training and test sets
dataset <- dataset_boston_housing()
c(c(train_data, train_targets), c(test_data, test_targets)) %<-% dataset

# Normalize the features using the training set statistics only
# (named train_mean / train_sd to avoid masking base::mean())
train_mean <- apply(train_data, 2, mean)
train_sd <- apply(train_data, 2, sd)
train_data <- scale(train_data, center = train_mean, scale = train_sd)
test_data <- scale(test_data, center = train_mean, scale = train_sd)

# A small fully connected network with a single output for regression
model <- keras_model_sequential() %>%
  layer_dense(
    units = 64, activation = "relu",
    input_shape = dim(train_data)[[2]]
  ) %>%
  layer_dense(units = 64, activation = "relu") %>%
  layer_dense(units = 1)

model %>% compile(
  optimizer = optimizer_rmsprop(lr = 0.001),
  loss = "mse",
  metrics = c("mae")
)

Next, we fit the model with a cyclical learning rate schedule. In addition, we dynamically shrink the learning rate bounds (`base_lr` and `max_lr`) by multiplying them with 0.9 whenever the validation loss has not decreased for three epochs (`patience`). After each reduction, we wait two epochs (`cooldown`) before the patience counter starts again.

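# With batches of 32 samples, this is (approximately) the number of
# iterations (batches) per epoch
iter_per_epoch <- nrow(train_data) / 32
callback_clr <- new_callback_cyclical_learning_rate(
  step_size = iter_per_epoch * 2, # half a cycle spans two epochs
  base_lr = 0.001,
  max_lr = 0.006,
  mode = "triangular",
  patience = 3, # epochs without improvement before the bounds shrink
  factor = 0.9, # multiplier applied to the bounds
  cooldown = 2, # epochs to wait after a reduction
  verbose = 0
)
model %>% fit(
  train_data, train_targets,
  validation_data = list(test_data, test_targets),
  batch_size = 32, epochs = 50, verbose = 0,
  callbacks = list(callback_clr)
)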
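To make the interplay of `patience`, `factor` and `cooldown` more concrete, here is a minimal base-R sketch. It is not the KerasMisc implementation: `plateau_scale()` is a hypothetical helper, and its logic is assumed to resemble that of Keras' `callback_reduce_lr_on_plateau()`. Given the validation loss per epoch, it returns the cumulative factor by which the learning rate bounds would have been multiplied at the end of each epoch.

```r
# Hypothetical helper, not part of KerasMisc: a sketch of a plateau rule
# modelled on the logic of Keras' ReduceLROnPlateau callback.
plateau_scale <- function(val_loss, patience = 3, factor = 0.9, cooldown = 2) {
  scale <- 1          # cumulative multiplier for base_lr and max_lr
  best <- Inf         # best validation loss seen so far
  wait <- 0           # epochs without improvement
  cooldown_left <- 0  # epochs remaining in the current cooldown
  out <- numeric(length(val_loss))
  for (epoch in seq_along(val_loss)) {
    in_cooldown <- cooldown_left > 0
    if (in_cooldown) {
      # During cooldown the patience counter is held at zero
      cooldown_left <- cooldown_left - 1
      wait <- 0
    }
    if (val_loss[epoch] < best) {
      best <- val_loss[epoch]
      wait <- 0
    } else if (!in_cooldown) {
      wait <- wait + 1
      if (wait >= patience) {
        # Shrink the bounds and start a cooldown period
        scale <- scale * factor
        cooldown_left <- cooldown
        wait <- 0
      }
    }
    out[epoch] <- scale
  }
  out
}

# Made-up validation losses for 13 epochs
val_loss <- c(10, 9, 9.5, 9.2, 9.1, 9.3, 9.4, 9.6, 9.8, 8, 8.5, 8.6, 8.7)
plateau_scale(val_loss)
# returns 1 for epochs 1-4, 0.9 for epochs 5-12 and 0.81 for epoch 13
```

With these toy losses, the bounds shrink for the first time at epoch 5, after three epochs without improvement. Epochs 6 and 7 are cooldown epochs, the counter restarts at epoch 8, the improvement at epoch 10 resets it, and the bounds shrink again at epoch 13.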

We can now have a look at the learning rates:

head(callback_clr$history)
#>            lr base_lr max_lr iteration epochs
#> 1 0.001000000   0.001  0.006         0      1
#> 2 0.001198020   0.001  0.006         1      1
#> 3 0.001396040   0.001  0.006         2      1
#> 4 0.001594059   0.001  0.006         3      1
#> 5 0.001792079   0.001  0.006         4      1
#> 6 0.001990099   0.001  0.006         5      1
# Plot with ggplot2 if it is installed, otherwise fall back to base graphics
backend <- if (rlang::is_installed("ggplot2")) "ggplot2" else "base"
plot_clr_history(callback_clr, granularity = "iteration", backend = backend)
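The first rows of the history are consistent with the usual triangular cyclical learning rate formula. The base-R sketch below assumes that the callback follows this formula; `triangular_lr()` is a hypothetical helper, not a KerasMisc function. With 404 training samples in the Boston housing data and batches of 32, the step size used above is 2 * 404 / 32 = 25.25 iterations, so during the first half cycle the learning rate grows by (0.006 - 0.001) / 25.25, or about 0.000198, per iteration.

```r
# Hypothetical helper, not part of KerasMisc: the triangular cyclical
# learning rate as a function of the iteration number.
triangular_lr <- function(iteration, step_size, base_lr, max_lr) {
  # Current cycle; one full cycle spans 2 * step_size iterations
  cycle <- floor(1 + iteration / (2 * step_size))
  # Distance from the peak of the cycle: 0 at the peak, 1 at the boundaries
  x <- abs(iteration / step_size - 2 * cycle + 1)
  base_lr + (max_lr - base_lr) * pmax(0, 1 - x)
}

# Step size used above: two epochs of 404 / 32 iterations each
step_size <- 2 * 404 / 32
lr <- triangular_lr(0:5, step_size = step_size, base_lr = 0.001, max_lr = 0.006)
lr
```

These values should match the `lr` column of `head(callback_clr$history)` shown above. Later rows may deviate from this pure formula once the bounds have been reduced on a plateau.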



lorenzwalthert/KerasMisc documentation built on May 7, 2021, 6:31 a.m.