
jax.numpy.tril_indices


jax.numpy.tril_indices(n, k=0, m=None)

Return the indices of the lower triangle of an array of size (n, m).

JAX implementation of numpy.tril_indices().

Parameters:
  • n (DimSize) – int. Number of rows of the array for which the indices are returned.

  • k (DimSize) – optional, int, default=0. Specifies the diagonal on and below which the indices of the lower triangle are returned. k=0 refers to the main diagonal, k<0 refers to a sub-diagonal below the main diagonal, and k>0 refers to a diagonal above the main diagonal.

  • m (DimSize | None) – optional, int. Number of columns of the array for which the indices are returned. If not specified, then m=n.
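Taken together, the parameters describe a simple membership rule: position (i, j) of an n × m array is returned when j - i <= k. The following is a minimal pure-Python sketch of that rule, assuming the indices are emitted in row-major order (which is what the examples on this page show). The function name tril_indices_reference is hypothetical and is not part of JAX; it only illustrates the semantics.

```python
def tril_indices_reference(n, k=0, m=None):
    """Hypothetical sketch of the lower-triangle rule; not a JAX API."""
    if m is None:
        m = n  # default to a square (n, n) array, as documented for m
    rows, cols = [], []
    # Scan in row-major order; (i, j) is kept when it lies on or below diagonal k.
    for i in range(n):
        for j in range(m):
            if j - i <= k:
                rows.append(i)
                cols.append(j)
    return rows, cols


print(tril_indices_reference(3))        # ([0, 1, 1, 2, 2, 2], [0, 0, 1, 0, 1, 2])
print(tril_indices_reference(3, k=-1))  # ([1, 2, 2], [0, 0, 1])
```

The sketch returns plain lists rather than arrays, so it is only a way to reason about which (row, column) pairs appear and in what order.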

Returns:

A tuple of two arrays containing the indices of the lower triangle, one along each axis.

Return type:

tuple[Array, Array]
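Because the result is a pair of row and column index arrays, it can be used directly for integer-array indexing. The sketch below is illustrative only: it gathers the lower-triangle values of a small example matrix x (an assumption made for the demo) and then writes to those positions through the .at[...] update syntax, since JAX arrays are immutable.

```python
import jax.numpy as jnp

x = jnp.arange(9).reshape(3, 3)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
rows, cols = jnp.tril_indices(3)

# Gather: element p of the result is x[rows[p], cols[p]].
lower_values = x[rows, cols]
print(lower_values)  # [0 3 4 6 7 8]

# Scatter: JAX arrays cannot be modified in place, so .at[...].set returns a new array
# with every lower-triangle entry (including the diagonal) replaced by 0.
upper_only = x.at[rows, cols].set(0)
print(upper_only)
```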

See also

Examples

If only n is provided, the indices of the lower triangle of an (n, n) array are returned.

>>> jnp.tril_indices(3)
(Array([0, 1, 1, 2, 2, 2], dtype=int32), Array([0, 0, 1, 0, 1, 2], dtype=int32))

If both n and m are provided, the indices of the lower triangle of an (n, m) array are returned.

>>> jnp.tril_indices(3, m=2)
(Array([0, 1, 1, 2, 2], dtype=int32), Array([0, 0, 1, 0, 1], dtype=int32))

If k=1, the indices on and below the first diagonal above the main diagonal are returned.

>>> jnp.tril_indices(3, k=1)
(Array([0, 0, 1, 1, 1, 2, 2, 2], dtype=int32), Array([0, 1, 0, 1, 2, 0, 1, 2], dtype=int32))

If k=-1, the indices on and below the first sub-diagonal below the main diagonal are returned.

>>> jnp.tril_indices(3, k=-1)
(Array([1, 2, 2], dtype=int32), Array([0, 0, 1], dtype=int32))
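The k examples above can be cross-checked against a boolean mask. Assuming the indices are returned in row-major order (as every example here shows), they should coincide with the nonzero positions of jnp.tri(n, m, k=k, dtype=bool), which is True on and below diagonal k. The helper tril_indices_via_mask is hypothetical and only serves as a consistency sketch; it does not claim to be how JAX implements tril_indices.

```python
import jax.numpy as jnp


def tril_indices_via_mask(n, k=0, m=None):
    """Hypothetical helper: lower-triangle indices recovered from a boolean mask."""
    m = n if m is None else m
    mask = jnp.tri(n, m, k=k, dtype=bool)  # True where column - row <= k
    return jnp.nonzero(mask)  # nonzero positions, listed in row-major order


# Compare against jnp.tril_indices for the k values used in the examples above.
for k in (-1, 0, 1):
    r1, c1 = jnp.tril_indices(3, k=k)
    r2, c2 = tril_indices_via_mask(3, k=k)
    same = bool(jnp.array_equal(r1, r2)) and bool(jnp.array_equal(c1, c2))
    print(k, same)
```

This check runs eagerly; jnp.nonzero produces a data-dependent output shape, so using it under jax.jit would require passing a static size argument, whereas jnp.tril_indices only depends on the static n, k, and m.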
