ag_if_vars: Specify 'tf.cond()' output structure when autographing 'if'


View source: R/hints.R

Description

This function can be used to specify the output structure from tf.cond() when autographing an if statement. In most use cases, use of this function is purely optional. If not supplied, the if output structure is automatically built.

Usage

ag_if_vars(
  ...,
  modified = list(),
  return = FALSE,
  undefs = NULL,
  control_flow = 0
)

Arguments

...

Variables modified by the tf.cond() node, supplied as bare symbols like foo or as expressions using $, e.g., foo$bar. Symbols do not have to exist before the autographed if, so long as they are created in both branches.

modified

Variable names supplied as a character vector, or as a list of character vectors when specifying nested complex structures. This is an escape hatch for the lazy evaluation semantics of the ... argument.

return

Logical, whether to include the return value of the evaluated R expression in the tf.cond(). If FALSE (the default), only the objects assigned in scope are captured.

undefs

A bare character vector or a list of character vectors. Supplied names are exported as undefs in the parent frame. This is used to give a more informative error message when attempting to access a variable that can't be balanced between branches.

control_flow

An integer, the maximum number of control-flow statements (break and/or next) that will be captured in a single branch as part of the tf.cond(). Do not count statements in loops that dispatch to standard R control flow (e.g., don't count break statements in a for loop that iterates over an R vector).
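The correspondence between bare expressions in ... and the character form accepted by modified can be sketched in base R. The helper expr_to_path() below is hypothetical and is not part of tfautograph; it only illustrates how foo$bar maps to c("foo", "bar"), which may help when the variable names have to be built programmatically.

```r
# Hypothetical helper (NOT part of tfautograph): turn a bare symbol or a
# `$` expression into the character path form used by `modified`.
expr_to_path <- function(e) {
  if (is.symbol(e)) return(as.character(e))
  # anything else must be a call to `$`, e.g. x$a or foo$bar$baz
  stopifnot(is.call(e), identical(e[[1]], as.symbol("$")))
  c(expr_to_path(e[[2]]), as.character(e[[3]]))
}

# Unevaluated expressions, standing in for what would be passed via `...`
paths <- lapply(expression(y, x$a, foo$bar$baz), expr_to_path)
str(paths)
```

Under these assumptions, the resulting list has the same shape as the value one would pass as modified = list("y", c("x", "a"), ...).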

Details

If the output structure is not explicitly supplied via ag_if_vars(), it is composed automatically: the true and false branches of the expression are traced into concrete functions, then the output signatures of the two branch functions are balanced. Balancing is performed either by fetching a variable from an outer scope or by reclassifying a symbol as an undef.

When dealing with complex composites (that is, nested structures where a modified tensor is part of a named list or dictionary), care is taken to avoid unnecessarily capturing other, unmodified tensors in the structure. This is done by pruning unmodified tensors from the returned output structure, and then recursively merging the result back into the original object. One limitation of the implementation is that lists must either be fully named with unique names, or not named at all; partially named lists or lists with duplicated names throw an error. This is due to the conversion that happens when going between Python and R: named lists are converted to Python dictionaries, which require all keys to be unique. Additionally, pruning of unmodified objects from an autographed if is currently only supported for named lists (Python dictionaries). Unnamed lists or tuples are passed as is (i.e., no pruning and merging is done), which may lead to unnecessary bloat in the constructed graphs.
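The naming rule and the prune-then-merge idea can be illustrated in base R, with plain numbers standing in for tensors. This is a minimal sketch, not the package's implementation: is_valid_structure() is a hypothetical helper, and utils::modifyList() is used only as a stand-in for the recursive merge described above.

```r
# Hypothetical helper (NOT part of tfautograph): TRUE when a list is either
# fully named with unique names, or not named at all.
is_valid_structure <- function(x) {
  nms <- names(x)
  if (is.null(nms)) return(TRUE)             # completely unnamed: allowed
  all(nzchar(nms)) && !anyDuplicated(nms)    # fully named with unique names
}

is_valid_structure(list(a = 1, b = 2))   # fully named
is_valid_structure(list(1, 2))           # unnamed
is_valid_structure(list(a = 1, 2))       # partially named
is_valid_structure(list(a = 1, a = 2))   # duplicated names

# Prune-and-merge with plain values standing in for tensors:
original <- list(a = 1, b = list(c = 2, d = 3))
pruned   <- list(b = list(c = 20))   # assume only b$c was modified
merged   <- utils::modifyList(original, pruned)
str(merged)
```

Only the modified leaf is carried in the pruned structure; the untouched elements a and b$d are recovered from the original object during the merge.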

Value

NULL, invisibly

Examples

## Not run: 
# these examples only have an effect in graph mode
# to enter graph mode easily we'll create a few helpers
ag <- autograph

# pass the symbols you expect to be modified or created like this:
ag_if_vars(x)
ag(if (y > 0) {
  x <- y * y
} else {
  x <- y
})

# if the return value from the if expression is important, pass `return = TRUE`
ag_if_vars(return = TRUE)
x <- ag(if(y > 0) y * y else y)

# pass complex nested structures like this
x <- list(a = 1, b = 2)

ag_if_vars(x$a)
ag(if(y > 0) {
  x$a <- y
})

# undefs mark branch-local variables
ag_if_vars(y, x$a, undefs = "tmp_local_var")
ag(if(y > 0) {
  y <- y * 100
  tmp_local_var <- y + 1
  x$a <- tmp_local_var
})

# supplying `undefs` is not necessary; it exists purely as a guardrail for
# defensive programming and/or as a way to improve code readability

## modified vars can be supplied in `...` or via the `modified` argument.
## these pairs of ag_if_vars() calls are equivalent
ag_if_vars(y, x$a)
ag_if_vars(modified = list("y", c("x", "a")))

ag_if_vars(x, y, z)
ag_if_vars(modified = c("x", "y", "z"))


## control flow
# count the number of odd integers between 0 and 10
ag({
  x <- 10
  count <- 0
  while(x > 0) {
    x <- x - 1  # advance the counter so that the loop terminates
    ag_if_vars(control_flow = 1)
    if(x %% 2 == 0)
      next
    count <- count + 1
  }
})

## End(Not run)

tfautograph documentation built on Sept. 18, 2021, 1:07 a.m.