random_categorical | R Documentation |
This function takes as input logits, a 2-D input tensor with shape
(batch_size, num_classes). Each row of the input represents a categorical
distribution, with each column index containing the unnormalized
log-probability for a given class.
The function outputs a 2-D tensor with shape (batch_size, num_samples),
where each row contains samples from the corresponding row in logits.
Each column index contains an independent sample drawn from the input
distribution.
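The sampling semantics described above can be sketched in base R. This is an illustration only, not the keras3 implementation: sample_categorical_sketch is a hypothetical helper, it works on a plain R matrix rather than a tensor, and it returns 1-based class indices as is usual in R.

```r
# Illustrative sketch of the sampling semantics in base R.
# `sample_categorical_sketch` is a hypothetical helper, not part of keras3.
sample_categorical_sketch <- function(logits, num_samples) {
  stopifnot(is.matrix(logits), num_samples >= 1)
  num_classes <- ncol(logits)
  out <- matrix(0L, nrow = nrow(logits), ncol = num_samples)
  for (i in seq_len(nrow(logits))) {
    # Softmax turns unnormalized log-probabilities into probabilities;
    # subtracting the row maximum first avoids overflow in exp().
    z <- logits[i, ] - max(logits[i, ])
    probs <- exp(z) / sum(exp(z))
    # Draw num_samples independent class indices for this row.
    out[i, ] <- sample.int(num_classes, size = num_samples,
                           replace = TRUE, prob = probs)
  }
  out
}

set.seed(1)
logits <- rbind(c(0, 0, 0),       # uniform over three classes
                c(-50, 0, -50))   # almost all mass on the second class
samples <- sample_categorical_sketch(logits, num_samples = 5)
dim(samples)   # (batch_size, num_samples)
```

Each row of samples is drawn only from the matching row of logits, so the two rows above follow different distributions.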
random_categorical(logits, num_samples, dtype = "int32", seed = NULL)
logits |
2-D Tensor with shape (batch_size, num_classes). Each row should define a categorical distribution with the unnormalized log-probabilities for all classes. |
num_samples |
Int, the number of independent samples to draw for each row of the input. This will be the second dimension of the output tensor's shape. |
dtype |
Optional dtype of the output tensor. |
seed |
Optional R integer or instance of
Remark concerning the JAX backend: When tracing functions with the
JAX backend the global |
A 2-D tensor with shape (batch_size, num_samples).
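Because logits holds unnormalized log-probabilities, a row does not need to sum to one after exponentiation, and adding a constant to every entry of a row leaves the implied distribution unchanged. The base R sketch below illustrates this, assuming the usual softmax normalization; softmax_sketch is a hypothetical helper, not a keras3 function.

```r
# Hypothetical helper: softmax over one row of logits (numerically stable).
softmax_sketch <- function(x) {
  z <- x - max(x)
  exp(z) / sum(exp(z))
}

row_logits <- c(1, 2, 3)
p_original <- softmax_sketch(row_logits)
p_shifted  <- softmax_sketch(row_logits + 10)   # same row, shifted by a constant

# The shift cancels in the normalization, so both rows imply the same distribution.
same_distribution <- isTRUE(all.equal(p_original, p_shifted))

# Taking log() of known probabilities gives valid logits for those probabilities.
target_probs <- c(0.2, 0.3, 0.5)
recovered <- softmax_sketch(log(target_probs))
```

In practice this means known class probabilities can be passed as log(probabilities), while raw model scores can be passed without normalizing them first.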
Other random:
random_beta()
random_binomial()
random_dropout()
random_gamma()
random_integer()
random_normal()
random_seed_generator()
random_shuffle()
random_truncated_normal()
random_uniform()