torch.linalg.tensorsolve

torch.linalg.tensorsolve(A, B, dims=None, *, out=None) → Tensor

Computes the solution X to the system torch.tensordot(A, X) = B.

If m is the product of the first B.ndim dimensions of A and n is the product of the rest of the dimensions, this function expects m and n to be equal.

The returned tensor x satisfies tensordot(A, x, dims=x.ndim) == B. x has shape A.shape[B.ndim:].
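The shape requirement above means the system can be viewed as an ordinary m × m linear system. The following is a minimal sketch of that view, not a description of the library's internals: it flattens A to a matrix, solves with torch.linalg.solve, and compares against tensorsolve. The near-identity matrix M and the fixed seed are assumptions chosen only to keep the example well-conditioned and reproducible.

```python
import math
import torch

torch.manual_seed(0)

# Assumption for illustration: a near-identity 24 x 24 matrix, so the
# flattened system is comfortably invertible.
M = torch.eye(24, dtype=torch.float64) + 0.05 * torch.randn(24, 24, dtype=torch.float64)
A = M.reshape(2 * 3, 4, 2, 3, 4)
B = torch.randn(2 * 3, 4, dtype=torch.float64)

m = math.prod(A.shape[:B.ndim])  # product of the first B.ndim dimensions
n = math.prod(A.shape[B.ndim:])  # product of the remaining dimensions

X = torch.linalg.tensorsolve(A, B)

# Equivalent matrix formulation: solve (m x n) @ x = b, then restore the shape.
X_ref = torch.linalg.solve(A.reshape(m, n), B.reshape(m)).reshape(A.shape[B.ndim:])

print(m, n, tuple(X.shape))
print(torch.allclose(X, X_ref))
```

The final check compares the two routes numerically; the residual can also be checked directly with torch.tensordot(A, X, dims=X.ndim).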

If dims is specified, A will be reshaped as

A = movedim(A, dims, range(len(dims) - A.ndim + 1, 0))
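As a hedged illustration of the movedim expression above, the sketch below moves the listed dimensions to the end by hand and checks that solving the rearranged tensor without dims gives the same result as passing dims directly. The destination is converted to a tuple because that is the form torch.movedim is documented to accept; the shapes and seed are arbitrary choices for the example.

```python
import torch

torch.manual_seed(0)

A = torch.randn(6, 4, 4, 3, 2, dtype=torch.float64)
B = torch.randn(4, 3, 2, dtype=torch.float64)
dims = (0, 2)

# Destination positions from the formula: the last len(dims) axes.
dest = tuple(range(len(dims) - A.ndim + 1, 0))  # (-2, -1) for this A

# Manual rearrangement: dimensions 0 and 2 go to the end, the rest keep their order.
A_moved = torch.movedim(A, dims, dest)

X_dims = torch.linalg.tensorsolve(A, B, dims=dims)
X_moved = torch.linalg.tensorsolve(A_moved, B)

print(tuple(A_moved.shape))
print(torch.allclose(X_dims, X_moved))
```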

Supports inputs of float, double, cfloat and cdouble dtypes.

See also

torch.linalg.tensorinv() computes the multiplicative inverse of torch.tensordot().
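The relationship between the two functions can be sketched as follows: contracting the tensor inverse of A with B should reproduce the tensorsolve result. This is an illustrative comparison under the assumption of a well-conditioned A (a near-identity matrix reshaped to five dimensions); forming the inverse explicitly is generally not needed when only one solution is required.

```python
import torch

torch.manual_seed(0)

# Assumption for illustration: near-identity matrix, reshaped to a 5-D tensor.
M = torch.eye(24, dtype=torch.float64) + 0.05 * torch.randn(24, 24, dtype=torch.float64)
A = M.reshape(6, 4, 2, 3, 4)
B = torch.randn(6, 4, dtype=torch.float64)

X = torch.linalg.tensorsolve(A, B)

# tensorinv with ind=B.ndim returns a tensor of shape A.shape[ind:] + A.shape[:ind].
A_inv = torch.linalg.tensorinv(A, ind=B.ndim)

# Contract the trailing B.ndim dimensions of A_inv with B.
X_via_inv = torch.tensordot(A_inv, B, dims=B.ndim)

print(tuple(A_inv.shape))
print(torch.allclose(X, X_via_inv))
```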

Parameters
  • A (Tensor) – tensor to solve for. Its shape must satisfy prod(A.shape[:B.ndim]) == prod(A.shape[B.ndim:]).

  • B (Tensor) – tensor of shape A.shape[:B.ndim].

  • dims (Tuple[int], optional) – dimensions of A to be moved. If None, no dimensions are moved. Default: None.

Keyword Arguments

out (Tensor, optional) – output tensor. Ignored if None. Default: None.

Raises

RuntimeError – if the reshaped A.view(m, m) with m as above is not invertible or the product of the first B.ndim dimensions is not equal to the product of the rest of the dimensions.
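Both failure modes can be observed with a small sketch. The helper try_solve is hypothetical, introduced only for this example, and the inputs are run on CPU; it returns the error message when RuntimeError (or a subclass) is raised and None otherwise. Message text is not shown because it may differ between versions.

```python
import torch

def try_solve(A, B):
    """Hypothetical helper: return the RuntimeError message, or None on success."""
    try:
        torch.linalg.tensorsolve(A, B)
    except RuntimeError as err:
        return str(err)
    return None

# Shape mismatch: the first B.ndim dimension has size 2, the rest have product 12.
mismatch_msg = try_solve(torch.randn(2, 3, 4), torch.randn(2))

# Shapes are compatible (6 == 6), but the flattened 6 x 6 matrix is all zeros.
singular_msg = try_solve(torch.zeros(2, 3, 6), torch.randn(2, 3))

# Valid input: the flattened matrix is the 6 x 6 identity.
ok_msg = try_solve(torch.eye(6).reshape(2, 3, 6), torch.randn(2, 3))

print(mismatch_msg is not None, singular_msg is not None, ok_msg is None)
```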

Examples:

>>> A = torch.eye(2 * 3 * 4).reshape((2 * 3, 4, 2, 3, 4))
>>> B = torch.randn(2 * 3, 4)
>>> X = torch.linalg.tensorsolve(A, B)
>>> X.shape
torch.Size([2, 3, 4])
>>> torch.allclose(torch.tensordot(A, X, dims=X.ndim), B)
True

>>> A = torch.randn(6, 4, 4, 3, 2)
>>> B = torch.randn(4, 3, 2)
>>> X = torch.linalg.tensorsolve(A, B, dims=(0, 2))
>>> X.shape
torch.Size([6, 4])
>>> A = A.permute(1, 3, 4, 0, 2)
>>> A.shape[B.ndim:]
torch.Size([6, 4])
>>> torch.allclose(torch.tensordot(A, X, dims=X.ndim), B, atol=1e-6)
True