tfb_affine: Affine bijector

View source: R/bijectors.R

tfb_affine    R Documentation

Affine bijector

Description

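This bijector is initialized with a shift Tensor and scale arguments, giving the forward operation

Y = g(X) = scale @ X + shift

where @ denotes matrix multiplication and the scale term is logically equivalent to

scale = scale_identity_multiplier * tf.diag(tf.ones(d)) +
        tf.diag(scale_diag) +
        scale_tril +
        scale_perturb_factor @ tf.diag(scale_perturb_diag) @ tf.transpose(scale_perturb_factor)

with d the event size. Terms whose arguments are NULL are dropped (a NULL scale_perturb_diag acts as the identity); see Details for when an identity matrix is added implicitly.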
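To make the scale formula concrete, here is a minimal base-R sketch for a single, unbatched event of size k. It assumes plain R matrices stand in for Tensors and that @ is R's %*%; the helpers affine_scale() and affine_forward() are hypothetical illustrations, not functions exported by tfprobability.

```r
# Hypothetical helper (not part of tfprobability): assemble the dense k x k
# scale matrix described for tfb_affine(), for one unbatched event.
affine_scale <- function(k,
                         scale_identity_multiplier = NULL,
                         scale_diag = NULL,
                         scale_tril = NULL,
                         scale_perturb_factor = NULL,
                         scale_perturb_diag = NULL) {
  scale <- matrix(0, k, k)
  # An implicit identity is added only when none of these three is given
  if (is.null(scale_identity_multiplier) && is.null(scale_diag) &&
      is.null(scale_tril)) {
    scale <- scale + diag(k)
  }
  if (!is.null(scale_identity_multiplier)) {
    scale <- scale + scale_identity_multiplier * diag(k)
  }
  if (!is.null(scale_diag)) {
    # nrow = k keeps a length-1 scale_diag from being read as a size
    scale <- scale + diag(scale_diag, nrow = k)
  }
  if (!is.null(scale_tril)) {
    tril <- scale_tril
    tril[upper.tri(tril)] <- 0  # entries above the diagonal are ignored
    scale <- scale + tril
  }
  if (!is.null(scale_perturb_factor)) {
    v <- scale_perturb_factor  # k x r factor matrix
    r <- ncol(v)
    # A NULL scale_perturb_diag means an r x r identity in the middle
    d <- if (is.null(scale_perturb_diag)) diag(r) else diag(scale_perturb_diag, nrow = r)
    scale <- scale + v %*% d %*% t(v)
  }
  scale
}

# Hypothetical helper: forward pass Y = scale @ X + shift for a length-k vector
affine_forward <- function(x, scale, shift = NULL) {
  y <- drop(scale %*% x)
  if (!is.null(shift)) y <- y + shift
  y
}

s <- affine_scale(2, scale_diag = c(2, 3))
affine_forward(c(1, 1), s, shift = c(10, 20))            # c(12, 23)
affine_scale(2, scale_perturb_factor = matrix(1, 2, 1))  # rows (2, 1), (1, 2)
```

The last call shows that a low-rank factor on its own still gets the implicit identity, because none of scale_identity_multiplier, scale_diag, or scale_tril was supplied.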

Usage

tfb_affine(
  shift = NULL,
  scale_identity_multiplier = NULL,
  scale_diag = NULL,
  scale_tril = NULL,
  scale_perturb_factor = NULL,
  scale_perturb_diag = NULL,
  adjoint = FALSE,
  validate_args = FALSE,
  name = "affine",
  dtype = NULL
)

Arguments

shift

Floating-point Tensor. If this is set to NULL, no shift is applied.

scale_identity_multiplier

Floating-point rank-0 Tensor representing a scaling applied to the identity matrix. When scale_identity_multiplier = scale_diag = scale_tril = NULL, then scale += IdentityMatrix. Otherwise no implicit identity matrix is added to scale; a non-NULL scale_identity_multiplier contributes scale_identity_multiplier * IdentityMatrix.

scale_diag

Floating-point Tensor representing the diagonal matrix. scale_diag has shape [N1, N2, ..., k], which represents a k x k diagonal matrix. When NULL, no diagonal term is added to scale.

scale_tril

Floating-point Tensor representing the lower triangular matrix. scale_tril has shape [N1, N2, ..., k, k], which represents a k x k lower triangular matrix. When NULL, no scale_tril term is added to scale. The elements above the diagonal are ignored.
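The statement that entries above the diagonal are ignored can be pictured with a short base-R sketch for one unbatched k x k matrix. The helper lower_triangle() is hypothetical and only mimics the masking; it is not how tfprobability implements it.

```r
# Hypothetical helper: keep the diagonal and everything below it,
# zeroing the entries that scale_tril would ignore.
lower_triangle <- function(m) {
  m[upper.tri(m)] <- 0
  m
}

m <- matrix(1:9, nrow = 3)  # filled column-wise: rows (1, 4, 7), (2, 5, 8), (3, 6, 9)
lower_triangle(m)           # rows (1, 0, 0), (2, 5, 0), (3, 6, 9)
```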

scale_perturb_factor

Floating-point Tensor representing a factor matrix whose last two dimensions have shape (k, r). When NULL, no rank-r update is added to scale.

scale_perturb_diag

Floating-point Tensor representing the diagonal matrix. scale_perturb_diag has shape [N1, N2, ..., r], which represents an r x r diagonal matrix. When NULL, the low-rank update takes the form scale_perturb_factor @ tf.transpose(scale_perturb_factor).
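As a rough illustration of the low-rank term, this base-R sketch computes V D t(V) for one unbatched k x r factor V, taking D as the identity when no diagonal is supplied. The helper name low_rank_update() is hypothetical.

```r
# Hypothetical helper: the rank-r update added to scale, built from
# scale_perturb_factor (V, k x r) and scale_perturb_diag (diagonal of D).
low_rank_update <- function(scale_perturb_factor, scale_perturb_diag = NULL) {
  v <- scale_perturb_factor
  r <- ncol(v)
  # NULL scale_perturb_diag: D is the r x r identity, so the update is V %*% t(V)
  d <- if (is.null(scale_perturb_diag)) diag(r) else diag(scale_perturb_diag, nrow = r)
  v %*% d %*% t(v)
}

v <- matrix(c(1, 2, 3), ncol = 1)  # k = 3, r = 1
low_rank_update(v)                 # same as tcrossprod(v), a rank-1 matrix
low_rank_update(v, 2)              # every entry doubled
```

Storing V and the diagonal of D rather than a dense k x k matrix is what keeps this term cheap when r is much smaller than k.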

adjoint

Logical indicating whether to use the scale matrix as specified or its adjoint. Default value: FALSE.
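For real-valued matrices, the adjoint is simply the transpose (for complex matrices it would be the conjugate transpose). A minimal base-R sketch of what adjoint = TRUE changes in the forward pass, using a hypothetical affine_forward() helper on one unbatched vector:

```r
# Hypothetical helper: forward map that uses either scale or its adjoint.
# Real-valued case only, where the adjoint equals t(scale).
affine_forward <- function(x, scale, shift = 0, adjoint = FALSE) {
  m <- if (adjoint) t(scale) else scale
  drop(m %*% x) + shift
}

scale <- matrix(c(1, 0, 2, 1), nrow = 2)        # rows (1, 2), (0, 1)
affine_forward(c(1, 1), scale)                  # c(3, 1)
affine_forward(c(1, 1), scale, adjoint = TRUE)  # c(1, 3)
```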

validate_args

Logical, default FALSE. Whether to validate input with asserts. If validate_args is FALSE and the inputs are invalid, correct behavior is not guaranteed.

name

Name prefixed to Ops created by this class.

dtype

tf$DType to prefer when converting args to Tensors. If NULL, a common dtype is inferred from the args; if none can be inferred, float32 is used.

Details

If none of scale_identity_multiplier, scale_diag, or scale_tril is specified, then scale += IdentityMatrix. Otherwise, specifying a scale argument has the semantics of scale += Expand(arg); for example, a non-NULL scale_diag means scale += tf$diag(scale_diag).
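The defaulting rule is easy to misread, so here it is as a tiny base-R predicate. The function adds_implicit_identity() is a hypothetical illustration of the rule above, not part of the package.

```r
# Hypothetical helper: TRUE when tfb_affine() would add an implicit identity
# to scale, i.e. when none of the three main scale arguments is supplied.
adds_implicit_identity <- function(scale_identity_multiplier = NULL,
                                   scale_diag = NULL,
                                   scale_tril = NULL) {
  is.null(scale_identity_multiplier) && is.null(scale_diag) && is.null(scale_tril)
}

adds_implicit_identity()                      # TRUE:  scale = I
adds_implicit_identity(scale_diag = c(2, 3))  # FALSE: scale = diag(c(2, 3)), not I + diag(c(2, 3))
```

Note that scale_perturb_factor is not part of the check: a low-rank factor supplied on its own is added on top of the implicit identity.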

Value

a bijector instance.

See Also

For usage examples see tfb_forward(), tfb_inverse(), tfb_inverse_log_det_jacobian().
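For intuition about what those functions compute for this bijector, the base-R sketch below inverts Y = scale %*% X + shift and evaluates the forward log-det-Jacobian, which for an affine map is the constant log|det(scale)|; the inverse log-det-Jacobian is its negative. The helpers are hypothetical and cover one unbatched event with an invertible scale.

```r
# Hypothetical helpers (not tfprobability functions) for one unbatched event.
affine_inverse <- function(y, scale, shift = 0) {
  # Solve scale %*% x = y - shift instead of forming solve(scale) explicitly
  drop(solve(scale, y - shift))
}
affine_forward_log_det_jacobian <- function(scale) {
  # The Jacobian of an affine map is the constant matrix `scale`;
  # [[1]] drops the "logarithm" attribute returned by determinant()
  determinant(scale, logarithm = TRUE)$modulus[[1]]
}

scale <- matrix(c(2, 0, 1, 3), nrow = 2)  # rows (2, 1), (0, 3); det = 6
shift <- c(1, -1)
x <- c(0.5, 2)
y <- drop(scale %*% x) + shift            # c(4, 5)
affine_inverse(y, scale, shift)           # recovers c(0.5, 2)
affine_forward_log_det_jacobian(scale)    # log(6)
```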

Other bijectors: tfb_absolute_value(), tfb_affine_linear_operator(), tfb_affine_scalar(), tfb_ascending(), tfb_batch_normalization(), tfb_blockwise(), tfb_chain(), tfb_cholesky_outer_product(), tfb_cholesky_to_inv_cholesky(), tfb_correlation_cholesky(), tfb_cumsum(), tfb_discrete_cosine_transform(), tfb_expm1(), tfb_exp(), tfb_ffjord(), tfb_fill_scale_tri_l(), tfb_fill_triangular(), tfb_glow(), tfb_gompertz_cdf(), tfb_gumbel_cdf(), tfb_gumbel(), tfb_identity(), tfb_inline(), tfb_invert(), tfb_iterated_sigmoid_centered(), tfb_kumaraswamy_cdf(), tfb_kumaraswamy(), tfb_lambert_w_tail(), tfb_masked_autoregressive_default_template(), tfb_masked_autoregressive_flow(), tfb_masked_dense(), tfb_matrix_inverse_tri_l(), tfb_matvec_lu(), tfb_normal_cdf(), tfb_ordered(), tfb_pad(), tfb_permute(), tfb_power_transform(), tfb_rational_quadratic_spline(), tfb_rayleigh_cdf(), tfb_real_nvp_default_template(), tfb_real_nvp(), tfb_reciprocal(), tfb_reshape(), tfb_scale_matvec_diag(), tfb_scale_matvec_linear_operator(), tfb_scale_matvec_lu(), tfb_scale_matvec_tri_l(), tfb_scale_tri_l(), tfb_scale(), tfb_shifted_gompertz_cdf(), tfb_shift(), tfb_sigmoid(), tfb_sinh_arcsinh(), tfb_sinh(), tfb_softmax_centered(), tfb_softplus(), tfb_softsign(), tfb_split(), tfb_square(), tfb_tanh(), tfb_transform_diagonal(), tfb_transpose(), tfb_weibull_cdf(), tfb_weibull()


tfprobability documentation built on Sept. 1, 2022, 5:07 p.m.