jax.numpy.tril

jax.numpy.tril(m, k=0)

Return the lower triangle of an array.

JAX implementation of numpy.tril().

Parameters:
  • m (ArrayLike) – input array. Must have m.ndim >= 2.

  • k (int) – optional, default 0. Specifies the sub-diagonal above which the elements of the array are set to zero. k = 0 refers to the main diagonal, k < 0 refers to a sub-diagonal below the main diagonal, and k > 0 refers to a sub-diagonal above the main diagonal.

Returns:

An array with the same shape as the input, containing the lower triangle of the given array, with the elements above the sub-diagonal specified by k set to zero.

Return type:

Array
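
The effect of k can be sketched with an explicit index mask: the element at row i, column j is kept when j <= i + k and zeroed otherwise. The helper below, tril_via_mask, is a hypothetical name introduced only for illustration; it is not part of JAX and is not claimed to mirror the library's internals. It is a minimal sketch that should agree with jnp.tril on the inputs shown.

```python
import jax.numpy as jnp


def tril_via_mask(m, k=0):
    """Hypothetical helper: rebuild the lower triangle with an explicit mask."""
    m = jnp.asarray(m)
    rows = jnp.arange(m.shape[-2])[:, None]  # row index i, shaped as a column
    cols = jnp.arange(m.shape[-1])[None, :]  # column index j, shaped as a row
    keep = cols <= rows + k                  # True on and below the k-th diagonal
    return jnp.where(keep, m, 0)             # mask broadcasts over leading axes


x = jnp.arange(1, 13).reshape(3, 4)  # the 3x4 matrix holding 1..12

# The mask-based sketch and jnp.tril should agree for each offset.
for k in (-1, 0, 1):
    assert jnp.array_equal(tril_via_mask(x, k), jnp.tril(x, k=k))
```

Because the mask is built from the last two dimensions only, the same sketch also applies to inputs with leading batch axes.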

Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 2, 3, 4],
...                [5, 6, 7, 8],
...                [9, 10, 11, 12]])
>>> jnp.tril(x)
Array([[ 1,  0,  0,  0],
       [ 5,  6,  0,  0],
       [ 9, 10, 11,  0]], dtype=int32)
>>> jnp.tril(x, k=1)
Array([[ 1,  2,  0,  0],
       [ 5,  6,  7,  0],
       [ 9, 10, 11, 12]], dtype=int32)
>>> jnp.tril(x, k=-1)
Array([[ 0,  0,  0,  0],
       [ 5,  0,  0,  0],
       [ 9, 10,  0,  0]], dtype=int32)

When m.ndim > 2, jnp.tril operates batch-wise on the trailing two axes.

>>> x1 = jnp.array([[[1, 2],
...                  [3, 4]],
...                 [[5, 6],
...                  [7, 8]]])
>>> jnp.tril(x1)
Array([[[1, 0],
        [3, 4]],
       [[5, 0],
        [7, 8]]], dtype=int32)
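
One way to read the batch-wise behaviour is that each matrix along the leading axis is processed independently. The sketch below checks this by comparing a single call on the stacked input against jax.vmap applied to jnp.tril; using vmap here is an illustrative choice for the comparison, not a description of how jnp.tril is implemented.

```python
import jax
import jax.numpy as jnp

# A stack of two 2x2 matrices, shape (2, 2, 2).
x1 = jnp.array([[[1, 2], [3, 4]],
                [[5, 6], [7, 8]]])

batched = jnp.tril(x1)               # one call on the whole stack
per_matrix = jax.vmap(jnp.tril)(x1)  # tril mapped over the leading axis

# Both routes should zero the same entries in each matrix.
assert jnp.array_equal(batched, per_matrix)
```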