vi_monte_carlo_variational_loss: Monte-Carlo approximation of an f-Divergence variational loss

View source: R/vi-functions.R

vi_monte_carlo_variational_loss    R Documentation

Monte-Carlo approximation of an f-Divergence variational loss

Description

Variational losses measure the divergence between an unnormalized target distribution p (provided via target_log_prob_fn) and a surrogate distribution q (provided as surrogate_posterior). When the target distribution is an unnormalized posterior from conditioning a model on data, minimizing the loss with respect to the parameters of surrogate_posterior performs approximate posterior inference.

Usage

vi_monte_carlo_variational_loss(
  target_log_prob_fn,
  surrogate_posterior,
  sample_size = 1L,
  importance_sample_size = 1L,
  discrepancy_fn = vi_kl_reverse,
  use_reparametrization = NULL,
  seed = NULL,
  name = NULL
)

Arguments

target_log_prob_fn

A function that takes a set of Tensor arguments and returns a Tensor log-density. Given q_sample <- surrogate_posterior$sample(sample_size), this will be called (in Python) as target_log_prob_fn(*q_sample) if q_sample is a list or a tuple, target_log_prob_fn(**q_sample) if q_sample is a dictionary, or target_log_prob_fn(q_sample) if q_sample is a Tensor. It should support batched evaluation, i.e., it should return a result of shape [sample_size].

surrogate_posterior

A tfp$distributions$Distribution instance defining a variational posterior (could be a tfp$distributions$JointDistribution). Crucially, the distribution's log_prob and (if reparameterized) sample methods must directly invoke all ops that generate gradients to the underlying variables. One way to ensure this is to use tfp$util$DeferredTensor to represent any parameters defined as transformations of unconstrained variables, so that the transformations execute at runtime instead of at distribution creation.

sample_size

integer number of Monte Carlo samples to use in estimating the variational divergence. Larger values may stabilize the optimization, but at higher cost per step in time and memory. Default value: 1.

importance_sample_size

integer number of terms used to define an importance-weighted divergence. If importance_sample_size > 1, then the surrogate_posterior is optimized to function as an importance-sampling proposal distribution. In this case it often makes sense to use importance sampling to approximate posterior expectations (see vi_fit_surrogate_posterior for an example). Default value: 1.

discrepancy_fn

function representing a Csiszar f function in log-space. That is, discrepancy_fn(log(u)) = f(u), where f is convex in u. Default value: vi_kl_reverse.
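The log-space convention can be checked numerically. The sketch below uses plain base R stand-ins, kl_reverse_log and kl_forward_log, which are hypothetical names and not the tfprobability functions (those operate on Tensors). It only illustrates that discrepancy_fn(log(u)) equals f(u).

```r
# Hypothetical base R stand-ins: each takes log(u) and returns f(u).
kl_reverse_log <- function(logu) -logu               # f(u) = -log(u)
kl_forward_log <- function(logu) exp(logu) * logu    # f(u) = u * log(u)

u <- c(0.5, 1, 2)
rev_vals <- kl_reverse_log(log(u))
fwd_vals <- kl_forward_log(log(u))

# Both agree with f evaluated directly on u, and both are zero at u = 1.
print(all.equal(rev_vals, -log(u)))
print(all.equal(fwd_vals, u * log(u)))
```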

use_reparametrization

logical. When NULL (the default), automatically set to: surrogate_posterior$reparameterization_type == tfp$distributions$FULLY_REPARAMETERIZED. When TRUE, uses the standard Monte-Carlo average. When FALSE, uses the score-gradient trick (see Details). When FALSE, consider using vi_csiszar_vimco.

seed

integer seed for surrogate_posterior$sample.

name

name prefixed to Ops created by this function.

Details

This function defines divergences of the form E_q[discrepancy_fn(log p(z) - log q(z))], sometimes known as f-divergences.

In the special case discrepancy_fn(logu) == -logu (the default vi_kl_reverse), this is the reverse Kullback-Leibler divergence KL[q||p], whose negation applied to an unnormalized p is the widely-used evidence lower bound (ELBO). Other cases of interest available under tfp$vi include the forward KL[p||q] (given by vi_kl_forward(logu) == exp(logu) * logu), total variation distance, Amari alpha-divergences, and more.
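As a rough illustration of the default case, the following base R sketch estimates E_q[-(log p(z) - log q(z))] for an unnormalized Gaussian target and compares it with the closed form KL[q||p] - log Z. The target, the normalizer log_z, and the surrogate parameters are all assumptions chosen for the example. No tfprobability code is involved.

```r
set.seed(1)

# Assumed unnormalized target: p(z) = Z * Normal(z; 1, 1) with Z = 3.
log_z <- log(3)
target_log_prob_fn <- function(z) log_z + dnorm(z, mean = 1, sd = 1, log = TRUE)

# Assumed surrogate q = Normal(0, 1.5).
q_mean <- 0
q_sd <- 1.5

n <- 1e5
z <- rnorm(n, q_mean, q_sd)
logu <- target_log_prob_fn(z) - dnorm(z, q_mean, q_sd, log = TRUE)

# Reverse-KL discrepancy: f(u) = -log(u), so the loss is the average of -logu.
loss_mc <- mean(-logu)

# Closed form: KL[q || normalized p] - log Z; its negation is the ELBO.
kl_exact <- log(1 / q_sd) + (q_sd^2 + (q_mean - 1)^2) / 2 - 0.5
loss_exact <- kl_exact - log_z

print(c(monte_carlo = loss_mc, exact = loss_exact))
```

Because KL[q||p] is non-negative, the negated loss never exceeds log Z in expectation, which is the sense in which it is a lower bound on the evidence.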

Csiszar f-divergences

A Csiszar function f is a convex function from R^+ (the positive reals) to R. The Csiszar f-Divergence is given by:

D_f[p(X), q(X)] := E_{q(X)}[ f( p(X) / q(X) ) ]
                ~= (1/m) sum_{j=1}^m f( p(x_j) / q(x_j) ),
                   where x_j ~iid q(X)

For example, f(u) = -log(u) recovers KL[q||p], while f(u) = u * log(u) recovers the forward KL[p||q]. These and other functions are available in tfp$vi.
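The estimator above can be sketched in base R for two normalized Gaussians, where both KL directions have closed forms to compare against. The helper names csiszar_mc and kl_normal and the distribution parameters are assumptions made for this illustration only.

```r
set.seed(42)

# Closed-form KL[N(m1, s1) || N(m2, s2)], used only as a reference value.
kl_normal <- function(m1, s1, m2, s2) {
  log(s2 / s1) + (s1^2 + (m1 - m2)^2) / (2 * s2^2) - 0.5
}

# Monte Carlo estimate of D_f[p, q] = E_q[f(p(X) / q(X))] from m draws of q.
csiszar_mc <- function(f, m, p_mean, p_sd, q_mean, q_sd) {
  x <- rnorm(m, q_mean, q_sd)
  u <- dnorm(x, p_mean, p_sd) / dnorm(x, q_mean, q_sd)
  mean(f(u))
}

f_reverse <- function(u) -log(u)      # recovers KL[q || p]
f_forward <- function(u) u * log(u)   # recovers KL[p || q]

# Assumed p = Normal(0, 1) and q = Normal(0.5, 1.2).
rev_mc <- csiszar_mc(f_reverse, 2e5, 0, 1, 0.5, 1.2)
fwd_mc <- csiszar_mc(f_forward, 2e5, 0, 1, 0.5, 1.2)

rev_exact <- kl_normal(0.5, 1.2, 0, 1)
fwd_exact <- kl_normal(0, 1, 0.5, 1.2)

print(c(rev_mc = rev_mc, rev_exact = rev_exact))
print(c(fwd_mc = fwd_mc, fwd_exact = fwd_exact))
```

Here q is wider than p, so the ratio p/q stays bounded and the forward-KL estimate is well behaved. With a q narrower than p the same estimator can have very high variance.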

Tricks: Reparameterization and Score-Gradient

When q is "reparameterized", i.e., a diffeomorphic transformation of a parameterless distribution (e.g., Normal(Y; m, s) <=> Y = sX + m, X ~ Normal(0,1)), we can swap gradient and expectation, i.e., grad[Avg{ s_i : i=1...n }] = Avg{ grad[s_i] : i=1...n } where S_n=Avg{s_i} and s_i = f(x_i), x_i ~iid q(X).
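A small base R check of this swap, under the assumption f(y) = y^2 so that E[f(Y)] = m^2 + s^2 and the exact gradients are 2m and 2s. Base R has no automatic differentiation, so the per-sample gradients are written out by hand via the chain rule through Y = sX + m.

```r
set.seed(7)

m <- 1.5
s <- 0.7
n <- 1e5

x <- rnorm(n)        # draws from the parameter-free base distribution
y <- s * x + m       # reparameterized draws from Normal(m, s)

# Assumed integrand f(y) = y^2, with derivative f'(y) = 2 * y.
f_prime <- function(y) 2 * y

# Average of per-sample gradients: dy/dm = 1 and dy/ds = x.
grad_m <- mean(f_prime(y) * 1)
grad_s <- mean(f_prime(y) * x)

print(c(grad_m = grad_m, exact_m = 2 * m))
print(c(grad_s = grad_s, exact_s = 2 * s))
```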

However, if q is not reparameterized, TensorFlow's gradient will be incorrect since the chain-rule stops at samples of unreparameterized distributions. In this circumstance using the Score-Gradient trick results in an unbiased gradient, i.e.,

grad[ E_q[f(X)] ]
  = grad[ int dx q(x) f(x) ]
  = int dx grad[ q(x) f(x) ]
  = int dx [ q'(x) f(x) + q(x) f'(x) ]
  = int dx q(x) [q'(x) / q(x) f(x) + f'(x) ]
  = int dx q(x) grad[ f(x) q(x) / stop_grad[q(x)] ]
  = E_q[ grad[ f(x) q(x) / stop_grad[q(x)] ] ]
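The identity can be checked numerically in base R. In this sketch f(x) = x^2 does not depend on the parameter m, so the f'(x) term vanishes and the estimator reduces to the average of f(x) * d/dm log q(x). The integrand and parameter values are assumptions for illustration, and the score (x - m) / s^2 is derived by hand for a Normal(m, s) density.

```r
set.seed(11)

m <- 1.5
s <- 0.7
n <- 1e5

# Samples are treated as fixed draws; no gradient flows through them.
x <- rnorm(n, mean = m, sd = s)

f <- function(x) x^2            # assumed integrand; E_q[f(X)] = m^2 + s^2
score_m <- (x - m) / s^2        # d/dm log q(x) for q = Normal(m, s)

# Score-gradient estimate of d/dm E_q[f(X)].
score_terms <- f(x) * score_m
grad_m_score <- mean(score_terms)

# Per-sample reparameterized gradient of the same quantity, for comparison.
reparam_terms <- 2 * x

print(c(score = grad_m_score, reparam = mean(reparam_terms), exact = 2 * m))
print(c(sd_score = sd(score_terms), sd_reparam = sd(reparam_terms)))
```

Both estimators target the same gradient, but in this example the score-gradient terms have a noticeably larger spread.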

When q is fully reparameterized (q$reparameterization_type == tfp$distributions$FULLY_REPARAMETERIZED), it is usually preferable to set use_reparametrization = TRUE.

Example Application: The Csiszar f-Divergence is a useful framework for variational inference. To see this, observe that

f(p(x)) =  f( E_{q(Z | x)}[ p(x, Z) / q(Z | x) ] )
        <= E_{q(Z | x)}[ f( p(x, Z) / q(Z | x) ) ]
        := D_f[p(x, Z), q(Z | x)]

The inequality follows from the fact that the "perspective" of f, i.e., (s, t) |-> t f(s / t), is convex in (s, t) when s/t in domain(f) and t is a positive real. Since the above framework includes the popular Evidence Lower BOund (ELBO) as a special case, i.e., f(u) = -log(u), we call this framework "Evidence Divergence Bound Optimization" (EDBO).
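The bound can be illustrated in base R with f(u) = -log(u) on a toy conjugate model, assumed here to be Z ~ Normal(0, 1) and x | Z ~ Normal(Z, 1). The evidence p(x) and the exact posterior are then known in closed form. The model, the observation x_obs, and the helper edbo_mc are all hypothetical choices for this sketch.

```r
set.seed(3)

x_obs <- 1.3

# Assumed model: Z ~ Normal(0, 1), x | Z ~ Normal(Z, 1).
log_joint <- function(z) {
  dnorm(z, 0, 1, log = TRUE) + dnorm(x_obs, z, 1, log = TRUE)
}

# Marginally x ~ Normal(0, sqrt(2)), so log p(x) is available exactly.
log_evidence <- dnorm(x_obs, 0, sqrt(2), log = TRUE)

# Monte Carlo estimate of E_q[f(p(x, Z) / q(Z | x))] with f(u) = -log(u).
edbo_mc <- function(q_mean, q_sd, n = 1e5) {
  z <- rnorm(n, q_mean, q_sd)
  mean(-(log_joint(z) - dnorm(z, q_mean, q_sd, log = TRUE)))
}

# A deliberately mismatched surrogate gives a strict upper bound on f(p(x)).
bound_loose <- edbo_mc(0.3, 1)

# The exact posterior Normal(x / 2, sqrt(1 / 2)) makes the bound tight.
bound_tight <- edbo_mc(x_obs / 2, sqrt(0.5))

print(c(f_of_evidence = -log_evidence, loose = bound_loose, tight = bound_tight))
```

With the exact posterior every sample contributes the same value, -log p(x), so the estimate has no Monte Carlo noise. The gap for the mismatched surrogate equals KL[q || posterior].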

Value

monte_carlo_variational_loss: float-like Tensor Monte Carlo approximation of the Csiszar f-Divergence.

References

  • Ali, Syed Mumtaz, and Samuel D. Silvey. "A general class of coefficients of divergence of one distribution from another." Journal of the Royal Statistical Society: Series B (Methodological) 28.1 (1966): 131-142.

See Also

Other vi-functions: vi_amari_alpha(), vi_arithmetic_geometric(), vi_chi_square(), vi_csiszar_vimco(), vi_dual_csiszar_function(), vi_fit_surrogate_posterior(), vi_jeffreys(), vi_jensen_shannon(), vi_kl_forward(), vi_kl_reverse(), vi_log1p_abs(), vi_modified_gan(), vi_pearson(), vi_squared_hellinger(), vi_symmetrized_csiszar_function()


tfprobability documentation built on Sept. 1, 2022, 5:07 p.m.