R/samplers.R

#' @title Bernoulli sample
#'
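#' @description Samples from a Bernoulli distribution.
#'
#' @param probs tensor of success probabilities, each between 0 and 1
#' @param logits tensor of log-odds; supply exactly one of probs and logits
#' @param dtype data type of the returned samples
#' @param sample_shape a list or vector of integers giving the shape of the
#'   samples drawn for each batch element; it is prepended to the batch shape
#' @param seed optional integer random seed
#' @return a Tensor of 0/1 samples of type dtype whose shape is sample_shape
#'   followed by the shape of probs (or logits)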
#' @export
sample_bernoulli <- function(probs = NULL, logits = NULL,
                             dtype = tf$int32,
                             sample_shape = list(),
                             seed = NULL) {

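  args <- list(
    probs = probs,
    logits = logits,
    dtype = dtype,
    # pass the shape as a list of integers: R doubles would reach Python as
    # floats, and as.list keeps a single value as a one-element list
    sample_shape = as.list(as.integer(sample_shape)),
    seed = seed
  )

  if (!is.null(seed))
    args$seed <- as.integer(seed)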

  do.call(tfa$seq2seq$sampler$bernoulli_sample, args)

}
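The wrapper above only forwards its arguments to TensorFlow Addons. For intuition about what comes back, here is a base-R sketch of Bernoulli sampling, assuming the usual uniform-threshold construction: draw u from Uniform(0, 1) and return 1 when u < p. The helper name bernoulli_sketch is hypothetical and not part of tfaddons; logits are mapped to probabilities with a sigmoid, and a plain R vector is treated as a 1-D batch. Because R arrays are column-major, each probability is repeated once per draw so that the sample dimensions come first, matching the documented output shape.

```r
# Hypothetical helper: a base-R sketch of Bernoulli sampling, not tfaddons code.
bernoulli_sketch <- function(probs = NULL, logits = NULL, sample_shape = integer(0)) {
  if (is.null(probs)) {
    if (is.null(logits)) stop("supply one of probs or logits")
    probs <- 1 / (1 + exp(-logits))  # sigmoid: log-odds -> probabilities
  }
  # Batch shape: dim() for arrays, otherwise treat the vector as 1-D.
  batch_shape <- if (is.null(dim(probs))) length(probs) else dim(probs)
  n_draws <- prod(sample_shape)  # prod(integer(0)) is 1: one draw per element
  # Column-major layout: repeat each probability once per draw so that the
  # sample dimensions vary fastest, matching c(sample_shape, batch_shape).
  p_rep <- rep(as.vector(probs), each = n_draws)
  draws <- as.integer(runif(length(p_rep)) < p_rep)  # 1 when u < p
  array(draws, dim = c(sample_shape, batch_shape))
}
```

With the package and TensorFlow installed, a call such as sample_bernoulli(probs = c(0.1, 0.9), sample_shape = list(4L), seed = 42L) should likewise return a 4 x 2 integer tensor of 0/1 values.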

#' @title Categorical sample
#'
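#' @description Samples from a categorical distribution.
#'
#' @param logits tensor of unnormalized log-probabilities; the last dimension
#'   indexes the categories
#' @param dtype data type of the returned samples
#' @param sample_shape a list or vector of integers giving the shape of the
#'   samples drawn for each batch element; it is prepended to the batch shape
#' @param seed optional integer random seed
#' @return a Tensor of sampled category indices (starting at 0) of type dtype
#'   whose shape is sample_shape followed by the shape of logits without its
#'   last dimension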
#' @export
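sample_categorical <- function(logits, dtype = tf$int32,
                               sample_shape = list(),
                               seed = NULL) {

  args <- list(
    logits = logits,
    dtype = dtype,
    # pass the shape as a list of integers: R doubles would reach Python as
    # floats, and as.list keeps a single value as a one-element list
    sample_shape = as.list(as.integer(sample_shape)),
    seed = seed
  )

  if (!is.null(seed))
    args$seed <- as.integer(seed)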

  do.call(tfa$seq2seq$sampler$categorical_sample, args)

}
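For intuition, here is a base-R sketch of the same sampling idea: apply a softmax to the logits over the last dimension, then draw indices with sample.int(). This is a swapped-in technique for illustration, not how tfaddons draws samples internally. The helper name categorical_sketch is hypothetical; it handles only a vector (one distribution, empty batch shape) or a matrix (rows are batch elements, columns are categories), and it returns 0-based indices on the assumption that the result should line up with TensorFlow's convention rather than R's.

```r
# Hypothetical helper: a base-R sketch of categorical sampling, not tfaddons code.
categorical_sketch <- function(logits, sample_shape = integer(0)) {
  # A plain vector is one distribution (empty batch shape); for a matrix,
  # rows are batch elements and columns (the last dimension) are categories.
  is_single <- is.null(dim(logits))
  logits_2d <- if (is_single) matrix(logits, nrow = 1) else logits
  batch_shape <- if (is_single) integer(0) else nrow(logits_2d)
  n_draws <- prod(sample_shape)  # prod(integer(0)) is 1
  k <- ncol(logits_2d)
  # One column of draws per batch element, so draws vary fastest (column-major).
  draws <- vapply(seq_len(nrow(logits_2d)), function(b) {
    z <- logits_2d[b, ] - max(logits_2d[b, ])  # shift for numerical stability
    p <- exp(z) / sum(exp(z))                  # softmax over the categories
    sample.int(k, n_draws, replace = TRUE, prob = p) - 1L  # 0-based, as in TF
  }, integer(n_draws))
  out_shape <- c(sample_shape, batch_shape)
  if (length(out_shape) == 0) return(draws[1])  # no sample or batch dims
  array(draws, dim = out_shape)
}
```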

tfaddons documentation built on July 2, 2020, 2:12 a.m.