knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
library(keras)
library(capsnet)



mnist <- dataset_mnist()

# reshape to (samples, 28, 28, 1) and rescale pixels from [0, 255] to [0, 1],
# the range the reconstruction MSE (and its 0.392 weight below) assumes
mnist$train$x <- array_reshape(mnist$train$x, dim = c(60000, 28, 28, 1)) / 255
mnist$test$x <- array_reshape(mnist$test$x, dim = c(10000, 28, 28, 1)) / 255

# one-hot encode the labels: they are used both as targets for the capsule
# output and as an extra input that masks the decoder during training
mnist$train$y <- to_categorical(mnist$train$y, 10)
mnist$test$y <- to_categorical(mnist$test$y, 10)
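
As a quick sanity check, the reshaped arrays should have the dimensions used above:

dim(mnist$train$x)  #> 60000 28 28 1
dim(mnist$train$y)  #> 60000 10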

# build the capsule network for 28x28x1 inputs and 10 classes, with 3 routing
# iterations as in the paper; the training graph is exposed as model$train_model
model <- create_capsnet(input_shape = c(28, 28, 1), n_class = 10L, num_routing = 3L)
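
To inspect the architecture before training, you can print a Keras summary of the training graph (a standard Keras call, nothing package-specific):

summary(model$train_model)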

# Margin loss from Sabour et al. (2017): for each class k,
#   L_k = T_k * max(0, 0.9 - ||v_k||)^2 + 0.5 * (1 - T_k) * max(0, ||v_k|| - 0.1)^2
# where T_k = 1 iff class k is present and y_pred holds the capsule lengths ||v_k||
margin_loss <- function(y_true, y_pred) {
  K <- keras::backend()
  L <- y_true * K$square(K$maximum(0, 0.9 - y_pred)) +
    0.5 * (1 - y_true) * K$square(K$maximum(0, y_pred - 0.1))

  # sum over the class axis (0-based axis 1), then average over the batch
  K$mean(K$sum(L, 1L))
}
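
For intuition, the same formula can be replicated in plain R and evaluated on a toy one-hot target and a made-up vector of capsule lengths (illustrative only, not used in training):

margin_loss_r <- function(y_true, y_pred) {
  L <- y_true * pmax(0, 0.9 - y_pred)^2 +
    0.5 * (1 - y_true) * pmax(0, y_pred - 0.1)^2
  mean(rowSums(L))
}

y_true <- matrix(c(0, 1, 0), nrow = 1)          # true class is 2
y_pred <- matrix(c(0.05, 0.95, 0.2), nrow = 1)  # predicted capsule lengths
margin_loss_r(y_true, y_pred)                   # small loss: correct capsule is long
#> [1] 0.005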



# exponential learning-rate decay: 0.001 * 0.9^epoch
# (Keras passes a 0-based epoch index to the schedule function)
lr_decay <- callback_learning_rate_scheduler(schedule = function(epoch) {
  0.001 * (0.9^epoch)
})
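
The first few learning rates under this schedule:

sapply(0:4, function(epoch) 0.001 * 0.9^epoch)
#> [1] 0.0010000 0.0009000 0.0008100 0.0007290 0.0006561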


# compile the model: margin loss on the capsule output, MSE on the
# reconstruction, with the reconstruction down-weighted by
# 0.392 = 0.0005 * 784 as in the paper; accuracy is tracked on the
# 'capsnet' output
model$train_model %>% compile(
  optimizer = optimizer_adam(),
  loss = list(margin_loss, "mse"),
  loss_weights = list(1.0, 0.392),
  metrics = list(capsnet = "accuracy")
)


# the training graph takes (images, labels) as inputs and produces
# (class capsule lengths, reconstructions) as outputs, so the targets
# are the labels and the images themselves
model$train_model %>%
  fit(
    x = list(mnist$train$x, mnist$train$y),
    y = list(mnist$train$y, mnist$train$x),
    batch_size = 100,
    epochs = 50,
    validation_data = list(
      list(mnist$test$x, mnist$test$y),
      list(mnist$test$y, mnist$test$x)
    ),
    callbacks = list(lr_decay)
  )
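
After training, a minimal sketch of test-set evaluation, reusing the same input/target pairing as fit() (standard Keras evaluate(); the names of the returned losses and metrics depend on the model's output names):

model$train_model %>%
  evaluate(
    x = list(mnist$test$x, mnist$test$y),
    y = list(mnist$test$y, mnist$test$x),
    batch_size = 100
  )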

