op_custom_gradient | R Documentation |
This decorator allows fine-grained control over the gradients of a sequence of operations. This may be useful for multiple reasons, including providing a more efficient or numerically stable gradient for a sequence of operations.
op_custom_gradient(f)
f |
Function
|
A function h(...) which returns the same value as f(...)[[1]] and whose gradient is determined by f(...)[[2]].
Backend-agnostic example.
# log(1 + exp(x)) with a hand-written gradient:
#   d/dx log(1 + exp(x)) = 1 - 1 / (1 + exp(x))
# The wrapped function returns a tuple of (forward value, gradient function).
log1pexp <- op_custom_gradient(\(x) {
  e <- op_exp(x)

  # Accept both calling conventions so the same code works on every backend:
  # the upstream gradient may arrive positionally (first element of `...`)
  # or through the named `upstream` argument.
  grad <- function(..., upstream = NULL) {
    upstream <- upstream %||% ..1
    op_multiply(upstream, 1.0 - 1.0 / op_add(1, e))
  }

  tuple(op_log(1 + e), grad)
})

# The gradient check below uses the TensorFlow GradientTape API directly,
# so it only runs when TensorFlow is the active backend.
if (config_backend() == "tensorflow") {
  tf <- tensorflow::tf
  x <- op_convert_to_tensor(100.0)
  with(tf$GradientTape() %as% tape, {
    tape$watch(x)
    y <- log1pexp(x)
  })
  dy_dx <- tape$gradient(y, x)
  # Compare with a numeric tolerance rather than exact floating-point `==`.
  stopifnot(all.equal(as.numeric(dy_dx), 1))
}
Note that the grad function, which returns the gradient computation, takes ... and/or a named upstream argument, depending on the backend in use. With the JAX and TensorFlow backends, it requires only one argument, whereas with the PyTorch backend it may use the upstream argument.
When working with the TensorFlow or JAX backend, grad(upstream) is sufficient. With PyTorch, the grad function requires ... as well as upstream, e.g. grad <- \(..., upstream).
Follow the example above to use op_custom_gradient() in a way that is compatible with all backends.
Other core ops:
op_associative_scan()
op_cast()
op_cond()
op_convert_to_numpy()
op_convert_to_tensor()
op_dtype()
op_fori_loop()
op_is_tensor()
op_map()
op_scan()
op_scatter()
op_scatter_update()
op_searchsorted()
op_shape()
op_slice()
op_slice_update()
op_stop_gradient()
op_switch()
op_unstack()
op_vectorized_map()
op_while_loop()
Other ops:
op_abs()
op_add()
op_all()
op_any()
op_append()
op_arange()
op_arccos()
op_arccosh()
op_arcsin()
op_arcsinh()
op_arctan()
op_arctan2()
op_arctanh()
op_argmax()
op_argmin()
op_argpartition()
op_argsort()
op_array()
op_associative_scan()
op_average()
op_average_pool()
op_batch_normalization()
op_binary_crossentropy()
op_bincount()
op_bitwise_and()
op_bitwise_invert()
op_bitwise_left_shift()
op_bitwise_not()
op_bitwise_or()
op_bitwise_right_shift()
op_bitwise_xor()
op_broadcast_to()
op_cast()
op_categorical_crossentropy()
op_ceil()
op_celu()
op_cholesky()
op_clip()
op_concatenate()
op_cond()
op_conj()
op_conv()
op_conv_transpose()
op_convert_to_numpy()
op_convert_to_tensor()
op_copy()
op_correlate()
op_cos()
op_cosh()
op_count_nonzero()
op_cross()
op_ctc_decode()
op_ctc_loss()
op_cumprod()
op_cumsum()
op_depthwise_conv()
op_det()
op_diag()
op_diagflat()
op_diagonal()
op_diff()
op_digitize()
op_divide()
op_divide_no_nan()
op_dot()
op_dot_product_attention()
op_dtype()
op_eig()
op_eigh()
op_einsum()
op_elu()
op_empty()
op_equal()
op_erf()
op_erfinv()
op_exp()
op_exp2()
op_expand_dims()
op_expm1()
op_extract_sequences()
op_eye()
op_fft()
op_fft2()
op_flip()
op_floor()
op_floor_divide()
op_fori_loop()
op_full()
op_full_like()
op_gelu()
op_get_item()
op_glu()
op_greater()
op_greater_equal()
op_hard_shrink()
op_hard_sigmoid()
op_hard_silu()
op_hard_tanh()
op_histogram()
op_hstack()
op_identity()
op_ifft2()
op_imag()
op_image_affine_transform()
op_image_crop()
op_image_extract_patches()
op_image_hsv_to_rgb()
op_image_map_coordinates()
op_image_pad()
op_image_resize()
op_image_rgb_to_grayscale()
op_image_rgb_to_hsv()
op_in_top_k()
op_inner()
op_inv()
op_irfft()
op_is_tensor()
op_isclose()
op_isfinite()
op_isinf()
op_isnan()
op_istft()
op_leaky_relu()
op_left_shift()
op_less()
op_less_equal()
op_linspace()
op_log()
op_log10()
op_log1p()
op_log2()
op_log_sigmoid()
op_log_softmax()
op_logaddexp()
op_logdet()
op_logical_and()
op_logical_not()
op_logical_or()
op_logical_xor()
op_logspace()
op_logsumexp()
op_lstsq()
op_lu_factor()
op_map()
op_matmul()
op_max()
op_max_pool()
op_maximum()
op_mean()
op_median()
op_meshgrid()
op_min()
op_minimum()
op_mod()
op_moments()
op_moveaxis()
op_multi_hot()
op_multiply()
op_nan_to_num()
op_ndim()
op_negative()
op_nonzero()
op_norm()
op_normalize()
op_not_equal()
op_one_hot()
op_ones()
op_ones_like()
op_outer()
op_pad()
op_power()
op_prod()
op_psnr()
op_qr()
op_quantile()
op_ravel()
op_real()
op_reciprocal()
op_relu()
op_relu6()
op_repeat()
op_reshape()
op_rfft()
op_right_shift()
op_roll()
op_round()
op_rsqrt()
op_saturate_cast()
op_scan()
op_scatter()
op_scatter_update()
op_searchsorted()
op_segment_max()
op_segment_sum()
op_select()
op_selu()
op_separable_conv()
op_shape()
op_sigmoid()
op_sign()
op_silu()
op_sin()
op_sinh()
op_size()
op_slice()
op_slice_update()
op_slogdet()
op_soft_shrink()
op_softmax()
op_softplus()
op_softsign()
op_solve()
op_solve_triangular()
op_sort()
op_sparse_categorical_crossentropy()
op_sparse_plus()
op_sparsemax()
op_split()
op_sqrt()
op_square()
op_squareplus()
op_squeeze()
op_stack()
op_std()
op_stft()
op_stop_gradient()
op_subtract()
op_sum()
op_svd()
op_swapaxes()
op_switch()
op_take()
op_take_along_axis()
op_tan()
op_tanh()
op_tanh_shrink()
op_tensordot()
op_threshold()
op_tile()
op_top_k()
op_trace()
op_transpose()
op_tri()
op_tril()
op_triu()
op_trunc()
op_unravel_index()
op_unstack()
op_var()
op_vdot()
op_vectorize()
op_vectorized_map()
op_vstack()
op_where()
op_while_loop()
op_zeros()
op_zeros_like()
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.