JAX documentation

jax.numpy.tri

jax.numpy.tri(N, M=None, k=0, dtype=None)

Return an array with ones on and below the diagonal and zeros elsewhere.

JAX implementation of numpy.tri().
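Put precisely, element (i, j) of the result is one exactly when j <= i + k, and zero otherwise. The sketch below expresses that rule with broadcasting and checks it against jnp.tri for a few settings of N, M, and k. The helper name tri_sketch is hypothetical and for illustration only; it is not how the library implements jnp.tri.

```python
import jax.numpy as jnp


def tri_sketch(N, M=None, k=0, dtype=jnp.float32):
    """Hypothetical re-implementation of jnp.tri, for illustration only."""
    M = N if M is None else M
    rows = jnp.arange(N)[:, None]  # shape (N, 1): row index i
    cols = jnp.arange(M)[None, :]  # shape (1, M): column index j
    # Broadcasting yields an (N, M) boolean array that is True where j <= i + k.
    return (cols <= rows + k).astype(dtype)


# Compare against jnp.tri for the same cases used in the examples.
for N, M, k in [(3, None, 0), (3, 4, 0), (3, None, 1), (3, 4, -1)]:
    assert jnp.array_equal(tri_sketch(N, M, k), jnp.tri(N, M, k))
```

The broadcasting comparison avoids any explicit loop over rows and columns, so the whole mask is produced by a single vectorized operation.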

Parameters:
  • N (int) – Number of rows in the returned array.

  • M (int | None) – optional, int. Number of columns in the returned array. If not specified, then M = N.

  • k (int) – optional, int, default=0. Specifies the diagonal on and below which the array is filled with ones. k=0 refers to the main diagonal, k<0 refers to a sub-diagonal below the main diagonal, and k>0 refers to a super-diagonal above the main diagonal.

  • dtype (DTypeLike | None) – optional, data type of the returned array. If not specified, a floating-point type is used (float32 under JAX's default configuration, as in the examples).
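To see how k and dtype affect the result, the sketch below compares several calls. It assumes JAX's default configuration (64-bit mode disabled), under which the default result dtype is float32; the int32 and boolean dtypes are arbitrary choices for illustration.

```python
import jax.numpy as jnp

# Default dtype: float32 under JAX's default configuration.
default = jnp.tri(3)
print(default.dtype)  # float32

# An explicit integer dtype gives integer ones and zeros.
as_int = jnp.tri(3, dtype=jnp.int32)
print(as_int.dtype)  # int32

# A boolean dtype is convenient when the result is used as a mask.
as_bool = jnp.tri(3, dtype=bool)

# Shifting k moves the boundary diagonal: k=1 also fills the first
# super-diagonal, while k=-1 leaves the main diagonal empty.
upper_shift = jnp.tri(3, k=1)
lower_shift = jnp.tri(3, k=-1)
print(int(upper_shift.sum()), int(default.sum()), int(lower_shift.sum()))  # 8 6 3
```

Counting ones is a quick sanity check: for a 3x3 result, k=0 gives 1 + 2 + 3 = 6, and each step of k adds or removes one diagonal.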

Returns:

An array of shape (N, M) in which the elements on and below the diagonal specified by k are set to one and all other elements are zero.
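Because the result holds ones on and below the chosen diagonal and zeros elsewhere, it can act as a lower-triangular mask. The sketch below multiplies an arbitrary example matrix x (an assumption for illustration) by the mask and checks that this matches jnp.tril with the same k.

```python
import jax.numpy as jnp

# An arbitrary 3x4 example matrix (illustrative only).
x = jnp.arange(1.0, 13.0).reshape(3, 4)

for k in (-1, 0, 1):
    mask = jnp.tri(3, 4, k=k)  # shape (3, 4): ones on and below diagonal k
    masked = x * mask          # zero out every entry above diagonal k
    assert masked.shape == (3, 4)
    # Masking with tri is equivalent to taking the lower triangle directly.
    assert jnp.array_equal(masked, jnp.tril(x, k=k))
```

When only the lower triangle of an existing array is needed, jnp.tril is the more direct call; an explicit mask is useful when the same pattern must be applied to several arrays or combined with other masks.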

Return type:

Array

Examples

>>> jnp.tri(3)
Array([[1., 0., 0.],
       [1., 1., 0.],
       [1., 1., 1.]], dtype=float32)

When M is not equal to N:

>>> jnp.tri(3, 4)
Array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.]], dtype=float32)

When k > 0:

>>> jnp.tri(3, k=1)
Array([[1., 1., 0.],
       [1., 1., 1.],
       [1., 1., 1.]], dtype=float32)

When k < 0:

>>> jnp.tri(3, 4, k=-1)
Array([[0., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 1., 0., 0.]], dtype=float32)