op_scan: Scan a function over leading array axes while carrying along...

op_scan R Documentation

Scan a function over leading array axes while carrying along state.

Description

When the type of xs is an array type or NULL, and the type of ys is an array type, the semantics of op_scan() are given roughly by this implementation:

op_scan <- function(f, init, xs = NULL, length = NULL) {
  xs <- xs %||% vector("list", length)
  if(!is.list(xs))
    xs <- op_unstack(xs)
  ys <- vector("list", length(xs))
  carry <- init
  for (i in seq_along(xs)) {
    c(carry, y) %<-% f(carry, xs[[i]])
    ys[[i]] <- y
  }
  list(carry, op_stack(ys))
}

The loop-carried value carry (init) must hold a fixed shape and dtype across all iterations.

In TensorFlow, y must match carry in shape and dtype. This is not required in other backends.
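
The reference loop above depends on op_unstack(), op_stack(), and the %<-% operator. As a rough illustration only, the same control flow can be emulated in base R with plain vectors standing in for tensors. The helper name scan_ref is hypothetical and is not part of keras3; unlist() is an assumed stand-in for op_stack() that only works for scalar outputs.

```r
# Base-R emulation of the reference loop. scan_ref is a hypothetical name,
# not a keras3 function; plain vectors and lists stand in for tensors.
scan_ref <- function(f, init, xs = NULL, length = NULL) {
  if (is.null(xs)) xs <- vector("list", length)  # NULL slices when xs is absent
  ys <- vector("list", length(xs))
  carry <- init
  for (i in seq_along(xs)) {
    out <- f(carry, xs[[i]])   # f returns list(new_carry, y)
    carry <- out[[1]]
    ys[[i]] <- out[[2]]
  }
  list(carry, unlist(ys))      # unlist() stands in for op_stack() on scalar outputs
}

res <- scan_ref(function(c, x) list(c + x, c + x), 0L, 1:5)
res[[1]]  # 15
res[[2]]  # 1 3 6 10 15
```

The carry stays a length-one integer on every iteration, which mirrors the fixed shape and dtype requirement stated above.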

Usage

op_scan(f, init, xs = NULL, length = NULL, reverse = FALSE, unroll = 1L)

Arguments

f

Callable that defines the logic for each loop iteration. It accepts two arguments: the first is the current value of the loop carry, and the second is a slice of xs along its leading axis. It returns a pair: the first element is the new value of the loop carry, and the second is a slice of the output.

init

The initial loop carry value. This can be a scalar, tensor, or any nested structure. It must match the structure of the first element returned by f.

xs

Optional value to scan along its leading axis. This can be a tensor or any nested structure. If xs is not provided, you must specify length to define the number of loop iterations. Defaults to NULL.

length

Optional integer specifying the number of loop iterations. If length is not provided, it defaults to the size of the leading axis of the arrays in xs. Defaults to NULL.
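
When xs is omitted, only length determines how many times f is called. Following the reference loop in the Description, f then receives NULL as its second argument. The sketch below emulates that case in base R; scan_n and double_fn are hypothetical names used for illustration, not keras3 functions.

```r
# Hypothetical base-R stand-in for a scan with no xs: only `length` drives the loop.
scan_n <- function(f, init, length) {
  ys <- vector("list", length)
  carry <- init
  for (i in seq_len(length)) {
    out <- f(carry, NULL)      # no input slice is available, so x is NULL
    carry <- out[[1]]
    ys[[i]] <- out[[2]]
  }
  list(carry, unlist(ys))
}

# Double the carry on each step and emit the value seen before doubling.
double_fn <- function(c, x) list(c * 2, c)
res <- scan_n(double_fn, 1, length = 4L)
res[[1]]  # 16
res[[2]]  # 1 2 4 8
```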

reverse

Optional boolean specifying whether to run the scan forward (the default, FALSE) or in reverse. Running in reverse is equivalent to reversing the leading axes of the arrays in both xs and ys.
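
One way to read the reverse behavior is that slices are visited from last to first while each output stays at the position of its input. The base-R sketch below is an illustration under that reading, not the backend implementation; scan_rev is a hypothetical name.

```r
# Hypothetical base-R emulation of a reversed scan over a plain vector.
scan_rev <- function(f, init, xs) {
  ys <- vector("list", length(xs))
  carry <- init
  for (i in rev(seq_along(xs))) {   # visit slices from last to first
    out <- f(carry, xs[[i]])
    carry <- out[[1]]
    ys[[i]] <- out[[2]]             # y is stored at the position of its input
  }
  list(carry, unlist(ys))
}

res <- scan_rev(function(c, x) list(c + x, c + x), 0L, 1:5)
res[[1]]  # 15
res[[2]]  # 15 14 12 9 5
```

The final carry matches the forward scan because addition is order-independent, while the stacked outputs form a cumulative sum taken from the end.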

unroll

Optional positive integer or boolean specifying how many scan iterations to unroll within a single iteration of a loop. If an integer is provided, it determines how many unrolled loop iterations to run within a single rolled iteration of the loop. If a boolean is provided, it determines whether the loop is completely unrolled (unroll = TRUE) or left completely rolled (unroll = FALSE). Note that unrolling is only supported by the JAX and TensorFlow backends.
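
Unrolling changes how the loop is laid out, not what it computes. As a conceptual sketch only, the base-R function below performs two scan steps per trip through the outer loop, which is roughly what unroll = 2L means. scan_unroll2 is a hypothetical name, and real backends apply unrolling at the compiler or graph level rather than in user code.

```r
# Hypothetical illustration of unroll = 2: each trip through the outer loop
# performs up to two scan steps. Results match the fully rolled loop.
scan_unroll2 <- function(f, init, xs) {
  n <- length(xs)
  ys <- vector("list", n)
  carry <- init
  i <- 1L
  while (i <= n) {
    out <- f(carry, xs[[i]])          # first step of this trip
    carry <- out[[1]]
    ys[[i]] <- out[[2]]
    if (i + 1L <= n) {                # second step, skipped for a trailing odd element
      out <- f(carry, xs[[i + 1L]])
      carry <- out[[1]]
      ys[[i + 1L]] <- out[[2]]
    }
    i <- i + 2L
  }
  list(carry, unlist(ys))
}

res <- scan_unroll2(function(c, x) list(c + x, c + x), 0L, 1:5)
res[[1]]  # 15
res[[2]]  # 1 3 6 10 15
```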

Value

A pair where the first element represents the final loop carry value and the second element represents the stacked outputs of f when scanned over the leading axis of the inputs.

Examples

sum_fn <- function(c, x) list(c + x, c + x)
init <- op_array(0L)
xs <- op_array(1:5)
c(carry, result) %<-% op_scan(sum_fn, init, xs)
carry
## tf.Tensor(15, shape=(), dtype=int32)

result
## tf.Tensor([ 1  3  6 10 15], shape=(5), dtype=int32)

See Also

Other core ops:
op_cast()
op_cond()
op_convert_to_numpy()
op_convert_to_tensor()
op_custom_gradient()
op_dtype()
op_fori_loop()
op_is_tensor()
op_map()
op_scatter()
op_scatter_update()
op_shape()
op_slice()
op_slice_update()
op_stop_gradient()
op_switch()
op_unstack()
op_vectorized_map()
op_while_loop()

Other ops:
op_abs()
op_add()
op_all()
op_any()
op_append()
op_arange()
op_arccos()
op_arccosh()
op_arcsin()
op_arcsinh()
op_arctan()
op_arctan2()
op_arctanh()
op_argmax()
op_argmin()
op_argpartition()
op_argsort()
op_array()
op_average()
op_average_pool()
op_batch_normalization()
op_binary_crossentropy()
op_bincount()
op_broadcast_to()
op_cast()
op_categorical_crossentropy()
op_ceil()
op_cholesky()
op_clip()
op_concatenate()
op_cond()
op_conj()
op_conv()
op_conv_transpose()
op_convert_to_numpy()
op_convert_to_tensor()
op_copy()
op_correlate()
op_cos()
op_cosh()
op_count_nonzero()
op_cross()
op_ctc_decode()
op_ctc_loss()
op_cumprod()
op_cumsum()
op_custom_gradient()
op_depthwise_conv()
op_det()
op_diag()
op_diagonal()
op_diff()
op_digitize()
op_divide()
op_divide_no_nan()
op_dot()
op_dtype()
op_eig()
op_eigh()
op_einsum()
op_elu()
op_empty()
op_equal()
op_erf()
op_erfinv()
op_exp()
op_expand_dims()
op_expm1()
op_extract_sequences()
op_eye()
op_fft()
op_fft2()
op_flip()
op_floor()
op_floor_divide()
op_fori_loop()
op_full()
op_full_like()
op_gelu()
op_get_item()
op_greater()
op_greater_equal()
op_hard_sigmoid()
op_hard_silu()
op_hstack()
op_identity()
op_imag()
op_image_affine_transform()
op_image_crop()
op_image_extract_patches()
op_image_hsv_to_rgb()
op_image_map_coordinates()
op_image_pad()
op_image_resize()
op_image_rgb_to_grayscale()
op_image_rgb_to_hsv()
op_in_top_k()
op_inv()
op_irfft()
op_is_tensor()
op_isclose()
op_isfinite()
op_isinf()
op_isnan()
op_istft()
op_leaky_relu()
op_less()
op_less_equal()
op_linspace()
op_log()
op_log10()
op_log1p()
op_log2()
op_log_sigmoid()
op_log_softmax()
op_logaddexp()
op_logical_and()
op_logical_not()
op_logical_or()
op_logical_xor()
op_logspace()
op_logsumexp()
op_lstsq()
op_lu_factor()
op_map()
op_matmul()
op_max()
op_max_pool()
op_maximum()
op_mean()
op_median()
op_meshgrid()
op_min()
op_minimum()
op_mod()
op_moments()
op_moveaxis()
op_multi_hot()
op_multiply()
op_nan_to_num()
op_ndim()
op_negative()
op_nonzero()
op_norm()
op_normalize()
op_not_equal()
op_one_hot()
op_ones()
op_ones_like()
op_outer()
op_pad()
op_power()
op_prod()
op_psnr()
op_qr()
op_quantile()
op_ravel()
op_real()
op_reciprocal()
op_relu()
op_relu6()
op_repeat()
op_reshape()
op_rfft()
op_roll()
op_round()
op_rsqrt()
op_scatter()
op_scatter_update()
op_segment_max()
op_segment_sum()
op_select()
op_selu()
op_separable_conv()
op_shape()
op_sigmoid()
op_sign()
op_silu()
op_sin()
op_sinh()
op_size()
op_slice()
op_slice_update()
op_slogdet()
op_softmax()
op_softplus()
op_softsign()
op_solve()
op_solve_triangular()
op_sort()
op_sparse_categorical_crossentropy()
op_split()
op_sqrt()
op_square()
op_squeeze()
op_stack()
op_std()
op_stft()
op_stop_gradient()
op_subtract()
op_sum()
op_svd()
op_swapaxes()
op_switch()
op_take()
op_take_along_axis()
op_tan()
op_tanh()
op_tensordot()
op_tile()
op_top_k()
op_trace()
op_transpose()
op_tri()
op_tril()
op_triu()
op_unstack()
op_var()
op_vdot()
op_vectorize()
op_vectorized_map()
op_vstack()
op_where()
op_while_loop()
op_zeros()
op_zeros_like()


rstudio/keras documentation built on July 8, 2024, 3:07 p.m.