layer_dense_variational: Dense Variational Layer

View source: R/layers.R

Description

This layer uses variational inference to fit a "surrogate" posterior to the distribution over both the kernel matrix and the bias terms, which are otherwise used in a manner similar to layer_dense(). The "weights posterior" is fit according to the following generative process:

[K, b] ~ Prior()
M = matmul(X, K) + b
Y ~ Likelihood(M)
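
To make the three steps concrete, the following base-R sketch simulates one draw from such a generative process. It is not the layer's implementation: the standard normal prior, the unit-variance Gaussian likelihood, and the sizes n_obs, n_in and units are all assumptions chosen for illustration.

```r
# Illustrative simulation of the generative process in base R.
# Prior, likelihood and all sizes below are assumptions, not package defaults.
set.seed(1)

n_obs <- 5L   # number of input rows (hypothetical)
n_in  <- 3L   # input dimensionality (hypothetical)
units <- 2L   # output dimensionality, as in the `units` argument

X <- matrix(rnorm(n_obs * n_in), nrow = n_obs, ncol = n_in)

# [K, b] ~ Prior(): here an assumed standard normal prior
K <- matrix(rnorm(n_in * units), nrow = n_in, ncol = units)
b <- rnorm(units)

# M = matmul(X, K) + b: b[j] is added to every entry of column j
M <- sweep(X %*% K, 2, b, `+`)

# Y ~ Likelihood(M): here an assumed Gaussian likelihood with unit sd
Y <- matrix(rnorm(length(M), mean = M, sd = 1), nrow = n_obs, ncol = units)
```

Variational inference then replaces the fixed draw of K and b with a learned distribution over them, which is what the posterior function passed to this layer describes.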

Usage

layer_dense_variational(
  object,
  units,
  make_posterior_fn,
  make_prior_fn,
  kl_weight = NULL,
  kl_use_exact = FALSE,
  activation = NULL,
  use_bias = TRUE,
  ...
)

Arguments

object

What to compose the new Layer instance with. Typically a Sequential model or a Tensor (e.g., as returned by layer_input()). The return value depends on object. If object is:

  • missing or NULL, the Layer instance is returned.

  • a Sequential model, the model with an additional layer is returned.

  • a Tensor, the output tensor from layer_instance(object) is returned.

units

Positive integer, dimensionality of the output space.

make_posterior_fn

Function that takes tf$size(kernel), tf$size(bias), and dtype, and returns a callable that takes an input and produces a tfd$Distribution instance representing the surrogate posterior over the kernel and bias.

make_prior_fn

Function that takes tf$size(kernel), tf$size(bias), and dtype, and returns a callable that takes an input and produces a tfd$Distribution instance representing the prior over the kernel and bias.
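
As a sketch of what these two functions might look like, the example below builds a mean-field normal posterior and a normal prior with a trainable mean. The names posterior_mean_field, prior_trainable and n_train are hypothetical, and the construction (a layer_variable() feeding a layer_distribution_lambda()) is one common pattern rather than the only valid one. It assumes that keras, tensorflow and tfprobability are installed with a working TensorFlow backend.

```r
library(keras)
library(tensorflow)
library(tfprobability)

# Hypothetical helper: independent normals with trainable means and scales.
posterior_mean_field <- function(kernel_size, bias_size = 0, dtype = NULL) {
  n <- kernel_size + bias_size
  c <- log(expm1(1))  # softplus(c) equals 1, so scales start near 1
  keras_model_sequential(list(
    # 2 * n trainable values: n locations followed by n unconstrained scales
    layer_variable(shape = 2 * n, dtype = dtype),
    layer_distribution_lambda(make_distribution_fn = function(t) {
      tfd_independent(
        tfd_normal(
          loc = t[1:n],
          scale = 1e-5 + tf$nn$softplus(c + t[(n + 1):(2 * n)])
        ),
        reinterpreted_batch_ndims = 1
      )
    })
  ))
}

# Hypothetical helper: unit-scale normal prior with trainable means.
prior_trainable <- function(kernel_size, bias_size = 0, dtype = NULL) {
  n <- kernel_size + bias_size
  keras_model_sequential(list(
    layer_variable(shape = n, dtype = dtype),
    layer_distribution_lambda(make_distribution_fn = function(t) {
      tfd_independent(
        tfd_normal(loc = t, scale = 1),
        reinterpreted_batch_ndims = 1
      )
    })
  ))
}

n_train <- 100  # assumed number of training examples

model <- keras_model_sequential() %>%
  layer_dense_variational(
    units = 1,
    make_posterior_fn = posterior_mean_field,
    make_prior_fn = prior_trainable,
    kl_weight = 1 / n_train  # a commonly used scaling; an assumption here
  )
```

Both helpers return a Keras model, which is itself callable on an input and yields a distribution, matching the "callable which produces a tfd$Distribution instance" contract described above.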

kl_weight

Amount by which to scale the KL divergence loss between prior and posterior.

kl_use_exact

Logical indicating that the analytical KL divergence should be used rather than a Monte Carlo approximation.
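
The difference between the two KL options can be illustrated outside the layer with a pair of univariate normals in base R. The distributions and the kl_weight value below are hypothetical, and the Monte Carlo estimator shown (averaging the log-density ratio over posterior samples) is a standard one that is assumed here rather than taken from the package source.

```r
# Illustrative comparison of an analytical and a Monte Carlo KL divergence.
set.seed(42)

# Hypothetical surrogate posterior q and prior p
q_mean <- 0.5; q_sd <- 0.8
p_mean <- 0;   p_sd <- 1

# Analytical KL(q || p) for two univariate normals
kl_exact <- log(p_sd / q_sd) +
  (q_sd^2 + (q_mean - p_mean)^2) / (2 * p_sd^2) - 0.5

# Monte Carlo estimate: mean of log q(z) - log p(z) with z drawn from q
z <- rnorm(1e5, mean = q_mean, sd = q_sd)
kl_mc <- mean(dnorm(z, q_mean, q_sd, log = TRUE) -
              dnorm(z, p_mean, p_sd, log = TRUE))

# The KL term is scaled by kl_weight before it enters the loss
kl_weight <- 1 / 100  # hypothetical value
kl_loss <- kl_weight * kl_exact
```

The analytical form is only available for distribution pairs with a known closed-form KL divergence, which is why a sampling-based approximation is offered as the alternative.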

activation

An activation function. See keras::layer_dense. Default: NULL.

use_bias

Whether or not the dense layers constructed in this layer should have a bias term. See keras::layer_dense. Default: TRUE.

...

Additional keyword arguments passed to the keras::layer_dense constructed by this layer.

Value

A Keras layer. As described under object, the exact return value depends on what the layer is composed with.

See Also

Other layers: layer_autoregressive(), layer_conv_1d_flipout(), layer_conv_1d_reparameterization(), layer_conv_2d_flipout(), layer_conv_2d_reparameterization(), layer_conv_3d_flipout(), layer_conv_3d_reparameterization(), layer_dense_flipout(), layer_dense_local_reparameterization(), layer_dense_reparameterization(), layer_variable()


tfprobability documentation built on Sept. 1, 2022, 5:07 p.m.