gauss_cat_loss: A 'torch::nn_module()' Representing a 'gauss_cat_loss'

View source: R/approach_vaeac_torch_modules.R

Description

The gauss_cat_loss module computes the log probability of the ground truth for each object, given the mask and the distribution parameters. More precisely, it computes the log-likelihoods of the true (full) training observations under the generative distributions. The parameters of these distributions, distr_params, are inferred from the masked versions of the observations.
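The following base-R sketch illustrates the per-observation log-likelihood without torch. It assumes a parameter layout that this page does not confirm. Each continuous feature is assumed to use two columns of distr_params: a mean, and an unconstrained sigma parameter that is mapped through softplus and clamped at min_sigma. Each categorical feature is assumed to use one_hot_max_sizes[i] logit columns, which are mapped through softmax, clamped at min_prob and renormalised. Categories are assumed to be coded 1, ..., K. The function gauss_cat_loglik is hypothetical, takes a plain numeric vector instead of a torch tensor, and omits the mask and NaN handling.

```r
# Hypothetical helper (not part of shapr): per-row log-likelihood of a mixed
# data matrix, assuming the parameter layout described above.
# The mask and NaN handling are deliberately left out of this sketch.
gauss_cat_loglik <- function(groundtruth, distr_params, one_hot_max_sizes,
                             min_sigma = 1e-4, min_prob = 1e-4) {
  softplus <- function(x) log1p(exp(x))  # may overflow for very large x
  log_prob <- matrix(0, nrow(groundtruth), ncol(groundtruth))
  j <- 1  # current column in distr_params
  for (i in seq_along(one_hot_max_sizes)) {
    size <- one_hot_max_sizes[i]
    if (size <= 1) {
      # Continuous feature: Gaussian with a mean and a softplus-transformed sigma
      mu <- distr_params[, j]
      sigma <- pmax(softplus(distr_params[, j + 1]), min_sigma)
      log_prob[, i] <- dnorm(groundtruth[, i], mean = mu, sd = sigma, log = TRUE)
      j <- j + 2
    } else {
      # Categorical feature: numerically stable softmax over `size` logits
      logits <- distr_params[, j:(j + size - 1), drop = FALSE]
      probs <- exp(logits - apply(logits, 1, max))
      probs <- probs / rowSums(probs)
      # Clamp from below at min_prob and renormalise (assumed scheme)
      probs <- pmax(probs, min_prob)
      probs <- probs / rowSums(probs)
      # Assumed coding: category k is stored as the integer k (1..size)
      idx <- cbind(seq_len(nrow(probs)), groundtruth[, i])
      log_prob[, i] <- log(probs[idx])
      j <- j + size
    }
  }
  rowSums(log_prob)  # one log-likelihood per observation
}

# Usage: one continuous feature and one categorical feature with 3 levels
gt <- cbind(c(0.5, -1.0), c(3, 1))
params <- cbind(c(0, 0), c(0, 0), matrix(0, 2, 3))
gauss_cat_loglik(gt, params, c(1, 3))
```

With all parameters equal to zero, sigma is softplus(0) = log(2) and every category has probability 1/3. Each row's value is therefore the Gaussian log-density plus log(1/3).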

Usage

gauss_cat_loss(one_hot_max_sizes, min_sigma = 1e-04, min_prob = 1e-04)

Arguments

one_hot_max_sizes

A torch tensor of length n_features containing the one-hot sizes of the n_features features. For example, if the i-th feature is a categorical feature with 5 levels, then one_hot_max_sizes[i] = 5. For continuous features, the size can be either 0 or 1.
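The sketch below derives one-hot sizes from a data.frame in base R. It uses the number of factor levels for categorical columns and 1 for continuous ones. It also counts how many distribution-parameter columns these sizes imply. That count rests on an assumption: two columns (mean and sigma) per continuous feature and one column per category. The helpers one_hot_sizes and n_distr_params are hypothetical. Because the documented argument is a torch tensor, such a vector would still need converting, for example with torch::torch_tensor().

```r
# Hypothetical helper: one-hot sizes for a data.frame, using the number of
# factor levels for categorical columns and 1 for continuous columns.
one_hot_sizes <- function(df) {
  vapply(df, function(x) if (is.factor(x)) nlevels(x) else 1L, integer(1))
}

# Hypothetical helper: number of distribution-parameter columns implied by the
# sizes, assuming 2 columns per continuous feature and 1 per category.
n_distr_params <- function(sizes) sum(ifelse(sizes <= 1, 2L, sizes))

df <- data.frame(
  age = c(30, 45),
  colour = factor(c("red", "blue"),
                  levels = c("red", "green", "blue", "black", "white"))
)
sizes <- one_hot_sizes(df)   # age = 1, colour = 5
n_distr_params(sizes)        # 2 + 5 = 7
```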

min_sigma

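The minimal value of sigma, the standard deviation of the Gaussian distributions used for the continuous features. For numerical stability, it might be desirable that sigma is not too close to zero.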
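The next example shows why a lower bound on sigma can help, assuming sigma is clamped from below at min_sigma. When the predicted mean matches the observation exactly, the Gaussian log-density equals -log(sigma) - 0.5 * log(2 * pi). That value grows without bound as sigma shrinks, so a likelihood-maximising model could push sigma towards zero. The values of sigma below are arbitrary.

```r
# Log-density at x = mean for shrinking sigma, with and without a lower bound
min_sigma <- 1e-4
sigmas <- c(1e-1, 1e-4, 1e-8, 1e-16)

unclamped <- dnorm(0, mean = 0, sd = sigmas, log = TRUE)
clamped   <- dnorm(0, mean = 0, sd = pmax(sigmas, min_sigma), log = TRUE)

# With the clamp, the log-density can never exceed this value
upper_bound <- -log(min_sigma) - 0.5 * log(2 * pi)
```

With min_sigma = 1e-04, the bound is roughly 8.29, whereas the unclamped value for sigma = 1e-16 is roughly 35.9.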

min_prob

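The minimal probability of each category in the categorical distributions used for the categorical features. For numerical stability, it might be desirable that these probabilities are not too close to zero.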
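The following example shows how a floor on the probabilities keeps the log-likelihood finite. It assumes one plausible scheme, not confirmed by this page: softmax probabilities are clamped from below at min_prob and then renormalised. Very confident logits make the softmax underflow to exactly zero for some categories, and the log-probability of an observed category then becomes -Inf. The logits used here are arbitrary.

```r
# Numerically stable softmax
softmax <- function(z) {
  e <- exp(z - max(z))
  e / sum(e)
}

min_prob <- 1e-4
logits <- c(800, 0, -800)

p_raw <- softmax(logits)        # underflows to c(1, 0, 0)

# Clamp from below at min_prob, then renormalise so the probabilities sum to 1
p_clamped <- pmax(p_raw, min_prob)
p_clamped <- p_clamped / sum(p_clamped)

log(p_raw[3])       # -Inf
log(p_clamped[3])   # finite, close to log(min_prob)
```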

Details

The module works with mixed data (continuous and categorical features) represented as 2-dimensional inputs. It correctly handles missing values in the ground truth, provided they are represented by NaNs.
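The next sketch shows one way NaN entries can be made harmless for a single continuous feature, using made-up values. It is an assumption about the general idea, not shapr's code. NaNs are first replaced by a valid dummy value so that the density can be evaluated; a valid value is also needed, for example, to index categorical probabilities. Their contributions are then set to zero, so they do not affect the summed log-likelihood.

```r
gt    <- c(0.3, NaN, -1.2)   # observed values, one missing
mu    <- c(0.0, 0.5, -1.0)   # predicted means
sigma <- c(1.0, 1.0,  0.5)   # predicted standard deviations

is_missing <- is.nan(gt)
gt_safe <- ifelse(is_missing, 0, gt)              # dummy value for NaNs
log_prob <- dnorm(gt_safe, mu, sigma, log = TRUE)
log_prob[is_missing] <- 0                         # NaNs contribute nothing

total <- sum(log_prob)
```

In R, is.nan(NA) is FALSE, so data that encode missing values as NA rather than NaN would need is.na() instead.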

Author(s)

Lars Henry Berge Olsen


shapr documentation built on April 4, 2025, 12:18 a.m.