# Packages ----------------------------------------------------------------

library(torch)
library(torchvision)
library(luz)

# Datasets and loaders ----------------------------------------------------

dir <- "./mnist" # caching directory

# MNIST dataset that yields image triplets instead of (image, label) pairs.
#
# Every item is `list(input, target)`:
#   * `input`  - named list with the `anchor` image, a `negative` image
#                (different label) and a `positive` image (same label);
#   * `target` - an empty list, because the triplet loss needs no labels.
triplet_mnist_dataset <- dataset(
  inherit = mnist_dataset,

  # Draw one element of `candidates` uniformly at random.
  # `sample(x, 1)` draws from `1:x` when `x` is a single number, which would
  # silently return an unrelated index, so the vector is indexed explicitly.
  pick_one = function(candidates) {
    candidates[sample.int(length(candidates), 1L)]
  },

  # Build the triplet whose anchor is the image at position `index`.
  .getitem = function(index) {
    label <- self$targets[index]

    negative_pool <- which(self$targets != label)
    if (length(negative_pool) == 0L) {
      stop(
        "Cannot build a triplet: no sample has a label different from the ",
        "anchor's.",
        call. = FALSE
      )
    }

    # Leave the anchor out of the positive candidates; otherwise the
    # anchor/positive pair can be the very same image (zero distance).
    positive_pool <- which(self$targets == label)
    positive_pool <- positive_pool[positive_pool != index]
    if (length(positive_pool) == 0L) {
      # The anchor is the only sample of its class: reuse it.
      positive_pool <- index
    }

    # Negative is drawn before positive, as in the original sampling order.
    negative_index <- self$pick_one(negative_pool)
    positive_index <- self$pick_one(positive_pool)

    # Fall back to the raw image when the dataset was built without a
    # transform.
    prepare <- if (is.null(self$transform)) identity else self$transform

    list(
      list( # input is a list with 3 images.
        anchor = prepare(self$data[index, , ]),
        negative = prepare(self$data[negative_index, , ]),
        positive = prepare(self$data[positive_index, , ])
      ),
      list() # no 'target'
    )
  }
)

train_ds <- triplet_mnist_dataset(
  dir,
  download = TRUE,
  transform = transform_to_tensor
)

test_ds <- triplet_mnist_dataset(
  dir,
  train = FALSE,
  transform = transform_to_tensor
)

train_dl <- dataloader(train_ds, batch_size = 128, shuffle = TRUE)
test_dl <- dataloader(test_ds, batch_size = 128)

# Define the network ------------------------------------------------------

# Convolutional encoder mapping a batch of single-channel images to
# `embedding_dim`-dimensional vectors.
net <- nn_module(
  "Net",
  initialize = function(embedding_dim) {
    self$conv1 <- nn_conv2d(1, 32, 3, 1)
    self$conv2 <- nn_conv2d(32, 64, 3, 1)
    self$dropout1 <- nn_dropout(0.25)
    self$dropout2 <- nn_dropout(0.5)
    # 9216 = 64 channels * 12 * 12, the size left after two 3x3 convolutions
    # and a 2x2 max-pool when the input images are 28x28.
    self$fc1 <- nn_linear(9216, 512)
    self$fc2 <- nn_linear(512, embedding_dim)
  },
  forward = function(x) {
    x <- self$conv1(x)
    x <- nnf_relu(x)
    x <- self$conv2(x)
    x <- nnf_relu(x)
    x <- nnf_max_pool2d(x, 2)
    x <- self$dropout1(x)
    # Keep the first (batch) dimension, flatten the rest.
    x <- torch_flatten(x, start_dim = 2)
    x <- self$fc1(x)
    x <- nnf_relu(x)
    x <- self$dropout2(x)
    x <- self$fc2(x)
    x
  }
)

# Wraps the encoder together with the triplet margin loss. The loss is
# provided as a `loss` method of the module, which is why `setup()` below is
# only given an optimizer.
triplet_model <- torch::nn_module(
  initialize = function(embedding_dim = 2, margin = 1) {
    self$embedding <- net(embedding_dim = embedding_dim)
    self$criterion <- nn_triplet_margin_loss(margin = margin)
  },
  # `input` is the named list (anchor / negative / positive) produced by
  # `triplet_mnist_dataset`; any further arguments are ignored.
  loss = function(input, ...) {
    # Embed the three images with the same (shared-weight) encoder.
    embeds <- lapply(input, self$embedding)
    self$criterion(
      embeds$anchor,
      embeds$positive,
      embeds$negative
    )
  }
)

# Train -------------------------------------------------------------------

fitted <- triplet_model %>%
  setup(optimizer = optim_adam) %>%
  set_hparams(embedding_dim = 2) %>%
  fit(train_dl, epochs = 10, valid_data = test_dl)

# Serializing

luz_save(fitted, "triplet.pt")
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.