jax.numpy.linalg.qr
- jax.numpy.linalg.qr(a, mode='reduced')
Compute the QR decomposition of an array.
JAX implementation of numpy.linalg.qr(). The QR decomposition of a matrix A is given by
\[A = QR\]
where Q has orthonormal columns (i.e. \(Q^H Q = I\); Q is unitary when it is square) and R is an upper-triangular matrix.
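A small sketch of the definition above for complex input, where the conjugate transpose \(Q^H\) matters. The 3x2 matrix is arbitrary and only for illustration; the default "reduced" mode is assumed, so Q is 3x2 (not square) but still satisfies \(Q^H Q = I\).

```python
import jax.numpy as jnp

# Arbitrary complex-valued 3x2 matrix, chosen only for illustration.
a = jnp.array([[1 + 2j, 3 - 1j],
               [2 - 1j, 1 + 1j],
               [0 + 1j, 4 + 0j]], dtype=jnp.complex64)

# Default "reduced" mode: Q has shape (3, 2) and R has shape (2, 2).
Q, R = jnp.linalg.qr(a)

# Q^H Q = I, where Q^H is the conjugate transpose of Q.
is_orthonormal = jnp.allclose(Q.conj().T @ Q, jnp.eye(2), atol=1e-5)

# R is upper-triangular: entries below the diagonal are zero.
is_upper = jnp.allclose(R, jnp.triu(R))

# The factors reproduce the input: A = QR.
reconstructs = jnp.allclose(Q @ R, a, atol=1e-5)

print(Q.shape, R.shape, bool(is_orthonormal), bool(is_upper), bool(reconstructs))
# (3, 2) (2, 2) True True True
```

For a real input, Q.conj().T reduces to Q.T, which is the check used in the examples.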
- Parameters:
a (ArrayLike) – array of shape (…, M, N)
mode (str) –
Computational mode. Supported values are:
  - "reduced" (default): return Q of shape (..., M, K) and R of shape (..., K, N), where K = min(M, N).
  - "complete": return Q of shape (..., M, M) and R of shape (..., M, N).
  - "raw": return LAPACK-internal representations of shape (..., M, N) and (..., K).
  - "r": return R only.
- Returns:
A tuple (Q, R) if mode is "reduced" or "complete", the pair of LAPACK-internal arrays if mode is "raw", or the single array R if mode is "r", where:
  - Q has orthonormal columns, with shape (..., M, K) if mode is "reduced" or (..., M, M) if mode is "complete".
  - R is an upper-triangular matrix of shape (..., K, N) if mode is "reduced" or "r", or (..., M, N) if mode is "complete".
Here K = min(M, N).
- Return type:
Array | QRResult
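The mode-dependent shapes can be checked directly. This is a sketch under a few illustrative assumptions: the input is a random tall 5x3 matrix (M=5, N=3, so K=3) drawn with jax.random, the batched call relies on leading dimensions being treated as batch dimensions per the (…, M, N) convention, and the "r" comparison uses a square slice so that R has the same shape in every mode.

```python
import jax
import jax.numpy as jnp

# Arbitrary tall input: M=5, N=3, so K = min(M, N) = 3.
a = jax.random.normal(jax.random.PRNGKey(0), (5, 3))

# "reduced" (default): Q is (M, K) and R is (K, N).
Q, R = jnp.linalg.qr(a, mode="reduced")
print(Q.shape, R.shape)    # (5, 3) (3, 3)

# "complete": Q is square (M, M) and R is (M, N).
Qc, Rc = jnp.linalg.qr(a, mode="complete")
print(Qc.shape, Rc.shape)  # (5, 5) (5, 3)

# Both factorizations reproduce the input.
print(bool(jnp.allclose(Q @ R, a, atol=1e-5)),
      bool(jnp.allclose(Qc @ Rc, a, atol=1e-5)))  # True True

# Leading dimensions act as a batch: each 5x3 matrix is factored independently.
batch = jax.random.normal(jax.random.PRNGKey(1), (4, 5, 3))
Qb, Rb = jnp.linalg.qr(batch)
print(Qb.shape, Rb.shape)  # (4, 5, 3) (4, 3, 3)

# mode="r" returns a single array rather than a tuple. On a square input
# it should match the R factor from the default mode.
sq = a[:3]  # 3x3 slice of a
R_only = jnp.linalg.qr(sq, mode="r")
_, R_sq = jnp.linalg.qr(sq)
print(isinstance(R_only, tuple), bool(jnp.allclose(R_only, R_sq)))  # False True
```

The extra columns of Qc in "complete" mode multiply the all-zero bottom rows of Rc, which is why both products equal a.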
See also
- jax.scipy.linalg.qr(): SciPy-style QR decomposition API
- jax.lax.linalg.qr(): XLA-style QR decomposition API
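A sketch of the same reduced factorization requested through each of the three APIs. It assumes the SciPy-style function accepts mode="economic" (SciPy's name for the reduced factorization) and that the XLA-style function takes a full_matrices keyword; the 4x3 input is arbitrary and chosen tall so that reduced and complete factorizations would differ in shape.

```python
import jax.numpy as jnp
import jax.scipy.linalg
from jax.lax import linalg as lax_linalg

# Arbitrary tall input: M=4, N=3, so the reduced Q is 4x3 and R is 3x3.
a = jnp.array([[1., 5., 6.],
               [2., 4., 3.],
               [3., 2., 1.],
               [4., 1., 5.]])

# NumPy-style: "reduced" is the default mode.
Q1, R1 = jnp.linalg.qr(a)

# SciPy-style: "economic" is the reduced factorization
# (SciPy's default, "full", corresponds to NumPy's "complete").
Q2, R2 = jax.scipy.linalg.qr(a, mode="economic")

# XLA-style: full_matrices=False requests the reduced factorization.
Q3, R3 = lax_linalg.qr(a, full_matrices=False)

for Q, R in [(Q1, R1), (Q2, R2), (Q3, R3)]:
    # Each line should read: (4, 3) (3, 3) True
    print(Q.shape, R.shape, bool(jnp.allclose(Q @ R, a, atol=1e-5)))
```

The three calls differ mainly in naming conventions (mode strings versus a boolean flag), so the choice between them is largely a matter of which ecosystem's API the surrounding code already follows.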
Examples
Compute the QR decomposition of a matrix:
>>> a = jnp.array([[1., 2., 3., 4.],
...                [5., 4., 2., 1.],
...                [6., 3., 1., 5.]])
>>> Q, R = jnp.linalg.qr(a)
>>> Q
Array([[-0.12700021, -0.7581426 , -0.6396022 ],
       [-0.63500065, -0.43322435,  0.63960224],
       [-0.7620008 ,  0.48737738, -0.42640156]], dtype=float32)
>>> R
Array([[-7.8740077, -5.080005 , -2.4130025, -4.953006 ],
       [ 0.       , -1.7870499, -2.6534991, -1.028908 ],
       [ 0.       ,  0.       , -1.0660033, -4.050814 ]], dtype=float32)
Check that Q is orthonormal:
>>> jnp.allclose(Q.T @ Q, jnp.eye(3), atol=1E-5)
Array(True, dtype=bool)
Reconstruct the input:
>>> jnp.allclose(Q @ R, a)
Array(True, dtype=bool)
