# Packages ----------------------------------------------------------------

library(torch)
library(torchvision)
library(luz)

# Datasets and loaders ----------------------------------------------------

# Directory where the MNIST files are cached.
dir <- "./mnist"

# MNIST variant for autoencoding: each item's target `y` is overwritten with
# its input `x`, so the network is trained to reproduce its own input.
mnist_dataset2 <- torch::dataset(
  inherit = mnist_dataset,
  .getitem = function(i) {
    item <- super$.getitem(i)
    item$y <- item$x
    item
  }
)

# Training split; a download into `dir` is requested here.
train_ds <- mnist_dataset2(
  dir,
  download = TRUE,
  transform = transform_to_tensor
)

# Non-training split (`train = FALSE`), read from the same directory.
test_ds <- mnist_dataset2(
  dir,
  train = FALSE,
  transform = transform_to_tensor
)

# Batches of 128; only the training loader shuffles.
train_dl <- dataloader(train_ds, batch_size = 128, shuffle = TRUE)
test_dl <- dataloader(test_ds, batch_size = 128)

# Building the network ----------------------------------------------------

# Convolutional autoencoder.
#   encoder: conv(1 -> 6, k = 5) -> ReLU -> conv(6 -> 16, k = 5) -> ReLU
#   decoder: convT(16 -> 6, k = 5) -> ReLU -> convT(6 -> 1, k = 5) -> sigmoid
net <- nn_module(
  "Net",
  initialize = function() {
    self$encoder <- nn_sequential(
      nn_conv2d(1, 6, kernel_size = 5),
      nn_relu(),
      nn_conv2d(6, 16, kernel_size = 5),
      nn_relu()
    )
    self$decoder <- nn_sequential(
      nn_conv_transpose2d(16, 6, kernel_size = 5),
      nn_relu(),
      nn_conv_transpose2d(6, 1, kernel_size = 5),
      nn_sigmoid()
    )
  },
  # Full pass: encode the input, then decode the encoding.
  forward = function(x) {
    self$decoder(self$encoder(x))
  },
  # Encoder output only, flattened from the second dimension onward.
  predict = function(x) {
    torch_flatten(self$encoder(x), start_dim = 2)
  }
)

# Train -------------------------------------------------------------------

# Attach an MSE loss and the Adam optimizer, then fit for a single epoch,
# using the test loader as validation data.
fitted <- fit(
  setup(net, loss = nn_mse_loss(), optimizer = optim_adam),
  train_dl,
  epochs = 1,
  valid_data = test_dl
)

# Create predictions ------------------------------------------------------

preds <- predict(fitted, test_dl)

# Serialize ---------------------------------------------------------------

luz_save(fitted, "mnist-autoencoder.pt")
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.