View source: R/vi-optimization.R
vi_fit_surrogate_posterior (R Documentation)
The default behavior constructs and minimizes the negative variational evidence lower bound (ELBO), given by

    q_samples <- surrogate_posterior$sample(num_draws)
    elbo_loss <- -tf$reduce_mean(
      target_log_prob_fn(q_samples) - surrogate_posterior$log_prob(q_samples)
    )
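To make the formula concrete without TensorFlow, here is a minimal base-R sketch of the same Monte Carlo estimate. It assumes a one-dimensional Gaussian surrogate N(q_mean, q_sd) in place of surrogate_posterior and a toy target whose evidence is known; the helper neg_elbo_mc and the toy target are hypothetical illustrations, not part of tfprobability.

```r
# Hypothetical base-R sketch of the Monte Carlo negative ELBO above.
# The surrogate is a stand-in Gaussian q(z) = N(q_mean, q_sd).
neg_elbo_mc <- function(target_log_prob_fn, q_mean, q_sd, num_draws) {
  # Draw samples from the surrogate posterior
  q_samples <- rnorm(num_draws, mean = q_mean, sd = q_sd)
  # Log density of each sample under the surrogate
  q_log_prob <- dnorm(q_samples, mean = q_mean, sd = q_sd, log = TRUE)
  # Average the log-ratio and negate it, as in elbo_loss
  -mean(target_log_prob_fn(q_samples) - q_log_prob)
}

# Toy unnormalized target: a N(2, 1) density scaled by an evidence of 3,
# so log p(x) = log(3).
log_evidence <- log(3)
target_log_prob_fn <- function(z) {
  dnorm(z, mean = 2, sd = 1, log = TRUE) + log_evidence
}

set.seed(1)
# q matches the normalized target, so every log-ratio equals log(3)
# and the loss is -log(3), about -1.0986 (up to rounding).
neg_elbo_mc(target_log_prob_fn, q_mean = 2, q_sd = 1, num_draws = 100)
```

When q is far from the target (for example q_mean = 0), the estimate is larger on average: it exceeds -log p(x) by KL[q || p(z | x)], which is 2 for N(0, 1) against N(2, 1).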
Usage:

    vi_fit_surrogate_posterior(
      target_log_prob_fn,
      surrogate_posterior,
      optimizer,
      num_steps,
      convergence_criterion = NULL,
      trace_fn = tfp$vi$optimization$`_trace_loss`,
      variational_loss_fn = NULL,
      discrepancy_fn = tfp$vi$kl_reverse,
      sample_size = 1,
      importance_sample_size = 1,
      trainable_variables = NULL,
      jit_compile = NULL,
      seed = NULL,
      name = "fit_surrogate_posterior"
    )
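Arguments:

| Argument | Description |
|---|---|
| target_log_prob_fn | function that takes a set of Tensor arguments and returns a Tensor log-density. It should support batched evaluation, i.e., return a result of shape [sample_size]. |
| surrogate_posterior | A tfp$distributions$Distribution instance defining a variational posterior (could be a tfp$distributions$JointDistribution). Its log_prob and (if reparameterized) sample methods must directly invoke all ops that generate gradients to the underlying variables. |
| optimizer | Optimizer instance to use. This may be a TF1-style tf$train$Optimizer, a TF2-style tf$optimizers$Optimizer, or any Python-compatible object that implements optimizer$apply_gradients(grads_and_vars). |
| num_steps | integer number of steps to run the optimizer. |
| convergence_criterion | Optional instance of tfp$optimizer$convergence_criteria$ConvergenceCriterion representing a criterion for detecting convergence. If NULL, the optimization runs for num_steps steps; otherwise it runs for at most num_steps steps, as determined by the criterion. Default value: NULL. |
| trace_fn | function with signature state <- trace_fn(loss, grads, variables), where state may be a Tensor or nested structure of Tensors. The state values are accumulated (by tf$scan) and returned. The default trace_fn simply returns the loss. |
| variational_loss_fn | function with signature loss <- variational_loss_fn(target_log_prob_fn, surrogate_posterior, sample_size, seed) defining a variational loss function. The default is a Monte Carlo approximation to the standard evidence lower bound (ELBO), equivalent to minimizing the 'reverse' Kullback-Leibler divergence between the surrogate q and the true posterior p. Default value: NULL. |
| discrepancy_fn | A function or Python callable representing a Csiszar f-function in log-space. This argument is ignored if variational_loss_fn is explicitly specified. Default value: tfp$vi$kl_reverse. |
| sample_size | integer number of Monte Carlo samples to use in estimating the variational divergence. Larger values may stabilize the optimization, but at a higher cost per step in time and memory. Default value: 1. |
| importance_sample_size | An integer number of terms used to define an importance-weighted divergence. If importance_sample_size > 1, the surrogate_posterior is optimized to function as an importance-sampling proposal distribution, and posterior expectations should then be approximated by importance sampling. This argument is ignored if variational_loss_fn is explicitly specified. Default value: 1. |
| trainable_variables | Optional list of tf$Variable instances to optimize with respect to. If NULL, defaults to the set of all variables accessed during the computation of the variational bound, i.e., those defining surrogate_posterior and the model target_log_prob_fn. Default value: NULL. |
| jit_compile | If TRUE, compiles the loss function and gradient update using XLA. Default value: NULL. |
| seed | integer to seed the random number generator. |
| name | name prefixed to ops created by this function. Default value: 'fit_surrogate_posterior'. |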
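The importance_sample_size argument changes the objective. As we understand it, with the default discrepancy_fn (reverse KL) each of the sample_size outer draws averages importance_sample_size weights p(x, z) / q(z) before taking the log, which gives an importance-weighted (IWAE-style) bound. The base-R sketch below illustrates that idea with a hypothetical Gaussian surrogate; iw_neg_bound_mc and log_mean_exp are illustrative names, not tfprobability functions.

```r
# Numerically stable log(mean(exp(x)))
log_mean_exp <- function(x) {
  m <- max(x)
  m + log(mean(exp(x - m)))
}

# Hypothetical sketch of an importance-weighted negative bound for a
# Gaussian surrogate q(z) = N(q_mean, q_sd).
iw_neg_bound_mc <- function(target_log_prob_fn, q_mean, q_sd,
                            sample_size, importance_sample_size) {
  bounds <- vapply(seq_len(sample_size), function(i) {
    # importance_sample_size draws for this outer sample
    z <- rnorm(importance_sample_size, q_mean, q_sd)
    # Log importance weights log p(x, z) - log q(z)
    log_w <- target_log_prob_fn(z) - dnorm(z, q_mean, q_sd, log = TRUE)
    # Average the weights (in log space) before taking the log
    log_mean_exp(log_w)
  }, numeric(1))
  -mean(bounds)
}

# Toy target: N(2, 1) scaled by an evidence of 3, so log p(x) = log(3)
target_log_prob_fn <- function(z) dnorm(z, 2, 1, log = TRUE) + log(3)

set.seed(2)
# A poorly placed surrogate: more importance terms tighten the bound
iw_neg_bound_mc(target_log_prob_fn, q_mean = 0, q_sd = 1,
                sample_size = 20, importance_sample_size = 1000)
```

With importance_sample_size = 1 the inner average has a single term, so the sketch reduces to the ordinary Monte Carlo negative ELBO.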
This corresponds to minimizing the 'reverse' Kullback-Leibler divergence (KL[q||p]) between the variational distribution q and the posterior p obtained by normalizing target_log_prob_fn, and defines a lower bound on the marginal log likelihood, log p(x) >= -elbo_loss.
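The lower bound follows from a standard identity. Writing log p(x, z) for target_log_prob_fn evaluated at z and q(z) for the density of surrogate_posterior (notation introduced here for the derivation), and using log p(x, z) = log p(x) + log p(z | x):

```latex
\begin{aligned}
\mathbb{E}_{q(z)}\left[\log p(x, z) - \log q(z)\right]
  &= \log p(x) + \mathbb{E}_{q(z)}\left[\log p(z \mid x) - \log q(z)\right] \\
  &= \log p(x) - \mathrm{KL}\left[q(z) \,\|\, p(z \mid x)\right]
  \;\le\; \log p(x).
\end{aligned}
```

Because log p(x) does not depend on q, minimizing the negative ELBO is the same as minimizing the reverse KL divergence. Strictly, the bound holds for the exact expectation; the Monte Carlo elbo_loss is an unbiased estimate of its negation.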
More generally, this function supports fitting variational distributions that minimize any Csiszar f-divergence.
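A Csiszar f-divergence has the form E_q[f(p(x, z) / q(z))]. With f(u) = -log(u) it is the reverse KL (the ELBO case); with f(u) = u log(u) it is the forward KL. The base-R sketch below writes both f-functions in log space, taking logu = log p(x, z) - log q(z), and plugs them into a generic Monte Carlo loss. The function names are hypothetical; TFP's own tfp$vi$kl_reverse and tfp$vi$kl_forward also take a self_normalized flag, which this sketch omits.

```r
# Csiszar f-functions in log space (non-self-normalized forms)
kl_reverse_logspace <- function(logu) -logu             # f(u) = -log(u)
kl_forward_logspace <- function(logu) exp(logu) * logu  # f(u) = u * log(u)

# Hypothetical generic Monte Carlo variational loss for a Gaussian
# surrogate q(z) = N(q_mean, q_sd): mean of f(p / q) over draws from q.
mc_variational_loss <- function(target_log_prob_fn, q_mean, q_sd,
                                num_draws, discrepancy_fn) {
  z <- rnorm(num_draws, q_mean, q_sd)
  logu <- target_log_prob_fn(z) - dnorm(z, q_mean, q_sd, log = TRUE)
  mean(discrepancy_fn(logu))
}

# Normalized toy target N(2, 1), so log p(x) = 0
target_log_prob_fn <- function(z) dnorm(z, 2, 1, log = TRUE)

set.seed(3)
# With the reverse-KL f-function this is exactly the negative ELBO
mc_variational_loss(target_log_prob_fn, q_mean = 0, q_sd = 1,
                    num_draws = 1000, discrepancy_fn = kl_reverse_logspace)
```

Swapping discrepancy_fn is the only change needed to target a different divergence, which mirrors the role of the discrepancy_fn argument.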
Value: results, a Tensor or nested structure of Tensors, according to the return type of trace_fn. Each Tensor has an added leading dimension of size num_steps, packing the trajectory of the result over the course of the optimization.
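The shape of the result is easiest to see in a toy fitting loop. The base-R sketch below is not how tfprobability works internally (it uses a TensorFlow optimizer and automatic differentiation); it assumes a one-dimensional Gaussian surrogate, plain gradient descent, and gradients derived by hand for the reparameterized negative ELBO. Like the default trace_fn, it records the loss at every step, so the trace has length num_steps. The function fit_gaussian_surrogate is hypothetical.

```r
# Hypothetical sketch: gradient descent on a reparameterized Monte Carlo
# negative ELBO for q(z) = N(q_mean, exp(log_sd)), tracing the loss.
fit_gaussian_surrogate <- function(num_steps, sample_size = 10,
                                   learning_rate = 0.05) {
  # Toy target: N(2, 1) scaled by an evidence of 3
  target_log_prob_fn <- function(z) dnorm(z, 2, 1, log = TRUE) + log(3)
  # Derivative of the target log density with respect to z
  target_grad <- function(z) -(z - 2)

  q_mean <- 0
  log_sd <- 0
  # One traced loss value per step (the leading num_steps dimension)
  trace <- numeric(num_steps)
  for (step in seq_len(num_steps)) {
    q_sd <- exp(log_sd)
    eps <- rnorm(sample_size)
    z <- q_mean + q_sd * eps  # reparameterized samples from q
    log_q <- dnorm(z, q_mean, q_sd, log = TRUE)
    trace[step] <- -mean(target_log_prob_fn(z) - log_q)
    # Under reparameterization log q(z) = -log_sd - eps^2 / 2 + const,
    # so it contributes nothing for q_mean and a constant -1 for log_sd
    grad_mean <- -mean(target_grad(z))
    grad_log_sd <- -mean(target_grad(z) * q_sd * eps) - 1
    q_mean <- q_mean - learning_rate * grad_mean
    log_sd <- log_sd - learning_rate * grad_log_sd
  }
  list(trace = trace, q_mean = q_mean, q_sd = exp(log_sd))
}

set.seed(42)
fit <- fit_gaussian_surrogate(num_steps = 500)
```

After fitting, fit$q_mean and fit$q_sd should be close to 2 and 1, and the tail of fit$trace should hover near -log(3).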
Other vi-functions: vi_amari_alpha(), vi_arithmetic_geometric(), vi_chi_square(), vi_csiszar_vimco(), vi_dual_csiszar_function(), vi_jeffreys(), vi_jensen_shannon(), vi_kl_forward(), vi_kl_reverse(), vi_log1p_abs(), vi_modified_gan(), vi_monte_carlo_variational_loss(), vi_pearson(), vi_squared_hellinger(), vi_symmetrized_csiszar_function()