random_categorical | R Documentation |
This function takes as input logits, a 2-D input tensor with shape
(batch_size, num_classes). Each row of the input represents a categorical
distribution, with each column index containing the unnormalized
log-probability for a given class.
The function will output a 2-D tensor with shape (batch_size, num_samples),
where each row contains samples from the corresponding row in logits.
Each column index contains an independent sample drawn from the input
distribution.
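To make the relationship between logits and sampled class indices concrete, the following is a minimal base-R sketch of the same idea. It is not how random_categorical() is implemented and it will not reproduce its random stream; the helper softmax_row and the variable names are hypothetical, introduced only for illustration. It assumes the usual softmax conversion from unnormalized log-probabilities to probabilities, then draws one-indexed samples per row.

```r
# Illustrative sketch only: base R, no keras3 required.
# `softmax_row` is a hypothetical helper, not part of any package.
softmax_row <- function(z) {
  e <- exp(z - max(z))  # subtract the max for numerical stability
  e / sum(e)
}

logits <- matrix(c(100, 0.1, 99), nrow = 1)  # shape (batch_size = 1, num_classes = 3)
num_samples <- 5

# apply() over rows returns one column per row, so transpose back
# to shape (batch_size, num_classes).
probs <- t(apply(logits, 1, softmax_row))

set.seed(1234)
# For each row, draw num_samples independent one-indexed class labels.
samples <- t(apply(probs, 1, function(p) {
  sample.int(length(p), size = num_samples, replace = TRUE, prob = p)
}))

dim(samples)  # (batch_size, num_samples)
```

With these logits, classes 1 and 3 carry nearly all of the probability mass (roughly 0.73 and 0.27), so class 2 is very unlikely to be drawn.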
x <- matrix(c(100, .1, 99), nrow = 1)
random_categorical(x, num_samples = 5, seed = 1234)
## tf.Tensor([[3 1 1 3 3]], shape=(1, 5), dtype=int32)
random_categorical(x, num_samples = 5, seed = 1234, zero_indexed = TRUE)
## tf.Tensor([[2 0 0 2 2]], shape=(1, 5), dtype=int32)
op_take(x, random_categorical(x, num_samples = 5, seed = 1234))
## tf.Tensor([[ 99. 100. 100. 99. 99.]], shape=(1, 5), dtype=float64)
op_take(x, random_categorical(x, num_samples = 5, seed = 1234, zero_indexed = TRUE), zero_indexed = TRUE)
## tf.Tensor([[ 99. 100. 100. 99. 99.]], shape=(1, 5), dtype=float64)
random_categorical(
  logits,
  num_samples,
  dtype = "int32",
  seed = NULL,
  zero_indexed = FALSE
)
logits |
2-D Tensor with shape (batch_size, num_classes). Each row should define a categorical distribution with the unnormalized log-probabilities for all classes. |
num_samples |
Int, the number of independent samples to draw for each row of the input. This will be the second dimension of the output tensor's shape. |
dtype |
Optional dtype of the output tensor. |
seed |
Optional R integer or instance of
Remark concerning the JAX backend: When tracing functions with the
JAX backend the global |
zero_indexed |
If |
A 2-D tensor with shape (batch_size, num_samples).
Other random:
random_beta()
random_binomial()
random_dropout()
random_gamma()
random_integer()
random_normal()
random_seed_generator()
random_shuffle()
random_truncated_normal()
random_uniform()