

jax.numpy.triu

jax.numpy.triu(m, k=0)

Return the upper triangle of an array.

JAX implementation of numpy.triu().
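As a quick sanity check, jnp.triu can be compared directly against numpy.triu on the same input. The sketch below assumes both NumPy and JAX are installed; the 4x3 test matrix and the range of k values are arbitrary choices for illustration.

```python
import numpy as np
import jax.numpy as jnp

# Arbitrary non-square test matrix (values chosen for illustration only).
x_np = np.arange(1, 13).reshape(4, 3)

for k in (-2, -1, 0, 1, 2):
    expected = np.triu(x_np, k=k)               # NumPy reference result
    actual = jnp.triu(jnp.asarray(x_np), k=k)   # JAX result, a jax.Array
    # Convert back to NumPy and compare values element-wise.
    assert np.array_equal(np.asarray(actual), expected)

print("jnp.triu matches numpy.triu for k in -2..2")
```

Note that under JAX's default configuration (64-bit mode disabled) the JAX result is int32 while the NumPy input is typically int64, so the comparison above checks values rather than dtypes.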

Parameters:
  • m (ArrayLike) – input array. Must have m.ndim >= 2.

  • k (int) – optional, int, default=0. Specifies the diagonal below which the elements of the array are set to zero. k=0 refers to the main diagonal, k<0 refers to a sub-diagonal below the main diagonal, and k>0 refers to a super-diagonal above the main diagonal.
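The role of k can be stated as a mask: element (i, j) is kept when j - i >= k and set to zero otherwise. The sketch below expresses that rule with a hypothetical helper, triu_reference, which is not part of JAX and is not necessarily how jnp.triu is implemented internally; it is only checked against jnp.triu on an arbitrary matrix.

```python
import jax.numpy as jnp

def triu_reference(m, k=0):
    """Hypothetical helper (not part of JAX): zero every element (i, j) with j - i < k."""
    m = jnp.asarray(m)
    rows = jnp.arange(m.shape[-2])[:, None]  # row index i, shape (R, 1)
    cols = jnp.arange(m.shape[-1])[None, :]  # column index j, shape (1, C)
    keep = (cols - rows) >= k                # True on and above the k-th diagonal
    # The 2-D mask broadcasts over any leading batch axes of m.
    return jnp.where(keep, m, jnp.zeros_like(m))

x = jnp.arange(1, 13).reshape(4, 3)
for k in (-1, 0, 1):
    assert (triu_reference(x, k) == jnp.triu(x, k)).all()
```

Writing the condition as j - i >= k makes the sign convention explicit: raising k moves the cutoff diagonal up and to the right, so more elements are zeroed.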

Returns:

An array with the same shape as the input, containing the upper triangle of the given array: elements below the diagonal specified by k are set to zero.

Return type:

Array
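One way to see exactly which elements are set to zero: they are the elements that jnp.tril keeps with offset k - 1, so the two pieces sum back to the original array. The sketch below checks this identity, together with the preserved shape and dtype, on an arbitrary 4x3 matrix.

```python
import jax.numpy as jnp

x = jnp.arange(1, 13).reshape(4, 3)

for k in (-1, 0, 1):
    upper = jnp.triu(x, k=k)
    lower = jnp.tril(x, k=k - 1)  # exactly the elements that triu zeroes out
    # Shape and dtype of the input are preserved.
    assert upper.shape == x.shape and upper.dtype == x.dtype
    # The kept and zeroed parts are disjoint and together cover every element.
    assert (upper + lower == x).all()
```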

See also

Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6],
...                [7, 8, 9],
...                [10, 11, 12]])
>>> jnp.triu(x)
Array([[1, 2, 3],
       [0, 5, 6],
       [0, 0, 9],
       [0, 0, 0]], dtype=int32)
>>> jnp.triu(x, k=1)
Array([[0, 2, 3],
       [0, 0, 6],
       [0, 0, 0],
       [0, 0, 0]], dtype=int32)
>>> jnp.triu(x, k=-1)
Array([[ 1,  2,  3],
       [ 4,  5,  6],
       [ 0,  8,  9],
       [ 0,  0, 12]], dtype=int32)

When m.ndim > 2, jnp.triu operates batch-wise on the trailing two axes.

>>> x1 = jnp.array([[[1, 2],
...                  [3, 4]],
...                 [[5, 6],
...                  [7, 8]]])
>>> jnp.triu(x1)
Array([[[1, 2],
        [0, 4]],

       [[5, 6],
        [0, 8]]], dtype=int32)
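Operating batch-wise means that for an input of shape (..., M, N), each trailing M x N matrix is processed independently. The sketch below checks this against jax.vmap and an explicit Python loop over a batch of two 3x3 matrices; the input values and the choice k=1 are arbitrary.

```python
import jax
import jax.numpy as jnp

# Arbitrary batch of two 3x3 matrices, shape (2, 3, 3).
batch = jnp.arange(18).reshape(2, 3, 3)

batched = jnp.triu(batch, k=1)                         # acts on the last two axes
mapped = jax.vmap(lambda m: jnp.triu(m, k=1))(batch)   # apply to each 3x3 matrix
looped = jnp.stack([jnp.triu(m, k=1) for m in batch])  # explicit per-matrix loop

assert (batched == mapped).all()
assert (batched == looped).all()
```

Because jnp.triu already handles leading batch axes, wrapping it in jax.vmap is unnecessary in practice; the comparison is only meant to make the batch-wise semantics concrete.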