# Packages ----------------------------------------------------------------
library(torch)
library(torchvision)
library(luz)

# Datasets and loaders ----------------------------------------------------

# Directory where the MNIST files are cached between runs.
dir <- file.path(".", "mnist")

# MNIST dataset that yields (anchor, negative, positive) image triplets.
#
# Everything except item retrieval is inherited from `mnist_dataset`. Each
# item is `list(input, target)`: `input` is a named list holding the three
# transformed images and `target` is an empty list, because the triplet loss
# does not use a label.
triplet_mnist_dataset <- dataset(
  inherit = mnist_dataset,
  .getitem = function(index) {
    label <- self$targets[index]
    same_class <- self$targets == label

    # Draw one element from a pool of indices. Going through `sample.int()`
    # avoids the `sample(x, 1)` trap where a length-one `x` is read as `1:x`
    # and an unrelated index is returned.
    pick_one <- function(pool) pool[sample.int(length(pool), 1L)]

    negative_pool <- which(!same_class)
    if (length(negative_pool) == 0L) {
      stop(
        "Cannot build a triplet: every sample shares the anchor's label.",
        call. = FALSE
      )
    }

    # Leave the anchor out of the positive candidates so the positive is a
    # different image whenever the class has more than one member; otherwise
    # the anchor is the only possible positive.
    positive_pool <- which(same_class)
    positive_pool <- positive_pool[positive_pool != index]
    if (length(positive_pool) == 0L) {
      positive_pool <- index
    }

    # The negative is drawn before the positive.
    negative_index <- pick_one(negative_pool)
    positive_index <- pick_one(positive_pool)

    # Return the raw image slices when no transform was supplied.
    transform <- if (is.null(self$transform)) identity else self$transform

    list(
      list( # input is a list with 3 images.
        anchor = transform(self$data[index, , ]),
        negative = transform(self$data[negative_index, , ]),
        positive = transform(self$data[positive_index, , ])
      ),
      list() # no 'target'
    )
  }
)

# Training split; the files are downloaded into `dir` when requested here.
train_ds <- triplet_mnist_dataset(
  root = dir,
  train = TRUE,
  transform = transform_to_tensor,
  download = TRUE
)

# Held-out split, read from the same cache directory.
test_ds <- triplet_mnist_dataset(
  root = dir,
  train = FALSE,
  transform = transform_to_tensor
)

# Batches of 128 triplets; only the training loader shuffles.
train_dl <- dataloader(dataset = train_ds, batch_size = 128, shuffle = TRUE)
test_dl <- dataloader(dataset = test_ds, batch_size = 128, shuffle = FALSE)

# Define the network ------------------------------------------------------

# Embedding network: two 3x3 convolutions, a 2x2 max-pool, then two linear
# layers that project to `embedding_dim` outputs. Dropout follows the pooling
# step and the first linear layer.
net <- nn_module(
  "Net",
  initialize = function(embedding_dim) {
    self$conv1 <- nn_conv2d(
      in_channels = 1, out_channels = 32, kernel_size = 3, stride = 1
    )
    self$conv2 <- nn_conv2d(
      in_channels = 32, out_channels = 64, kernel_size = 3, stride = 1
    )
    self$dropout1 <- nn_dropout(p = 0.25)
    self$dropout2 <- nn_dropout(p = 0.5)
    # 9216 = 64 x 12 x 12, which is what the layers above produce for a
    # 1 x 28 x 28 input.
    self$fc1 <- nn_linear(in_features = 9216, out_features = 512)
    self$fc2 <- nn_linear(in_features = 512, out_features = embedding_dim)
  },
  forward = function(x) {
    # Convolutional stage.
    features <- nnf_relu(self$conv1(x))
    features <- nnf_relu(self$conv2(features))
    features <- self$dropout1(nnf_max_pool2d(features, 2))

    # Keep the first dimension and flatten the rest before the linear layers.
    flat <- torch_flatten(features, start_dim = 2)
    hidden <- self$dropout2(nnf_relu(self$fc1(flat)))

    self$fc2(hidden)
  }
)

# Wraps the embedding network together with a triplet margin loss.
#
# The module defines `loss()` and no `forward()`: `loss()` receives the input
# list directly, embeds every element of it, and returns the criterion value.
triplet_model <- torch::nn_module(
  initialize = function(embedding_dim = 2, margin = 1) {
    self$embedding <- net(embedding_dim = embedding_dim)
    self$criterion <- nn_triplet_margin_loss(margin = margin)
  },
  loss = function(input, ...) {
    # Apply the shared embedding network to each element of `input`, in
    # order; arguments passed through `...` are not used.
    embedded <- lapply(input, self$embedding)

    anchor_emb <- embedded$anchor
    positive_emb <- embedded$positive
    negative_emb <- embedded$negative

    self$criterion(anchor_emb, positive_emb, negative_emb)
  }
)

# Training -----------------------------------------------------------------

# Attach the optimizer, set the embedding size, then train for 10 epochs
# while validating on the test loader.
fitted <- fit(
  set_hparams(
    setup(triplet_model, optimizer = optim_adam),
    embedding_dim = 2
  ),
  train_dl,
  epochs = 10,
  valid_data = test_dl
)

# Serializing ---------------------------------------------------------------

# Write the fitted object to disk.
luz_save(fitted, "triplet.pt")


# mlverse/luz documentation built on Sept. 19, 2024, 11:20 p.m.