gan_trainer | R Documentation
Provides a function to quickly train a GAN model.
gan_trainer(
  data,
  noise_dim = 2,
  noise_distribution = "normal",
  value_function = "original",
  data_type = "tabular",
  generator = NULL,
  generator_optimizer = NULL,
  discriminator = NULL,
  discriminator_optimizer = NULL,
  base_lr = 1e-04,
  ttur_factor = 4,
  weight_clipper = NULL,
  batch_size = 50,
  epochs = 150,
  plot_progress = FALSE,
  plot_interval = "epoch",
  eval_dropout = FALSE,
  synthetic_examples = 500,
  plot_dimensions = c(1, 2),
  device = "cpu"
)
data |
The input data set. Must be a matrix, array, torch::torch_tensor, or torch::dataset. |
noise_dim |
The dimension (length) of the GAN noise vector z. Defaults to 2. |
noise_distribution |
The noise distribution. Expects a function that samples from a distribution and returns a torch_tensor. For convenience, "normal" and "uniform" automatically select a built-in sampling function. Defaults to "normal". |
value_function |
The value function for GAN training. Expects a function that takes the discriminator scores of real and fake data as input and returns a list with the discriminator loss and the generator loss. For reference see: . For convenience, three value functions, "original", "wasserstein" and "f-wgan", are already implemented. Defaults to "original". |
data_type |
The type of the input data, either "tabular" or "image". Defaults to "tabular". |
generator |
The generator network. Expects a neural network provided as a torch::nn_module. Defaults to NULL, which creates a simple fully connected neural network. |
generator_optimizer |
The optimizer for the generator network. Expects a torch::optim_xxx function, e.g. torch::optim_adam(). Defaults to NULL, which will set up |
discriminator |
The discriminator network. Expects a neural network provided as a torch::nn_module. Defaults to NULL, which creates a simple fully connected neural network. |
discriminator_optimizer |
The optimizer for the discriminator network. Expects a torch::optim_xxx function, e.g. torch::optim_adam(). Defaults to NULL, which will set up |
base_lr |
The base learning rate for the optimizers. Default is 0.0001. Only used if no optimizer is explicitly passed to the trainer. |
ttur_factor |
A multiplier for the learning rate of the discriminator, implementing the two time-scale update rule (TTUR). Defaults to 4. |
weight_clipper |
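A function that clips the weights of the discriminator during training. The Wasserstein GAN requires the discriminator (critic) to be Lipschitz continuous, and the original formulation enforces this constraint by clipping the discriminator weights to a small range after each update. Defaults to NULL (no clipping). |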
batch_size |
The number of training samples in each mini-batch. Defaults to 50. |
epochs |
The number of training epochs. Defaults to 150. |
plot_progress |
Whether to monitor training progress with plots. Defaults to FALSE. |
plot_interval |
The number of training steps between plots: either a number of steps or "epoch" (one plot per epoch). Defaults to "epoch". |
eval_dropout |
Should dropout be applied during the sampling of synthetic data? Defaults to FALSE. |
synthetic_examples |
The number of synthetic examples to generate. Defaults to 500. For image data, a smaller number such as 16 is more reasonable. |
plot_dimensions |
If training progress is plotted, the dimensions (columns) of the data to display. Defaults to c(1, 2), i.e., the first two columns of the tabular data. |
device |
The device on which training is done (e.g., "cpu" or "cuda"). Defaults to "cpu". |
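The noise_distribution argument also accepts a custom sampling function. The sketch below is hypothetical: the name scaled_normal_noise is illustrative, and it assumes that gan_trainer calls the function with a size vector such as c(batch_size, noise_dim) and handles moving the result to the training device. Check both assumptions against the package source before relying on this.

```r
# Hypothetical custom noise sampler: normal noise with standard deviation 0.5.
# Assumption: gan_trainer calls it as noise_fn(c(batch_size, noise_dim)).
scaled_normal_noise <- function(size) {
  torch::torch_randn(size) * 0.5
}

# Usage (not run):
# trained_gan <- gan_trainer(transformed_data,
#                            noise_distribution = scaled_normal_noise)
```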
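To illustrate the value_function interface, the sketch below implements the least-squares GAN (LSGAN) loss, which is not one of the built-in options. It assumes that the returned list elements are named d_loss and g_loss and that the discriminator outputs unbounded scores; both are assumptions to verify against the package. The function only uses arithmetic and mean(), so it runs on plain numeric vectors; whether it also works unchanged on torch tensors depends on torch providing those methods for tensors (also an assumption).

```r
# Hypothetical least-squares GAN value function.
# The discriminator is pushed to score real data near 1 and fake data near 0,
# while the generator is pushed to make fake data score near 1.
lsgan_value_function <- function(real_scores, fake_scores) {
  d_loss <- mean((real_scores - 1)^2) + mean(fake_scores^2)
  g_loss <- mean((fake_scores - 1)^2)
  # Assumption: gan_trainer expects the elements to be named d_loss and g_loss.
  list(d_loss = d_loss, g_loss = g_loss)
}

# A perfect discriminator yields a discriminator loss of 0 and a generator loss of 1.
losses <- lsgan_value_function(real_scores = c(1, 1), fake_scores = c(0, 0))
losses$d_loss  # 0
losses$g_loss  # 1
```

It would be passed as gan_trainer(transformed_data, value_function = lsgan_value_function).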
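The interplay of base_lr and ttur_factor can be worked out with the default values. This is a minimal sketch assuming the default generator optimizer uses base_lr directly and the default discriminator optimizer uses base_lr * ttur_factor; neither value is used when the corresponding optimizer is passed explicitly.

```r
# Learning rates implied by the defaults (assumed mapping, see lead-in).
base_lr <- 1e-04
ttur_factor <- 4

generator_lr <- base_lr
discriminator_lr <- base_lr * ttur_factor

generator_lr      # 1e-04
discriminator_lr  # 4e-04
```

Under this reading, the discriminator learns four times faster than the generator by default.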
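The idea behind weight_clipper can be shown on plain numeric vectors. The helper clip_weights below is hypothetical and only illustrates the clamping step; a real weight_clipper would apply the same operation to the discriminator's torch parameters, and the exact signature gan_trainer expects is not documented here. The clipping constant 0.01 is the value used in the original Wasserstein GAN paper.

```r
# Conceptual sketch of WGAN weight clipping on numeric weight vectors.
# clip_weights() is a hypothetical helper, not part of the package.
clip_weights <- function(weights, clip_value = 0.01) {
  # Clamp every weight into [-clip_value, clip_value].
  lapply(weights, function(w) pmin(pmax(w, -clip_value), clip_value))
}

clipped <- clip_weights(list(layer1 = c(-0.5, 0.005, 0.2)))
clipped$layer1  # -0.01, 0.005, 0.01
```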
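To see how batch_size, epochs and plot_interval relate, the arithmetic below counts training steps for a hypothetical data set of 1000 rows with the default settings. It assumes that one epoch is one pass over the data in mini-batches of batch_size rows; if the trainer drops an incomplete final batch, floor() would replace ceiling() (the two agree here because 1000 is divisible by 50).

```r
# Hypothetical data size; the other values are the gan_trainer defaults.
n_rows <- 1000
batch_size <- 50
epochs <- 150

# Assumption: one epoch is one pass over the data in mini-batches.
steps_per_epoch <- ceiling(n_rows / batch_size)  # 20
total_steps <- steps_per_epoch * epochs          # 3000

# With plot_interval = 100 instead of "epoch", a plot is drawn every 100 steps.
plot_interval <- 100
n_plots <- total_steps %/% plot_interval         # 30
```

With the default plot_interval = "epoch", the same run would produce one plot per epoch, i.e. 150 plots.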
gan_trainer trains the neural networks and returns an object of class trained_RGAN that contains the final generator and discriminator, their respective optimizers, and the training settings.
## Not run:
# Before running for the first time, the torch backend needs to be installed
torch::install_torch()
# Load data
data <- sample_toydata()
# Build new transformer
transformer <- data_transformer$new()
# Fit transformer to data
transformer$fit(data)
# Transform data and store as new object
transformed_data <- transformer$transform(data)
# Train the default GAN
trained_gan <- gan_trainer(transformed_data)
# Sample synthetic data from the trained GAN
synthetic_data <- sample_synthetic_data(trained_gan, transformer)
# Plot the results
GAN_update_plot(data = data,
                synth_data = synthetic_data,
                main = "Real and Synthetic Data after Training")
## End(Not run)