KerasMisc

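The goal of KerasMisc is to provide a collection of tools that enhance the R implementation of Keras. Currently, the package features Keras callbacks, in particular a cyclical learning rate callback whose bandwidth can shrink when the validation loss stops improving.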

Contributions welcome.

Installation

You can install the development version of KerasMisc from GitHub with:

remotes::install_github("lorenzwalthert/KerasMisc")

Features

Keras callbacks

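Let's create a model. We load the Boston housing data, standardize the features using the mean and standard deviation of the training set, and compile a small dense network for regression: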

library(keras)
library(KerasMisc)
dataset <- dataset_boston_housing()
c(c(train_data, train_targets), c(test_data, test_targets)) %<-% dataset

mean <- apply(train_data, 2, mean)
std <- apply(train_data, 2, sd)
train_data <- scale(train_data, center = mean, scale = std)
test_data <- scale(test_data, center = mean, scale = std)


model <- keras_model_sequential() %>%
  layer_dense(
    units = 64, activation = "relu",
    input_shape = dim(train_data)[[2]]
  ) %>%
  layer_dense(units = 64, activation = "relu") %>%
  layer_dense(units = 1)
model %>% compile(
  optimizer = optimizer_rmsprop(lr = 0.001),
  loss = "mse",
  metrics = c("mae")
)

Next, we can fit the model with a cyclical learning rate schedule. We dynamically adjust the bandwidth of the learning rate (multiplying it by 0.9) whenever the validation loss has not decreased for three epochs. After each such adjustment, we wait two epochs (cooldown) before the patience counter starts again.

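# approximate number of iterations (batches) per epoch;
# 32 is the default batch size of fit()
iter_per_epoch <- nrow(train_data) / 32
callback_clr <- new_callback_cyclical_learning_rate(
  step_size = iter_per_epoch * 2, # half a cycle spans two epochs
  base_lr = 0.001,
  max_lr = 0.006,
  mode = "triangular",
  patience = 3, # epochs without improvement before adjusting
  factor = 0.9, # multiplier applied at each adjustment
  cooldown = 2, # epochs to wait after an adjustment
  verbose = 0
)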
model %>% fit(
  train_data, train_targets,
  validation_data = list(test_data, test_targets),
  epochs = 50, verbose = 0,
  callbacks = list(callback_clr)
)
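
The patience and cooldown rule passed to the callback above can be illustrated without Keras. The following is a minimal sketch in plain R, not the package's implementation: `bandwidth_multiplier()` is a hypothetical helper, the validation losses are made up for illustration, and the exact bookkeeping (strict improvement over the best loss so far, no counting during cooldown) is an assumption based on the description above.

```r
# Hypothetical helper, not part of KerasMisc: for a vector of per-epoch
# validation losses, return the bandwidth multiplier in effect after each epoch.
bandwidth_multiplier <- function(val_loss, patience = 3, factor = 0.9, cooldown = 2) {
  best <- Inf          # best validation loss seen so far
  wait <- 0            # consecutive counted epochs without improvement
  cooldown_left <- 0   # epochs during which stagnation is not counted
  multiplier <- 1
  out <- numeric(length(val_loss))

  for (epoch in seq_along(val_loss)) {
    in_cooldown <- cooldown_left > 0
    if (in_cooldown) cooldown_left <- cooldown_left - 1

    if (val_loss[epoch] < best) {
      # improvement: remember it and reset the patience counter
      best <- val_loss[epoch]
      wait <- 0
    } else if (!in_cooldown) {
      # stagnation outside cooldown: count it
      wait <- wait + 1
      if (wait >= patience) {
        multiplier <- multiplier * factor
        cooldown_left <- cooldown
        wait <- 0
      }
    }
    out[epoch] <- multiplier
  }
  out
}

# Made-up validation losses: two improving epochs, then a plateau
val_loss <- c(10, 9, 9.5, 9.2, 9.1, 9.3, 9.4, 9.6, 9.7, 9.8)
multiplier <- bandwidth_multiplier(val_loss)
print(multiplier)
```

With these illustrative losses, the multiplier stays at 1 for the first four epochs, drops to 0.9 in epoch 5 (three epochs without improvement), is held during the two cooldown epochs, and reaches 0.81 in epoch 10.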

We can now have a look at the learning rates:

head(callback_clr$history)
backend <- if (rlang::is_installed("ggplot2")) "ggplot2" else "base"
plot_clr_history(callback_clr, granularity = "iteration", backend = backend)
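
For orientation, the shape of a triangular schedule can be written down in a few lines. This sketch assumes the common formulation of cyclical learning rates, in which the rate rises linearly from `base_lr` to `max_lr` over `step_size` iterations and falls back over the next `step_size`. `triangular_lr()` is a hypothetical helper and the `step_size` of 25 is chosen for illustration only; KerasMisc may compute its schedule differently, in particular once the bandwidth has been adjusted.

```r
# Hypothetical helper, not part of KerasMisc: learning rate of a triangular
# cyclical policy at a given (0-based) iteration.
triangular_lr <- function(iteration, step_size, base_lr, max_lr) {
  # index of the current cycle; one cycle spans 2 * step_size iterations
  cycle <- floor(1 + iteration / (2 * step_size))
  # distance from the peak of the current cycle, scaled to [0, 1]
  x <- abs(iteration / step_size - 2 * cycle + 1)
  base_lr + (max_lr - base_lr) * pmax(0, 1 - x)
}

# Illustrative values: two full cycles with a step size of 25 iterations
lrs <- triangular_lr(0:100, step_size = 25, base_lr = 0.001, max_lr = 0.006)
plot(0:100, lrs, type = "l", xlab = "iteration", ylab = "learning rate")
```

Under these assumptions the rate starts at `base_lr`, peaks at `max_lr` at iterations 25 and 75, and returns to `base_lr` at iterations 50 and 100.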


lorenzwalthert/KerasMisc documentation built on May 7, 2021, 6:31 a.m.