# MNIST classification with mixup augmentation, trained through luz.
#
# The script builds the MNIST train/test dataloaders, defines a small
# convolutional network, fits it with the luz mixup callback, predicts on the
# test loader and saves the fitted object to disk.

# Packages ----------------------------------------------------------------

library(torch)
library(torchvision)
library(luz)

# Seed both the R and the torch random number generators.
set.seed(1)
torch_manual_seed(1)

# Datasets and loaders ----------------------------------------------------

dir <- "./mnist" # caching directory

# Training split; `download = TRUE` requests the data for `dir`.
train_ds <- mnist_dataset(
  dir,
  download = TRUE,
  transform = transform_to_tensor
)

# Test split, read from the same directory (no download requested here).
test_ds <- mnist_dataset(
  dir,
  train = FALSE,
  transform = transform_to_tensor
)

train_dl <- dataloader(train_ds, batch_size = 128, shuffle = TRUE)
test_dl <- dataloader(test_ds, batch_size = 128)

# Building the network ---------------------------------------------------

net <- nn_module(
  "Net",
  initialize = function() {
    self$conv1 <- nn_conv2d(1, 32, 3, 1)
    self$conv2 <- nn_conv2d(32, 64, 3, 1)
    self$dropout1 <- nn_dropout(0.25)
    self$dropout2 <- nn_dropout(0.5)
    # 9216 = 64 channels * 12 * 12, which is the flattened size after two
    # 3x3 convolutions and a 2x2 max-pool, assuming 28x28 inputs.
    self$fc1 <- nn_linear(9216, 128)
    self$fc2 <- nn_linear(128, 10)
  },
  forward = function(x) {
    x %>%
      self$conv1() %>%
      nnf_relu() %>%
      self$conv2() %>%
      nnf_relu() %>%
      nnf_max_pool2d(2) %>%
      self$dropout1() %>%
      torch_flatten(start_dim = 2) %>%
      self$fc1() %>%
      nnf_relu() %>%
      self$dropout2() %>%
      self$fc2()
  }
)

# Train -------------------------------------------------------------------

fitted <- net %>%
  setup(
    # we must use an un-reduced loss to allow mixing up
    loss = nn_cross_entropy_loss(reduction = "none"),
    optimizer = optim_adam,
    metrics = luz_metric_set(
      valid_metrics = luz_metric_accuracy()
    )
  ) %>%
  fit(
    train_dl,
    epochs = 10,
    valid_data = test_dl,
    callbacks = list(
      luz_callback_mixup(
        alpha = 0.4,
        auto_loss = TRUE # loss is automatically modified by the mixup callback
      )
    )
  )

# Making predictions ------------------------------------------------------

preds <- predict(fitted, test_dl)
preds$shape

# Serialization -----------------------------------------------------------

luz_save(fitted, "mnist-mixup.pt")
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.