vaeac_get_val_iwae: Compute the Importance Sampling Estimator (Validation Error)

View source: R/approach_vaeac_torch_modules.R

vaeac_get_val_iwae    R Documentation

Compute the Importance Sampling Estimator (Validation Error)

Description

Compute the Importance Sampling Estimator which the vaeac model uses to evaluate its performance on the validation data.

Usage

vaeac_get_val_iwae(
  val_dataloader,
  mask_generator,
  batch_size,
  vaeac_model,
  val_iwae_n_samples
)

Arguments

val_dataloader

A torch dataloader which loads the validation data.

mask_generator

A mask generator object that generates the masks.

batch_size

Integer. The number of samples to include in each batch.

vaeac_model

The vaeac model.

val_iwae_n_samples

Number of samples to generate for computing the IWAE for each validation sample.

Details

Computes the mean IWAE log-likelihood estimate over the validation set. IWAE (named after the importance weighted autoencoder) refers here to the importance sampling estimator

\log p_{\theta, \psi}(x|y) \approx \log \left( \frac{1}{S} \sum_{i=1}^S \frac{p_\theta(x|z_i, y) \, p_\psi(z_i|y)}{q_\phi(z_i|x,y)} \right),

where z_i \sim q_\phi(z|x,y) and S is the number of samples drawn per validation observation (val_iwae_n_samples). For more details, see Olsen et al. (2022).
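
The estimator above can be illustrated with a small base R sketch. This is not the shapr implementation: it assumes a hypothetical toy model (z ~ N(0, 1), x | z ~ N(z, 1)) in place of the vaeac networks, drops the conditioning on y, and uses a Gaussian proposal centred at x / 2. The helper names log_mean_exp and iwae_toy are made up for this example. The average of the importance weights is taken on the log scale with the log-sum-exp trick, which avoids underflow when the log weights are very negative.

```r
# Numerically stable log(mean(exp(log_w))) via the log-sum-exp trick.
log_mean_exp <- function(log_w) {
  m <- max(log_w)
  m + log(mean(exp(log_w - m)))
}

# Toy IWAE estimate of log p(x) for a single observation x.
# Assumed toy model (NOT the vaeac networks): z ~ N(0, 1), x | z ~ N(z, 1).
# Assumed proposal: q(z | x) = N(x / 2, proposal_sd^2).
iwae_toy <- function(x, n_samples, proposal_sd) {
  z <- rnorm(n_samples, mean = x / 2, sd = proposal_sd)   # z_i ~ q(z | x)
  log_w <- dnorm(x, mean = z, sd = 1, log = TRUE) +       # log p(x | z_i)
    dnorm(z, mean = 0, sd = 1, log = TRUE) -              # + log p(z_i)
    dnorm(z, mean = x / 2, sd = proposal_sd, log = TRUE)  # - log q(z_i | x)
  log_mean_exp(log_w)
}

set.seed(1)
x_val <- c(-1.0, 0.3, 2.0)  # toy "validation set"

# One IWAE value per validation observation, then the average over the set.
iwae_per_obs <- sapply(x_val, iwae_toy, n_samples = 40, proposal_sd = 1)
val_iwae <- mean(iwae_per_obs)

# In this toy model the marginal is known, x ~ N(0, 2), so the estimate
# can be compared with the exact average log-likelihood.
exact <- mean(dnorm(x_val, mean = 0, sd = sqrt(2), log = TRUE))
```

In this toy setting the exact posterior is N(x / 2, 1/2), so setting proposal_sd = sqrt(0.5) makes every importance weight equal to p(x) and the estimate exact for any number of samples. With any other proposal the estimate is only approximate and generally improves as n_samples grows.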

Value

The average IWAE over all instances in the validation dataset.

Author(s)

Lars Henry Berge Olsen


NorskRegnesentral/shapr documentation built on April 19, 2024, 1:19 p.m.