vi_monte_carlo_variational_loss | R Documentation |
Variational losses measure the divergence between an unnormalized target distribution p (provided via target_log_prob_fn) and a surrogate distribution q (provided as surrogate_posterior). When the target distribution is an unnormalized posterior from conditioning a model on data, minimizing the loss with respect to the parameters of surrogate_posterior performs approximate posterior inference.
vi_monte_carlo_variational_loss( target_log_prob_fn, surrogate_posterior, sample_size = 1L, importance_sample_size = 1L, discrepancy_fn = vi_kl_reverse, use_reparametrization = NULL, seed = NULL, name = NULL )
target_log_prob_fn |
function that takes a set of |
surrogate_posterior |
A |
sample_size |
|
importance_sample_size |
integer number of terms used to define an importance-weighted divergence. If importance_sample_size > 1, then the surrogate_posterior is optimized to function as an importance-sampling proposal distribution. In this case it often makes sense to use importance sampling to approximate posterior expectations (see vi_fit_surrogate_posterior for an example). Default value: 1. |
discrepancy_fn |
function representing a Csiszar |
use_reparametrization |
|
seed |
|
name |
name prefixed to Ops created by this function. |
This function defines divergences of the form E_q[discrepancy_fn(log p(z) - log q(z))], sometimes known as f-divergences.
In the special case discrepancy_fn(logu) == -logu (the default vi_kl_reverse), this is the reverse Kullback-Leibler divergence KL[q||p], whose negation applied to an unnormalized p is the widely-used evidence lower bound (ELBO). Other cases of interest available under tfp$vi include the forward KL[p||q] (given by vi_kl_forward(logu) == exp(logu) * logu), total variation distance, Amari alpha-divergences, and more.
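As a rough illustration of the loss form E_q[discrepancy_fn(log p(z) - log q(z))], the following standalone Python sketch estimates it by sampling. It does not call tfprobability; the Gaussian target, the scaling constant 3, and the helper names (normal_log_prob, kl_reverse, monte_carlo_loss) are assumptions made up for this example.

```python
import math
import random


def normal_log_prob(z, loc, scale):
    """Log density of Normal(loc, scale) evaluated at z."""
    return (-0.5 * ((z - loc) / scale) ** 2
            - math.log(scale) - 0.5 * math.log(2.0 * math.pi))


def kl_reverse(logu):
    """Reverse-KL discrepancy in log space: f(u) = -log(u)."""
    return -logu


def monte_carlo_loss(target_log_prob, q_loc, q_scale, discrepancy_fn,
                     sample_size, rng):
    """Average discrepancy_fn(log p(z) - log q(z)) over draws z ~ q."""
    total = 0.0
    for _ in range(sample_size):
        z = rng.gauss(q_loc, q_scale)
        logu = target_log_prob(z) - normal_log_prob(z, q_loc, q_scale)
        total += discrepancy_fn(logu)
    return total / sample_size


# Hypothetical unnormalized target: 3 * Normal(1, 2), so log Z = log(3).
LOG_Z = math.log(3.0)


def target_log_prob(z):
    return LOG_Z + normal_log_prob(z, 1.0, 2.0)


rng = random.Random(0)
loss = monte_carlo_loss(target_log_prob, 0.0, 1.0, kl_reverse, 200_000, rng)

# Closed-form KL[q || p] for q = Normal(0, 1) and p = Normal(1, 2).
exact_kl = math.log(2.0 / 1.0) + (1.0 + (0.0 - 1.0) ** 2) / (2.0 * 2.0 ** 2) - 0.5
# With an unnormalized target the loss equals KL[q || p] - log Z,
# i.e. the negated ELBO.
negative_elbo = exact_kl - LOG_Z
```

Under these assumptions, loss should agree with negative_elbo up to Monte Carlo error, which illustrates why the reverse-KL loss on an unnormalized target is the negated ELBO.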
Csiszar f-divergences
A Csiszar function f is a convex function from R^+ (the positive reals) to R. The Csiszar f-Divergence is given by:
D_f[p(X), q(X)] := E_{q(X)}[ f( p(X) / q(X) ) ]
                ~= m**-1 sum_j^m f( p(x_j) / q(x_j) ),
                   where x_j ~iid q(X)
For example, f = lambda u: -log(u) recovers KL[q||p], while f = lambda u: u * log(u) recovers the forward KL[p||q]. These and other functions are available in tfp$vi.
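The two choices of f mentioned above can be checked on a small discrete example. This is a plain-Python sketch, not the tfp$vi implementation; the probability vectors and helper names are hypothetical.

```python
import math
import random

p = [0.5, 0.3, 0.2]  # hypothetical target probabilities
q = [0.2, 0.5, 0.3]  # hypothetical surrogate probabilities


def f_kl_reverse(u):
    return -math.log(u)


def f_kl_forward(u):
    return u * math.log(u)


def csiszar_exact(f, p, q):
    """E_q[f(p/q)] computed exactly by summing over the support."""
    return sum(qi * f(pi / qi) for pi, qi in zip(p, q))


def csiszar_monte_carlo(f, p, q, m, rng):
    """m**-1 sum_j f(p(x_j) / q(x_j)) with x_j drawn iid from q."""
    draws = rng.choices(range(len(q)), weights=q, k=m)
    return sum(f(p[j] / q[j]) for j in draws) / m


def kl(a, b):
    """Kullback-Leibler divergence KL[a || b] for discrete distributions."""
    return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b))


rng = random.Random(0)
reverse_exact = csiszar_exact(f_kl_reverse, p, q)   # equals KL[q || p]
forward_exact = csiszar_exact(f_kl_forward, p, q)   # equals KL[p || q]
reverse_mc = csiszar_monte_carlo(f_kl_reverse, p, q, 100_000, rng)
forward_mc = csiszar_monte_carlo(f_kl_forward, p, q, 100_000, rng)
```

The exact sums match kl(q, p) and kl(p, q) respectively, and the sampled averages approach them as m grows.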
Tricks: Reparameterization and Score-Gradient
When q is "reparameterized", i.e., a diffeomorphic transformation of a parameterless distribution (e.g., Normal(Y; m, s) <=> Y = sX + m, X ~ Normal(0,1)), we can swap gradient and expectation, i.e., grad[Avg{ s_i : i=1...n }] = Avg{ grad[s_i] : i=1...n } where S_n=Avg{s_i} and s_i = f(x_i), x_i ~iid q(X).
However, if q is not reparameterized, TensorFlow's gradient will be incorrect since the chain-rule stops at samples of unreparameterized distributions. In this circumstance using the Score-Gradient trick results in an unbiased gradient, i.e.,
grad[ E_q[f(X)] ]
  = grad[ int dx q(x) f(x) ]
  = int dx grad[ q(x) f(x) ]
  = int dx [ q'(x) f(x) + q(x) f'(x) ]
  = int dx q(x) [q'(x) / q(x) f(x) + f'(x) ]
  = int dx q(x) grad[ f(x) q(x) / stop_grad[q(x)] ]
  = E_q[ grad[ f(x) q(x) / stop_grad[q(x)] ] ]
Unless q.reparameterization_type != tfd.FULLY_REPARAMETERIZED, it is usually preferable to set use_reparametrization = TRUE.
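To make the two estimators concrete, here is a minimal plain-Python sketch for q = Normal(m, 1) and the hypothetical integrand f(x) = x^2, whose expectation m^2 + 1 has gradient 2m. The derivatives are written out by hand rather than obtained from TensorFlow, and f is assumed not to depend on m, so the f'(x) term of the score-gradient identity vanishes.

```python
import random


def f(x):
    return x * x


def reparameterized_grad(m, n, rng):
    """Average of d f(m + eps) / dm = 2 * (m + eps), eps ~ Normal(0, 1)."""
    total = 0.0
    for _ in range(n):
        eps = rng.gauss(0.0, 1.0)
        total += 2.0 * (m + eps)
    return total / n


def score_grad(m, n, rng):
    """Average of f(x) * d log q(x) / dm, where d log q / dm = (x - m)."""
    total = 0.0
    for _ in range(n):
        x = rng.gauss(m, 1.0)
        total += f(x) * (x - m)
    return total / n


m = 1.5
exact = 2.0 * m  # d/dm E[x^2] = d/dm (m^2 + 1)
rng = random.Random(0)
g_reparam = reparameterized_grad(m, 200_000, rng)
g_score = score_grad(m, 200_000, rng)
```

Both averages estimate the same gradient, but in this toy setting the score-gradient estimate has noticeably higher variance, which is consistent with the advice to prefer reparameterization when it is available.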
Example Application: The Csiszar f-Divergence is a useful framework for variational inference. Observe that
f(p(x)) = f( E_{q(Z | x)}[ p(x, Z) / q(Z | x) ] )
        <= E_{q(Z | x)}[ f( p(x, Z) / q(Z | x) ) ]
        := D_f[p(x, Z), q(Z | x)]
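The bound above can be checked numerically on a toy model with a three-valued latent Z. The joint probabilities and the surrogate q below are made-up numbers, and divergence_bound is a hypothetical helper, not part of tfprobability.

```python
import math

joint = [0.10, 0.05, 0.15]  # hypothetical p(x, z) for z = 0, 1, 2 at a fixed x
evidence = sum(joint)       # p(x), obtained by marginalizing out z
q = [0.5, 0.2, 0.3]         # hypothetical surrogate q(z | x)


def f(u):
    """Reverse-KL Csiszar function, so f(p(x)) = -log p(x)."""
    return -math.log(u)


def divergence_bound(f, joint, q):
    """E_{q(Z | x)}[ f( p(x, Z) / q(Z | x) ) ], summed over the support."""
    return sum(qz * f(pz / qz) for pz, qz in zip(joint, q))


bound = divergence_bound(f, joint, q)

# With q equal to the exact posterior p(z | x), the bound becomes tight.
posterior = [pz / evidence for pz in joint]
tight = divergence_bound(f, joint, posterior)
```

Here f(evidence) is never larger than bound, and tight coincides with f(evidence) up to floating-point error, so minimizing the divergence over q pushes the bound toward -log p(x).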
The inequality follows from the fact that the "perspective" of f, i.e., (s, t) |-> t f(s / t), is convex in (s, t) when s/t in domain(f) and t is a real. Since the above framework includes the popular Evidence Lower BOund (ELBO) as a special case, i.e., f(u) = -log(u), we call this framework "Evidence Divergence Bound Optimization" (EDBO).
monte_carlo_variational_loss: float-like Tensor Monte Carlo approximation of the Csiszar f-Divergence.
Ali, Syed Mumtaz, and Samuel D. Silvey. "A general class of coefficients of divergence of one distribution from another." Journal of the Royal Statistical Society: Series B (Methodological) 28.1 (1966): 131-142.
Other vi-functions: vi_amari_alpha(), vi_arithmetic_geometric(), vi_chi_square(), vi_csiszar_vimco(), vi_dual_csiszar_function(), vi_fit_surrogate_posterior(), vi_jeffreys(), vi_jensen_shannon(), vi_kl_forward(), vi_kl_reverse(), vi_log1p_abs(), vi_modified_gan(), vi_pearson(), vi_squared_hellinger(), vi_symmetrized_csiszar_function()