#' Modulated Convolution
#'
#' Scales the convolution weights per sample by a style vector, optionally
#' demodulates them, and runs the whole batch as a single grouped convolution.
#'
#' @param x Input tensor: `[batch_size, in_channels, in_height, in_width]`.
#' @param w Weight tensor:
#'   `[out_channels, in_channels, kernel_height, kernel_width]`.
#' @param s Style tensor: `[batch_size, in_channels]`.
#' @param demodulate Apply weight demodulation (and input pre-normalization)?
#' @param padding Padding forwarded to the convolution: int or `c(padH, padW)`.
#' @param input_gain Optional scale factors for the input channels. Must be
#'   expandable to `[batch_size, in_channels]`; `NULL` disables the scaling.
#'
#' @return The convolution output reshaped to
#'   `[batch_size, out_channels, out_height, out_width]`.
#' @export
#' @importFrom zeallot `%<-%`
modulated_conv2d <- function(
    x,
    w,
    s,
    demodulate = TRUE,
    padding = 0,
    input_gain = NULL) {
  batch_size <- x$shape[1]

  # Pre-declare the destructuring targets, then unpack the weight shape.
  out_channels <- in_channels <- kh <- kw <- NULL
  c(out_channels, in_channels, kh, kw) %<-% w$shape

  assert_shape(w, c(out_channels, in_channels, kh, kw)) # [OIkk]
  assert_shape(x, c(batch_size, in_channels, NA, NA)) # [NIHW]
  assert_shape(s, c(batch_size, in_channels)) # [NI]

  # Lifts a [N, I] tensor to [N, 1, I, 1, 1] so it broadcasts against the
  # per-sample weights [N, O, I, k, k].
  per_sample_scale <- function(ni) {
    ni$unsqueeze(2)$unsqueeze(4)$unsqueeze(5)
  }

  # Pre-normalize inputs.
  if (demodulate) {
    w_inv_rms <- w$square()$mean(c(2, 3, 4), keepdim = TRUE)$rsqrt()
    w <- w * w_inv_rms
    s_inv_rms <- s$square()$mean()$rsqrt()
    s <- s * s_inv_rms
  }

  # Modulate weights: one copy of the kernel per sample, scaled by its style.
  w <- w$unsqueeze(1) * per_sample_scale(s) # [NOIkk]

  # Demodulate weights.
  if (demodulate) {
    demod_coefs <- (w$square()$sum(dim = c(3, 4, 5)) + 1e-8)$rsqrt() # [NO]
    w <- w * demod_coefs$unsqueeze(3)$unsqueeze(4)$unsqueeze(5) # [NOIkk]
  }

  # Apply input scaling.
  if (!is.null(input_gain)) {
    gain_ni <- input_gain$expand(c(batch_size, in_channels)) # [NI]
    w <- w * per_sample_scale(gain_ni) # [NOIkk]
  }

  # Execute as one fused op using grouped convolution: fold the batch into
  # the channel axis and use one group per sample.
  x_folded <- x$reshape(c(1, -1, x$shape[3:4]))
  w_folded <- w$reshape(c(-1, in_channels, kh, kw))
  y <- conv2d_gradfix(
    input = x_folded,
    weight = w_folded$to(x_folded$dtype),
    padding = padding,
    groups = batch_size
  )

  # Unfold the batch axis again.
  y$reshape(c(batch_size, -1, y$shape[3:4]))
}
## Fully connected layer. Similar to nn_linear, but the weight and bias are
## rescaled at run time (by `weight_gain` / `bias_gain`) and non-linear
## activations go through `bias_act()`.
## NOTE(review): `bias_act` is not defined in this file; per the original
## author's note it is a custom op that must be imported.
FullyConnectedLayer <- nn_module(
  initialize = function(
    in_features, # Number of input features.
    out_features, # Number of output features.
    activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
    bias = TRUE, # Apply additive bias before the activation function?
    lr_multiplier = 1, # Learning rate multiplier.
    weight_init = 1, # Initial standard deviation of the weight tensor.
    bias_init = 0 # Initial value of the additive bias.
  ) {
    self$in_features <- in_features
    self$out_features <- out_features
    self$activation <- activation
    self$weight <- nn_parameter(
      torch_randn(out_features, in_features) * (weight_init / lr_multiplier)
    )
    if (bias) {
      # Recycle a scalar (or shorter vector) to one value per output feature.
      bias_init <- array(bias_init, dim = out_features)
      self$bias <- nn_parameter(torch_tensor(bias_init / lr_multiplier))
    }
    # Run-time scale factors applied in forward().
    self$weight_gain <- lr_multiplier / sqrt(in_features)
    self$bias_gain <- lr_multiplier
  },
  forward = function(x) {
    w <- self$weight$to(x$dtype) * self$weight_gain
    b <- self$bias
    if (!is.null(b)) {
      b <- b$to(x$dtype)
      # Only scale an existing bias: `NULL * gain` would yield `numeric(0)`,
      # which is not NULL and would wrongly enable the bias paths below.
      if (self$bias_gain != 1) {
        b <- b * self$bias_gain
      }
    }
    if (self$activation == 'linear' && !is.null(b)) {
      # Fused bias + matmul for the plain linear case.
      x <- torch_addmm(b$unsqueeze(1), x, w$t())
    } else {
      x <- x$matmul(w$t())
      x <- bias_act(x, b, dim = 2, act = self$activation)
    }
    x
  },
  extra_repr = function() {
    glue::glue('in_features={self$in_features}, out_features={self$out_features}, activation={self$activation}')
  }
)
#----------------------------------------------------------------------------
# Mapping network: turns a latent `z` (and optional conditioning label `c`)
# into `num_ws` copies of an intermediate latent, with optional truncation
# towards a tracked moving average.
MappingNetwork <- nn_module(
  initialize = function(
    z_dim, # Input latent (Z) dimensionality.
    c_dim, # Conditioning label (C) dimensionality, 0 = no labels.
    w_dim, # Intermediate latent (W) dimensionality.
    num_ws, # Number of intermediate latents to output.
    num_layers = 2, # Number of mapping layers.
    lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers.
    w_avg_beta = 0.998 # Decay for tracking the moving average of W during training.
  ) {
    self$z_dim <- z_dim
    self$c_dim <- c_dim
    self$w_dim <- w_dim
    self$num_ws <- num_ws
    self$num_layers <- num_layers
    self$w_avg_beta <- w_avg_beta

    # Construct layers.
    if (self$c_dim > 0) {
      self$embed <- FullyConnectedLayer(self$c_dim, self$w_dim)
    }
    # Layer widths: the first layer sees z (plus the embedded label when
    # c_dim > 0); every layer outputs w_dim features.
    first_width <- self$z_dim + if (self$c_dim > 0) self$w_dim else 0
    features <- c(first_width, rep(self$w_dim, self$num_layers))
    # Layers are registered under zero-based names: fc0, fc1, ...
    for (idx in seq_len(self$num_layers)) {
      self[[glue::glue("fc{idx - 1}")]] <- FullyConnectedLayer(
        features[idx], features[idx + 1],
        activation = 'lrelu', lr_multiplier = lr_multiplier
      )
    }
    self$w_avg <- nn_buffer(torch_zeros(w_dim))
  },
  forward = function(z, c, truncation_psi = 1, truncation_cutoff = NULL, update_emas = FALSE) {
    assert_shape(z, c(NA, self$z_dim))
    if (is.null(truncation_cutoff)) {
      truncation_cutoff <- self$num_ws
    }

    # Embed, normalize, and concatenate inputs.
    x <- z$to(torch_float32())
    x <- x * (x$square()$mean(2, keepdim = TRUE) + 1e-8)$rsqrt()
    if (self$c_dim > 0) {
      assert_shape(c, c(NA, self$c_dim))
      y <- self$embed(c$to(torch_float32()))
      y <- y * (y$square()$mean(2, keepdim = TRUE) + 1e-8)$rsqrt()
      # x is always a tensor at this point, so concatenate unconditionally.
      x <- torch_cat(list(x, y), dim = 2)
    }

    # Execute layers. Iterate over every layer (seq_len, not seq_along of a
    # scalar) and use the same zero-based names as in initialize().
    for (idx in seq_len(self$num_layers)) {
      x <- self[[glue::glue('fc{idx - 1}')]](x)
    }

    # Update moving average of W.
    if (update_emas) {
      self$w_avg$copy_(x$detach()$mean(dim = 1)$lerp(self$w_avg, self$w_avg_beta))
    }

    # Broadcast and apply truncation.
    x <- x$unsqueeze(2)$`repeat`(c(1, self$num_ws, 1))
    if (truncation_psi != 1) {
      x[ , 1:truncation_cutoff] <- self$w_avg$lerp(x[ , 1:truncation_cutoff], truncation_psi)
    }
    x
  },
  extra_repr = function() {
    glue::glue('z_dim={self$z_dim}, c_dim={self$c_dim}, w_dim={self$w_dim}, num_ws={self$num_ws}')
  }
)
#----------------------------------------------------------------------------
# Synthesis input: produces Fourier features on a sampling grid, transformed
# by a learned rotation/translation predicted from `w` and by a
# user-specified transform buffer, then mixed by a trainable linear map.
SynthesisInput <- nn_module(
  initialize = function(w_dim, # Intermediate latent (W) dimensionality.
                        channels, # Number of output channels.
                        size, # Output spatial size: int or c(width, height)
                        sampling_rate, # Output sampling rate.
                        bandwidth # Output bandwidth.
  ) {
    self$w_dim <- w_dim
    self$channels <- channels
    self$size <- array(size, dim = 2) # Recycle a scalar to c(width, height).
    self$sampling_rate <- sampling_rate
    self$bandwidth <- bandwidth

    # Draw random frequencies from uniform 2D disc.
    freqs <- torch_randn(c(self$channels, 2))
    # freqs is [channels, 2]; the radius of each frequency is the norm over
    # the xy axis, i.e. dim = 2 with 1-based indexing -> [channels, 1].
    radii <- freqs$square()$sum(dim = 2, keepdim = TRUE)$sqrt()
    freqs <- freqs / (radii * radii$square()$exp()$pow(0.25))
    freqs <- freqs * bandwidth
    phases <- torch_rand(self$channels) - 0.5

    # Setup parameters and buffers.
    self$weight <- nn_parameter(torch_randn(c(self$channels, self$channels)))
    self$affine <- FullyConnectedLayer(w_dim, 4, weight_init = 0, bias_init = c(1, 0, 0, 0))
    self$transform <- nn_buffer(torch_eye(3, 3)) # User-specified inverse transform wrt. resulting image.
    self$freqs <- nn_buffer(freqs)
    self$phases <- nn_buffer(phases)
  },
  forward = function(w) {
    # Introduce batch dimension.
    transforms <- self$transform$unsqueeze(1) # [batch, row, col]
    freqs <- self$freqs$unsqueeze(1) # [batch, channel, xy]
    phases <- self$phases$unsqueeze(1) # [batch, channel]

    # Apply learned transformation.
    t <- self$affine(w) # t = (r_c, r_s, t_x, t_y)
    t <- t / t[ , 1:2]$norm(dim = 2, keepdim = TRUE) # t' = (r'_c, r'_s, t'_x, t'_y)
    m_r <- torch_eye(3, device = w$device)$unsqueeze(1)$`repeat`(c(w$shape[1], 1, 1)) # Inverse rotation wrt. resulting image.
    m_r[ , 1, 1] <- t[ , 1] # r'_c
    m_r[ , 1, 2] <- -t[ , 2] # r'_s
    m_r[ , 2, 1] <- t[ , 2] # r'_s
    m_r[ , 2, 2] <- t[ , 1] # r'_c
    m_t <- torch_eye(3, device = w$device)$unsqueeze(1)$`repeat`(c(w$shape[1], 1, 1)) # Inverse translation wrt. resulting image.
    m_t[ , 1, 3] <- -t[ , 3] # t'_x
    m_t[ , 2, 3] <- -t[ , 4] # t'_y
    transforms <- m_r %*% m_t %*% transforms # First rotate resulting image, then translate, and finally apply user-specified transform.

    # Transform frequencies.
    phases <- phases + (freqs %*% transforms[ , 1:2, 3:Inf])$squeeze(3)
    freqs <- freqs %*% transforms[ , 1:2, 1:2]

    # Dampen out-of-band frequencies that may occur due to the user-specified transform.
    amplitudes <- (1 - (freqs$norm(dim = 3) - self$bandwidth) / (self$sampling_rate / 2 - self$bandwidth))$clamp(0, 1)

    # Construct sampling grid.
    theta <- torch_eye(2, 3, device = w$device)
    theta[1, 1] <- 0.5 * self$size[1] / self$sampling_rate
    theta[2, 2] <- 0.5 * self$size[2] / self$sampling_rate
    grids <- nnf_affine_grid(theta$unsqueeze(1), c(1, 1, self$size[2], self$size[1]), align_corners = FALSE)

    # Compute Fourier features.
    x <- (grids$unsqueeze(4) %*% freqs$permute(c(1, 3, 2))$unsqueeze(2)$unsqueeze(3))$squeeze(4) # [batch, height, width, channel]
    x <- x + phases$unsqueeze(2)$unsqueeze(3)
    x <- torch_sin(x * (pi * 2))
    x <- x * amplitudes$unsqueeze(2)$unsqueeze(3)

    # Apply trainable mapping.
    weight <- self$weight / sqrt(self$channels)
    x <- x %*% weight$t()

    # Ensure correct shape.
    x <- x$permute(c(1, 4, 2, 3)) # [batch, channel, height, width]
    assert_shape(x, c(w$shape[1], self$channels, as.integer(self$size[2]), as.integer(self$size[1])))
    x
  },
  extra_repr = function() {
    # `size` has two elements; collapse it so glue returns a single string
    # instead of one string per element.
    glue::glue('w_dim={self$w_dim}, channels={self$channels}, size=[{toString(self$size)}],\nsampling_rate={self$sampling_rate}, bandwidth={self$bandwidth}')
  }
)
# Synthesis layer: style-modulated convolution followed by `filtered_lrelu()`
# (bias, up/down-sampling filters, leaky ReLU and clamping). The low-pass
# filters and the padding are designed once in initialize().
# NOTE(review): `filtered_lrelu`, `scipy_signal_firwin`, `kaiser_beta` and
# `kaiser_atten` are not defined in this file; they must be provided elsewhere.
SynthesisLayer <- nn_module(
  initialize = function(
    w_dim, # Intermediate latent (W) dimensionality.
    is_torgb, # Is this the final ToRGB layer?
    is_critically_sampled, # Does this layer use critical sampling?
    use_fp16, # Does this layer use FP16?
    # Input & output specifications.
    in_channels, # Number of input channels.
    out_channels, # Number of output channels.
    in_size, # Input spatial size: int or [width, height].
    out_size, # Output spatial size: int or [width, height].
    in_sampling_rate, # Input sampling rate (s).
    out_sampling_rate, # Output sampling rate (s).
    in_cutoff, # Input cutoff frequency (f_c).
    out_cutoff, # Output cutoff frequency (f_c).
    in_half_width, # Input transition band half-width (f_h).
    out_half_width, # Output Transition band half-width (f_h).
    # Hyperparameters.
    conv_kernel = 3, # Convolution kernel size. Ignored for final the ToRGB layer.
    filter_size = 6, # Low-pass filter size relative to the lower resolution when up/downsampling.
    lrelu_upsampling = 2, # Relative sampling rate for leaky ReLU. Ignored for final the ToRGB layer.
    use_radial_filters = FALSE, # Use radially symmetric downsampling filter? Ignored for critically sampled layers.
    conv_clamp = 256, # Clamp the output to [-X, +X], None = disable clamping.
    magnitude_ema_beta = 0.999 # Decay rate for the moving average of input magnitudes.
  ) {
    self$w_dim <- w_dim
    self$is_torgb <- is_torgb
    self$is_critically_sampled <- is_critically_sampled
    self$use_fp16 <- use_fp16
    self$in_channels <- in_channels
    self$out_channels <- out_channels
    self$in_size <- array(in_size, dim = 2) # Recycle a scalar to c(width, height).
    self$out_size <- array(out_size, dim = 2)
    self$in_sampling_rate <- in_sampling_rate
    self$out_sampling_rate <- out_sampling_rate
    self$tmp_sampling_rate <- max(in_sampling_rate, out_sampling_rate) * (if (is_torgb) 1 else lrelu_upsampling)
    self$in_cutoff <- in_cutoff
    self$out_cutoff <- out_cutoff
    self$in_half_width <- in_half_width
    self$out_half_width <- out_half_width
    self$conv_kernel <- if (is_torgb) 1 else conv_kernel
    self$conv_clamp <- conv_clamp
    self$magnitude_ema_beta <- magnitude_ema_beta

    # Setup parameters and buffers.
    self$affine <- FullyConnectedLayer(self$w_dim, self$in_channels, bias_init = 1)
    self$weight <- nn_parameter(torch_randn(c(self$out_channels, self$in_channels, self$conv_kernel, self$conv_kernel)))
    self$bias <- nn_parameter(torch_zeros(self$out_channels))
    self$magnitude_ema <- nn_buffer(torch_scalar_tensor(1.0))

    # Design upsampling filter.
    self$up_factor <- as.integer(round(self$tmp_sampling_rate / self$in_sampling_rate))
    # are_equal() only returns TRUE/FALSE; wrap it so a violated invariant
    # actually stops construction.
    assertthat::assert_that(
      assertthat::are_equal(self$in_sampling_rate * self$up_factor, self$tmp_sampling_rate)
    )
    self$up_taps <- if (self$up_factor > 1 && !self$is_torgb) filter_size * self$up_factor else 1
    up_filter <- self$design_lowpass_filter(
      numtaps = self$up_taps, cutoff = self$in_cutoff,
      width = self$in_half_width * 2, fs = self$tmp_sampling_rate
    )
    # An empty tensor marks the identity filter; store NULL in that case.
    if (up_filter$numel() > 0) {
      self$up_filter <- nn_buffer(up_filter)
    } else {
      self$up_filter <- NULL
    }

    # Design downsampling filter.
    self$down_factor <- as.integer(round(self$tmp_sampling_rate / self$out_sampling_rate))
    assertthat::assert_that(
      assertthat::are_equal(self$out_sampling_rate * self$down_factor, self$tmp_sampling_rate)
    )
    self$down_taps <- if (self$down_factor > 1 && !self$is_torgb) filter_size * self$down_factor else 1
    self$down_radial <- use_radial_filters && !self$is_critically_sampled
    down_filter <- self$design_lowpass_filter(
      numtaps = self$down_taps, cutoff = self$out_cutoff,
      width = self$out_half_width * 2, fs = self$tmp_sampling_rate,
      radial = self$down_radial
    )
    if (down_filter$numel() > 0) {
      self$down_filter <- nn_buffer(down_filter)
    } else {
      self$down_filter <- NULL
    }

    # Compute padding.
    pad_total <- (self$out_size - 1) * self$down_factor + 1 # Desired output size before downsampling.
    pad_total <- pad_total - (self$in_size + self$conv_kernel - 1) * self$up_factor # Input size after upsampling.
    pad_total <- pad_total + self$up_taps + self$down_taps - 2 # Size reduction caused by the filters.
    pad_lo <- (pad_total + self$up_factor) %/% 2 # Shift sample locations according to the symmetric interpretation (Appendix C.3).
    pad_hi <- pad_total - pad_lo
    self$padding <- c(as.integer(pad_lo[1]), as.integer(pad_hi[1]), as.integer(pad_lo[2]), as.integer(pad_hi[2]))
  },
  forward = function(x, w, noise_mode = 'random', force_fp32 = FALSE, update_emas = FALSE) {
    assertthat::assert_that(noise_mode %in% c('random', 'const', 'none')) # unused
    assert_shape(x, c(NA, self$in_channels, as.integer(self$in_size[2]), as.integer(self$in_size[1])))
    assert_shape(w, c(x$shape[1], self$w_dim))

    # Track input magnitude.
    if (update_emas) {
      magnitude_cur <- x$detach()$to(torch_float32())$square()$mean()
      self$magnitude_ema$copy_(magnitude_cur$lerp(self$magnitude_ema, self$magnitude_ema_beta))
    }
    input_gain <- self$magnitude_ema$rsqrt()

    # Execute affine layer.
    styles <- self$affine(w)
    if (self$is_torgb) {
      weight_gain <- 1 / sqrt(self$in_channels * (self$conv_kernel^2))
      styles <- styles * weight_gain
    }

    # Execute modulated conv2d. FP16 is only used on CUDA devices.
    dtype <- if (self$use_fp16 && !force_fp32 && x$device$type == 'cuda') torch_float16() else torch_float32()
    x <- modulated_conv2d(
      x = x$to(dtype), w = self$weight, s = styles,
      padding = self$conv_kernel - 1, demodulate = (!self$is_torgb), input_gain = input_gain
    )

    # Execute bias, filtered leaky ReLU, and clamping.
    gain <- if (self$is_torgb) 1 else sqrt(2)
    slope <- if (self$is_torgb) 1 else 0.2
    x <- filtered_lrelu(
      x = x, fu = self$up_filter, fd = self$down_filter, b = self$bias$to(x$dtype),
      up = self$up_factor, down = self$down_factor, padding = self$padding,
      gain = gain, slope = slope, clamp = self$conv_clamp
    )

    # Ensure correct shape and dtype.
    assert_shape(x, c(NA, self$out_channels, as.integer(self$out_size[2]), as.integer(self$out_size[1])))
    assertthat::assert_that(x$dtype == dtype)
    x
  },
  # Designs a low-pass filter with `numtaps` taps. Returns an empty tensor for
  # the identity filter (numtaps == 1), a 1D tensor for the separable filter,
  # or a 2D [numtaps, numtaps] tensor when `radial = TRUE`.
  design_lowpass_filter = function(numtaps, cutoff, width, fs, radial = FALSE) {
    assertthat::assert_that(numtaps >= 1)

    # Identity filter.
    if (numtaps == 1) {
      return(torch_empty(0))
    }

    # Separable Kaiser low-pass filter.
    if (!radial) {
      f <- scipy_signal_firwin(numtaps = numtaps, cutoff = cutoff, width = width, fs = fs)
      return(torch_tensor(f, dtype = torch_float32()))
    }

    # Radially symmetric jinc-based filter.
    taps <- (seq(0, numtaps - 1) - (numtaps - 1) / 2) / fs
    # Distance of every 2D tap position from the filter center
    # (equivalent to np.hypot(*np.meshgrid(x, x))).
    r <- sqrt(outer(taps^2, taps^2, "+"))
    f <- besselJ(2 * cutoff * (pi * r), 1) / (pi * r)
    beta <- kaiser_beta(kaiser_atten(numtaps, width / (fs / 2)))
    win <- signal::kaiser(numtaps, beta)
    f <- f * outer(win, win)
    f <- f / sum(f)
    torch_tensor(f, dtype = torch_float32())
  },
  extra_repr = function() {
    # Compact numeric formatting, like C's "%g".
    g <- function(v) sprintf("%g", v)
    glue::glue(paste(
      'w_dim={self$w_dim}, is_torgb={self$is_torgb},',
      'is_critically_sampled={self$is_critically_sampled}, use_fp16={self$use_fp16},',
      'in_sampling_rate={g(self$in_sampling_rate)}, out_sampling_rate={g(self$out_sampling_rate)},',
      'in_cutoff={g(self$in_cutoff)}, out_cutoff={g(self$out_cutoff)},',
      'in_half_width={g(self$in_half_width)}, out_half_width={g(self$out_half_width)},',
      'in_size=[{toString(self$in_size)}], out_size=[{toString(self$out_size)}],',
      'in_channels={self$in_channels}, out_channels={self$out_channels}',
      sep = "\n"))
  }
)
# Synthesis network: a SynthesisInput followed by `num_layers` SynthesisLayers
# and a final ToRGB layer. Per-layer cutoffs, stopbands, sampling rates, sizes
# and channel counts follow a geometric progression computed in initialize().
SynthesisNetwork <- nn_module(
  initialize = function(
    w_dim, # Intermediate latent (W) dimensionality.
    img_resolution, # Output image resolution.
    img_channels, # Number of color channels.
    channel_base = 32768, # Overall multiplier for the number of channels.
    channel_max = 512, # Maximum number of channels in any layer.
    num_layers = 14, # Total number of layers, excluding Fourier features and ToRGB.
    num_critical = 2, # Number of critically sampled layers at the end.
    first_cutoff = 2, # Cutoff frequency of the first layer (f_{c,0}).
    first_stopband = 2^2.1, # Minimum stopband of the first layer (f_{t,0}).
    last_stopband_rel = 2^0.3, # Minimum stopband of the last layer, expressed relative to the cutoff.
    margin_size = 10, # Number of additional pixels outside the image.
    output_scale = 0.25, # Scale factor for the output image.
    num_fp16_res = 4, # Use FP16 for the N highest resolutions.
    ... # Additional named arguments forwarded to every SynthesisLayer.
  ) {
    self$w_dim <- w_dim
    self$num_ws <- num_layers + 2 # One w for the input, one per layer, one for ToRGB.
    self$img_resolution <- img_resolution
    self$img_channels <- img_channels
    self$num_layers <- num_layers
    self$num_critical <- num_critical
    self$margin_size <- margin_size
    self$output_scale <- output_scale
    self$num_fp16_res <- num_fp16_res

    # Geometric progression of layer cutoffs and min. stopbands.
    last_cutoff <- self$img_resolution / 2 # f_{c,N}
    last_stopband <- last_cutoff * last_stopband_rel # f_{t,N}
    exponents <- pmin((seq_len(self$num_layers + 1) - 1) / (self$num_layers - self$num_critical), 1)
    cutoffs <- first_cutoff * (last_cutoff / first_cutoff)^exponents # f_c[i]
    stopbands <- first_stopband * (last_stopband / first_stopband)^exponents # f_t[i]

    # Compute remaining layer parameters.
    # Sampling rates are rounded up to the next power of two.
    sampling_rates <- 2^ceiling(log2(pmin(stopbands * 2, self$img_resolution))) # s[i]
    half_widths <- pmax(stopbands, sampling_rates / 2) - cutoffs # f_h[i]
    sizes <- sampling_rates + self$margin_size * 2
    n_sizes <- length(sizes)
    sizes[(n_sizes - 1):n_sizes] <- self$img_resolution # Last two layers have no margin.
    channels <- as.integer(round(pmin((channel_base / 2) / cutoffs, channel_max)))
    channels[length(channels)] <- self$img_channels

    # Construct layers.
    self$input <- SynthesisInput(
      w_dim = self$w_dim, channels = as.integer(channels[1]), size = as.integer(sizes[1]),
      sampling_rate = sampling_rates[1], bandwidth = cutoffs[1]
    )
    n_total <- self$num_layers + 1 # Regular layers plus the final ToRGB layer.
    layer_names <- character(n_total)
    for (idx in seq_len(n_total)) {
      prev <- max(idx - 1, 1) # The first layer takes its input spec from itself.
      is_torgb <- (idx == n_total)
      is_critically_sampled <- (idx >= n_total - self$num_critical)
      use_fp16 <- (sampling_rates[idx] * (2^self$num_fp16_res) > self$img_resolution)
      layer <- SynthesisLayer(
        w_dim = self$w_dim, is_torgb = is_torgb, is_critically_sampled = is_critically_sampled, use_fp16 = use_fp16,
        in_channels = as.integer(channels[prev]), out_channels = as.integer(channels[idx]),
        in_size = as.integer(sizes[prev]), out_size = as.integer(sizes[idx]),
        in_sampling_rate = as.integer(sampling_rates[prev]), out_sampling_rate = as.integer(sampling_rates[idx]),
        in_cutoff = cutoffs[prev], out_cutoff = cutoffs[idx],
        in_half_width = half_widths[prev], out_half_width = half_widths[idx],
        ...
      )
      # Zero-based layer index in the name: L0_<size>_<channels>, ...
      name <- as.character(glue::glue('L{idx - 1}_{layer$out_size[1]}_{layer$out_channels}'))
      self[[name]] <- layer
      layer_names[idx] <- name
    }
    self$layer_names <- layer_names
  },
  forward = function(ws, ...) {
    assert_shape(ws, c(NA, self$num_ws, self$w_dim))
    # Split into one [batch, w_dim] tensor per w.
    ws <- ws$to(torch_float32())$unbind(dim = 2)

    # Execute layers: the first w drives the input, the rest are paired with
    # the layers in construction order.
    x <- self$input(ws[[1]])
    layer_ws <- ws[-1]
    for (i in seq_along(self$layer_names)) {
      x <- self[[self$layer_names[[i]]]](x, layer_ws[[i]], ...)
    }
    if (self$output_scale != 1) {
      x <- x * self$output_scale
    }

    # Ensure correct shape and dtype.
    assert_shape(x, c(NA, self$img_channels, self$img_resolution, self$img_resolution))
    x$to(torch_float32())
  },
  extra_repr = function() {
    glue::glue(paste(
      'w_dim={self$w_dim}, num_ws={self$num_ws},',
      'img_resolution={self$img_resolution}, img_channels={self$img_channels},',
      'num_layers={self$num_layers}, num_critical={self$num_critical},',
      'margin_size={self$margin_size}, num_fp16_res={self$num_fp16_res}',
      sep = "\n"))
  }
)
#----------------------------------------------------------------------------
#' Generator
#'
#' Combines a `MappingNetwork` (z, c -> ws) with a `SynthesisNetwork`
#' (ws -> image).
#'
#' @export
Generator <- nn_module(
  initialize = function(
    z_dim, # Input latent (Z) dimensionality.
    c_dim, # Conditioning label (C) dimensionality.
    w_dim, # Intermediate latent (W) dimensionality.
    img_resolution, # Output resolution.
    img_channels, # Number of output color channels.
    mapping_kwargs = NULL, # Arguments for MappingNetwork (as a named list).
    ... # Arguments for SynthesisNetwork.
  ) {
    self$z_dim <- z_dim
    self$c_dim <- c_dim
    self$w_dim <- w_dim
    self$img_resolution <- img_resolution
    self$img_channels <- img_channels

    # The synthesis network is built first because it determines how many
    # intermediate latents the mapping network has to emit.
    self$synthesis <- SynthesisNetwork(
      w_dim = w_dim,
      img_resolution = img_resolution,
      img_channels = img_channels,
      ...
    )
    self$num_ws <- self$synthesis$num_ws

    # Fixed arguments first, then any user-supplied extras; a NULL
    # `mapping_kwargs` contributes nothing to the argument list.
    mapping_args <- c(
      list(z_dim = z_dim, c_dim = c_dim, w_dim = w_dim, num_ws = self$num_ws),
      mapping_kwargs
    )
    self$mapping <- rlang::exec(MappingNetwork, !!!mapping_args)
  },
  forward = function(z, c, truncation_psi = 1, truncation_cutoff = NULL, update_emas = FALSE, ...) {
    # Map latents (and labels) to per-layer intermediate latents.
    latents <- self$mapping(
      z, c,
      truncation_psi = truncation_psi,
      truncation_cutoff = truncation_cutoff,
      update_emas = update_emas
    )
    # Synthesize the image; extra arguments go to the synthesis network.
    self$synthesis(latents, update_emas = update_emas, ...)
  }
)
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.