In this example we will train a DenseNet-100-12 to classify images from the CIFAR10 small images dataset. Because of GPU memory limitations we had to use a batch size of 32 instead of the 64 used in the original paper. Training takes ~660s per epoch on an NVIDIA GeForce 1080 Ti. Final accuracy on the test set was 0.9410 versus 0.9433 reported in the paper.

# Libraries ---------------------------------------------------------------
library(keras)
library(densenet)

# Parameters --------------------------------------------------------------

# Mini-batch size passed to fit().
batch_size <- 32
# Total number of training epochs; lr_schedule() also reads this value to
# decide where the learning rate steps down.
epochs <- 350

# Data Preparation --------------------------------------------------------

# Load CIFAR10 (see ?dataset_cifar10 for more info).
cifar10 <- dataset_cifar10()

# Standardise every channel with the mean and standard deviation computed on
# the training images only; the same shift and scale are applied to the test
# images.
for (channel in seq_len(3)) {
  train_channel <- cifar10$train$x[, , , channel]
  channel_mean <- mean(train_channel)
  channel_sd <- sd(train_channel)

  cifar10$train$x[, , , channel] <- (train_channel - channel_mean) / channel_sd
  cifar10$test$x[, , , channel] <-
    (cifar10$test$x[, , , channel] - channel_mean) / channel_sd
}

# One-hot encode the labels over the 10 classes.
y_test <- to_categorical(cifar10$test$y, num_classes = 10)
y_train <- to_categorical(cifar10$train$y, num_classes = 10)

# Normalised image arrays used for fitting and evaluation.
x_test <- cifar10$test$x
x_train <- cifar10$train$x

# Model Definition -------------------------------------------------------

# Input tensor of shape 32 x 32 x 3.
input_img <- layer_input(shape = c(32, 32, 3))

# DenseNet of depth 100 built on that input, with the top layers included and
# a dropout rate of 0.2.
model <- application_densenet(
  include_top = TRUE,
  input_tensor = input_img,
  dropout_rate = 0.2,
  depth = 100
)

# SGD with Nesterov momentum, starting at a learning rate of 0.1.
opt <- optimizer_sgd(lr = 0.1, momentum = 0.9, nesterov = TRUE)

# Attach the optimizer, loss and metric to the model.
compile(
  model,
  optimizer = opt,
  loss = "categorical_crossentropy",
  metrics = "accuracy"
)

# Model fitting -----------------------------------------------------------

# callbacks for weights and learning rate
# Step-wise learning-rate schedule.
#
# Returns 0.1 while `epoch` is at most half of `total_epochs`, 0.01 while it
# is at most three quarters of `total_epochs`, and 0.001 afterwards.
#
# Arguments:
#   epoch        Epoch index supplied by the scheduler callback.
#   lr           Current learning rate. Unused; accepted so the function can
#                also be called with (epoch, lr).
#   total_epochs Total number of training epochs. Defaults to the
#                script-level `epochs`, looked up when the function is called.
#
# Value: the learning rate (a single number) to use for `epoch`.
lr_schedule <- function(epoch, lr = NULL, total_epochs = epochs) {
  # Branches are tested in increasing order, so each one only needs its upper
  # bound; the lower bound is implied by the previous branch having failed.
  if (epoch <= 0.5 * total_epochs) {
    0.1
  } else if (epoch <= 0.75 * total_epochs) {
    0.01
  } else {
    0.001
  }
}

# Callback that sets the learning rate from lr_schedule().
lr_reducer <- callback_learning_rate_scheduler(lr_schedule)

# Callback that writes the model to the given .hdf5 file during training.
# NOTE(review): `monitor` usually only has an effect together with
# `save_best_only = TRUE` -- confirm whether that was intended.
model_checkpoint <- callback_model_checkpoint(
  "inst/weights/cifar10-DenseNet-100-12.hdf5",
  monitor = "val_acc"
)

training_callbacks <- list(lr_reducer, model_checkpoint)

# Train the model; the test set is used as validation data.
history <- fit(
  model,
  x_train, y_train,
  batch_size = batch_size,
  epochs = epochs,
  validation_data = list(x_test, y_test),
  callbacks = training_callbacks
)

# Plot the metrics recorded in the training history.
plot(history)

# Evaluate the trained model on the test set.
model %>% evaluate(x_test, y_test)


dfalbel/densenet documentation built on May 31, 2019, 11:49 p.m.