vi_csiszar_vimco: Use VIMCO to lower the variance of the gradient of...

View source: R/vi-functions.R

vi_csiszar_vimco    R Documentation

Use VIMCO to lower the variance of the gradient of csiszar_function(Avg(logu))

Description

This function generalizes VIMCO (Mnih and Rezende, 2016) to Csiszar f-Divergences.

Usage

vi_csiszar_vimco(
  f,
  p_log_prob,
  q,
  num_draws,
  num_batch_draws = 1,
  seed = NULL,
  name = NULL
)

Arguments

f

function representing a Csiszar-function in log-space.

p_log_prob

function representing the natural-log of the probability under distribution p. (In variational inference p is the joint distribution.)

q

tfd$Distribution-like instance; must implement: sample(n, seed), and log_prob(x). (In variational inference q is the approximate posterior distribution.)

num_draws

Integer scalar number of draws used to approximate the f-Divergence expectation. Must be at least 2, since the leave-one-out averages need at least one other draw.

num_batch_draws

Integer scalar number of independent batches of num_draws draws; the returned objective is averaged over these batches.

seed

integer seed for q$sample.

name

String prefixed to Ops created by this function.

Details

Note: if q$reparameterization_type == tfd$FULLY_REPARAMETERIZED, consider using monte_carlo_csiszar_f_divergence.

The VIMCO loss is:

vimco = f(Avg{logu[i] : i=0,...,m-1})
where,
logu[i] = log( p(x, h[i]) / q(h[i] | x) )
h[i] iid~ q(H | x)
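
As a rough illustration of the loss above, the following base-R sketch evaluates the VIMCO value for one set of log-weights. It assumes the average is taken over u and returned in log-space, i.e. log(mean(exp(logu))), which matches the log Avg term in the variance-reducing expression. The helpers log_mean_exp, vimco_value, and kl_reverse are hypothetical names for this example and are not part of tfprobability.

```r
# Numerically stable log(mean(exp(x))) via the log-sum-exp trick.
log_mean_exp <- function(x) {
  m <- max(x)
  m + log(mean(exp(x - m)))
}

# Hypothetical helper: VIMCO value for one vector of m log-weights.
# Assumption: f is applied to the log of the arithmetic average of u.
vimco_value <- function(f, logu) f(log_mean_exp(logu))

# Reverse-KL Csiszar function in log-space: f(u) = -log(u).
kl_reverse <- function(logu) -logu

logu <- log(c(1, 2, 3))          # u = 1, 2, 3, so Avg{u} = 2
value <- vimco_value(kl_reverse, logu)
print(value)                     # equals -log(2)
```

In practice logu[i] would come from p_log_prob(h[i]) minus the log-probability of h[i] under q; here fixed values are used so the result can be checked by hand.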

Interestingly, the VIMCO gradient is not the naive gradient of vimco. Rather, it is characterized by:

grad[vimco] - variance_reducing_term

where,

variance_reducing_term = Sum{ grad[log q(h[i] | x)] * (vimco - f(log Avg{h[j;i] : j=0,...,m-1})) : i=0,...,m-1 }
h[j;i] = u[j]                             for j != i
h[j;i] = GeometricAverage{u[k] : k != i}  for j == i

(We omitted stop_gradient for brevity. See implementation for more details.) The Avg{h[j;i] : j} term is a kind of "swap-out average" where the i-th element has been replaced by the leave-i-out Geometric-average.
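
The swap-out average can be sketched in base R as below. This is a naive loop written for readability, not the package's implementation; log_sooavg is a hypothetical name, and the code assumes u[j] = exp(logu[j]). It relies on the fact that the log of a geometric average is the arithmetic mean of the logs.

```r
# Hypothetical helper: for each i, return log Avg{h[j;i] : j}, where the
# i-th weight is replaced by the geometric average of the other weights.
log_sooavg <- function(logu) {
  m <- length(logu)
  stopifnot(m > 1)  # a leave-one-out average needs at least one other draw
  vapply(seq_len(m), function(i) {
    others <- logu[-i]
    # log(GeometricAverage{u[k] : k != i}) is the mean of the other logs.
    swapped <- c(others, mean(others))
    # Stable log(mean(exp(swapped))) via the log-sum-exp trick.
    mx <- max(swapped)
    mx + log(mean(exp(swapped - mx)))
  }, numeric(1))
}

logu <- log(c(1, 4, 16))
sooavg <- exp(log_sooavg(logu))
print(sooavg)  # 28/3, 7, 7/3
```

For i = 1 the other weights are 4 and 16, their geometric average is 8, and the swap-out average is (4 + 16 + 8) / 3 = 28/3; the other two entries follow the same way.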

This implementation prefers numerical precision over efficiency; its cost is O(num_draws * num_batch_draws * prod(batch_shape) * prod(event_shape)). (The constant may be fairly large, perhaps around 12.)

Value

vimco: the Csiszar f-Divergence generalized VIMCO objective.

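References

Andriy Mnih and Danilo Rezende. Variational Inference for Monte Carlo objectives. In International Conference on Machine Learning, 2016. https://arxiv.org/abs/1602.06725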

See Also

Other vi-functions: vi_amari_alpha(), vi_arithmetic_geometric(), vi_chi_square(), vi_dual_csiszar_function(), vi_fit_surrogate_posterior(), vi_jeffreys(), vi_jensen_shannon(), vi_kl_forward(), vi_kl_reverse(), vi_log1p_abs(), vi_modified_gan(), vi_monte_carlo_variational_loss(), vi_pearson(), vi_squared_hellinger(), vi_symmetrized_csiszar_function()


tfprobability documentation built on Sept. 1, 2022, 5:07 p.m.