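The goal of KerasMisc is to provide a collection of tools that enhance the R implementation of Keras. Currently, the package features:

- a Keras callback for cyclical learning rates, `new_callback_cyclical_learning_rate()`, which can optionally shrink the learning-rate bandwidth when the validation loss plateaus, and
- `plot_clr_history()` for plotting the learning rates recorded by that callback.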
Contributions welcome.
You can install the development version of KerasMisc from GitHub with:

```r
remotes::install_github("lorenzwalthert/KerasMisc")
```
Keras callbacks
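Let's create a model. As an example, we standardize the Boston housing data and build a small dense network for regression: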
```r
library(keras)
library(KerasMisc)

# Load the Boston housing data, already split into training and test sets
dataset <- dataset_boston_housing()
c(c(train_data, train_targets), c(test_data, test_targets)) %<-% dataset

# Standardize all features with the training set's mean and standard deviation
mean <- apply(train_data, 2, mean)
std <- apply(train_data, 2, sd)
train_data <- scale(train_data, center = mean, scale = std)
test_data <- scale(test_data, center = mean, scale = std)

# Two hidden layers and a single linear output unit for regression
model <- keras_model_sequential() %>%
  layer_dense(
    units = 64, activation = "relu",
    input_shape = dim(train_data)[[2]]
  ) %>%
  layer_dense(units = 64, activation = "relu") %>%
  layer_dense(units = 1)

model %>% compile(
  optimizer = optimizer_rmsprop(lr = 0.001),
  loss = "mse",
  metrics = c("mae")
)
```
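Next, we fit the model with a cyclical learning rate schedule. In the triangular mode, the learning rate moves linearly back and forth between `base_lr` and `max_lr`, and one half cycle lasts `step_size` iterations (here two epochs). In addition, we dynamically shrink the bandwidth of the learning rate by multiplying it with `factor = 0.9` whenever the validation loss has not decreased for three epochs (`patience = 3`). After each reduction, we wait two epochs (`cooldown = 2`) before the patience counter starts again.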
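```r
# fit() uses a batch size of 32 by default, so one epoch consists of
# roughly nrow(train_data) / 32 iterations (batches)
iter_per_epoch <- nrow(train_data) / 32

callback_clr <- new_callback_cyclical_learning_rate(
  step_size = iter_per_epoch * 2, # half a cycle lasts two epochs
  base_lr = 0.001,
  max_lr = 0.006,
  mode = "triangular",
  patience = 3,  # epochs without improvement before the bandwidth shrinks
  factor = 0.9,  # multiplier applied to the bandwidth
  cooldown = 2,  # epochs to wait after a reduction
  verbose = 0
)

model %>% fit(
  train_data, train_targets,
  validation_data = list(test_data, test_targets),
  epochs = 50,
  verbose = 0,
  callbacks = list(callback_clr)
)
```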
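To make the schedule more concrete, here is a minimal sketch in plain R (no Keras required) of what the callback is configured to do. `triangular_lr()` implements the triangular policy from Leslie Smith's cyclical learning rate paper, where `step_size` is the number of iterations in half a cycle. `decay_bandwidth()` mimics the plateau logic described above and returns, for each epoch, the factor by which the bandwidth `max_lr - base_lr` is scaled. Both function names are hypothetical. It is also our assumption that the decay multiplies the bandwidth rather than, say, `max_lr` alone. This is not KerasMisc's internal implementation.

```r
# Hypothetical sketch, not KerasMisc's implementation.

# Triangular cyclical learning rate at the given iteration(s).
# step_size: iterations per half cycle.
# bandwidth_scale: assumed multiplier on (max_lr - base_lr), e.g. after decay.
triangular_lr <- function(iteration, step_size, base_lr, max_lr,
                          bandwidth_scale = 1) {
  cycle <- floor(1 + iteration / (2 * step_size))
  x <- abs(iteration / step_size - 2 * cycle + 1)  # 1 at cycle edges, 0 at peak
  base_lr + bandwidth_scale * (max_lr - base_lr) * pmax(0, 1 - x)
}

# Per-epoch bandwidth scale: multiply by `factor` after `patience` epochs
# without a new best validation loss, then skip `cooldown` epochs before
# counting patience again.
decay_bandwidth <- function(val_loss, patience = 3, factor = 0.9, cooldown = 2) {
  scale_per_epoch <- numeric(length(val_loss))
  current_scale <- 1
  best <- Inf
  wait <- 0           # epochs without improvement since the last reset
  cooldown_left <- 0  # remaining epochs in which patience is not counted
  for (epoch in seq_along(val_loss)) {
    improved <- val_loss[epoch] < best
    if (improved) {
      best <- val_loss[epoch]
      wait <- 0
    }
    if (cooldown_left > 0) {
      cooldown_left <- cooldown_left - 1
    } else if (!improved) {
      wait <- wait + 1
      if (wait >= patience) {
        current_scale <- current_scale * factor
        wait <- 0
        cooldown_left <- cooldown
      }
    }
    scale_per_epoch[epoch] <- current_scale
  }
  scale_per_epoch
}

# Example: 13 iterations per epoch and a half cycle of two epochs
iter_per_epoch <- 13
lr <- triangular_lr(0:(4 * iter_per_epoch), step_size = 2 * iter_per_epoch,
                    base_lr = 0.001, max_lr = 0.006)

# Loss improves once, then stalls
decay_bandwidth(c(10, rep(9, 9)))
# returns c(1, 1, 1, 1, 0.9, 0.9, 0.9, 0.9, 0.9, 0.81)
```

In this sketch, each epoch's factor from `decay_bandwidth()` would be passed as `bandwidth_scale` for that epoch's iterations. Scaling only the bandwidth keeps `base_lr` as a fixed lower bound, so the learning rate never falls below it.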
We can now have a look at the learning rates:
```r
head(callback_clr$history)
```
```r
# Plot with ggplot2 if it is installed, otherwise fall back to base graphics
backend <- if (rlang::is_installed("ggplot2")) "ggplot2" else "base"
plot_clr_history(callback_clr, granularity = "iteration", backend = backend)
```