jax.numpy.tril
- jax.numpy.tril(m, k=0)
Return lower triangle of an array.
JAX implementation of numpy.tril().
- Parameters:
m (ArrayLike) – input array. Must have m.ndim >= 2.
k (int) – optional, default 0. Specifies the diagonal above which the elements of the array are set to zero. k=0 refers to the main diagonal, k<0 refers to a sub-diagonal below the main diagonal, and k>0 refers to a super-diagonal above the main diagonal.
- Returns:
An array with the same shape as the input, containing the lower triangle of the given array, with the elements above the diagonal specified by k set to zero.
- Return type:
Array
See also
jax.numpy.triu(): Returns an upper triangle of an array.
jax.numpy.tri(): Returns an array with ones on and below the diagonal and zeros elsewhere.
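The relationship between these three functions can be sketched as follows. This is an illustrative check rather than a description of how jnp.tril is implemented internally; the variable names (mask, via_tri, via_index, recombined) are chosen for this example only.

```python
import jax.numpy as jnp

x = jnp.arange(1, 13).reshape(3, 4)  # 3x4 matrix holding 1..12
k = 1

# Boolean mask that is True on and below the k-th diagonal, built with jnp.tri.
mask = jnp.tri(x.shape[0], x.shape[1], k, dtype=bool)
via_tri = jnp.where(mask, x, 0)

# The same mask as an explicit index rule: keep element (i, j) when j <= i + k.
i = jnp.arange(x.shape[0])[:, None]
j = jnp.arange(x.shape[1])[None, :]
via_index = jnp.where(j <= i + k, x, 0)

# tril(x, k) and triu(x, k + 1) cover disjoint sets of elements,
# so adding them recovers the original array.
recombined = jnp.tril(x, k) + jnp.triu(x, k + 1)

print(jnp.array_equal(via_tri, jnp.tril(x, k)))
print(jnp.array_equal(via_index, jnp.tril(x, k)))
print(jnp.array_equal(recombined, x))
```

Writing the index rule out makes the sign convention for k concrete: a positive k keeps more of the array, a negative k keeps less.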
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 2, 3, 4],
...                [5, 6, 7, 8],
...                [9, 10, 11, 12]])
>>> jnp.tril(x)
Array([[ 1,  0,  0,  0],
       [ 5,  6,  0,  0],
       [ 9, 10, 11,  0]], dtype=int32)
>>> jnp.tril(x, k=1)
Array([[ 1,  2,  0,  0],
       [ 5,  6,  7,  0],
       [ 9, 10, 11, 12]], dtype=int32)
>>> jnp.tril(x, k=-1)
Array([[ 0,  0,  0,  0],
       [ 5,  0,  0,  0],
       [ 9, 10,  0,  0]], dtype=int32)
When m.ndim > 2, jnp.tril operates batch-wise on the trailing axes.
>>> x1 = jnp.array([[[1, 2],
...                  [3, 4]],
...                 [[5, 6],
...                  [7, 8]]])
>>> jnp.tril(x1)
Array([[[1, 0],
        [3, 4]],
       [[5, 0],
        [7, 8]]], dtype=int32)
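One way to read the batch-wise behavior is that every 2-D slice along the leading axes is processed independently. The sketch below compares jnp.tril on a 3-D input with an explicit jax.vmap over the leading axis; the names batched and mapped are illustrative only.

```python
import jax
import jax.numpy as jnp

x1 = jnp.arange(1, 9).reshape(2, 2, 2)  # batch of two 2x2 matrices

# Direct call: the trailing two axes are treated as the matrix axes.
batched = jnp.tril(x1)

# Explicitly map the 2-D operation over the leading (batch) axis for comparison.
mapped = jax.vmap(jnp.tril)(x1)

print(batched.shape)
print(jnp.array_equal(batched, mapped))
```

Since jnp.tril already handles leading batch axes, the jax.vmap call is not needed in practice; it appears here only to make the per-slice semantics explicit.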
