vi_monte_carlo_variational_loss    R Documentation
Variational losses measure the divergence between an unnormalized target
distribution p (provided via target_log_prob_fn) and a surrogate
distribution q (provided as surrogate_posterior). When the
target distribution is an unnormalized posterior from conditioning a model on
data, minimizing the loss with respect to the parameters of
surrogate_posterior performs approximate posterior inference.
vi_monte_carlo_variational_loss(
  target_log_prob_fn,
  surrogate_posterior,
  sample_size = 1L,
  importance_sample_size = 1L,
  discrepancy_fn = vi_kl_reverse,
  use_reparametrization = NULL,
  seed = NULL,
  name = NULL
)
target_log_prob_fn
function that takes a set of Tensor arguments (a sample from surrogate_posterior) and returns a Tensor giving the unnormalized log-density of the target distribution at that sample.
surrogate_posterior
A distribution object defining the surrogate posterior q. Minimizing the loss with respect to this distribution's trainable parameters performs approximate posterior inference.
sample_size
integer number of Monte Carlo samples drawn from surrogate_posterior to approximate the expectation that defines the divergence. Default value: 1.
importance_sample_size
integer number of terms used to define an importance-weighted divergence. If importance_sample_size > 1, then the surrogate_posterior is optimized to function as an importance-sampling proposal distribution. In this case it often makes sense to use importance sampling to approximate posterior expectations (see vi_fit_surrogate_posterior for an example). Default value: 1.
discrepancy_fn
function representing a Csiszar f function in log-space; the loss evaluates discrepancy_fn(log p(z) - log q(z)) (see Details). Default value: vi_kl_reverse, giving the reverse KL divergence (the negative evidence lower bound).
use_reparametrization
logical; whether to use the reparameterization (pathwise) gradient estimator rather than the score-gradient trick (see Details). When NULL (the default), the choice is based on whether surrogate_posterior is fully reparameterized.
seed
integer seed used when sampling from surrogate_posterior.
name
name prefixed to Ops created by this function.
This function defines divergences of the form
E_q[discrepancy_fn(log p(z) - log q(z))], sometimes known as
f-divergences.
In the special case discrepancy_fn(logu) == -logu (the default
vi_kl_reverse), this is the reverse Kullback-Leibler divergence
KL[q||p], whose negation applied to an unnormalized p is the widely-used
evidence lower bound (ELBO). Other cases of interest available under
tfp$vi include the forward KL[p||q] (given by vi_kl_forward(logu) == exp(logu) * logu),
total variation distance, Amari alpha-divergences, and more.
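As a sketch of basic usage (not taken from the package's own examples; the Normal target and surrogate below are purely illustrative), the default settings produce a reverse-KL / negative-ELBO loss that can be minimized with respect to the surrogate's variables:

library(tensorflow)
library(tfprobability)

# Illustrative unnormalized target: here simply a Normal(1, 2) log-density.
target_log_prob_fn <- function(z) tfd_normal(loc = 1, scale = 2)$log_prob(z)

# Trainable surrogate q; its variables are what gets optimized.
surrogate_posterior <- tfd_normal(loc = tf$Variable(0), scale = tf$Variable(1))

# With the default discrepancy_fn = vi_kl_reverse this is a Monte Carlo
# estimate of KL[q || p] up to p's normalizing constant (the negative ELBO).
loss <- vi_monte_carlo_variational_loss(
  target_log_prob_fn  = target_log_prob_fn,
  surrogate_posterior = surrogate_posterior,
  sample_size         = 100L
)
# Minimizing `loss` with respect to surrogate_posterior$trainable_variables
# (for example via vi_fit_surrogate_posterior) performs approximate inference.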
Csiszar f-divergences
A Csiszar function f is a convex function from R^+ (the positive reals)
to R. The Csiszar f-Divergence is given by:
D_f[p(X), q(X)] := E_{q(X)}[ f( p(X) / q(X) ) ]
~= m**-1 sum_j^m f( p(x_j) / q(x_j) ),
where x_j ~iid q(X)
For example, f(u) = -log(u) recovers the reverse KL[q||p], while f(u) = u * log(u)
recovers the forward KL[p||q]. These and other functions are available in tfp$vi.
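As a sketch (with made-up Normal densities standing in for p and q), the average above can be formed directly by drawing from q and applying f to the density ratios; vi_kl_reverse supplies f in log-space:

library(tensorflow)
library(tfprobability)

p <- tfd_normal(loc = 1, scale = 2)   # stands in for the target p(X)
q <- tfd_normal(loc = 0, scale = 1)   # stands in for the approximating q(X)

m     <- 10000L
x     <- q$sample(m)                            # x_j ~ iid q(X)
logu  <- p$log_prob(x) - q$log_prob(x)          # log( p(x_j) / q(x_j) )
d_hat <- tf$reduce_mean(vi_kl_reverse(logu))    # m^-1 sum_j f( p(x_j) / q(x_j) ), here an estimate of KL[q||p]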
Tricks: Reparameterization and Score-Gradient
When q is "reparameterized", i.e., a diffeomorphic transformation of a
parameterless distribution (e.g., Normal(Y; m, s) <=> Y = sX + m, X ~ Normal(0,1)),
we can swap gradient and expectation, i.e.,
grad[ Avg{ s_i : i = 1...n } ] = Avg{ grad[s_i] : i = 1...n },
where s_i = f(x_i) and x_i ~iid q(X).
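A minimal sketch of this pathwise setup with TensorFlow from R (the parameters m and s and the choice f(z) = z^2 are purely illustrative):

library(tensorflow)

m <- tf$Variable(0.5)
s <- tf$Variable(1.2)

with(tf$GradientTape() %as% tape, {
  eps  <- tf$random$normal(shape = shape(100L))  # X ~ Normal(0, 1), parameter-free
  z    <- m + s * eps                            # Y = sX + m, reparameterized draws
  loss <- tf$reduce_mean(z ^ 2)                  # Avg{ s_i } with s_i = f(z_i) = z_i^2
})
tape$gradient(loss, list(m, s))                  # Avg{ grad[s_i] }: gradients flow through the samples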
However, if q is not reparameterized, TensorFlow's gradient will be incorrect since the chain-rule stops at samples of unreparameterized distributions. In this circumstance using the Score-Gradient trick results in an unbiased gradient, i.e.,
grad[ E_q[f(X)] ]
= grad[ int dx q(x) f(x) ]
= int dx grad[ q(x) f(x) ]
= int dx [ q'(x) f(x) + q(x) f'(x) ]
= int dx q(x) [ q'(x) / q(x) f(x) + f'(x) ]
= int dx q(x) grad[ f(x) q(x) / stop_grad[q(x)] ]
= E_q[ grad[ f(x) q(x) / stop_grad[q(x)] ] ]
Whenever surrogate_posterior is fully reparameterized (i.e., its reparameterization_type
is tfd$FULLY_REPARAMETERIZED), it is usually preferable to set use_reparametrization = TRUE.
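A minimal sketch of the score-gradient surrogate derived above, roughly what is used when use_reparametrization = FALSE (illustrative; f(x) = x^2, and the sample is explicitly detached so that the only gradient path is the score term):

library(tensorflow)
library(tfprobability)

loc <- tf$Variable(0.5)
q   <- tfd_normal(loc = loc, scale = 1)

with(tf$GradientTape() %as% tape, {
  x    <- tf$stop_gradient(q$sample(1000L))            # x_i ~ q, no pathwise gradient
  lq   <- q$log_prob(x)
  surr <- (x ^ 2) * tf$exp(lq - tf$stop_gradient(lq))  # f(x) * q(x) / stop_grad[q(x)], in log-space
  loss <- tf$reduce_mean(surr)                         # numerically equals Avg{ f(x_i) }
})
tape$gradient(loss, loc)                               # unbiased estimate of grad E_q[f(X)]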
Example Application: The Csiszar f-Divergence is a useful framework for variational inference. To see this, observe that,
f(p(x)) = f( E_{q(Z | x)}[ p(x, Z) / q(Z | x) ] )
<= E_{q(Z | x)}[ f( p(x, Z) / q(Z | x) ) ]
:= D_f[p(x, Z), q(Z | x)]
The inequality follows from the fact that the "perspective" of f, i.e.,
(s, t) |-> t f(s / t), is convex in (s, t) when s/t is in the domain of f and
t is a positive real. Since the above framework includes the popular Evidence Lower
Bound (ELBO) as a special case, i.e., f(u) = -log(u), we call this framework
"Evidence Divergence Bound Optimization" (EDBO).
Value: monte_carlo_variational_loss, a float-like Tensor containing the Monte Carlo
approximation of the Csiszar f-Divergence.
Ali, Syed Mumtaz, and Samuel D. Silvey. "A general class of coefficients of divergence of one distribution from another." Journal of the Royal Statistical Society: Series B (Methodological) 28.1 (1966): 131-142.
Other vi-functions:
vi_amari_alpha(),
vi_arithmetic_geometric(),
vi_chi_square(),
vi_csiszar_vimco(),
vi_dual_csiszar_function(),
vi_fit_surrogate_posterior(),
vi_jeffreys(),
vi_jensen_shannon(),
vi_kl_forward(),
vi_kl_reverse(),
vi_log1p_abs(),
vi_modified_gan(),
vi_pearson(),
vi_squared_hellinger(),
vi_symmetrized_csiszar_function()