linalg_tensorsolve: Computes the solution 'X' to the system 'torch_tensordot(A,...

View source: R/linalg.R

linalg_tensorsolve    R Documentation

Computes the solution X to the system torch_tensordot(A, X) = B.

Description

If m is the product of the first B$ndim dimensions of A and n is the product of the remaining dimensions, this function expects m and n to be equal. The returned tensor X satisfies torch_tensordot(A, X, dims = X$ndim) == B. Its shape is given by the trailing dimensions of A, that is, A$shape[(B$ndim + 1):A$ndim].
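The condition m == n is what turns the tensor system into an ordinary square linear system. The sketch below illustrates that reduction in base R. It is not the package's implementation. `tensorsolve_sketch()` and `contract()` are hypothetical helpers written for illustration. Base R arrays are column-major, unlike torch tensors, so the layout details differ from torch's, but the reduction is the same: flatten A into an m x n matrix, solve against the flattened B, and reshape the solution.

```r
# Hypothetical helper: base-R sketch of the reduction behind tensorsolve.
# A has dims c(dim(B), dim(X)); both groups are flattened into a square
# m x n matrix (m == n), the system is solved, and the solution vector is
# reshaped back to the trailing dims of A.
tensorsolve_sketch <- function(A, B) {
  dA <- dim(A)
  dB <- if (is.null(dim(B))) length(B) else dim(B)
  k <- length(dB)
  # The leading dims of A must match the shape of B
  stopifnot(identical(as.integer(dA[seq_len(k)]), as.integer(dB)))
  dX <- dA[-seq_len(k)]
  m <- prod(dB)
  n <- prod(dX)
  stopifnot(m == n)  # the flattened system must be square
  # Column-major flattening keeps A's row/column indices aligned with B and X
  M <- matrix(A, nrow = m, ncol = n)
  x <- solve(M, as.vector(B))
  array(x, dim = dX)
}

# Hypothetical helper: contract the trailing dims of A with all dims of X,
# the base-R analogue of torch_tensordot(A, X, dims = X$ndim)
contract <- function(A, X, out_dim) {
  array(matrix(A, ncol = length(X)) %*% as.vector(X), dim = out_dim)
}

set.seed(1)
A <- array(rnorm(6 * 4 * 2 * 3 * 4), dim = c(6, 4, 2, 3, 4))
B <- array(rnorm(6 * 4), dim = c(6, 4))
X <- tensorsolve_sketch(A, B)
dim(X)  # 2 3 4
all.equal(contract(A, X, dim(B)), B)
```

The `stopifnot()` checks mirror the shape requirements stated above. If the leading dimensions of A do not match B, or if m differs from n, the sketch stops with an error instead of attempting a non-square solve.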

Usage

linalg_tensorsolve(A, B, dims = NULL)

Arguments

A

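(Tensor): the coefficient tensor. Its first B$ndim dimensions must match the shape of B, and the product of those dimensions must equal the product of the remaining ones.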

B

(Tensor): the right-hand side of the system, with shape equal to the first B$ndim dimensions of A.

dims

(integer vector, optional): dimensions of A to move to the end before solving. If NULL, no dimensions are moved. Default: NULL.

Details

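If dims is specified, A is first rearranged as A <- torch_movedim(A, dims, seq(A$ndim - length(dims) + 1, A$ndim)). In other words, the dimensions listed in dims are moved, in the given order, to the end of A.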
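The rearrangement controlled by dims can be reproduced with base R's `aperm()`. The helper `move_dims_to_end()` below is hypothetical and only illustrates the permutation. For an array of shape (6, 4, 4, 3, 2) and dims = c(1, 3), it yields the permutation c(2, 4, 5, 1, 3), which is the same one applied by hand in the second example.

```r
# Hypothetical helper: move the dims listed in `dims` to the end of an array,
# keeping the remaining dims in their original order and the moved dims in
# the order given, as a movedim to the last positions would.
move_dims_to_end <- function(A, dims) {
  keep <- setdiff(seq_along(dim(A)), dims)
  aperm(A, c(keep, dims))
}

A <- array(rnorm(6 * 4 * 4 * 3 * 2), dim = c(6, 4, 4, 3, 2))
dim(move_dims_to_end(A, c(1, 3)))  # 4 3 2 6 4
```

After the move, the first three dimensions (4, 3, 2) match B in the second example, and the last two (6, 4) give the shape of the solution.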

Supports inputs of float, double, cfloat and cdouble dtypes.

See Also

  • linalg_tensorinv() computes the multiplicative inverse of torch_tensordot().

Other linalg: linalg_cholesky_ex(), linalg_cholesky(), linalg_det(), linalg_eigh(), linalg_eigvalsh(), linalg_eigvals(), linalg_eig(), linalg_householder_product(), linalg_inv_ex(), linalg_inv(), linalg_lstsq(), linalg_matrix_norm(), linalg_matrix_power(), linalg_matrix_rank(), linalg_multi_dot(), linalg_norm(), linalg_pinv(), linalg_qr(), linalg_slogdet(), linalg_solve(), linalg_svdvals(), linalg_svd(), linalg_tensorinv(), linalg_vector_norm()

Examples

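if (torch_is_installed()) {
# A has shape (6, 4, 2, 3, 4); B matches its first two dimensions
A <- torch_eye(2 * 3 * 4)$reshape(c(2 * 3, 4, 2, 3, 4))
B <- torch_randn(2 * 3, 4)
X <- linalg_tensorsolve(A, B)
X$shape # the trailing dimensions of A: 2 3 4
torch_allclose(torch_tensordot(A, X, dims = X$ndim), B)

# Move dimensions 1 and 3 of A to the end before solving
A <- torch_randn(6, 4, 4, 3, 2)
B <- torch_randn(4, 3, 2)
X <- linalg_tensorsolve(A, B, dims = c(1, 3))
# Apply the same rearrangement by hand to check the result
A <- A$permute(c(2, 4, 5, 1, 3))
torch_allclose(torch_tensordot(A, X, dims = X$ndim), B, atol = 1e-6)
}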

torch documentation built on June 7, 2023, 6:19 p.m.