jax.numpy.linalg.solve
- jax.numpy.linalg.solve(a, b)
Solve a linear system of equations.
JAX implementation of numpy.linalg.solve().
This solves a (batched) linear system of equations a @ x = b for x, given a and b. If a is singular, this will return nan or inf values.
- Parameters:
a (ArrayLike) – array of shape (..., N, N).
b (ArrayLike) – array of shape (N,) (for a 1-dimensional right-hand side) or (..., N, M) (for a batched 2-dimensional right-hand side).
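The shape rules above can be illustrated with a small sketch. It assumes a batch of K = 4 random 3x3 matrices, each shifted by a multiple of the identity so that it is comfortably non-singular; the sizes, PRNG key, and shift are arbitrary illustrative choices.

```python
import jax
import jax.numpy as jnp

K, N, M = 4, 3, 2  # batch size, system size, number of right-hand-side columns
key_a, key_b, key_c = jax.random.split(jax.random.PRNGKey(0), 3)

# A batch of K matrices with shape (K, N, N); adding 2 * N * I keeps them
# far from singular for this illustration.
a = jax.random.normal(key_a, (K, N, N)) + 2 * N * jnp.eye(N)

# 1-D right-hand side of shape (N,): it is broadcast against every matrix
# in the batch, so the result has shape (K, N).
b_vec = jax.random.normal(key_b, (N,))
x_vec = jnp.linalg.solve(a, b_vec)

# Batched 2-D right-hand side of shape (K, N, M): each of the M columns is a
# separate right-hand side, so the result has shape (K, N, M).
b_mat = jax.random.normal(key_c, (K, N, M))
x_mat = jnp.linalg.solve(a, b_mat)

print(x_vec.shape)  # (4, 3)
print(x_mat.shape)  # (4, 3, 2)

# Residual checks: multiplying back by a should reproduce b in both cases.
print(jnp.allclose(jnp.einsum("kij,kj->ki", a, x_vec), b_vec, atol=1e-4))
print(jnp.allclose(a @ x_mat, b_mat, atol=1e-4))
```

Note that a 1-D b is shared across the whole batch; to pair a different vector with each matrix, b can be given an explicit trailing column axis of size 1, i.e. shape (K, N, 1).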
- Returns:
An array containing the result of the linear solve if a is non-singular. The result has shape (..., N) if b is of shape (N,), and has shape (..., N, M) otherwise. If a is singular, the result contains nan or inf values.
- Return type:
Array
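Because a singular a yields nan or inf values rather than raising an error, code that needs to detect failure has to inspect the output itself. A minimal sketch, using a hypothetical helper solve_checked (not part of JAX) that flags non-finite results with jnp.isfinite:

```python
import jax.numpy as jnp

def solve_checked(a, b):
    """Hypothetical helper: solve a @ x = b and report whether x is all finite."""
    x = jnp.linalg.solve(a, b)
    ok = jnp.all(jnp.isfinite(x))  # False if any entry is nan or inf
    return x, ok

# Non-singular system (the same one used in the examples).
A = jnp.array([[1., 2., 3.],
               [2., 4., 2.],
               [3., 2., 1.]])
x, ok = solve_checked(A, jnp.array([14., 16., 10.]))
print(ok)  # True

# Exactly singular: the second row is twice the first, so per the
# description above the solution is expected to contain nan or inf.
S = jnp.array([[1., 2.],
               [2., 4.]])
x_bad, ok_bad = solve_checked(S, jnp.array([1., 1.]))
print(x_bad, ok_bad)
```

A finiteness check only catches exact breakdowns: a nearly singular a can produce finite but inaccurate values, so inspecting the residual a @ x - b or jnp.linalg.cond(a) may also be warranted.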
See also
jax.scipy.linalg.solve(): SciPy-style API for solving linear systems.
jax.lax.custom_linear_solve(): matrix-free linear solver.
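The two alternatives listed above can be compared on a small system. This sketch assumes a symmetric positive-definite matrix chosen for illustration, passes assume_a="pos" to jax.scipy.linalg.solve to declare that structure, and uses jax.scipy.sparse.linalg.cg as the inner solver for jax.lax.custom_linear_solve; the CG choice and the helper names matvec and cg_solve are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg
import jax.scipy.sparse.linalg

# Small symmetric positive-definite matrix (illustrative).
A = jnp.array([[4., 1., 0.],
               [1., 3., 1.],
               [0., 1., 2.]])
b = jnp.array([1., 2., 3.])

# 1. NumPy-style general solve.
x_np = jnp.linalg.solve(A, b)

# 2. SciPy-style solve; assume_a="pos" declares A positive definite.
x_sp = jax.scipy.linalg.solve(A, b, assume_a="pos")

# 3. Matrix-free solve: custom_linear_solve accesses the operator only
#    through matvec, and delegates the actual solve to an inner solver.
def matvec(v):
    return A @ v

def cg_solve(matvec_fn, rhs):
    x, _ = jax.scipy.sparse.linalg.cg(matvec_fn, rhs)  # returns (x, info)
    return x

# symmetric=True lets the same inner solver handle the transposed system
# when differentiating.
x_mf = jax.lax.custom_linear_solve(matvec, b, cg_solve, symmetric=True)

print(jnp.allclose(x_np, x_sp, atol=1e-5))
print(jnp.allclose(x_np, x_mf, atol=1e-3))
```

All three results should agree up to floating-point and CG tolerance. The matrix-free route is mainly worthwhile when the operator is too large to materialize or is only available as a function.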
Examples
A simple 3x3 linear system:
>>> A = jnp.array([[1., 2., 3.],
...                [2., 4., 2.],
...                [3., 2., 1.]])
>>> b = jnp.array([14., 16., 10.])
>>> x = jnp.linalg.solve(A, b)
>>> x
Array([1., 2., 3.], dtype=float32)
Confirming that the result solves the system:
>>> jnp.allclose(A @ x, b)
Array(True, dtype=bool)
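The example above uses a single right-hand side. As a sketch (the second column of B is an arbitrary illustrative choice), several right-hand sides for the same A can be stacked as the columns of an (N, M) array and solved in one call, which should match solving each column separately. Wrapping the call in jax.jit only illustrates that the function can be compiled.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[1., 2., 3.],
               [2., 4., 2.],
               [3., 2., 1.]])

# Two right-hand sides stacked as the columns of a (3, 2) array.
B = jnp.array([[14., 1.],
               [16., 0.],
               [10., 0.]])

# One call solves for both columns; the result has shape (3, 2).
X = jax.jit(jnp.linalg.solve)(A, B)

# Column-by-column equivalent: map over the columns of B (in_axes=(None, 1))
# and stack the per-column solutions back as columns (out_axes=1).
X_cols = jax.vmap(jnp.linalg.solve, in_axes=(None, 1), out_axes=1)(A, B)

print(X.shape)  # (3, 2)
print(jnp.allclose(X, X_cols, atol=1e-5))
print(jnp.allclose(A @ X, B, atol=1e-4))
```

The first column of X reproduces the single-vector solution [1., 2., 3.]; the second is the first column of the inverse of A.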
