gan_trainer: gan_trainer

View source: R/gan-trainer.R

gan_trainer    R Documentation

gan_trainer

Description

Provides a function to quickly train a GAN model.

Usage

gan_trainer(
  data,
  noise_dim = 2,
  noise_distribution = "normal",
  value_function = "original",
  data_type = "tabular",
  generator = NULL,
  generator_optimizer = NULL,
  discriminator = NULL,
  discriminator_optimizer = NULL,
  base_lr = 1e-04,
  ttur_factor = 4,
  weight_clipper = NULL,
  batch_size = 50,
  epochs = 150,
  plot_progress = FALSE,
  plot_interval = "epoch",
  eval_dropout = FALSE,
  synthetic_examples = 500,
  plot_dimensions = c(1, 2),
  track_loss = FALSE,
  plot_loss = FALSE,
  device = "cpu"
)

Arguments

data

The input data set. Must be a matrix, array, torch::torch_tensor or torch::dataset.
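
As a minimal illustration, a numeric data frame (not listed among the accepted types) can be converted to a matrix before it is passed in:

# A plain data frame is not among the accepted input types; convert numeric
# columns to a matrix first.
data <- as.matrix(mtcars[, c("mpg", "hp", "wt")])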

noise_dim

The dimension of the GAN noise vector z. Defaults to 2.

noise_distribution

The noise distribution. Expects a function that samples from a distribution and returns a torch_tensor. For convenience, "normal" and "uniform" automatically set a suitable sampling function. Defaults to "normal".
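
A hedged sketch of a custom sampler follows; the argument it receives (assumed here to be a size vector) is an assumption and should be checked against R/gan-trainer.R.

# Hypothetical custom noise sampler: uniform noise in [-1, 1] returned as a
# torch_tensor. The assumed signature (a size argument) is not confirmed by
# this help page.
uniform_noise <- function(size) {
  torch::torch_rand(size) * 2 - 1
}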

value_function

The value function for GAN training. Expects a function that takes discriminator scores of real and fake data as input and returns a list with the discriminator loss and the generator loss. For convenience, the three loss functions "original", "wasserstein" and "f-wgan" are already implemented. Defaults to "original".
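
As a hedged sketch, a custom value function for the original GAN objective (with the non-saturating generator loss) could look as follows; the argument order and the list element names d_loss and g_loss are assumptions about the expected interface, so compare with the implemented "original" value function before relying on them.

# Hypothetical custom value function: standard discriminator loss plus the
# non-saturating generator loss. Argument order and list element names are
# assumptions.
my_value_function <- function(real_scores, fake_scores) {
  d_loss <- -(torch::torch_mean(torch::torch_log(torch::torch_sigmoid(real_scores))) +
                torch::torch_mean(torch::torch_log(1 - torch::torch_sigmoid(fake_scores))))
  g_loss <- -torch::torch_mean(torch::torch_log(torch::torch_sigmoid(fake_scores)))
  list(d_loss = d_loss, g_loss = g_loss)
}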

data_type

"tabular" or "image", controls the data type, defaults to "tabular".

generator

The generator network. Expects a neural network provided as a torch::nn_module. Default is NULL, which creates a simple fully connected neural network.

generator_optimizer

The optimizer for the generator network. Expects a torch::optim_xxx function, e.g. torch::optim_adam(). Default is NULL, which sets up torch::optim_adam(g_net$parameters, lr = base_lr).

discriminator

The discriminator network. Expects a neural network provided as a torch::nn_module. Default is NULL, which creates a simple fully connected neural network.

discriminator_optimizer

The optimizer for the discriminator network. Expects a torch::optim_xxx function, e.g. torch::optim_adam(). Default is NULL, which sets up torch::optim_adam(d_net$parameters, lr = base_lr * ttur_factor).
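
A hedged sketch of supplying custom networks and optimizers; the layer sizes are illustrative (they assume noise_dim = 2 and two data columns), and the sketch assumes that already constructed optimizer objects are expected, matching the described defaults.

# Hypothetical custom networks; input/output sizes must match noise_dim and
# the number of columns in the (transformed) data.
g_net <- torch::nn_sequential(
  torch::nn_linear(2, 64), torch::nn_relu(),
  torch::nn_linear(64, 2)
)
d_net <- torch::nn_sequential(
  torch::nn_linear(2, 64), torch::nn_relu(),
  torch::nn_linear(64, 1)
)
trained_gan <- gan_trainer(
  transformed_data,
  noise_dim = 2,
  generator = g_net,
  generator_optimizer = torch::optim_adam(g_net$parameters, lr = 1e-4),
  discriminator = d_net,
  discriminator_optimizer = torch::optim_adam(d_net$parameters, lr = 4e-4)
)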

base_lr

The base learning rate for the optimizers. Default is 0.0001. Only used if no optimizer is explicitly passed to the trainer.

ttur_factor

A multiplier for the learning rate of the discriminator, used to implement the two time-scale update rule (TTUR). Defaults to 4.

weight_clipper

The Wasserstein GAN puts constraints on the weights of the discriminator; the discriminator weights are therefore clipped during training. Defaults to NULL.
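
For example, a Wasserstein GAN run might be set up roughly as follows; the format of weight_clipper (assumed here to be a lower/upper bound pair) is an assumption and should be checked against the package source.

# Hypothetical WGAN configuration; the clipping bound format is an assumption.
trained_wgan <- gan_trainer(
  transformed_data,
  value_function = "wasserstein",
  weight_clipper = c(-0.01, 0.01)
)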

batch_size

The number of training samples drawn for each mini-batch. Defaults to 50.

epochs

The number of training epochs. Defaults to 150.

plot_progress

Monitor training progress with plots. Defaults to FALSE.

plot_interval

Number of training steps between plots. Provide a number of steps or "epoch". Defaults to "epoch".

eval_dropout

Should dropout be applied during the sampling of synthetic data? Defaults to FALSE.

synthetic_examples

Number of synthetic examples that should be generated. Defaults to 500. For image data, a smaller number, e.g. 16, is more reasonable.

plot_dimensions

If training progress is monitored with a plot, which dimensions of the data should be shown? Defaults to c(1, 2), i.e. the first two columns of the tabular data.
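
For instance, to draw a progress plot of the first two data columns every 100 training steps:

# Plot training progress every 100 steps, showing data columns 1 and 2.
trained_gan <- gan_trainer(
  transformed_data,
  plot_progress = TRUE,
  plot_interval = 100,
  plot_dimensions = c(1, 2)
)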

track_loss

Store the training losses as additional output. Defaults to FALSE.

plot_loss

Monitor the losses during training with plots. Defaults to FALSE.

device

The device (e.g. "cpu", "cuda", or "mps") on which training should be run. Defaults to "cpu".
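
For example, training can be moved to a GPU when one is available:

# Use CUDA if available, otherwise fall back to the CPU.
device <- if (torch::cuda_is_available()) "cuda" else "cpu"
trained_gan <- gan_trainer(transformed_data, device = device)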

Value

gan_trainer trains the neural networks and returns an object of class trained_RGAN that contains the final generator and discriminator, their respective optimizers, and the training settings.
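
A minimal way to inspect the returned object without assuming its element names:

# List the components of the trained_RGAN object; the exact element names are
# not documented here, so inspect them rather than assuming them.
str(trained_gan, max.level = 1)
names(trained_gan)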

Examples

## Not run: 
# Before the first run, the torch backend needs to be installed
torch::install_torch()
# Load data
data <- sample_toydata()
# Build new transformer
transformer <- data_transformer$new()
# Fit transformer to data
transformer$fit(data)
# Transform data and store as new object
transformed_data <- transformer$transform(data)
# Train the default GAN
trained_gan <- gan_trainer(transformed_data)
# Sample synthetic data from the trained GAN
synthetic_data <- sample_synthetic_data(trained_gan, transformer)
# Plot the results
GAN_update_plot(data = data,
                synth_data = synthetic_data,
                main = "Real and Synthetic Data after Training")

## End(Not run)
