tfd_pixel_cnn: The Pixel CNN++ distribution

View source: R/distributions.R

Description

Pixel CNN++ (Salimans et al., 2017) models a distribution over image data, parameterized by a neural network. It builds on Pixel CNN and Conditional Pixel CNN, as originally proposed by van den Oord et al. (2016).

The model expresses the joint distribution over pixels as a product of conditional distributions:

p(x|h) = prod{ p(x[i] | x[0:i], h) : i=0, ..., d },

in which p(x[i] | x[0:i], h) is the probability of the i-th pixel conditional on the pixels that precede it in raster order (color channels in RGB order, then left to right, then top to bottom). h is optional additional data on which to condition the image distribution, such as class labels or VAE embeddings.

The Pixel CNN++ network enforces this dependency structure by applying a mask to the kernels of the convolutional layers, which ensures that the values for each pixel depend only on pixels above and to the left of it.

Pixel values are modeled with a mixture of quantized logistic distributions, which take on a set of distinct integer values (e.g. between 0 and 255 for an 8-bit image). The color intensity v of each pixel is modeled as:

v ~ sum{q[i] * quantized_logistic(loc[i], scale[i]) : i = 0, ..., k },

in which k is the number of mixture components and the q[i] are the Categorical probabilities over the components.
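
To make the mixture of quantized logistics concrete, here is a minimal base-R sketch of its probability mass function. It does not call tfprobability: the helper names quantized_logistic_pmf() and mixture_pmf() and the parameter values are hypothetical, and the sketch assumes the usual construction in which a logistic density is integrated over unit-width bins, with the two edge bins absorbing the tails.

```r
# Mass that a logistic(loc, scale) variable places on integer v after rounding
# to low..high. The edge bins absorb the tails (an assumption of this sketch).
quantized_logistic_pmf <- function(v, loc, scale, low = 0, high = 255) {
  upper <- ifelse(v >= high, 1, plogis((v + 0.5 - loc) / scale))
  lower <- ifelse(v <= low, 0, plogis((v - 0.5 - loc) / scale))
  upper - lower
}

# Mixture mass: sum over components of q[i] * quantized_logistic(loc[i], scale[i]).
mixture_pmf <- function(v, q, loc, scale, low = 0, high = 255) {
  stopifnot(abs(sum(q) - 1) < 1e-8)
  comps <- mapply(
    function(l, s) quantized_logistic_pmf(v, l, s, low, high),
    loc, scale
  )
  # One row per value of v, one column per mixture component.
  as.vector(matrix(comps, nrow = length(v)) %*% q)
}

# Hypothetical two-component mixture over 8-bit intensities.
q     <- c(0.6, 0.4)
loc   <- c(40, 200)
scale <- c(5, 12)
pmf   <- mixture_pmf(0:255, q, loc, scale)
```

Because the bins tile the whole real line, the values in pmf sum to one (up to floating-point error), so the mixture is a proper distribution over the 256 intensity levels.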

Usage

tfd_pixel_cnn(
  image_shape,
  conditional_shape = NULL,
  num_resnet = 5,
  num_hierarchies = 3,
  num_filters = 160,
  num_logistic_mix = 10,
  receptive_field_dims = c(3, 3),
  dropout_p = 0.5,
  resnet_activation = "concat_elu",
  use_weight_norm = TRUE,
  use_data_init = TRUE,
  high = 255,
  low = 0,
  dtype = tf$float32,
  name = "PixelCNN"
)
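
As a rough illustration of the signature above, the following sketch builds a deliberately small network and evaluates it on random data. It assumes a working TensorFlow and TensorFlow Probability installation. The image size, the reduced hyperparameters, and the random input are illustrative choices only, not recommended settings.

```r
library(tensorflow)
library(tfprobability)

# Illustrative sizes only: a tiny 8 x 8 grayscale image and a small network.
image_shape <- c(8L, 8L, 1L)

dist <- tfd_pixel_cnn(
  image_shape = image_shape,
  num_resnet = 1,
  num_hierarchies = 2,
  num_filters = 16,
  num_logistic_mix = 5,
  dropout_p = 0.3
)

# A batch of two random images with integer-valued pixels in [low, high] = [0, 255].
pixels <- as.numeric(sample(0:255, size = 2 * prod(image_shape), replace = TRUE))
images <- tf$cast(array(pixels, dim = c(2, image_shape)), tf$float32)

# Log-probability of each image under the (untrained) distribution.
log_prob <- tfd_log_prob(dist, images)

# Draw one new image; pixels are generated sequentially, so this can be slow
# for larger images.
draw <- tfd_sample(dist, 1L)
```

The network weights are untrained here, so the log-probabilities and the sample carry no meaning beyond showing the expected input and output shapes.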

Arguments

image_shape

3D TensorShape or tuple for the [height, width, channels] dimensions of the image.

conditional_shape

TensorShape or tuple for the shape of the conditional input, or NULL if there is no conditional input.

num_resnet

integer, the number of layers (shown in Figure 2 of https://arxiv.org/abs/1606.05328) within each highest-level block of Figure 2 of https://pdfs.semanticscholar.org/9e90/6792f67cbdda7b7777b69284a81044857656.pdf.

num_hierarchies

integer, the number of highest-level blocks (separated by expansions/contractions of dimensions in Figure 2 of https://pdfs.semanticscholar.org/9e90/6792f67cbdda7b7777b69284a81044857656.pdf).

num_filters

integer, the number of convolutional filters.

num_logistic_mix

integer, number of components in the logistic mixture distribution.

receptive_field_dims

tuple, height and width in pixels of the receptive field of the convolutional layers above and to the left of a given pixel. The width (second element of the tuple) should be odd. Figure 1 (middle) of https://arxiv.org/abs/1606.05328 shows a receptive field of (3, 5) (the row containing the current pixel is included in the height). The default of (3, 3) was used to produce the results in https://pdfs.semanticscholar.org/9e90/6792f67cbdda7b7777b69284a81044857656.pdf.
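
The shape of such a receptive field can be pictured with a small base-R sketch. This is only an illustration of the "above and to the left" dependency structure, not the layer implementation used by the library: the helper causal_mask() is hypothetical, and it assumes the current pixel sits in the bottom row at the center column of the kernel.

```r
# Build a 0/1 mask with dims = c(height, width). Rows above the current pixel
# are fully visible; in the bottom row only positions to the left of the
# center are visible (plus the center itself if include_center = TRUE).
causal_mask <- function(dims, include_center = FALSE) {
  h <- dims[1]
  w <- dims[2]
  stopifnot(w %% 2 == 1)  # the width should be odd so a center column exists
  center <- (w + 1) / 2
  mask <- matrix(1, nrow = h, ncol = w)
  last_visible <- if (include_center) center else center - 1
  if (last_visible < w) {
    mask[h, (last_visible + 1):w] <- 0
  }
  mask
}

# A receptive field of (3, 5): two full rows above, two pixels to the left.
mask <- causal_mask(c(3, 5))
```

Multiplying a convolution kernel elementwise by a mask of this form zeroes the weights that would otherwise let a pixel see itself or pixels that come later in raster order.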

dropout_p

float, the dropout probability. Should be between 0 and 1.

resnet_activation

string, the type of activation to use in the resnet blocks. May be 'concat_elu', 'elu', or 'relu'.

use_weight_norm

logical, if TRUE then use weight normalization (works only in Eager mode).

use_data_init

logical, if TRUE then use data-dependent initialization (has no effect if use_weight_norm is FALSE).

high

integer, the maximum value of the input data (255 for an 8-bit image).

low

integer, the minimum value of the input data.

dtype

Data type of the Distribution.

name

string, the name of the Distribution.

Value

a distribution instance.

See Also

For usage examples see e.g. tfd_sample(), tfd_log_prob(), tfd_mean().

Other distributions: tfd_autoregressive(), tfd_batch_reshape(), tfd_bates(), tfd_bernoulli(), tfd_beta_binomial(), tfd_beta(), tfd_binomial(), tfd_categorical(), tfd_cauchy(), tfd_chi2(), tfd_chi(), tfd_cholesky_lkj(), tfd_continuous_bernoulli(), tfd_deterministic(), tfd_dirichlet_multinomial(), tfd_dirichlet(), tfd_empirical(), tfd_exp_gamma(), tfd_exp_inverse_gamma(), tfd_exponential(), tfd_gamma_gamma(), tfd_gamma(), tfd_gaussian_process_regression_model(), tfd_gaussian_process(), tfd_generalized_normal(), tfd_geometric(), tfd_gumbel(), tfd_half_cauchy(), tfd_half_normal(), tfd_hidden_markov_model(), tfd_horseshoe(), tfd_independent(), tfd_inverse_gamma(), tfd_inverse_gaussian(), tfd_johnson_s_u(), tfd_joint_distribution_named_auto_batched(), tfd_joint_distribution_named(), tfd_joint_distribution_sequential_auto_batched(), tfd_joint_distribution_sequential(), tfd_kumaraswamy(), tfd_laplace(), tfd_linear_gaussian_state_space_model(), tfd_lkj(), tfd_log_logistic(), tfd_log_normal(), tfd_logistic(), tfd_mixture_same_family(), tfd_mixture(), tfd_multinomial(), tfd_multivariate_normal_diag_plus_low_rank(), tfd_multivariate_normal_diag(), tfd_multivariate_normal_full_covariance(), tfd_multivariate_normal_linear_operator(), tfd_multivariate_normal_tri_l(), tfd_multivariate_student_t_linear_operator(), tfd_negative_binomial(), tfd_normal(), tfd_one_hot_categorical(), tfd_pareto(), tfd_poisson_log_normal_quadrature_compound(), tfd_poisson(), tfd_power_spherical(), tfd_probit_bernoulli(), tfd_quantized(), tfd_relaxed_bernoulli(), tfd_relaxed_one_hot_categorical(), tfd_sample_distribution(), tfd_sinh_arcsinh(), tfd_skellam(), tfd_spherical_uniform(), tfd_student_t_process(), tfd_student_t(), tfd_transformed_distribution(), tfd_triangular(), tfd_truncated_cauchy(), tfd_truncated_normal(), tfd_uniform(), tfd_variational_gaussian_process(), tfd_vector_diffeomixture(), tfd_vector_exponential_diag(), tfd_vector_exponential_linear_operator(), tfd_vector_laplace_diag(), 
tfd_vector_laplace_linear_operator(), tfd_vector_sinh_arcsinh_diag(), tfd_von_mises_fisher(), tfd_von_mises(), tfd_weibull(), tfd_wishart_linear_operator(), tfd_wishart_tri_l(), tfd_wishart(), tfd_zipf()


tfprobability documentation built on Sept. 1, 2022, 5:07 p.m.