gauss_cat_loss: A 'torch::nn_module()' Representing a 'gauss_cat_loss'

View source: R/approach_vaeac_torch_modules.R

gauss_cat_loss    R Documentation

Description

The gauss_cat_loss module computes the log probability of the ground truth for each object, given the mask and the distribution parameters. That is, it computes the log-likelihoods of the true (full) training observations under the generative distributions whose parameters, distr_params, are inferred from the masked versions of the observations.
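To make the computation concrete, here is a minimal base-R sketch (no torch) of a per-row log-likelihood for mixed data. It is not the package's implementation. The function name gauss_cat_log_lik_sketch, the column layout of distr_params (two columns per continuous feature, one logit column per level of a categorical feature), the 1-based coding of categorical values, the softplus transform for sigma and the way min_prob is applied are all assumptions made for illustration.

```r
# Numerically stable softplus: log(1 + exp(z)).
softplus <- function(z) pmax(z, 0) + log1p(exp(-abs(z)))

# Hypothetical sketch: per-row log-likelihood of mixed data `x` under the
# parameters `distr_params`. Assumed layout: a continuous feature (size <= 1)
# uses two columns (mean, unconstrained sigma); a categorical feature with K
# levels uses K logit columns, and its values in `x` are coded 1, ..., K.
gauss_cat_log_lik_sketch <- function(x, distr_params, one_hot_max_sizes,
                                     min_sigma = 1e-4, min_prob = 1e-4) {
  log_lik <- numeric(nrow(x))
  col <- 1
  for (j in seq_along(one_hot_max_sizes)) {
    size <- one_hot_max_sizes[j]
    obs <- x[, j]
    missing <- is.na(obs)  # is.na() is TRUE for NaN as well as NA
    if (size <= 1) {
      # Continuous feature: Gaussian with sigma bounded below by min_sigma.
      mu <- distr_params[, col]
      sigma <- pmax(softplus(distr_params[, col + 1]), min_sigma)
      lp <- dnorm(obs, mean = mu, sd = sigma, log = TRUE)
      col <- col + 2
    } else {
      # Categorical feature: softmax over the logits (max subtracted for stability).
      logits <- distr_params[, col:(col + size - 1), drop = FALSE]
      probs <- exp(logits - apply(logits, 1, max))
      probs <- probs / rowSums(probs)
      # Assumption: floor the probabilities at min_prob and renormalise.
      probs <- pmax(probs, min_prob)
      probs <- probs / rowSums(probs)
      idx <- ifelse(missing, 1, obs)  # placeholder level for missing values
      lp <- log(probs[cbind(seq_len(nrow(x)), idx)])
      col <- col + size
    }
    lp[missing] <- 0  # missing entries do not contribute
    log_lik <- log_lik + lp
  }
  log_lik
}

# Usage: one continuous feature and one categorical feature with 3 levels.
x <- cbind(c(0.5, NaN), c(2, 1))
distr_params <- cbind(c(0, 0), c(0, 0), matrix(0, nrow = 2, ncol = 3))
ll <- gauss_cat_log_lik_sketch(x, distr_params, one_hot_max_sizes = c(1, 3))
```

Because missing entries contribute zero, a row in which every feature is missing gets a log-likelihood of 0 in this sketch.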

Usage

gauss_cat_loss(one_hot_max_sizes, min_sigma = 1e-04, min_prob = 1e-04)

Arguments

one_hot_max_sizes

A torch tensor of dimension n_features containing the one-hot sizes of the n_features features. That is, if the ith feature is a categorical feature with 5 levels, then one_hot_max_sizes[i] = 5. For continuous features, the size can be either 0 or 1.

min_sigma

Minimal standard deviation (sigma) of the Gaussian distributions used for the continuous features. For numerical stability, sigma should not be too close to zero.

min_prob

Minimal probability of each level of the categorical features. For numerical stability, the probabilities should not be too close to zero.
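Why these bounds matter can be seen in a small base-R illustration of degenerate predictions. It assumes both parameters act as simple lower bounds applied with a plain clamp (max()); the package may apply them differently, and the degenerate values below are hypothetical.

```r
min_sigma <- 1e-4
min_prob <- 1e-4

# Hypothetical degenerate predictions: sigma collapsed to 0 and a category
# that was given probability 0.
sigma_raw <- 0
prob_raw <- 0

# Without bounds, the log-likelihood terms are infinite.
ll_gauss_raw <- dnorm(0, mean = 0, sd = sigma_raw, log = TRUE)  # Inf
ll_cat_raw <- log(prob_raw)                                     # -Inf

# With lower bounds (assumed clamp), both terms stay finite.
ll_gauss <- dnorm(0, mean = 0, sd = max(sigma_raw, min_sigma), log = TRUE)
ll_cat <- log(max(prob_raw, min_prob))
c(ll_gauss, ll_cat)
```

An infinite term in a single observation would make the summed training loss, and hence its gradients, unusable, which is what the floors guard against.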

Details

The module works with mixed (continuous and categorical) data represented as 2-dimensional inputs, and it handles missing values in groundtruth correctly as long as they are represented by NaNs.
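As an illustration of that input format, the base-R sketch below turns a small hypothetical data frame into a 2-dimensional numeric matrix with NaN for every missing entry, together with a matching one_hot_max_sizes vector. The data, the integer-code encoding of the categorical feature and the variable names are assumptions; the package may prepare its inputs differently.

```r
# Hypothetical mixed data set with one continuous and one categorical feature,
# where R's NA marks missing values.
df <- data.frame(
  height = c(1.72, NA, 1.65),
  colour = factor(c("red", "blue", NA), levels = c("red", "green", "blue"))
)

# Assumed encoding: continuous features unchanged, categorical features as
# their integer level codes (1, ..., K), and every missing entry as NaN.
x <- cbind(df$height, as.integer(df$colour))
x[is.na(x)] <- NaN

# Matching one-hot sizes: 1 for the continuous feature, K for the categorical one.
one_hot_max_sizes <- c(1, nlevels(df$colour))
```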

Author(s)

Lars Henry Berge Olsen


NorskRegnesentral/shapr documentation built on April 19, 2024, 1:19 p.m.