jax.numpy.tri
jax.numpy.tri(N, M=None, k=0, dtype=None)
Return an array with ones on and below the diagonal and zeros elsewhere.
JAX implementation of numpy.tri().
- Parameters:
N (int) – int. Number of rows of the returned array.
M (int | None) – optional, int. Number of columns of the returned array. If not specified, then M = N.
k (int) – optional, int, default=0. Specifies the diagonal on and below which the array is filled with ones. k=0 refers to the main diagonal, k<0 refers to a sub-diagonal below the main diagonal, and k>0 refers to a super-diagonal above the main diagonal.
dtype (DTypeLike | None) – optional, data type of the returned array. The default type is float.
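The role of k can be stated as a rule: entry (i, j) is one exactly when j <= i + k. The sketch below rebuilds that pattern with broadcasting over `jnp.arange` and checks it against `jnp.tri`. The helper name `tri_from_rule` is hypothetical and only illustrates the rule; it is not part of the JAX API.

```python
import jax.numpy as jnp

def tri_from_rule(N, M=None, k=0):
    """Hypothetical helper: rebuild tri's pattern from the rule j <= i + k."""
    M = N if M is None else M
    i = jnp.arange(N)[:, None]  # row indices, shape (N, 1)
    j = jnp.arange(M)[None, :]  # column indices, shape (1, M)
    # Broadcasting the comparison gives an (N, M) boolean grid; cast to float32.
    return (j <= i + k).astype(jnp.float32)

# The rule agrees with jnp.tri for diagonals below, on, and above the main one.
for k in (-1, 0, 1):
    assert jnp.array_equal(tri_from_rule(3, 4, k), jnp.tri(3, 4, k=k))
```

Raising k by one shifts the boundary one column to the right, so more ones appear; lowering it shifts the boundary left.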
- Returns:
An array of shape (N, M) with ones on and below the diagonal specified by k and zeros elsewhere.
- Return type:
Array
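As a quick illustration of the returned shape and the dtype parameter, the sketch below requests the same (2, 3) pattern as a boolean mask and as int32. The variable names are illustrative only.

```python
import jax.numpy as jnp

mask = jnp.tri(2, 3, dtype=bool)         # boolean mask of shape (N, M) = (2, 3)
counts = jnp.tri(2, 3, dtype=jnp.int32)  # same pattern stored as int32

print(mask.shape, mask.dtype)  # (2, 3) bool
print(counts.tolist())         # [[1, 0, 0], [1, 1, 0]]
```

A boolean result is convenient when the array is used as a mask (for example with `jnp.where`) rather than in arithmetic.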
See also
jax.numpy.tril(): Returns a lower triangle of an array.
jax.numpy.triu(): Returns an upper triangle of an array.
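The relationship to tril() and triu() can be sketched with a boolean mask from tri(): tril keeps the entries where tri(..., k=k) is true, and triu keeps the complement of tri(..., k=k-1). The helpers `tril_via_tri` and `triu_via_tri` are hypothetical names used only to show the correspondence for 2-D inputs; they are not how JAX necessarily implements tril or triu.

```python
import jax.numpy as jnp

a = jnp.arange(1, 13).reshape(3, 4)  # example 2-D input, shape (3, 4)

def tril_via_tri(a, k=0):
    """Hypothetical helper: keep entries where tri(...) is True, zero the rest."""
    keep = jnp.tri(*a.shape, k=k, dtype=bool)
    return jnp.where(keep, a, 0)

def triu_via_tri(a, k=0):
    """Hypothetical helper: triu keeps j >= i + k, the complement of tri(..., k=k-1)."""
    keep = ~jnp.tri(*a.shape, k=k - 1, dtype=bool)
    return jnp.where(keep, a, 0)

# Check the correspondence against jnp.tril / jnp.triu for a few diagonals.
for k in (-1, 0, 1):
    assert jnp.array_equal(tril_via_tri(a, k), jnp.tril(a, k))
    assert jnp.array_equal(triu_via_tri(a, k), jnp.triu(a, k))
```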
Examples
>>> import jax.numpy as jnp
>>> jnp.tri(3)
Array([[1., 0., 0.],
       [1., 1., 0.],
       [1., 1., 1.]], dtype=float32)
When M is not equal to N:
>>> jnp.tri(3, 4)
Array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.]], dtype=float32)
When k > 0:
>>> jnp.tri(3, k=1)
Array([[1., 1., 0.],
       [1., 1., 1.],
       [1., 1., 1.]], dtype=float32)
When k < 0:
>>> jnp.tri(3, 4, k=-1)
Array([[0., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 1., 0., 0.]], dtype=float32)
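Since this function is the JAX implementation of numpy.tri(), a simple sanity check is to compare values against NumPy for the argument combinations used in the examples above. This sketch compares values only: NumPy's default dtype is float64, while the JAX examples show float32 under the default configuration.

```python
import numpy as np
import jax.numpy as jnp

# (positional args, keyword args) pairs mirroring the examples above.
cases = [((3,), {}), ((3, 4), {}), ((3,), {"k": 1}), ((3, 4), {"k": -1})]

for args, kwargs in cases:
    # Convert the JAX array to NumPy and compare element values.
    np.testing.assert_array_equal(np.asarray(jnp.tri(*args, **kwargs)), np.tri(*args, **kwargs))
```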
