jax.numpy.triu
jax.numpy.triu(m, k=0)
Return the upper triangle of an array.
JAX implementation of numpy.triu().
Parameters:
  m (ArrayLike) – input array. Must have m.ndim >= 2.
  k (int) – optional, default 0. Specifies the diagonal below which the elements of the array are set to zero. k = 0 refers to the main diagonal, k < 0 to a diagonal below the main diagonal, and k > 0 to a diagonal above the main diagonal.
Returns:
  An array with the same shape as the input, containing the upper triangle of the given array: elements below the diagonal specified by k are set to zero.
Return type:
  Array
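The role of k can be stated as an index rule: element (i, j) of each trailing 2-D matrix is kept when j - i >= k and set to zero otherwise. The sketch below spells that rule out with broadcast row and column indices and checks it against jnp.triu. The helper name triu_by_index is hypothetical and is not part of JAX; it is an illustration of the semantics, not the library's implementation.

```python
import jax.numpy as jnp

def triu_by_index(m, k=0):
    """Hypothetical helper: zero every element (i, j) with j - i < k."""
    m = jnp.asarray(m)
    rows = jnp.arange(m.shape[-2])[:, None]  # row index i, shape (N, 1)
    cols = jnp.arange(m.shape[-1])[None, :]  # column index j, shape (1, M)
    keep = (cols - rows) >= k                # True on and above the k-th diagonal
    # The (N, M) mask broadcasts against any leading batch axes of m.
    return jnp.where(keep, m, jnp.zeros((), m.dtype))

# Same 4x3 matrix as in the examples: [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
x = jnp.arange(1, 13).reshape(4, 3)
for k in (-1, 0, 1):
    assert (triu_by_index(x, k) == jnp.triu(x, k=k)).all()
```

Because the mask only depends on the last two axes, the same helper also works on batched input.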
See also
  jax.numpy.tril(): Returns the lower triangle of an array.
  jax.numpy.tri(): Returns an array with ones on and below the diagonal and zeros elsewhere.
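The related functions fit together as follows: jnp.triu(x, k) keeps the elements with j >= i + k, while jnp.tril(x, k - 1) keeps those with j <= i + k - 1, so the two pieces are disjoint and sum back to x. Likewise, jnp.tri(N, M, k - 1) is true exactly where triu zeroes elements. The sketch below checks both identities on a small matrix; it is an illustration of the relationship, not a claim about how JAX implements triu internally.

```python
import jax.numpy as jnp

x = jnp.arange(1, 13).reshape(4, 3)
k = 1

# triu keeps j >= i + k; tril(x, k - 1) keeps j <= i + k - 1.
# The two masks are disjoint and together cover every element.
upper = jnp.triu(x, k=k)
lower = jnp.tril(x, k=k - 1)
assert (upper + lower == x).all()

# jnp.tri(N, M, k - 1) is True exactly where triu zeroes elements,
# so triu can also be written as a masked select.
mask = jnp.tri(x.shape[0], x.shape[1], k=k - 1, dtype=bool)
assert (jnp.where(mask, 0, x) == upper).all()
```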
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6],
...                [7, 8, 9],
...                [10, 11, 12]])
>>> jnp.triu(x)
Array([[1, 2, 3],
       [0, 5, 6],
       [0, 0, 9],
       [0, 0, 0]], dtype=int32)
>>> jnp.triu(x, k=1)
Array([[0, 2, 3],
       [0, 0, 6],
       [0, 0, 0],
       [0, 0, 0]], dtype=int32)
>>> jnp.triu(x, k=-1)
Array([[ 1,  2,  3],
       [ 4,  5,  6],
       [ 0,  8,  9],
       [ 0,  0, 12]], dtype=int32)
When m.ndim > 2, jnp.triu operates batch-wise on the last two axes, treating all leading axes as batch dimensions.
>>> x1 = jnp.array([[[1, 2],
...                  [3, 4]],
...                 [[5, 6],
...                  [7, 8]]])
>>> jnp.triu(x1)
Array([[[1, 2],
        [0, 4]],

       [[5, 6],
        [0, 8]]], dtype=int32)
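One way to see the batch-wise behavior is to apply jnp.triu to each matrix separately and compare. The sketch below does this with jax.vmap over the leading axis; it assumes a single leading batch axis, and vmap is used here only as an illustration of the semantics, not as a description of how jnp.triu handles batches internally.

```python
import jax
import jax.numpy as jnp

x1 = jnp.array([[[1, 2],
                 [3, 4]],
                [[5, 6],
                 [7, 8]]])

# Apply triu to each 2-D matrix separately by mapping over the leading axis.
per_matrix = jax.vmap(jnp.triu)(x1)

# Same result as the direct batch-wise call on the 3-D array.
assert (per_matrix == jnp.triu(x1)).all()

# A non-default k can be passed by closing over it.
per_matrix_k = jax.vmap(lambda m: jnp.triu(m, k=1))(x1)
assert (per_matrix_k == jnp.triu(x1, k=1)).all()
```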
