The goal of KerasMisc is to provide a collection of tools that enhance the R implementation of Keras. Currently, the package features:
Contributions welcome.
You can install the development version of KerasMisc from GitHub with:
remotes::install_github("lorenzwalthert/KerasMisc")
Keras callbacks
Let's create a model for the Boston housing data:
library(keras)
library(KerasMisc)
dataset <- dataset_boston_housing()
c(c(train_data, train_targets), c(test_data, test_targets)) %<-% dataset
mean <- apply(train_data, 2, mean)
std <- apply(train_data, 2, sd)
train_data <- scale(train_data, center = mean, scale = std)
test_data <- scale(test_data, center = mean, scale = std)
model <- keras_model_sequential() %>%
  layer_dense(
    units = 64, activation = "relu",
    input_shape = dim(train_data)[[2]]
  ) %>%
  layer_dense(units = 64, activation = "relu") %>%
  layer_dense(units = 1)
model %>% compile(
  optimizer = optimizer_rmsprop(lr = 0.001),
  loss = "mse",
  metrics = c("mae")
)
Next, we fit the model with a cyclical learning rate schedule. The
bandwidth of the learning rate is adjusted dynamically: it is multiplied
by 0.9 (factor) whenever the validation loss has not decreased for three
epochs (patience). After each such reduction, we wait two epochs
(cooldown) before the patience counter starts again.
# 32 is the default batch size of fit()
iter_per_epoch <- nrow(train_data) / 32
callback_clr <- new_callback_cyclical_learning_rate(
  step_size = iter_per_epoch * 2,
  base_lr = 0.001,
  max_lr = 0.006,
  mode = "triangular",
  patience = 3,
  factor = 0.9,
  cooldown = 2,
  verbose = 0
)
model %>% fit(
  train_data, train_targets,
  validation_data = list(test_data, test_targets),
  epochs = 50, verbose = 0,
  callbacks = list(callback_clr)
)
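To make the roles of patience, factor and cooldown concrete, here is a minimal sketch in plain R that does not need keras. It is not the package's implementation: shrink_on_plateau() is a hypothetical helper, the validation losses are made up, and it assumes that "no decrease" means no improvement over the best loss seen so far and that both base_lr and max_lr are scaled by factor.

```r
# Hypothetical helper, NOT part of KerasMisc: shrink the learning rate
# bounds when a series of validation losses stops improving.
shrink_on_plateau <- function(val_loss, base_lr, max_lr,
                              patience = 3, factor = 0.9, cooldown = 2) {
  best <- Inf    # best validation loss seen so far (assumption)
  wait <- 0      # consecutive epochs without improvement
  cooling <- 0   # remaining cooldown epochs
  history <- data.frame(
    epoch = seq_along(val_loss), base_lr = NA_real_, max_lr = NA_real_
  )
  for (epoch in seq_along(val_loss)) {
    improved <- val_loss[epoch] < best
    if (improved) best <- val_loss[epoch]
    if (cooling > 0) {
      # During cooldown the patience counter stays at zero.
      cooling <- cooling - 1
      wait <- 0
    } else if (improved) {
      wait <- 0
    } else {
      wait <- wait + 1
      if (wait >= patience) {
        # Assumption: both bounds are multiplied by `factor`.
        base_lr <- base_lr * factor
        max_lr <- max_lr * factor
        cooling <- cooldown
        wait <- 0
      }
    }
    history$base_lr[epoch] <- base_lr
    history$max_lr[epoch] <- max_lr
  }
  history
}

# Made-up validation losses for illustration only.
val_loss <- c(20, 18, 18.5, 19, 18.2, 18.1, 18.3, 17.5, 17.6, 17.7, 17.8)
result <- shrink_on_plateau(val_loss, base_lr = 0.001, max_lr = 0.006)
print(result)
```

With these invented losses, the bounds shrink for the first time in epoch 5, after three epochs without improvement. The two cooldown epochs that follow do not count towards patience.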
We can now have a look at the learning rates:
head(callback_clr$history)
#>            lr base_lr max_lr iteration epochs
#> 1 0.001000000   0.001  0.006         0      1
#> 2 0.001198020   0.001  0.006         1      1
#> 3 0.001396040   0.001  0.006         2      1
#> 4 0.001594059   0.001  0.006         3      1
#> 5 0.001792079   0.001  0.006         4      1
#> 6 0.001990099   0.001  0.006         5      1
backend <- ifelse(rlang::is_installed("ggplot2"), "ggplot2", "base")
plot_clr_history(callback_clr, granularity = "iteration", backend = backend)
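The first rows of the history above can be checked by hand. The sketch below assumes that the callback follows the standard triangular cyclical learning rate formula. triangular_lr() is a hypothetical helper written for this illustration, not a KerasMisc function, and the row count of 404 is the size of the Boston housing training set.

```r
# Hypothetical helper, NOT part of KerasMisc: triangular cyclical
# learning rate for one or more iteration numbers.
triangular_lr <- function(iteration, step_size, base_lr, max_lr) {
  # Index of the current cycle; one cycle spans 2 * step_size iterations.
  cycle <- floor(1 + iteration / (2 * step_size))
  # Distance from the peak of the current cycle, scaled to [0, 1].
  x <- abs(iteration / step_size - 2 * cycle + 1)
  base_lr + (max_lr - base_lr) * pmax(0, 1 - x)
}

# Same settings as in the example: 404 training rows, batch size 32,
# step size of two epochs.
step_size <- 404 / 32 * 2
lr <- triangular_lr(0:5, step_size, base_lr = 0.001, max_lr = 0.006)
print(round(lr, 9))
```

Under these assumptions the values agree with the lr column for iterations 0 to 5. The rate rises linearly from base_lr and reaches max_lr after step_size iterations.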