torch.linalg.solve
torch.linalg.solve(A, B, *, left=True, out=None) → Tensor
Computes the solution of a square system of linear equations with a unique solution.
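Letting $\mathbb{K}$ be $\mathbb{R}$ or $\mathbb{C}$, this function computes the solution $X \in \mathbb{K}^{n \times k}$ of the linear system associated to $A \in \mathbb{K}^{n \times n}, B \in \mathbb{K}^{n \times k}$, which is defined as

$$AX = B$$

If left=False, this function returns the matrix $X \in \mathbb{K}^{n \times k}$ that solves the system

$$XA = B \qquad A \in \mathbb{K}^{k \times k}, B \in \mathbb{K}^{n \times k}.$$

This system of linear equations has one solution if and only if $A$ is invertible. This function assumes that $A$ is invertible.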
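As a minimal sketch of the two forms of the system, the residuals can be checked directly. The sizes n = 4 and k = 3, the seed, and the use of float64 are arbitrary choices made for illustration:

```python
import torch

torch.manual_seed(0)
n, k = 4, 3

# left=True (default): solve A X = B with A of shape (n, n) and B of shape (n, k).
A = torch.randn(n, n, dtype=torch.float64)
B = torch.randn(n, k, dtype=torch.float64)
X = torch.linalg.solve(A, B)
print(X.shape)  # torch.Size([4, 3])
print(torch.allclose(A @ X, B))

# left=False: solve X A = B, so A must be (k, k) to match the columns of B.
A_right = torch.randn(k, k, dtype=torch.float64)
X_right = torch.linalg.solve(A_right, B, left=False)
print(X_right.shape)  # torch.Size([4, 3])
print(torch.allclose(X_right @ A_right, B))
```

With left=False, the size of A is tied to the number of columns of B rather than its number of rows.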
Supports inputs of float, double, cfloat and cdouble dtypes. Also supports batches of matrices, and if the inputs are batches of matrices then the output has the same batch dimensions.
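A short sketch with complex (cdouble) inputs and one batch dimension; the batch size, matrix sizes, and seed are illustrative assumptions. The output keeps both the dtype and the batch dimension of the inputs:

```python
import torch

torch.manual_seed(0)

# A batch of 2 complex systems, each with a 3x3 matrix and 4 right-hand sides.
A = torch.randn(2, 3, 3, dtype=torch.cdouble)
B = torch.randn(2, 3, 4, dtype=torch.cdouble)

X = torch.linalg.solve(A, B)
print(X.dtype)  # torch.complex128
print(X.shape)  # torch.Size([2, 3, 4])
print(torch.allclose(A @ X, B))
```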
Letting * be zero or more batch dimensions,

- If A has shape (*, n, n) and B has shape (*, n) (a batch of vectors) or shape (*, n, k) (a batch of matrices or "multiple right-hand sides"), this function returns X of shape (*, n) or (*, n, k) respectively.
- Otherwise, if A has shape (*, n, n) and B has shape (n,) or (n, k), B is broadcast to have shape (*, n) or (*, n, k) respectively. This function then returns the solution of the resulting batch of systems of linear equations.
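The broadcasting rule can be checked by comparing against an explicitly expanded right-hand side. This is a sketch with arbitrary sizes (a batch of 5 systems with n = 3 and k = 2); B.expand is used only to build the reference result:

```python
import torch

torch.manual_seed(0)
batch, n, k = 5, 3, 2

A = torch.randn(batch, n, n, dtype=torch.float64)  # shape (*, n, n) with * = (5,)
B = torch.randn(n, k, dtype=torch.float64)         # shape (n, k): no batch dimensions

# B is broadcast across the batch, so this should match expanding it by hand.
X = torch.linalg.solve(A, B)
X_expanded = torch.linalg.solve(A, B.expand(batch, n, k))
print(X.shape)  # torch.Size([5, 3, 2])
print(torch.allclose(X, X_expanded))

# A right-hand side of shape (n,) is broadcast to a batch of vectors (*, n).
b = torch.randn(n, dtype=torch.float64)
x = torch.linalg.solve(A, b)
print(x.shape)  # torch.Size([5, 3])
```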
Note
This function computes X = A.inverse() @ B in a faster and more numerically stable way than performing the computations separately.

Note
It is possible to compute the solution of the system $XA = B$ by passing the inputs A and B transposed and transposing the output returned by this function.

Note
A is allowed to be a non-batched torch.sparse_csr_tensor, but only with left=True.

Note
When inputs are on a CUDA device, this function synchronizes that device with the CPU. For a version of this function that does not synchronize, see torch.linalg.solve_ex().

See also
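torch.linalg.solve_triangular() computes the solution of a triangular system of linear equations with a unique solution.

Parameters
- A (Tensor) – tensor of shape (*, n, n) where * is zero or more batch dimensions.
- B (Tensor) – right-hand side tensor of shape (*, n), (*, n, k), (n,) or (n, k), according to the rules described above.

Keyword Arguments
- left (bool, optional) – whether to solve the system $AX = B$ or $XA = B$. Default: True.
- out (Tensor, optional) – output tensor. Ignored if None. Default: None.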
Raises
- RuntimeError – if the A matrix is not invertible or if any matrix in a batched A is not invertible.
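A minimal sketch of this failure mode, assuming a PyTorch version that provides torch.linalg.solve_ex: an exactly singular matrix makes torch.linalg.solve raise, while torch.linalg.solve_ex, mentioned in the notes, reports the problem through its info output instead. The matrices are made up for illustration:

```python
import torch

# An exactly singular 2x2 matrix: the second row is twice the first.
A = torch.tensor([[1.0, 2.0], [2.0, 4.0]], dtype=torch.float64)
b = torch.tensor([1.0, 1.0], dtype=torch.float64)

raised_single = False
try:
    torch.linalg.solve(A, b)
except RuntimeError as err:  # the concrete error class is expected to subclass RuntimeError
    raised_single = True
    print("solve raised:", err)

# One singular matrix in a batch is enough for the whole call to raise.
batch = torch.stack([torch.eye(2, dtype=torch.float64), A])  # shape (2, 2, 2)
raised_batched = False
try:
    torch.linalg.solve(batch, b)  # b has shape (n,), so it is broadcast across the batch
except RuntimeError:
    raised_batched = True
print("batched solve raised:", raised_batched)

# solve_ex does not raise by default; it reports a per-matrix status in `info`
# (0 for success, a positive value when the LU factorization hits a zero pivot).
result, info = torch.linalg.solve_ex(batch, b)
print(info)
```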
Examples:
>>> import torch
>>> A = torch.randn(3, 3)
>>> b = torch.randn(3)
>>> x = torch.linalg.solve(A, b)
>>> torch.allclose(A @ x, b)
True
>>> A = torch.randn(2, 3, 3)
>>> B = torch.randn(2, 3, 4)
>>> X = torch.linalg.solve(A, B)
>>> X.shape
torch.Size([2, 3, 4])
>>> torch.allclose(A @ X, B)
True
>>> A = torch.randn(2, 3, 3)
>>> b = torch.randn(3, 1)
>>> x = torch.linalg.solve(A, b)  # b is broadcast to size (2, 3, 1)
>>> x.shape
torch.Size([2, 3, 1])
>>> torch.allclose(A @ x, b)
True
>>> b = torch.randn(3)
>>> x = torch.linalg.solve(A, b)  # b is broadcast to size (2, 3)
>>> x.shape
torch.Size([2, 3])
>>> Ax = A @ x.unsqueeze(-1)
>>> torch.allclose(Ax, b.unsqueeze(-1).expand_as(Ax))
True
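The first two notes can also be checked numerically. In this sketch the sizes and seed are arbitrary, and the identity shift is only an illustrative way to keep the random matrices well-conditioned. solve agrees with A.inverse() @ B, and the transposition trick agrees with left=False; .mT transposes the last two dimensions:

```python
import torch

torch.manual_seed(0)
n, k = 4, 3
eye_n = torch.eye(n, dtype=torch.float64)
eye_k = torch.eye(k, dtype=torch.float64)

# Note 1: solve(A, B) gives the same X as A.inverse() @ B without forming the inverse.
A = torch.randn(n, n, dtype=torch.float64) + n * eye_n
B = torch.randn(n, k, dtype=torch.float64)
X_solve = torch.linalg.solve(A, B)
X_inv = A.inverse() @ B
print(torch.allclose(X_solve, X_inv))

# Note 2: X C = D is equivalent to C^T X^T = D^T, so transpose the inputs and the output.
C = torch.randn(k, k, dtype=torch.float64) + k * eye_k
D = torch.randn(n, k, dtype=torch.float64)
X_left_false = torch.linalg.solve(C, D, left=False)
X_transposed = torch.linalg.solve(C.mT, D.mT).mT
print(torch.allclose(X_left_false, X_transposed))
```

The comparison with the explicit inverse is only there to show that both compute the same X; the note's point is that solve reaches it faster and more stably.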
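When A is known to be triangular, torch.linalg.solve_triangular (listed under See also) should give the same solution as the general solver. This sketch builds an upper-triangular matrix with a diagonal drawn from [1, 2), an illustrative choice to keep it invertible. upper is passed as a keyword argument, and B is passed as a matrix of shape (n, k):

```python
import torch

torch.manual_seed(0)
n, k = 4, 2

# Strictly upper part from randn plus a diagonal in [1, 2): invertible and upper-triangular.
A = torch.randn(n, n, dtype=torch.float64).triu(1) + torch.diag(
    torch.rand(n, dtype=torch.float64) + 1.0
)
B = torch.randn(n, k, dtype=torch.float64)

X_tri = torch.linalg.solve_triangular(A, B, upper=True)
X_gen = torch.linalg.solve(A, B)
print(torch.allclose(X_tri, X_gen))
```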