# Packages ----------------------------------------------------------------

library(torch)
library(torchvision)
library(torchdatasets)
library(luz)

# Datasets and loaders ----------------------------------------------------

dir <- "./pets" # caching directory

# A light wrapper around `oxford_pet_dataset` that resizes and transforms
# input images and masks to the specified `size` and introduces the
# `augmentation` argument, allowing us to specify transformations that must be
# kept in sync between images and masks, e.g. flipping or cropping (a sketch
# of such an augmentation follows the data loaders below).
pet_dataset <- torch::dataset(
  inherit = oxford_pet_dataset,
  initialize = function(..., augmentation = NULL, size = c(224, 224)) {

    input_transform <- function(x) {
      x %>%
        transform_to_tensor() %>%
        transform_resize(size)
    }

    target_transform <- function(x) {
      x <- torch_tensor(x, dtype = torch_long())
      x <- x[newaxis,..]
      # nearest-neighbor interpolation so mask labels stay integral
      x <- transform_resize(x, size, interpolation = 0)
      x[1,..]
    }

    super$initialize(
      ...,
      transform = input_transform,
      target_transform = target_transform
    )

    if (is.null(augmentation))
      self$augmentation <- function(...) list(...)
    else
      self$augmentation <- augmentation
  },
  .getitem = function(i) {
    items <- super$.getitem(i)
    do.call(self$augmentation, items)
  }
)

train_ds <- pet_dataset(
  dir,
  download = TRUE,
  split = "train"
)

valid_ds <- pet_dataset(
  dir,
  download = TRUE,
  split = "valid"
)

train_dl <- dataloader(train_ds, batch_size = 32, shuffle = TRUE)
valid_dl <- dataloader(valid_ds, batch_size = 32)
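# Illustration only (not used in this run): a synced random horizontal flip
# that could be passed via the `augmentation` argument above. It assumes each
# item is a list of (image, mask), in that order, and flips both tensors
# together so pixels and labels stay aligned. `random_hflip` is a hypothetical
# helper, not part of the original example.
random_hflip <- function(...) {
  items <- list(...)
  if (runif(1) < 0.5) {
    items[[1]] <- transform_hflip(items[[1]])
    # the mask is 2-d (H, W); add and drop a channel dim so transform_hflip
    # flips the width axis as expected
    items[[2]] <- transform_hflip(items[[2]]$unsqueeze(1))$squeeze(1)
  }
  items
}
# e.g. pet_dataset(dir, download = TRUE, split = "train", augmentation = random_hflip)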
# Define the network --------------------------------------------------------

# We use a pre-trained MobileNet v2 encoder, with its weights frozen. We keep
# the intermediate feature maps so they can be used in the skip connections.
encoder <- torch::nn_module(
  initialize = function() {
    model <- model_mobilenet_v2(pretrained = TRUE)
    self$stages <- nn_module_list(list(
      nn_identity(),
      model$features[1:2],
      model$features[3:4],
      model$features[5:7],
      model$features[8:14],
      model$features[15:18]
    ))
    for (par in self$parameters) {
      par$requires_grad_(FALSE)
    }
  },
  forward = function(x) {
    features <- list()
    for (i in seq_along(self$stages)) {
      x <- self$stages[[i]](x)
      features[[length(features) + 1]] <- x
    }
    features
  }
)

# Each decoder block is composed of an upsampling (transposed convolution)
# layer followed by a convolution with "same" padding.
decoder_block <- nn_module(
  initialize = function(in_channels, skip_channels, out_channels) {
    self$upsample <- nn_conv_transpose2d(
      in_channels = in_channels,
      out_channels = out_channels,
      kernel_size = 2,
      stride = 2
    )
    self$activation <- nn_relu()
    self$conv <- nn_conv2d(
      in_channels = out_channels + skip_channels,
      out_channels = out_channels,
      kernel_size = 3,
      padding = "same"
    )
  },
  forward = function(x, skip) {
    x <- x %>%
      self$upsample() %>%
      self$activation()
    input <- torch_cat(list(x, skip), dim = 2)
    input %>%
      self$conv() %>%
      self$activation()
  }
)

# We build the decoder as a sequence of `decoder_block`s, choosing channel
# sizes that are compatible with the encoder outputs.
decoder <- nn_module(
  initialize = function(
    decoder_channels = c(256, 128, 64, 32, 16),
    encoder_channels = c(16, 24, 32, 96, 320)
  ) {
    encoder_channels <- rev(encoder_channels)
    skip_channels <- c(encoder_channels[-1], 3)
    in_channels <- c(encoder_channels[1], decoder_channels)

    depth <- length(encoder_channels)

    self$blocks <- nn_module_list()
    for (i in seq_len(depth)) {
      self$blocks$append(decoder_block(
        in_channels = in_channels[i],
        skip_channels = skip_channels[i],
        out_channels = decoder_channels[i]
      ))
    }
  },
  forward = function(features) {
    features <- rev(features)
    x <- features[[1]]
    for (i in seq_along(self$blocks)) {
      x <- self$blocks[[i]](x, features[[i + 1]])
    }
    x
  }
)

# Finally, the model is the composition of encoder and decoder, plus an output
# layer that produces per-pixel scores for each of the possible classes.
model <- nn_module(
  initialize = function() {
    self$encoder <- encoder()
    self$decoder <- decoder()
    self$output <- nn_sequential(
      nn_conv2d(16, 3, 3, padding = "same")
    )
  },
  forward = function(x) {
    x %>%
      self$encoder() %>%
      self$decoder() %>%
      self$output()
  }
)

# Train ----------------------------------------------------------------------

# We train using the cross-entropy loss. We could have used the Dice loss too,
# but it's harder to optimize.
model <- model %>%
  setup(optimizer = optim_adam, loss = nn_cross_entropy_loss())

# Use the learning rate finder to pick a reasonable learning rate.
f <- lr_finder(model, train_dl)
plot(f)

fitted <- model %>%
  set_opt_hparams(lr = 1e-3) %>%
  fit(train_dl, epochs = 10, valid_data = valid_dl)

plot(fitted)

# Plot validation image -------------------------------------------------------

library(raster)

# Predict the mask for a single validation example, take the most likely class
# per pixel, and overlay it on the input image.
preds <- predict(fitted, dataloader(dataset_subset(valid_ds, 2)))

mask <- as.array(torch_argmax(preds[1,..], 1)$to(device = "cpu"))
mask <- raster::ratify(raster::raster(mask))

img <- raster::brick(as.array(valid_ds[2][[1]]$permute(c(2, 3, 1))))
raster::plotRGB(img, scale = 1)

plot(mask, alpha = 0.4, legend = FALSE, axes = FALSE, add = TRUE)
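# Optional next steps (not part of the original example): compute the loss on
# the full validation set and serialize the fitted model for later use. The
# file name below is illustrative.
evaluation <- evaluate(fitted, valid_dl)
print(evaluation)

luz_save(fitted, "pet_segmentation_model.rds")
# fitted <- luz_load("pet_segmentation_model.rds")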