View source: R/gen-namespace.R
torch_gather | R Documentation |
Gather
torch_gather(self, dim, index, sparse_grad = FALSE)
self | (Tensor) the source tensor |
dim | (int) the axis along which to index |
index | (LongTensor) the indices of elements to gather |
sparse_grad | (bool, optional) If TRUE, gradient w.r.t. input will be a sparse tensor |
Gathers values along an axis specified by dim.
For a 3-D tensor, with 1-based dim and index values, the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 1
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 2
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 3
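The rule above can be mirrored on plain R arrays. The sketch below uses a hypothetical base-R helper, gather_ref, which is not part of torch. It assumes 1-based dim and index values, which is what the example call on this page uses (dim 2 for the columns of a 2-D tensor, index values 1 and 2).

```r
# Hypothetical reference helper (not part of torch): applies the gather rule
# to base R arrays. Assumes 1-based `dim` and 1-based values in `index`.
gather_ref <- function(input, dim, index) {
  stopifnot(length(dim(input)) == length(dim(index)))
  # One row per element of `index`, holding its coordinates (i, j, k, ...)
  pos <- arrayInd(seq_along(index), dim(index))
  src <- pos
  # Replace the coordinate along `dim` with the value stored in `index`
  src[, dim] <- index[pos]
  # Look up the source elements and give them the shape of `index`
  array(input[src], dim = dim(index))
}

# 3-D case with dim = 1: out[i, j, k] = input[index[i, j, k], j, k]
input3 <- array(1:8, dim = c(2, 2, 2))
index3 <- array(2L, dim = c(1, 2, 2))  # always pick the second slice
out3 <- gather_ref(input3, 1, index3)

# 2-D case with the same values as the example call on this page
input2 <- matrix(c(1, 2, 3, 4), ncol = 2, byrow = TRUE)
index2 <- matrix(c(1, 1, 2, 1), ncol = 2, byrow = TRUE)
out2 <- gather_ref(input2, 2, index2)
```

Here out3 holds the same values as input3[2, , , drop = FALSE], and out2 has rows (1, 1) and (4, 3).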
If input is an n-dimensional tensor with size (x_0, x_1, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1}) and dim selects the axis of size x_i, then index must be an n-dimensional tensor with size (x_0, x_1, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1}) where y >= 1, and out will have the same size as index.
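The size rule above can be written as a small check. This is a sketch only: gather_out_size is a hypothetical base-R helper, not a torch function, and it enforces the rule exactly as stated on this page (every axis other than dim must match, and the axis along dim must have length at least 1). It assumes dim is 1-based.

```r
# Hypothetical helper (not part of torch): validates the size rule and
# returns the size of the gathered output. `dim` is assumed to be 1-based.
gather_out_size <- function(input_size, dim, index_size) {
  # input and index must have the same number of dimensions
  stopifnot(length(input_size) == length(index_size))
  # y >= 1 along the gathered axis
  stopifnot(index_size[dim] >= 1)
  # every other axis must have the same length in input and index
  other <- seq_along(input_size) != dim
  stopifnot(all(index_size[other] == input_size[other]))
  # out has the same size as index
  index_size
}

# A 2 x 2 input gathered along dim 2 with a 2 x 3 index gives a 2 x 3 output
ok_size <- gather_out_size(c(2, 2), dim = 2, index_size = c(2, 3))

# A 3 x 2 index breaks the rule on the first axis, so the check fails
bad <- try(gather_out_size(c(2, 2), dim = 2, index_size = c(3, 2)), silent = TRUE)
```

The second call is wrapped in try() so that the failed check is captured in bad instead of stopping the script.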
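library(torch)
if (torch_is_installed()) {
  # 2 x 2 source tensor with rows (1, 2) and (3, 4)
  x <- torch_tensor(matrix(c(1, 2, 3, 4), ncol = 2, byrow = TRUE))
  # 1-based column positions to pick from each row of x
  index <- torch_tensor(matrix(c(1, 1, 2, 1), ncol = 2, byrow = TRUE), dtype = torch_int64())
  torch_gather(x, 2, index)  # rows of the result: (1, 1) and (4, 3)
}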