vi_fit_surrogate_posterior: Fit a surrogate posterior to a target (unnormalized) log density

View source: R/vi-optimization.R

vi_fit_surrogate_posterior    R Documentation

Fit a surrogate posterior to a target (unnormalized) log density

Description

The default behavior constructs and minimizes the negative variational evidence lower bound (ELBO), given by

q_samples <- surrogate_posterior$sample(num_draws)
elbo_loss <- -tf$reduce_mean(
  target_log_prob_fn(q_samples) - surrogate_posterior$log_prob(q_samples))
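For instance, a minimal end-to-end sketch (the Normal target, the softplus-constrained scale, and the Adam learning rate are illustrative choices, not defaults of this function):

library(tensorflow)
library(tfprobability)

# illustrative target: unnormalized log density of a Normal(1, 2)
target_log_prob_fn <- function(x) {
  tfd_normal(loc = 1, scale = 2)$log_prob(x)
}

# trainable surrogate; the softplus bijector keeps the scale positive
surrogate_posterior <- tfd_normal(
  loc = tf$Variable(0),
  scale = tfp$util$TransformedVariable(1, tfb_softplus())
)

losses <- vi_fit_surrogate_posterior(
  target_log_prob_fn = target_log_prob_fn,
  surrogate_posterior = surrogate_posterior,
  optimizer = tf$optimizers$Adam(learning_rate = 0.1),
  num_steps = 200L
)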

Usage

vi_fit_surrogate_posterior(
  target_log_prob_fn,
  surrogate_posterior,
  optimizer,
  num_steps,
  convergence_criterion = NULL,
  trace_fn = tfp$vi$optimization$`_trace_loss`,
  variational_loss_fn = NULL,
  discrepancy_fn = tfp$vi$kl_reverse,
  sample_size = 1,
  importance_sample_size = 1,
  trainable_variables = NULL,
  jit_compile = NULL,
  seed = NULL,
  name = "fit_surrogate_posterior"
)

Arguments

target_log_prob_fn

function that takes a set of Tensor arguments and returns a Tensor log-density. Given q_sample <- surrogate_posterior$sample(sample_size), this will be called (in Python) as target_log_prob_fn(*q_sample) if q_sample is a list or a tuple, target_log_prob_fn(**q_sample) if q_sample is a dictionary, or target_log_prob_fn(q_sample) if q_sample is a Tensor. It should support batched evaluation, i.e., should return a result of shape [sample_size].
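For instance, a sketch of a two-latent target whose surrogate samples arrive as a list, so the components are passed as separate arguments (the densities are illustrative):

# evaluating [sample_size] batches of draws returns a [sample_size]
# Tensor of joint log-density values
target_log_prob_fn <- function(loc, scale) {
  tfd_normal(loc = 0, scale = 10)$log_prob(loc) +
    tfd_half_normal(scale = 5)$log_prob(scale)
}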

surrogate_posterior

A tfp$distributions$Distribution instance defining a variational posterior (could be a tfp$distributions$JointDistribution). Crucially, the distribution's log_prob and (if reparameterized) sample methods must directly invoke all ops that generate gradients to the underlying variables. One way to ensure this is to use tfp$util$DeferredTensor to represent any parameters defined as transformations of unconstrained variables, so that the transformations execute at runtime instead of at distribution creation.
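For example, a sketch in which the positivity transform of a scale parameter is wrapped in a DeferredTensor so it executes on every read rather than once at construction (initial values are illustrative):

# the softplus runs each time the distribution reads its scale,
# so gradients flow back to the unconstrained variable
raw_scale <- tf$Variable(0.5)
surrogate_posterior <- tfd_normal(
  loc = tf$Variable(0),
  scale = tfp$util$DeferredTensor(raw_scale, tf$math$softplus)
)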

optimizer

Optimizer instance to use. This may be a TF1-style tf$train$Optimizer, a TF2-style tf$optimizers$Optimizer, or any Python object that implements optimizer$apply_gradients(grads_and_vars).

num_steps

integer number of steps to run the optimizer.

convergence_criterion

Optional instance of tfp$optimizer$convergence_criteria$ConvergenceCriterion representing a criterion for detecting convergence. If NULL, the optimization runs for exactly num_steps steps; otherwise, it runs for at most num_steps steps, stopping early once the provided criterion is met. Default value: NULL.
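For example, a sketch that lets the fit terminate early once the loss plateaus, reusing the target and surrogate from the Description sketch (LossNotDecreasing and its atol argument are assumed from the underlying TensorFlow Probability module):

# runs at most 1000 steps, stopping sooner if the loss stops decreasing
losses <- vi_fit_surrogate_posterior(
  target_log_prob_fn = target_log_prob_fn,
  surrogate_posterior = surrogate_posterior,
  optimizer = tf$optimizers$Adam(learning_rate = 0.1),
  num_steps = 1000L,
  convergence_criterion =
    tfp$optimizer$convergence_criteria$LossNotDecreasing(atol = 0.01)
)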

trace_fn

function with signature state = trace_fn(loss, grads, variables), where state may be a Tensor or nested structure of Tensors. The state values are accumulated (by tf$scan) and returned. The default trace_fn simply returns the loss, but in general can depend on the gradients and variables (if trainable_variables is not NULL then variables==trainable_variables; otherwise it is the list of all variables accessed during execution of loss_fn()), as well as any other quantities captured in the closure of trace_fn, for example, statistics of a variational distribution. Default value: function(loss, grads, variables) loss.
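For instance, a sketch of a trace_fn that records the global gradient norm alongside the loss (the element names are illustrative):

# each element of the returned list gains a leading num_steps axis
trace_fn <- function(loss, grads, variables) {
  list(loss = loss,
       grad_norm = tf$linalg$global_norm(grads))
}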

variational_loss_fn

function with signature loss <- variational_loss_fn(target_log_prob_fn, surrogate_posterior, sample_size, seed) defining a variational loss function. The default is a Monte Carlo approximation to the standard evidence lower bound (ELBO), equivalent to minimizing the 'reverse' KL[q||p] divergence between the surrogate q and the true posterior p. Default value: a partial application of tfp$vi$monte_carlo_variational_loss with discrepancy_fn = tfp$vi$kl_reverse and use_reparameterization = TRUE.
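For example, a sketch of a custom loss that reuses vi_monte_carlo_variational_loss but swaps in the forward KL divergence (the pass-through of these four arguments assumes that function's documented signature):

# identical Monte Carlo machinery to the default, with KL[p||q]
variational_loss_fn <- function(target_log_prob_fn, surrogate_posterior,
                                sample_size, seed) {
  vi_monte_carlo_variational_loss(
    target_log_prob_fn = target_log_prob_fn,
    surrogate_posterior = surrogate_posterior,
    sample_size = sample_size,
    discrepancy_fn = vi_kl_forward,
    seed = seed
  )
}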

discrepancy_fn

A function or Python callable representing a Csiszar f-function in log-space. See the docs for vi_monte_carlo_variational_loss() for examples. This argument is ignored if a variational_loss_fn is explicitly specified. Default value: tfp$vi$kl_reverse.
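As a sketch, another exported f-divergence can be selected directly, without writing a custom loss (again reusing the target and surrogate from the Description sketch):

losses <- vi_fit_surrogate_posterior(
  target_log_prob_fn = target_log_prob_fn,
  surrogate_posterior = surrogate_posterior,
  optimizer = tf$optimizers$Adam(learning_rate = 0.1),
  num_steps = 200L,
  discrepancy_fn = vi_squared_hellinger
)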

sample_size

integer number of Monte Carlo samples to use in estimating the variational divergence. Larger values may stabilize the optimization, but at higher cost per step in time and memory. Default value: 1.

importance_sample_size

An integer number of terms used to define an importance-weighted divergence. If importance_sample_size > 1, then the surrogate_posterior is optimized to function as an importance-sampling proposal distribution. In this case, posterior expectations should be approximated by importance sampling, as in the sketch below. This argument is ignored if a variational_loss_fn is explicitly specified. Default value: 1.
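A sketch of such a self-normalized importance-sampling estimate, once the surrogate has been fit (f is a hypothetical function of interest; the draw count is illustrative):

# weight draws from q by p(x)/q(x) and normalize the weights
draws <- surrogate_posterior$sample(1000L)
log_weights <- target_log_prob_fn(draws) -
  surrogate_posterior$log_prob(draws)
weights <- tf$nn$softmax(log_weights)
# approximate E[f(x)] under the posterior; f is hypothetical
estimate <- tf$reduce_sum(weights * f(draws))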

trainable_variables

Optional list of tf$Variable instances to optimize with respect to. If NULL, defaults to the set of all variables accessed during the computation of the variational bound, i.e., those defining surrogate_posterior and the model target_log_prob_fn. Default value: NULL.

jit_compile

If TRUE, compiles the loss function and gradient update using XLA. XLA performs compiler optimizations such as fusion and attempts to emit more efficient code, which may drastically improve performance. See the docs for tf$function. Default value: NULL.

seed

integer to seed the random number generator.

name

name prefixed to ops created by this function. Default value: 'fit_surrogate_posterior'.

Details

This corresponds to minimizing the 'reverse' Kullback-Leibler divergence (KL[q||p]) between the variational distribution and the unnormalized target_log_prob_fn, and defines a lower bound on the marginal log likelihood, log p(x) >= -elbo_loss.

More generally, this function supports fitting variational distributions that minimize any Csiszar f-divergence.

Value

results: Tensor or nested structure of Tensors, according to the return type of trace_fn. Each Tensor has an added leading dimension of size num_steps, packing the trajectory of the result over the course of the optimization.

See Also

Other vi-functions: vi_amari_alpha(), vi_arithmetic_geometric(), vi_chi_square(), vi_csiszar_vimco(), vi_dual_csiszar_function(), vi_jeffreys(), vi_jensen_shannon(), vi_kl_forward(), vi_kl_reverse(), vi_log1p_abs(), vi_modified_gan(), vi_monte_carlo_variational_loss(), vi_pearson(), vi_squared_hellinger(), vi_symmetrized_csiszar_function()

