# Packages ----------------------------------------------------------------
library(torch)
library(torchvision)
library(luz)

# Datasets and loaders ----------------------------------------------------

dir <- "./mnist" # caching directory

# Training split. `download = TRUE` lets torchvision fetch the files into
# `dir`.
train_ds <- mnist_dataset(
  dir,
  download = TRUE,
  transform = transform_to_tensor
)

# Test split, read from the same directory.
test_ds <- mnist_dataset(
  dir,
  train = FALSE,
  transform = transform_to_tensor
)

train_dl <- dataloader(train_ds, batch_size = 128, shuffle = TRUE)
test_dl <- dataloader(test_ds, batch_size = 128)

# Building the network ---------------------------------------------------

# Small convolutional classifier with a custom luz `step()` that implements
# gradient accumulation ("virtual batch size"): gradients of
# `accumulate_batches` consecutive mini-batches are summed before a single
# optimizer update is applied.
net <- nn_module(
  "Net",
  # accumulate_batches: number of mini-batches whose gradients are
  #   accumulated before each optimizer update (must be >= 1).
  initialize = function(accumulate_batches = 2) {
    stopifnot(
      is.numeric(accumulate_batches),
      length(accumulate_batches) == 1,
      accumulate_batches >= 1
    )

    self$conv1 <- nn_conv2d(1, 32, 3, 1)
    self$conv2 <- nn_conv2d(32, 64, 3, 1)
    self$dropout1 <- nn_dropout(0.25)
    self$dropout2 <- nn_dropout(0.5)
    self$fc1 <- nn_linear(9216, 128)
    self$fc2 <- nn_linear(128, 10)
    self$accumulate_batches <- accumulate_batches
  },
  # Maps an input batch `x` to 10 output scores per observation.
  forward = function(x) {
    x <- self$conv1(x)
    x <- nnf_relu(x)
    x <- self$conv2(x)
    x <- nnf_relu(x)
    x <- nnf_max_pool2d(x, 2)
    x <- self$dropout1(x)
    x <- torch_flatten(x, start_dim = 2)
    x <- self$fc1(x)
    x <- nnf_relu(x)
    x <- self$dropout2(x)
    x <- self$fc2(x)
    x
  },
  # Custom step method; luz runs it for every batch in training and
  # validation.
  step = function() {
    # Calculate predictions. They are saved in `ctx$pred` so other parts of
    # luz can use them.
    ctx$pred <- ctx$model(ctx$input)

    # Calculate the loss and save it in `ctx$loss` so that it is logged.
    # `ctx$loss` is deliberately left unscaled: only the tensor used for
    # back-propagation is divided by `accumulate_batches`, so the logged
    # training loss stays on the same scale as the validation loss.
    ctx$loss <- ctx$loss_fn(ctx$pred, ctx$target)

    # `ctx$training` is set automatically to `TRUE` during the training phase.
    if (ctx$training) {
      # Dividing by the window size makes the accumulated gradient an average
      # over the virtual batch instead of a sum.
      scaled_loss <- ctx$loss / self$accumulate_batches
      scaled_loss$backward()

      # Only after `accumulate_batches` batches do we take an optimizer step,
      # which gives the virtual batch size. We also step on the last batch of
      # the loader so that a trailing, incomplete window is applied instead
      # of leaking its gradients into the next epoch. (That last update is
      # still scaled by 1 / accumulate_batches, i.e. it is slightly smaller
      # than a full-window update.)
      window_full <- ctx$iter %% self$accumulate_batches == 0
      last_batch <- ctx$iter == length(ctx$data)

      if (window_full || last_batch) {
        opt <- ctx$optimizers[[1]]
        opt$step()
        opt$zero_grad()
      }
    }
  }
)

# Train -------------------------------------------------------------------

fitted <- net %>%
  set_hparams(accumulate_batches = 10) %>%
  setup(
    loss = nn_cross_entropy_loss(),
    optimizer = torch::optim_adam,
    metrics = list(
      luz_metric_accuracy()
    )
  ) %>%
  fit(train_dl, valid_data = test_dl, epochs = 10)

# Serialization -----------------------------------------------------------

luz_save(fitted, "mnist-virtual-batch_size.pt")
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.