autograd_backward: Computes the sum of gradients of given tensors w.r.t. graph...


View source: R/autograd.R

Description

The graph is differentiated using the chain rule. If any of the tensors are non-scalar (i.e. their data has more than one element) and require gradient, then the Jacobian-vector product is computed; in this case the function additionally requires specifying grad_tensors. It should be a list of matching length that contains the “vector” in the Jacobian-vector product, usually the gradient of the differentiated function w.r.t. the corresponding tensors (NULL is an acceptable value for all tensors that don’t need gradient tensors).
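A minimal sketch of the non-scalar case, assuming the torch package is installed. The output y has three elements, so a weighting “vector” (the illustrative name v) is passed through grad_tensors. Because y = x^2 is elementwise, its Jacobian is diagonal and the gradient stored in x$grad reduces to 2 * x * v.

```r
library(torch)

# Non-scalar output: y_i = x_i^2, so dy_i/dx_i = 2 * x_i
x <- torch_tensor(c(1, 2, 3), requires_grad = TRUE)
y <- x^2

# Hypothetical weighting vector v, one entry per element of y
v <- torch_tensor(c(1, 0.5, 2))

# grad_tensors must be supplied because y is not a scalar
autograd_backward(list(y), grad_tensors = list(v))

# Elementwise 2 * x * v: 2 * 1 * 1, 2 * 2 * 0.5, 2 * 3 * 2
x$grad
```

Omitting grad_tensors here would raise an error, because there is no single scalar to differentiate.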

Usage

autograd_backward(
  tensors,
  grad_tensors = NULL,
  retain_graph = create_graph,
  create_graph = FALSE
)

Arguments

tensors

(list of Tensor) – Tensors of which the derivative will be computed.

grad_tensors

(list of (Tensor or NULL)) – The “vector” in the Jacobian-vector product, usually gradients w.r.t. each element of the corresponding tensors. NULL values can be specified for scalar tensors or ones that don’t require grad. If a NULL value would be acceptable for all grad_tensors, then this argument is optional.

retain_graph

(bool, optional) – If FALSE, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to TRUE is not needed and often can be worked around in a much more efficient way. Defaults to the value of create_graph.

create_graph

(bool, optional) – If TRUE, the graph of the derivative will be constructed, allowing higher-order derivative products to be computed. Defaults to FALSE.
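The sketch below illustrates retain_graph and create_graph, assuming the torch package is installed. For retain_graph, y = x^3 saves x for its backward pass, so a second backward call on the same graph only works if the first call kept the graph alive. For create_graph, the gradient left in z$grad is itself part of a graph, so it can be differentiated again; the second step uses autograd_grad(), a separate function from the same package, and the intermediate names are illustrative.

```r
library(torch)

# retain_graph: keep the graph so backward can run through it again
x <- torch_tensor(2, requires_grad = TRUE)
y <- x^3
autograd_backward(list(y), retain_graph = TRUE)
autograd_backward(list(y))   # second pass reuses the retained graph
# Gradients accumulate: 3 * x^2 = 12 from each pass, 24 in total
x$grad

# create_graph: record the derivative computation itself
z <- torch_tensor(2, requires_grad = TRUE)
w <- z^3
autograd_backward(list(w), create_graph = TRUE)
dw_dz <- z$grad              # 3 * z^2 = 12, with a graph attached

# Differentiate the first derivative again: d(3 * z^2)/dz = 6 * z = 12
d2w_dz2 <- autograd_grad(list(dw_dz), list(z))[[1]]
d2w_dz2
```

Without retain_graph = TRUE on the first call, the second autograd_backward(list(y)) is expected to fail because the saved tensors have been freed.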

Details

This function accumulates gradients in the leaves, so you might need to zero them before calling it.
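A minimal sketch of accumulation and zeroing, assuming the torch package is installed. Each call builds a fresh graph for 2 * x, yet the second call adds to the existing x$grad rather than overwriting it; zeroing the leaf gradient in place with zero_() restores a clean slate. The variable names first, accumulated, and reset are illustrative only.

```r
library(torch)

x <- torch_tensor(1, requires_grad = TRUE)

# First pass: d(2 * x)/dx = 2
autograd_backward(list(2 * x))
first <- as_array(x$grad)

# Second pass on a new graph: the result is added to x$grad (2 + 2 = 4)
autograd_backward(list(2 * x))
accumulated <- as_array(x$grad)

# Zero the leaf gradient in place, then differentiate again (back to 2)
x$grad$zero_()
autograd_backward(list(2 * x))
reset <- as_array(x$grad)
```

In a training loop, the same zeroing step would typically run once per iteration, before calling autograd_backward.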

Examples

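library(torch)

if (torch_is_installed()) {
  # Two independent scalar graphs: y = 2 * x and b = 3 * a
  x <- torch_tensor(1, requires_grad = TRUE)
  y <- 2 * x

  a <- torch_tensor(1, requires_grad = TRUE)
  b <- 3 * a

  # Both outputs are scalars, so grad_tensors can be left as NULL
  autograd_backward(list(y, b))

  # Gradients are accumulated in the leaves: x$grad is 2, a$grad is 3
  x$grad
  a$grad
}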

torch documentation built on Oct. 7, 2021, 9:22 a.m.