vae_loss_correlated: A custom loss function for a VAE learning a multivariate...


View source: R/loss_functions.R

Description

A custom loss function for a VAE learning a multivariate normal distribution with a full covariance matrix

Usage

vae_loss_correlated(
  encoder,
  inv_skill_cov,
  det_skill_cov,
  skill_mean,
  kl_weight,
  rec_dim
)

Arguments

encoder

the encoder model of the VAE, used to obtain z_mean and z_log_cholesky from inputs

inv_skill_cov

a constant matrix tensor holding the inverse of the covariance matrix being learned

det_skill_cov

a constant tensor scalar representing the determinant of the covariance matrix being learned

skill_mean

a constant tensor vector representing the means of the latent skills being learned

kl_weight

weight for the KL divergence term

rec_dim

the number of nodes in the input/output of the VAE
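
For orientation, the arguments above line up with the standard closed-form KL divergence between two multivariate normal distributions. The block below is a sketch under assumptions, not a statement of the package's implementation: it assumes the encoder's outputs define a posterior with mean \mu_z and covariance \Sigma_z (how \Sigma_z is built from z_log_cholesky is not specified on this page), and that the prior has mean \mu = skill_mean and covariance \Sigma, with k latent skills.

```latex
\mathrm{KL}\!\left(\mathcal{N}(\mu_z, \Sigma_z)\,\|\,\mathcal{N}(\mu, \Sigma)\right)
  = \frac{1}{2}\Big[
      \operatorname{tr}\!\left(\Sigma^{-1}\Sigma_z\right)
      + (\mu - \mu_z)^{\top}\Sigma^{-1}(\mu - \mu_z)
      - k
      + \ln\det\Sigma
      - \ln\det\Sigma_z
    \Big]
```

Under this reading, \Sigma^{-1} corresponds to inv_skill_cov and \det\Sigma to det_skill_cov. Because both are constants of the prior, they can be computed once and passed in rather than recomputed for every batch.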

Value

returns a function of the true and predicted tensors, matching the Keras loss format
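
The "function that returns a function" pattern can be illustrated with a small base-R sketch. Everything here is hypothetical and for illustration only: make_loss_sketch and toy_encoder are invented names, the encoder is a plain R function assumed to return a mean vector and a covariance matrix directly, and scaling the reconstruction term by rec_dim is an assumption about how that argument is used. It works on a single numeric vector, not on Keras tensors.

```r
# Hypothetical base-R illustration of a loss factory; NOT the ML2Pvae code.
make_loss_sketch <- function(encoder, inv_skill_cov, det_skill_cov,
                             skill_mean, kl_weight, rec_dim) {
  k <- length(skill_mean)  # number of latent skills
  # The returned closure takes (y_true, y_pred), as a Keras loss would.
  function(y_true, y_pred) {
    enc <- encoder(y_true)  # assumed to return list(z_mean, z_cov)
    diff <- skill_mean - enc$z_mean
    # Closed-form KL divergence between two multivariate normals
    kl <- 0.5 * (sum(diag(inv_skill_cov %*% enc$z_cov)) +
                 as.numeric(t(diff) %*% inv_skill_cov %*% diff) -
                 k + log(det_skill_cov) - log(det(enc$z_cov)))
    # Binary cross-entropy averaged over items, scaled by rec_dim (assumption)
    bce <- -mean(y_true * log(y_pred) + (1 - y_true) * log(1 - y_pred))
    rec_dim * bce + kl_weight * kl
  }
}

# Toy prior over two correlated skills
skill_cov <- matrix(c(1, 0.3, 0.3, 1), nrow = 2)
# Toy encoder whose posterior equals the prior, so the KL term vanishes
toy_encoder <- function(x) list(z_mean = c(0, 0), z_cov = skill_cov)

loss_fn <- make_loss_sketch(toy_encoder, solve(skill_cov), det(skill_cov),
                            skill_mean = c(0, 0), kl_weight = 1, rec_dim = 3)
loss_value <- loss_fn(c(1, 0, 1), c(0.9, 0.2, 0.8))
```

The constants of the prior are captured once by the closure, so the returned function needs only the two arguments a Keras-style loss receives.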


converseg/ML2Pvae documentation built on April 6, 2021, 1:46 a.m.