torch.linalg.tensorsolve
- torch.linalg.tensorsolve(A, B, dims=None, *, out=None) → Tensor
Computes the solution X to the system torch.tensordot(A, X) = B.
If m is the product of the first B.ndim dimensions of A and n is the product of the rest of the dimensions, this function expects m and n to be equal.
The returned tensor x satisfies tensordot(A, x, dims=x.ndim) == B. x has shape A.shape[B.ndim:].
If dims is specified, the listed dimensions of A are first moved to the end: A = movedim(A, dims, range(-len(dims), 0)).
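The description above amounts to flattening A into an m × m matrix and solving an ordinary linear system. The following is a minimal sketch of that reading, not the library's implementation: tensorsolve_via_reshape is a hypothetical helper, and it assumes the dimensions listed in dims end up in the last positions of A, which is consistent with the example on this page.

```python
import math
import torch

def tensorsolve_via_reshape(A, B, dims=None):
    """Hypothetical helper (not a PyTorch API) mirroring the text above."""
    if dims is not None:
        # Assumption: the dimensions listed in dims are moved to the end of A.
        A = torch.movedim(A, tuple(dims), tuple(range(-len(dims), 0)))
    m = math.prod(A.shape[:B.ndim])  # product of the first B.ndim dimensions
    n = math.prod(A.shape[B.ndim:])  # product of the remaining dimensions
    if m != n:
        raise RuntimeError(f"expected m == n, got m={m} and n={n}")
    # Solve the flattened m x m system, then restore the tensor shape.
    x = torch.linalg.solve(A.reshape(m, n), B.reshape(m))
    return x.reshape(A.shape[B.ndim:])

torch.manual_seed(0)
A = torch.randn(6, 4, 4, 3, 2, dtype=torch.float64)
B = torch.randn(4, 3, 2, dtype=torch.float64)

X_ref = torch.linalg.tensorsolve(A, B, dims=(0, 2))
X_alt = tensorsolve_via_reshape(A, B, dims=(0, 2))
print(X_alt.shape, torch.allclose(X_ref, X_alt))
```

Double precision is used here only so that the comparison between the two results is not sensitive to rounding.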
Supports inputs of float, double, cfloat and cdouble dtypes.
See also
torch.linalg.tensorinv() computes the multiplicative inverse of torch.tensordot().
- Parameters
- Keyword Arguments
out (Tensor, optional) – output tensor. Ignored if None. Default: None.
- Raises
RuntimeError – if the reshaped A.view(m, m) with m as above is not invertible, or the product of the first B.ndim dimensions of A is not equal to the product of the rest of the dimensions.
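Both failure modes can be reproduced with small inputs. The sketch below is illustrative only: the tensor shapes and variable names are chosen for this example, and the exact error messages (not shown) may differ between PyTorch versions and devices.

```python
import torch

B = torch.randn(2, 3)

# Shape mismatch: the first B.ndim dimensions give m = 2 * 3 = 6,
# but the remaining dimension gives n = 4.
A_bad_shape = torch.randn(2, 3, 4)
try:
    torch.linalg.tensorsolve(A_bad_shape, B)
    shape_error = None
except RuntimeError as exc:
    shape_error = exc

# Not invertible: an all-zero A reshapes to a singular 6 x 6 matrix.
A_singular = torch.zeros(2, 3, 6)
try:
    torch.linalg.tensorsolve(A_singular, B)
    singular_error = None
except RuntimeError as exc:
    singular_error = exc

print(type(shape_error).__name__, type(singular_error).__name__)
```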
Examples:
>>> A = torch.eye(2 * 3 * 4).reshape((2 * 3, 4, 2, 3, 4))
>>> B = torch.randn(2 * 3, 4)
>>> X = torch.linalg.tensorsolve(A, B)
>>> X.shape
torch.Size([2, 3, 4])
>>> torch.allclose(torch.tensordot(A, X, dims=X.ndim), B)
True

>>> A = torch.randn(6, 4, 4, 3, 2)
>>> B = torch.randn(4, 3, 2)
>>> X = torch.linalg.tensorsolve(A, B, dims=(0, 2))
>>> X.shape
torch.Size([6, 4])
>>> A = A.permute(1, 3, 4, 0, 2)
>>> A.shape[B.ndim:]
torch.Size([6, 4])
>>> torch.allclose(torch.tensordot(A, X, dims=X.ndim), B, atol=1e-6)
True
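The relationship with torch.linalg.tensorinv() mentioned under See also can be checked numerically. This is a sketch under an assumption made for the example only: A is built as an identity plus a small random perturbation so that the flattened matrix is well conditioned. In practice, solving directly is usually preferable to forming the inverse explicitly.

```python
import torch

torch.manual_seed(0)

# Illustrative, well-conditioned A: identity plus a small perturbation.
A = (torch.eye(24, dtype=torch.float64)
     + 0.01 * torch.randn(24, 24, dtype=torch.float64)).reshape(6, 4, 2, 3, 4)
B = torch.randn(6, 4, dtype=torch.float64)

# Direct solve.
X = torch.linalg.tensorsolve(A, B)

# Same result via the explicit tensor inverse; ind=B.ndim splits A's
# dimensions into the leading (6, 4) and the trailing (2, 3, 4).
A_inv = torch.linalg.tensorinv(A, ind=B.ndim)
X_from_inv = torch.tensordot(A_inv, B, dims=B.ndim)

print(A_inv.shape, torch.allclose(X, X_from_inv))
```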