torch.lu_solve

torch.lu_solve(b, LU_data, LU_pivots, *, out=None) → Tensor

Returns the LU solve of the linear system Ax = b using the partially pivoted LU factorization of A from lu_factor().
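To make the idea of an "LU solve" concrete, here is a minimal pure-Python sketch of what solving with a packed, partially pivoted LU factorization involves: apply the recorded row swaps to b, then do forward substitution with L and back substitution with U. The helper names `lu_factor_sketch` and `lu_solve_sketch` are hypothetical, and the 1-based, LAPACK-style pivot convention is an assumption of this sketch. It is an illustration only, not PyTorch's implementation.

```python
def lu_factor_sketch(a):
    """Partially pivoted LU of a square matrix given as a list of rows.

    Returns (lu, pivots): L (unit diagonal, not stored) and U packed into
    one matrix, plus 1-based row-swap indices (assumed LAPACK-style).
    """
    n = len(a)
    lu = [list(row) for row in a]  # work on a copy
    pivots = []
    for k in range(n):
        # Choose the row with the largest-magnitude entry in column k.
        p = max(range(k, n), key=lambda i: abs(lu[i][k]))
        if lu[p][k] == 0:
            raise ValueError("matrix is singular")
        lu[k], lu[p] = lu[p], lu[k]
        pivots.append(p + 1)  # 1-based: row k was swapped with row p
        for i in range(k + 1, n):
            lu[i][k] /= lu[k][k]  # multiplier, stored in the L part
            for j in range(k + 1, n):
                lu[i][j] -= lu[i][k] * lu[k][j]
    return lu, pivots


def lu_solve_sketch(b, lu, pivots):
    """Solve A x = b for a single right-hand-side vector b."""
    n = len(lu)
    x = list(b)
    # Apply the recorded row swaps to b, in the order they were made.
    for k, p in enumerate(pivots):
        x[k], x[p - 1] = x[p - 1], x[k]
    # Forward substitution with the unit lower-triangular factor L.
    for i in range(n):
        for j in range(i):
            x[i] -= lu[i][j] * x[j]
    # Back substitution with the upper-triangular factor U.
    for i in reversed(range(n)):
        for j in range(i + 1, n):
            x[i] -= lu[i][j] * x[j]
        x[i] /= lu[i][i]
    return x


A = [[2.0, 1.0, 1.0], [4.0, 3.0, 3.0], [8.0, 7.0, 9.0]]
b = [4.0, 10.0, 24.0]
lu, pivots = lu_factor_sketch(A)
x = lu_solve_sketch(b, lu, pivots)  # close to [1.0, 1.0, 1.0]
```

The argument order of `lu_solve_sketch` mirrors torch.lu_solve(b, LU_data, LU_pivots). The factorization is the expensive step, so it can be computed once and reused for many right-hand sides.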

This function supports float, double, cfloat and cdouble dtypes for input.

Warning

torch.lu_solve() is deprecated in favor of torch.linalg.lu_solve(). torch.lu_solve() will be removed in a future PyTorch release. X = torch.lu_solve(B, LU, pivots) should be replaced with

X = torch.linalg.lu_solve(LU, pivots, B)
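
The only change in the basic call is the argument order. As a migration aid, the following sketch wraps the replacement function behind the old (b, LU_data, LU_pivots) order. The helper name `lu_solve_compat` is hypothetical and not part of PyTorch, and the matrix values are made up for illustration.

```python
import torch


def lu_solve_compat(b, LU_data, LU_pivots):
    # Hypothetical helper: accepts the deprecated argument order and
    # forwards to the non-deprecated torch.linalg.lu_solve.
    return torch.linalg.lu_solve(LU_data, LU_pivots, b)


# Small, well-conditioned example system (illustrative values).
A = torch.tensor([[4.0, 3.0, 0.0],
                  [6.0, 3.0, 1.0],
                  [0.0, 2.0, 5.0]], dtype=torch.float64)
B = torch.tensor([[1.0], [2.0], [3.0]], dtype=torch.float64)

LU, pivots = torch.linalg.lu_factor(A)
X = lu_solve_compat(B, LU, pivots)

# The residual of A @ X against B should be close to zero.
residual = torch.dist(A @ X, B)
```

A wrapper like this can make a gradual migration easier, although updating call sites to torch.linalg.lu_solve directly is the simpler long-term option.
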
Parameters
  • b (Tensor) – the RHS tensor of size (*, m, k), where * is zero or more batch dimensions.

  • LU_data (Tensor) – the pivoted LU factorization of A from lu_factor() of size (*, m, m), where * is zero or more batch dimensions.

  • LU_pivots (IntTensor) – the pivots of the LU factorization from lu_factor() of size (*, m), where * is zero or more batch dimensions. The batch dimensions of LU_pivots must be equal to the batch dimensions of LU_data.

Keyword Arguments

out (Tensor, optional) – the output tensor.
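
The shape requirements above can be checked with a small batched sketch. It uses torch.linalg.lu_factor() and the non-deprecated torch.linalg.lu_solve(), which expects the same shapes in a different argument order. The matrix values are made up for illustration, with a batch size of 2, m = 3 and k = 4.

```python
import torch

# Batch of two invertible 3x3 matrices: shape (2, 3, 3).
A = torch.tensor([[[4.0, 3.0, 0.0],
                   [6.0, 3.0, 1.0],
                   [0.0, 2.0, 5.0]],
                  [[2.0, 1.0, 1.0],
                   [1.0, 3.0, 2.0],
                   [1.0, 0.0, 0.0]]], dtype=torch.float64)

# Four right-hand sides per matrix: shape (2, 3, 4).
b = torch.arange(24, dtype=torch.float64).reshape(2, 3, 4)

LU, pivots = torch.linalg.lu_factor(A)    # LU: (2, 3, 3), pivots: (2, 3)
x = torch.linalg.lu_solve(LU, pivots, b)  # x: (2, 3, 4), same as b

print(LU.shape, pivots.shape, pivots.dtype, x.shape)
```

The solution has the same shape as b, and the batch dimensions of LU and pivots match, as the parameter descriptions require.
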

Example:

>>> A = torch.randn(2, 3, 3)
>>> b = torch.randn(2, 3, 1)
>>> LU, pivots = torch.linalg.lu_factor(A)
>>> x = torch.lu_solve(b, LU, pivots)
>>> torch.dist(A @ x, b)
tensor(1.00000e-07 *
       2.8312)