tf_map: 'tf.map_fn()'


View source: R/tf_map.R

Description

Thin wrapper around tf.map_fn() with the following differences: fn can be specified using purrr style ~ syntax, and elems comes before fn in the argument list.

Usage

tf_map(
  elems,
  fn,
  dtype = NULL,
  parallel_iterations = NULL,
  back_prop = TRUE,
  swap_memory = FALSE,
  infer_shape = TRUE,
  name = NULL
)
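
As a hedged illustration of the signature above, here is a minimal call. It assumes that the tensorflow and tfautograph R packages are installed and that TensorFlow is running eagerly. The tensor values and the doubling function are made up for the example.

```r
library(tensorflow)
library(tfautograph)

# Illustrative input: a rank-1 tensor that is unpacked into three scalar slices.
elems <- tf$constant(c(1, 2, 3))

# fn given as a plain R function; each call receives one slice of elems.
# The output structure matches elems, so dtype can stay NULL.
doubled <- tf_map(elems, function(x) x * 2)
```

Note that elems comes first, so the call also reads naturally in a pipe.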

Arguments

elems

A tensor or (possibly nested) sequence of tensors, each of which will be unpacked along its first dimension. The nested sequence of the resulting slices will be passed to fn.

fn

An R function, a formula in purrr style ~ syntax, a character string, a python function (or more generally, any python object with a __call__ method) or anything coercible via as.function(). The function will be called with one argument, which will have the same (possibly nested) structure as elems. Its output must have the same structure as dtype if one is provided, otherwise it must have the same structure as elems.

dtype

(optional) The output type(s) of fn. If fn returns a structure of Tensors differing from the structure of elems, then dtype is not optional and must have the same structure as the output of fn.

parallel_iterations

(optional) The number of iterations allowed to run in parallel. When graph building, the default value is 10. While executing eagerly, the default value is set to 1.

back_prop

(optional) TRUE enables support for back propagation.

swap_memory

(optional) TRUE enables GPU-CPU memory swapping.

infer_shape

(optional) FALSE disables tests for consistent output shapes.

name

(optional) Name prefix for the returned tensors.

Value

A tensor or (possibly nested) sequence of tensors. Each tensor packs the results of applying fn to tensors unpacked from elems along the first dimension, from first to last.
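
The unpack, apply and pack behaviour described above can be sketched in base R, with no TensorFlow involved. The helper map_first_dim() is hypothetical and not part of tfautograph. It works on a plain matrix and only mimics the semantics for the single-tensor case.

```r
# Hypothetical helper (not part of tfautograph): mimics how the result is
# built, using a plain R matrix in place of a tensor.
map_first_dim <- function(elems, fn) {
  # Unpack: one slice per index along the first dimension (one per row).
  slices <- lapply(seq_len(nrow(elems)), function(i) elems[i, ])
  # Apply fn to each slice, from first to last.
  results <- lapply(slices, fn)
  # Pack: stack the per-slice results along a new first dimension.
  do.call(rbind, results)
}

elems <- matrix(1:6, nrow = 3, byrow = TRUE)  # rows: (1, 2), (3, 4), (5, 6)
packed <- map_first_dim(elems, function(row) row * 10L)
print(packed)
```

Each row of packed is the result of fn for the matching row of elems, in the same order.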


tfautograph documentation built on Sept. 18, 2021, 1:07 a.m.