# Packages ----------------------------------------------------------------

library(torch)
library(torchvision)
library(luz)

# Seed R's RNG and torch's RNG so repeated runs are reproducible.
set.seed(1)
torch_manual_seed(1)

# Data --------------------------------------------------------------------

# Cache directory for the MNIST files.
dir <- "./mnist"

# Training split; `download = TRUE` lets mnist_dataset fetch the files into
# `dir` when they are not already there.
train_ds <- mnist_dataset(
  root = dir,
  train = TRUE,
  download = TRUE,
  transform = transform_to_tensor
)

# Held-out split from the same directory. No download flag is passed here;
# presumably the download above also provides these files — confirm.
test_ds <- mnist_dataset(
  root = dir,
  train = FALSE,
  transform = transform_to_tensor
)

# Batches of 128 examples; only the training loader is shuffled.
train_dl <- dataloader(dataset = train_ds, batch_size = 128, shuffle = TRUE)
test_dl <- dataloader(dataset = test_ds, batch_size = 128)

# Model -------------------------------------------------------------------

# Small CNN: two 3x3 convolutions + 2x2 max pooling, then two dense layers
# with dropout. The final layer emits 10 unnormalised scores per example;
# no softmax is applied because nn_cross_entropy_loss() consumes raw scores.
net <- nn_module(
  "Net",
  initialize = function() {
    # Convolutional feature extractor (single input channel -> 64 channels).
    self$conv1 <- nn_conv2d(
      in_channels = 1, out_channels = 32, kernel_size = 3, stride = 1
    )
    self$conv2 <- nn_conv2d(
      in_channels = 32, out_channels = 64, kernel_size = 3, stride = 1
    )

    # Regularisation: lighter dropout after pooling, heavier in the head.
    self$dropout1 <- nn_dropout(p = 0.25)
    self$dropout2 <- nn_dropout(p = 0.5)

    # Dense head. 9216 = 64 * 12 * 12, assuming 28x28 inputs: each unpadded
    # 3x3 conv trims 2 pixels (28 -> 26 -> 24), the 2x2 pool halves it (12).
    self$fc1 <- nn_linear(in_features = 9216, out_features = 128)
    self$fc2 <- nn_linear(in_features = 128, out_features = 10)
  },
  forward = function(x) {
    # Feature extraction.
    x <- nnf_relu(self$conv1(x))
    x <- nnf_relu(self$conv2(x))
    x <- nnf_max_pool2d(x, kernel_size = 2)
    x <- self$dropout1(x)

    # Keep dim 1 (the batch) and collapse every remaining dim into one.
    x <- torch_flatten(x, start_dim = 2)

    # Classification head.
    x <- nnf_relu(self$fc1(x))
    x <- self$dropout2(x)
    self$fc2(x)
  }
)

# Training ----------------------------------------------------------------

# Cross-entropy loss, Adam optimizer, accuracy metric, 10 epochs.
# NOTE(review): the test loader doubles as validation data, so the reported
# validation accuracy is not an independent held-out estimate.
fitted <- net %>%
  setup(
    loss = nn_cross_entropy_loss(),
    optimizer = optim_adam,
    metrics = list(luz_metric_accuracy())
  ) %>%
  fit(train_dl, epochs = 10, valid_data = test_dl)

# Inference ---------------------------------------------------------------

# Model outputs for every test example; show the resulting tensor shape.
preds <- predict(fitted, test_dl)
preds$shape

# Persistence -------------------------------------------------------------

# Write the fitted luz model to disk.
luz_save(fitted, "mnist-cnn.pt")
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.