gan_update_step | R Documentation
Performs a single GAN training update: draws a mini batch from the data and a noise sample, scores real and fake data with the discriminator, computes the discriminator and generator losses via the value function, and applies one optimizer step to each network.
gan_update_step(
  data,
  batch_size,
  noise_dim,
  sample_noise,
  device = "cpu",
  g_net,
  d_net,
  g_optim,
  d_optim,
  value_function,
  weight_clipper
)
data
The input data set. Must be a matrix, array, torch::torch_tensor, or torch::dataset.
batch_size
The number of training samples drawn for each mini batch. Defaults to 50.
noise_dim
The dimension of the GAN noise vector z. Defaults to 2.
sample_noise
A function that samples noise into a torch::torch_tensor.
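A minimal sketch of such a noise sampler, assuming it takes the batch size and noise dimension and returns a tensor of shape (batch_size, noise_dim); the exact signature the package expects is an assumption here:

```r
library(torch)

# Hypothetical Gaussian noise sampler: returns a (batch_size, noise_dim)
# tensor of standard-normal draws on the requested device.
sample_gaussian_noise <- function(batch_size, noise_dim, device = "cpu") {
  torch_randn(batch_size, noise_dim, device = device)
}

z <- sample_gaussian_noise(50, 2)
z$shape  # 50 x 2
```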
device
The device (e.g. "cpu" or "cuda") on which training should run. Defaults to "cpu".
g_net
The generator network. Expects a neural network provided as a torch::nn_module. Default is NULL, which creates a simple fully connected neural network.
d_net
The discriminator network. Expects a neural network provided as a torch::nn_module. Default is NULL, which creates a simple fully connected neural network.
g_optim
The optimizer for the generator network. Expects a torch::optim_xxx function, e.g. torch::optim_adam(). Default is NULL, which will set up a default optimizer.
d_optim
The optimizer for the discriminator network. Expects a torch::optim_xxx function, e.g. torch::optim_adam(). Default is NULL, which will set up a default optimizer.
value_function
The value function for GAN training. Expects a function that takes discriminator scores of real and fake data as input and returns a list with the discriminator loss and generator loss. For convenience, three value functions, "original", "wasserstein", and "f-wgan", are already implemented. Defaults to "original".
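A custom value function can be sketched along these lines, assuming the interface described above (real and fake discriminator scores in, a list of the two losses out); the list element names `d_loss` and `g_loss` are assumptions, and the losses shown are the original non-saturating GAN objective:

```r
library(torch)

# Sketch of a custom value function implementing the original
# (non-saturating) GAN loss:
#   d_loss = -( E[log D(x)] + E[log(1 - D(G(z)))] )
#   g_loss = -E[log D(G(z))]
# d_real and d_fake are raw discriminator scores (logits).
original_value_function <- function(d_real, d_fake) {
  d_loss <- -(torch_mean(torch_log(torch_sigmoid(d_real))) +
              torch_mean(torch_log(1 - torch_sigmoid(d_fake))))
  g_loss <- -torch_mean(torch_log(torch_sigmoid(d_fake)))
  list(d_loss = d_loss, g_loss = g_loss)
}
```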
weight_clipper
The Wasserstein GAN constrains the weights of the discriminator; this function clips the discriminator weights during training.
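A minimal sketch of such a clipper, following the original Wasserstein GAN recipe of clamping every discriminator parameter into [-c, c] after each update; the function signature and the clip value 0.01 are assumptions for illustration:

```r
library(torch)

# Hypothetical WGAN-style weight clipper: clamps each discriminator
# parameter in place to the interval [-clip_value, clip_value].
wgan_weight_clipper <- function(d_net, clip_value = 0.01) {
  with_no_grad({
    for (p in d_net$parameters) {
      p$clamp_(-clip_value, clip_value)
    }
  })
}
```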
Value: A function.
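An end-to-end usage sketch on toy two-dimensional data, assuming the call signature shown in the usage section above; the network architectures, learning rates, and the noise-sampler signature are illustrative assumptions, not package defaults:

```r
library(torch)

# Toy 2-D data set: 100 samples drawn from a standard normal.
data <- matrix(rnorm(200), ncol = 2)

# Small illustrative generator and discriminator (assumed architectures).
g_net <- nn_sequential(nn_linear(2, 16), nn_relu(), nn_linear(16, 2))
d_net <- nn_sequential(nn_linear(2, 16), nn_relu(), nn_linear(16, 1))

g_optim <- optim_adam(g_net$parameters, lr = 1e-4)
d_optim <- optim_adam(d_net$parameters, lr = 1e-4)

# One GAN update step with the built-in "original" value function.
gan_update_step(
  data = data,
  batch_size = 50,
  noise_dim = 2,
  sample_noise = function(n) torch_randn(n, 2),  # assumed signature
  device = "cpu",
  g_net = g_net,
  d_net = d_net,
  g_optim = g_optim,
  d_optim = d_optim,
  value_function = "original",
  weight_clipper = NULL
)
```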