Siamese Networks are neural networks which share weights between two or more sister networks, each producing embedding vectors of its respective inputs.
In supervised similarity learning, the networks are then trained to maximize the contrast (distance) between embeddings of inputs from different classes while minimizing the distance between embeddings of inputs from the same class, resulting in embedding spaces that reflect the class segmentation of the training inputs.
This implementation loosely follows Hadsell-et-al.'06 [1] (see the paper for more details), but the Euclidean distance is replaced by a subtraction layer followed by one fully connected (FC) layer.
[1] "Dimensionality Reduction by Learning an Invariant Mapping" https://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
Gets to 98.11% test accuracy after 20 epochs, at 3 seconds per epoch on an AMD Ryzen 7 PRO 4750U (CPU).
library(keras3)
contrastive_loss <- function(y_true, y_pred) {
  # Contrastive loss from Hadsell-et-al.'06
  # https://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
  margin <- 1
  margin_square <- op_square(op_maximum(margin - (y_pred), 0))
  op_mean((1 - y_true) * op_square(y_pred) + (y_true) * margin_square)
}
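To see what this loss rewards, here is a minimal base-R sketch of the same formula evaluated on four hand-picked pairs. The function name `contrastive_loss_r` and the toy values are illustrative assumptions, not part of the original example. As written, a pair with `y_true = 1` is penalized by `max(margin - y_pred, 0)^2` and a pair with `y_true = 0` by `y_pred^2`, so the score is pushed toward 1 for same-class pairs and toward 0 for different-class pairs.

```r
# Base-R re-statement of the loss above (hypothetical helper, no keras needed).
contrastive_loss_r <- function(y_true, y_pred, margin = 1) {
  margin_square <- pmax(margin - y_pred, 0)^2
  mean((1 - y_true) * y_pred^2 + y_true * margin_square)
}

# Toy labels (1 = same class, 0 = different class) and toy model scores.
y_true <- c(1, 1, 0, 0)
y_pred <- c(0.9, 0.2, 0.1, 0.8)

# Per-pair contributions: confident correct scores cost little,
# confident wrong scores cost a lot.
per_pair <- (1 - y_true) * y_pred^2 + y_true * pmax(1 - y_pred, 0)^2
print(per_pair)

# Mean over the batch, as in the keras version.
loss <- contrastive_loss_r(y_true, y_pred)
print(loss)
```

The second and fourth pairs are the mistakes (a same-class pair scored 0.2 and a different-class pair scored 0.8), and they dominate the batch loss.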
We will train the model to differentiate between digits of different classes. For example, digit 0 needs to be differentiated from the rest of the digits (1 through 9), digit 1 from 0 and 2 through 9, and so on.

To carry this out, we will select N random images from class A (for example, for digit 0) and pair them with N random images from another class B (for example, for digit 1). Then, we can repeat this process for all classes of digits (until digit 9). Once we have paired digit 0 with other digits, we can repeat this process for the remaining classes for the rest of the digits (from 1 until 9).
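The pairing procedure described above can be sketched in base R on a toy label vector. This is only an illustration of the class-A-versus-class-B idea under assumed names (`idx_by_class`, `neg_pairs`) and a made-up three-class label set. It is not the pairing function this example actually uses.

```r
set.seed(42)

# Hypothetical toy labels: three classes with five items each.
y <- rep(0:2, each = 5)

# Indices of the items belonging to each class, keyed by class name.
idx_by_class <- split(seq_along(y), y)
classes <- names(idx_by_class)

n <- 3  # number of random items drawn per class in each pairing
pairs <- list()
for (a in classes) {
  for (b in setdiff(classes, a)) {
    # Pair n random items of class a with n random items of class b.
    pairs[[paste(a, b, sep = "-")]] <- data.frame(
      idx1 = sample(idx_by_class[[a]], n),
      idx2 = sample(idx_by_class[[b]], n)
    )
  }
}
neg_pairs <- do.call(rbind, pairs)

# Every row pairs two items with different labels.
print(all(y[neg_pairs$idx1] != y[neg_pairs$idx2]))
```

With three classes, two partner classes each, and `n = 3`, the sketch yields 18 different-class pairs.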
create_pairs <- function(x, y) {
  # Positive and negative pair creation.
  # Each image is paired with a random image whose class is drawn with
  # weights 0.9 (same class) and 0.1 (each other class), so about half
  # of the pairs are positive.
  digit_indices <- tapply(seq_along(y), y, list)
  y1 <- y
  y2 <- sapply(y, function(a) sample(0:9, 1, prob = 0.1 + 0.8 * (0:9 == a)))
  idx1 <- 1:nrow(x)
  idx2 <- sapply(as.character(y2), function(a) sample(digit_indices[[a]], 1))
  is_same <- 1 * (y1 == y2)
  list(pair1 = x[idx1, ], pair2 = x[idx2, ], y = is_same)
}

compute_accuracy <- function(predictions, labels) {
  # Fraction of pairs scored above the fixed 0.5 threshold that are truly
  # same-class pairs. Note that this is the precision of the "same"
  # predictions; overall accuracy would be mean((predictions > 0.5) == labels).
  mean(labels[predictions > 0.5])
}
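The `prob` argument in `create_pairs` determines how often a pair is positive. A short base-R check, using an arbitrary digit `a <- 3` and illustrative variable names, shows how the weights `0.1 + 0.8 * (0:9 == a)` work out once `sample()` normalizes them.

```r
a <- 3  # arbitrary example digit

# Unnormalized weights: 0.9 for the same class, 0.1 for each other class.
w <- 0.1 + 0.8 * (0:9 == a)

# sample() rescales the weights to sum to one, so do the same here.
p <- w / sum(w)
p_same <- p[a + 1]  # digits are 0-based, R vectors are 1-based
print(p_same)

# Empirical check: draw many partner classes and count how often they match.
set.seed(1)
draws <- sample(0:9, 10000, replace = TRUE, prob = w)
freq_same <- mean(draws == a)
print(freq_same)
```

The same class gets weight 0.9 out of a total of 1.8, so roughly half of the generated pairs are positive and the remaining half is spread evenly over the other nine digits.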
# the data, shuffled and split between train and test sets
mnist <- dataset_mnist()
x_train <- mnist$train$x
y_train <- mnist$train$y
x_test <- mnist$test$x
y_test <- mnist$test$y

x_train <- array_reshape(x_train, c(nrow(x_train), 784))
x_test <- array_reshape(x_test, c(nrow(x_test), 784))
x_train <- x_train / 255
x_test <- x_test / 255

# create training+test positive and negative pairs
tr <- create_pairs(x_train, y_train)
te <- create_pairs(x_test, y_test)

names(tr)
## [1] "pair1" "pair2" "y"
# input layers
input_dim <- 784
input_1 <- layer_input(shape = c(input_dim))
input_2 <- layer_input(shape = c(input_dim))

# definition of the base network that will be shared
base_network <- keras_model_sequential() %>%
  layer_dense(units = 128, activation = 'relu') %>%
  layer_dropout(rate = 0.1) %>%
  layer_dense(units = 128, activation = 'relu') %>%
  layer_dropout(rate = 0.1) %>%
  layer_dense(units = 128, activation = 'relu')

# because we re-use the same instance `base_network`, the weights of
# the network will be shared across the two branches
branch_1 <- base_network(input_1)
branch_2 <- base_network(input_2)

# merging layer
out <- layer_subtract(list(branch_1, branch_2)) %>%
  layer_dropout(rate = 0.1) %>%
  layer_dense(units = 16, activation = 'relu') %>%
  layer_dense(1, activation = "sigmoid")

# create and compile model
model <- keras_model(list(input_1, input_2), out)
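The effect of sharing one base network between both branches can be illustrated without keras. The sketch below is a plain-R stand-in under stated assumptions: a single random weight matrix `W` and a one-layer ReLU `embed()` function take the place of `base_network`, and the dimensions are made up. It only shows the structural point that both inputs pass through the same weights before being subtracted.

```r
set.seed(0)

# One shared weight matrix (hypothetical stand-in for the shared base network).
W <- matrix(rnorm(4 * 3), nrow = 4, ncol = 3)

# Both branches call the same function, hence use the same weights.
embed <- function(x) pmax(x %*% W, 0)  # linear map followed by ReLU

x1 <- matrix(runif(4), nrow = 1)
x2 <- matrix(runif(4), nrow = 1)

# Subtraction of the two embeddings, as the merging layer does.
d_same <- embed(x1) - embed(x1)  # identical inputs
d_diff <- embed(x1) - embed(x2)  # different inputs

print(d_same)
print(d_diff)
```

Because the weights are shared, identical inputs always produce an all-zero difference vector, whatever the weights are. Two independently initialized branches would not have that property.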
model %>% compile(
  optimizer = "rmsprop",
  #loss = "binary_crossentropy",
  loss = contrastive_loss,
  metrics = metric_binary_accuracy
)

history <- model %>% fit(
  list(tr$pair1, tr$pair2), tr$y,
  batch_size = 128,
  epochs = 20,
  validation_data = list(
    list(te$pair1, te$pair2),
    te$y
  )
)
## Epoch 1/20
## 469/469 - 17s - 36ms/step - binary_accuracy: 0.7566 - loss: 0.1644 - val_binary_accuracy: 0.8753 - val_loss: 0.1001
## Epoch 2/20
## 469/469 - 1s - 1ms/step - binary_accuracy: 0.8896 - loss: 0.0867 - val_binary_accuracy: 0.9248 - val_loss: 0.0614
## Epoch 3/20
## 469/469 - 1s - 1ms/step - binary_accuracy: 0.9261 - loss: 0.0579 - val_binary_accuracy: 0.9409 - val_loss: 0.0461
## Epoch 4/20
## 469/469 - 1s - 1ms/step - binary_accuracy: 0.9474 - loss: 0.0419 - val_binary_accuracy: 0.9529 - val_loss: 0.0382
## Epoch 5/20
## 469/469 - 1s - 1ms/step - binary_accuracy: 0.9589 - loss: 0.0328 - val_binary_accuracy: 0.9603 - val_loss: 0.0316
## Epoch 6/20
## 469/469 - 1s - 1ms/step - binary_accuracy: 0.9660 - loss: 0.0272 - val_binary_accuracy: 0.9617 - val_loss: 0.0297
## Epoch 7/20
## 469/469 - 1s - 1ms/step - binary_accuracy: 0.9708 - loss: 0.0234 - val_binary_accuracy: 0.9665 - val_loss: 0.0270
## Epoch 8/20
## 469/469 - 1s - 1ms/step - binary_accuracy: 0.9747 - loss: 0.0201 - val_binary_accuracy: 0.9674 - val_loss: 0.0259
## Epoch 9/20
## 469/469 - 1s - 1ms/step - binary_accuracy: 0.9787 - loss: 0.0175 - val_binary_accuracy: 0.9697 - val_loss: 0.0247
## Epoch 10/20
## 469/469 - 1s - 1ms/step - binary_accuracy: 0.9807 - loss: 0.0157 - val_binary_accuracy: 0.9684 - val_loss: 0.0251
## Epoch 11/20
## 469/469 - 1s - 1ms/step - binary_accuracy: 0.9810 - loss: 0.0152 - val_binary_accuracy: 0.9710 - val_loss: 0.0230
## Epoch 12/20
## 469/469 - 1s - 1ms/step - binary_accuracy: 0.9831 - loss: 0.0138 - val_binary_accuracy: 0.9726 - val_loss: 0.0224
## Epoch 13/20
## 469/469 - 1s - 1ms/step - binary_accuracy: 0.9854 - loss: 0.0121 - val_binary_accuracy: 0.9724 - val_loss: 0.0224
## Epoch 14/20
## 469/469 - 1s - 1ms/step - binary_accuracy: 0.9861 - loss: 0.0114 - val_binary_accuracy: 0.9738 - val_loss: 0.0217
## Epoch 15/20
## 469/469 - 1s - 1ms/step - binary_accuracy: 0.9864 - loss: 0.0110 - val_binary_accuracy: 0.9738 - val_loss: 0.0213
## Epoch 16/20
## 469/469 - 1s - 1ms/step - binary_accuracy: 0.9881 - loss: 0.0101 - val_binary_accuracy: 0.9754 - val_loss: 0.0206
## Epoch 17/20
## 469/469 - 1s - 1ms/step - binary_accuracy: 0.9883 - loss: 0.0097 - val_binary_accuracy: 0.9733 - val_loss: 0.0214
## Epoch 18/20
## 469/469 - 1s - 1ms/step - binary_accuracy: 0.9881 - loss: 0.0097 - val_binary_accuracy: 0.9723 - val_loss: 0.0216
## Epoch 19/20
## 469/469 - 1s - 1ms/step - binary_accuracy: 0.9891 - loss: 0.0090 - val_binary_accuracy: 0.9749 - val_loss: 0.0202
## Epoch 20/20
## 469/469 - 1s - 1ms/step - binary_accuracy: 0.9898 - loss: 0.0085 - val_binary_accuracy: 0.9752 - val_loss: 0.0205
plot(history)
# compute final accuracy on training and test sets
tr_pred <- predict(model, list(tr$pair1, tr$pair2))[,1]
## 1875/1875 - 1s - 793us/step
tr_acc <- compute_accuracy(tr_pred, tr$y)
te_pred <- predict(model, list(te$pair1, te$pair2))[,1]
## 313/313 - 0s - 1ms/step
te_acc <- compute_accuracy(te_pred, te$y)
sprintf('* Accuracy on training set: %0.2f%%', (100 * tr_acc))
## [1] "* Accuracy on training set: 99.63%"
sprintf('* Accuracy on test set: %0.2f%%', (100 * te_acc))
## [1] "* Accuracy on test set: 97.93%"
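The figures printed above come from `compute_accuracy`, which averages the labels of the pairs scored above 0.5. That quantity is the share of predicted "same" pairs that really are the same, which is not identical to the fraction of all pairs classified correctly (the `binary_accuracy` metric in the training log). The toy base-R comparison below, with made-up scores and labels and illustrative variable names, shows that the two can differ.

```r
# Hypothetical scores and labels for six pairs (1 = same class).
predictions <- c(0.9, 0.8, 0.6, 0.4, 0.2, 0.7)
labels      <- c(1,   1,   0,   1,   0,   1)

# What compute_accuracy() returns: mean label among pairs scored above 0.5.
precision_same <- mean(labels[predictions > 0.5])

# Fraction of all pairs whose thresholded score matches the label.
overall_accuracy <- mean((predictions > 0.5) == labels)

print(precision_same)
print(overall_accuracy)
```

Here four pairs are scored above 0.5 and three of them are truly same-class, while four of the six pairs overall are classified correctly.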
par(mfrow=c(1,1))
vioplot::vioplot( te_pred ~ te$y )
i=3
visualizePair <- function(i) {
  image(rbind(matrix( te$pair1[i,],28,28)[,28:1], matrix( te$pair2[i,],28,28)[,28:1]))
  title(paste("true:", te$y[i],"| pred:", round(te_pred[i],5)))
}
par(mfrow=c(3,3))
lapply(1:9, visualizePair)
## [[1]]
## NULL
##
## [[2]]
## NULL
##
## [[3]]
## NULL
##
## [[4]]
## NULL
##
## [[5]]
## NULL
##
## [[6]]
## NULL
##
## [[7]]
## NULL
##
## [[8]]
## NULL
##
## [[9]]
## NULL