# Global knitr chunk options applied to every chunk of this vignette.
knitr::opts_chunk$set(comment = "#>", collapse = TRUE)
#install.packages(c("RGAN", "torchvision"))
library(RGAN)
library(torch)

Replication of the Example in Section 4.1

The first example, with very simple tabular data, should serve as an easy entry point to working with the RGAN package.

# Obtain the toy data via RGAN::sample_toydata(); `data` is reused below for
# plotting and as the reference when comparing against synthetic samples.
data <- RGAN::sample_toydata()

# Build new transformer to standardize data
transformer <- RGAN::data_transformer$new()

# Fit transformer to data
transformer$fit(data)

# Transform data and store as new object; the untransformed `data` is kept
# because both versions are plotted side by side afterwards.
transformed_data <-  transformer$transform(data)
# Plot original data and transformed data (Figure 1)

# Axis label for column `idx` of `x`: the column name when `x` has column
# names, otherwise the generic label "Var <idx>".
axis_label <- function(x, idx) {
  labels <- colnames(x)
  if (is.null(labels)) {
    paste0("Var ", idx)
  } else {
    labels[idx]
  }
}

# Scatter plot of the two columns `dimensions` of `x`, drawn with the styling
# shared by both panels of the figure.
plot_panel <- function(x, dimensions, xlab, ylab, main) {
  plot(
    x[, dimensions],
    bty = "n",
    col = viridis::viridis(2, alpha = 0.7)[1],
    pch = 19,
    xlab = xlab,
    ylab = ylab,
    main = main,
    las = 1
  )
}

dimensions <- c(1, 2)

# Both panels are labelled from the column names of the original `data`, so
# the axes read the same before and after the transformation.
x_label <- axis_label(data, dimensions[1])
y_label <- axis_label(data, dimensions[2])

# Two panels side by side; remember the previous graphics settings so they
# can be restored once the figure is drawn.
old_par <- par(mfrow = c(1, 2))

plot_panel(data, dimensions, x_label, y_label, main = "(A)")
plot_panel(transformed_data, dimensions, x_label, y_label, main = "(B)")

par(old_par)
# Train the default GAN

# Seed torch's random number generator before training.
torch_manual_seed(20220629)

# Pick the training device: CUDA if present, otherwise Apple's MPS backend if
# present, otherwise the CPU. (Hard-coding "mps" only suits Apple-Silicon
# machines and would discard the CUDA/CPU choice.)
device <- if (cuda_is_available()) {
  torch_device("cuda")
} else if (backends_mps_is_available()) {
  torch_device("mps")
} else {
  torch_device("cpu")
}

# Train on the transformed data; all other gan_trainer() arguments are left
# at their defaults.
trained_gan <- gan_trainer(transformed_data, device = device)

# Sample synthetic data from the trained GAN

# First draw: the `eval_dropout` setting of the trained GAN is switched off
# before sampling.
trained_gan$settings$eval_dropout <- FALSE

# The fitted transformer is passed along with the GAN.
# NOTE(review): presumably so samples are mapped back to the original data
# scale - confirm against the RGAN documentation.
synthetic_data_no_dropout <- sample_synthetic_data(trained_gan, transformer)


# Second draw: the same setting is switched on and a new sample is taken.
# `trained_gan` is modified in place, so the setting stays TRUE afterwards.
trained_gan$settings$eval_dropout <- TRUE

synthetic_data_dropout <- sample_synthetic_data(trained_gan, transformer)
# Plot the results: two panels side by side, each comparing the original
# `data` with one of the synthetic samples.
par(mfrow = c(1, 2))

# Panel (A): sample drawn with eval_dropout = FALSE.
GAN_update_plot(data = data, synth_data = synthetic_data_no_dropout, main = "(A)")

# Panel (B): sample drawn with eval_dropout = TRUE.
GAN_update_plot(data = data, synth_data = synthetic_data_dropout, main = "(B)")

Replication of the Example in Section 4.2

Note that this example is a bit more resource intensive. With a GPU, training the following GAN for one epoch takes about 17 minutes. Training on a CPU takes considerably longer: on my machine (Apple MacBook Air M1, running R and torch through Rosetta) it takes about 6 hours.

# Directory the CelebA images are read from. The image files themselves must
# be placed there by the user; this script does not download them.
# NOTE(review): image_folder_dataset() presumably expects the images inside at
# least one sub-folder of this root - confirm against the torchvision docs.
celeba_dir <- "~/Desktop/celeba"

# Make sure the directory exists. Calling dir.create() without a path is an
# error, so the path is passed explicitly; an already existing directory is
# left untouched.
if (!dir.exists(celeba_dir)) {
  dir.create(celeba_dir, recursive = TRUE)
}

# Per-image preprocessing: convert to a tensor, resize and centre-crop to
# 64 x 64, then normalize each of the three channels with mean 0.5 and
# standard deviation 0.5.
celeba_transform <- function(x) {
  x <- torchvision::transform_to_tensor(x)
  x <- torchvision::transform_resize(x, size = c(64, 64))
  x <- torchvision::transform_center_crop(x, c(64, 64))
  torchvision::transform_normalize(x, c(0.5, 0.5, 0.5), c(0.5, 0.5, 0.5))
}

dataset <- torchvision::image_folder_dataset(
  root = celeba_dir,
  transform = celeba_transform
)

# Pick the training device: CUDA if present, otherwise Apple's MPS backend if
# present, otherwise the CPU. (Hard-coding "mps" only suits Apple-Silicon
# machines and would discard the CUDA/CPU choice.)
device <- if (cuda_is_available()) {
  torch_device("cuda")
} else if (backends_mps_is_available()) {
  torch_device("mps")
} else {
  torch_device("cpu")
}

# DCGAN generator and discriminator, both without dropout, moved to the
# selected device. The discriminator is built with `sigmoid = FALSE`.
g_net <- DCGAN_Generator(dropout_rate = 0, noise_dim = 100)$to(device = device)
d_net <- DCGAN_Discriminator(dropout_rate = 0, sigmoid = FALSE)$to(device = device)

# One Adam optimizer per network, with identical hyperparameters.
g_optim <- torch::optim_adam(g_net$parameters, lr = 0.0002, betas = c(0.5, 0.999))
d_optim <- torch::optim_adam(d_net$parameters, lr = 0.0002, betas = c(0.5, 0.999))

# Shape of a single noise draw; its first entry matches the generator's
# `noise_dim = 100` above.
noise_dim <- c(100, 1, 1)

# A fixed batch of 16 standard-normal noise tensors on the training device.
# NOTE(review): `fixed_z` is not passed to the gan_trainer() call that
# follows - presumably kept for inspecting generator output manually.
fixed_z <- torch::torch_randn(c(16, noise_dim))$to(device = device)

# Train the DCGAN on the image dataset for a single epoch, using the
# generator, discriminator and optimizers prepared beforehand.
trained_gan <- gan_trainer(
  # Data
  data = dataset,
  data_type = "image",
  batch_size = 128,
  # Noise
  noise_dim = noise_dim,
  noise_distribution = "normal",
  # Networks and their optimizers
  generator = g_net,
  generator_optimizer = g_optim,
  discriminator = d_net,
  discriminator_optimizer = d_optim,
  # Objective
  value_function = "wasserstein",
  # Progress reporting
  plot_progress = FALSE,
  plot_interval = 10,
  synthetic_examples = 16,
  eval_dropout = FALSE,
  # Run configuration
  device = device,
  epochs = 1
)



mneunhoe/RGAN documentation built on Aug. 27, 2023, 7:57 a.m.