tf_switch: tf.switch_case

Description Usage Arguments Value Examples

View source: R/tf-ctrl-flow-wrappers.R

Description

tf.switch_case

Usage

1
2
3
4
5
6
7
tf_switch(
  branch_index,
  ...,
  branch_fns = list(...),
  default = NULL,
  name = "switch_case"
)

Arguments

branch_index

an integer tensor

..., branch_fns

a list of function bodies specified with a ~, optionally supplied with a branch index on the left hand side. See examples

default

A function defined with a ~, or something coercible via 'as.function()“

name

a string, passed on to tf.switch_case()

Value

The result from tf.switch_case()

Examples

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
## Not run: 
tf_pow <- tf_function(function(x, pow) {
   tf_switch(pow,
   0 ~ 1,
   1 ~ x,
   2 ~ x * x,
   3 ~ x * x * x,
   default = ~ -1)
})

# can optionally also omit the left hand side int, in which case the order of
# the functions is used.
tf_pow <- function(x, pow) {
  tf_switch(pow,
            ~ 1,
            ~ x,
            ~ x * x,
            ~ x * x * x,
            default = ~ -1)
}

# supply just some of the ints to override the default order
tf_pow <- function(x, pow) {
  tf_switch(pow,
            3 ~ x * x * x,
            2 ~ x * x,
            ~ 1,
            ~ x,
            default = ~ -1)
}

# A slightly less contrived example:
tf_norm <- tf_function(function(x, l) {
  tf_switch(l,
            0 ~ tf$reduce_sum(tf$cast(x != 0, tf$float32)), # L0 norm
            1 ~ tf$reduce_sum(tf$abs(x)),                   # L1 norm
            2 ~ tf$sqrt(tf$reduce_sum(tf$square(x))),       # L2 norm
            default = ~ tf$reduce_max(tf$abs(x)))         # L-infinity norm
})

## End(Not run)

tfautograph documentation built on Sept. 18, 2021, 1:07 a.m.