jax.numpy.linalg.inv
jax.numpy.linalg.inv(a)
Return the inverse of a square matrix.
JAX implementation of numpy.linalg.inv().
- Parameters:
  a (ArrayLike) – array of shape (..., N, N) specifying the square array(s) to be inverted.
- Returns:
  Array of shape (..., N, N) containing the inverse of the input.
- Return type:
  Array
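The leading ... dimensions of the (..., N, N) shape are treated as batch dimensions, so a whole stack of matrices can be inverted in one call. A minimal sketch, using an illustrative stack of two 2x2 matrices (the values are arbitrary and chosen only so the inverses are easy to check by hand):

```python
import jax.numpy as jnp

# Two 2x2 matrices stacked along a leading batch axis: shape (2, 2, 2),
# i.e. (..., N, N) with batch shape (2,) and N = 2.
stack = jnp.array([[[2., 0.],
                    [0., 4.]],
                   [[1., 2.],
                    [3., 4.]]])

# inv() inverts each trailing N x N matrix independently.
stack_inv = jnp.linalg.inv(stack)
print(stack_inv.shape)  # (2, 2, 2)

# Matmul also batches over the leading axis; each product should be the
# 2x2 identity, which broadcasts against the (2, 2, 2) result.
print(jnp.allclose(stack @ stack_inv, jnp.eye(2), atol=1e-5))
```

Here the first matrix is diagonal, so its inverse is diag(0.5, 0.25); the second has determinant -2, so its inverse is [[-2, 1], [1.5, -0.5]].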
Notes
In most cases, explicitly computing the inverse of a matrix is ill-advised. For example, to compute x = inv(A) @ b, it is more performant and numerically precise to use a direct solve, such as jax.scipy.linalg.solve().
See also
- jax.scipy.linalg.inv(): SciPy-style API for matrix inverse
- jax.numpy.linalg.solve(): direct linear solver
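To make the advice in the Notes concrete, the sketch below solves a @ X = B directly with both solvers named on this page and compares the result with the explicit inv(a) @ B. The matrix B, holding two right-hand sides as columns, is an illustrative assumption. On a small problem like this one all three results agree to float32 tolerance; the advantages of a direct solve are expected to show up on larger or ill-conditioned systems.

```python
import jax.numpy as jnp
import jax.scipy.linalg

a = jnp.array([[1., 2., 3.],
               [2., 4., 2.],
               [3., 2., 1.]])
# Two right-hand sides stacked as columns: shape (N, K) = (3, 2).
B = jnp.array([[1., 0.],
               [4., 1.],
               [2., 0.]])

# Preferred: solve a @ X = B directly, without forming the inverse.
X_numpy_style = jnp.linalg.solve(a, B)
X_scipy_style = jax.scipy.linalg.solve(a, B)

# Discouraged: form the inverse explicitly, then multiply.
X_via_inv = jnp.linalg.inv(a) @ B

# On this small example all three agree to float32 tolerance.
print(jnp.allclose(X_numpy_style, X_scipy_style, atol=1e-5))
print(jnp.allclose(X_numpy_style, X_via_inv, atol=1e-5))
# Residual check: a @ X should reproduce B.
print(jnp.allclose(a @ X_numpy_style, B, atol=1e-5))
```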
Examples
Compute the inverse of a 3x3 matrix:
>>> import jax.numpy as jnp
>>> a = jnp.array([[1., 2., 3.],
...                [2., 4., 2.],
...                [3., 2., 1.]])
>>> a_inv = jnp.linalg.inv(a)
>>> a_inv
Array([[ 0.        , -0.25      ,  0.5       ],
       [-0.25      ,  0.5       , -0.25000003],
       [ 0.5       , -0.25      ,  0.        ]], dtype=float32)
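As a hand check of this output (not a description of how jnp.linalg.inv computes it), the exact inverse of this matrix follows from the adjugate formula A^{-1} = adj(A) / det(A). The float32 result differs from the exact values only by rounding, as in the -0.25000003 entry.

```latex
A = \begin{pmatrix} 1 & 2 & 3 \\ 2 & 4 & 2 \\ 3 & 2 & 1 \end{pmatrix}, \qquad
\det A = 1(4 \cdot 1 - 2 \cdot 2) - 2(2 \cdot 1 - 2 \cdot 3) + 3(2 \cdot 2 - 4 \cdot 3) = 0 + 8 - 24 = -16,

A^{-1} = \frac{1}{\det A}\,\operatorname{adj}(A)
       = -\frac{1}{16}\begin{pmatrix} 0 & 4 & -8 \\ 4 & -8 & 4 \\ -8 & 4 & 0 \end{pmatrix}
       = \begin{pmatrix} 0 & -0.25 & 0.5 \\ -0.25 & 0.5 & -0.25 \\ 0.5 & -0.25 & 0 \end{pmatrix}.
```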
Check that multiplying with the inverse gives the identity:
>>> jnp.allclose(a @ a_inv, jnp.eye(3), atol=1E-5)
Array(True, dtype=bool)
Multiply the inverse by a vector b to find a solution to a @ x = b:
>>> b = jnp.array([1., 4., 2.])
>>> a_inv @ b
Array([ 0.  ,  1.25, -0.5 ], dtype=float32)
Note, however, that explicitly computing the inverse in such a case can lead to poor performance and loss of precision as the size of the problem grows. Instead, you should use a direct solver like jax.numpy.linalg.solve():
>>> jnp.linalg.solve(a, b)
Array([ 0.  ,  1.25, -0.5 ], dtype=float32)
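A common reason for computing an inverse is to reuse it for many right-hand sides. One alternative, sketched below, is to factor the matrix once and reuse the factorization. This sketch assumes the SciPy-style LU routines jax.scipy.linalg.lu_factor and jax.scipy.linalg.lu_solve; it is one option under that assumption, not the only recommended approach, and the second right-hand side is illustrative.

```python
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve

a = jnp.array([[1., 2., 3.],
               [2., 4., 2.],
               [3., 2., 1.]])

# Factor once: lu_factor returns the packed L/U factors and the pivot indices.
lu_and_piv = lu_factor(a)

# Reuse the same factorization for several right-hand sides,
# instead of forming inv(a) and multiplying.
rhs = [jnp.array([1., 4., 2.]), jnp.array([0., 1., 0.])]
solutions = [lu_solve(lu_and_piv, b) for b in rhs]

for b, x in zip(rhs, solutions):
    # Residual check: a @ x should reproduce b.
    print(x, jnp.allclose(a @ x, b, atol=1e-5))
```

The factorization costs about as much as one solve, so each additional right-hand side only pays for the cheaper triangular solves.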
