jax.numpy.tril_indices
- jax.numpy.tril_indices(n, k=0, m=None)
Return the indices of the lower triangle of an array of shape (n, m).

JAX implementation of numpy.tril_indices().

- Parameters:
n (DimSize) – int. Number of rows of the array for which the indices are returned.
k (DimSize) – optional, int, default=0. Specifies the diagonal on and below which the indices of the lower triangle are returned. k=0 refers to the main diagonal, k<0 refers to a diagonal below the main diagonal, and k>0 refers to a diagonal above the main diagonal.
m (DimSize | None) – optional, int. Number of columns of the array for which the indices are returned. If not specified, then m=n.
- Returns:
A tuple of two arrays containing the indices of the lower triangle, one along each axis.
- Return type:
tuple[Array, Array]
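To make the semantics concrete, the sketch below recomputes the same indices from a boolean lower-triangular mask. It is a minimal sketch using only the public jnp.tri and jnp.nonzero functions; `tril_indices_reference` is a hypothetical helper name, and this is not necessarily how JAX implements tril_indices internally.

```python
import jax.numpy as jnp


def tril_indices_reference(n, k=0, m=None):
    # Hypothetical helper: boolean mask of shape (n, m) (m defaults to n)
    # that is True on and below diagonal k.
    mask = jnp.tri(n, m, k, dtype=bool)
    # jnp.nonzero returns the (row, col) coordinates of the True entries
    # in row-major order, the same ordering jnp.tril_indices uses.
    return jnp.nonzero(mask)


# Compare against jnp.tril_indices for a few (n, k, m) combinations.
for n, k, m in [(3, 0, None), (3, 0, 2), (3, 1, None), (3, -1, None), (4, 2, 5)]:
    expected = jnp.tril_indices(n, k=k, m=m)
    actual = tril_indices_reference(n, k=k, m=m)
    assert all(jnp.array_equal(a, b) for a, b in zip(expected, actual))
```

Building the mask explicitly costs O(n * m) memory, which is fine for illustration.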
See also
- jax.numpy.triu_indices(): Returns the indices of the upper triangle of an array of size (n, m).
- jax.numpy.triu_indices_from(): Returns the indices of the upper triangle of a given array.
- jax.numpy.tril_indices_from(): Returns the indices of the lower triangle of a given array.
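The sketch below illustrates how these functions relate. Since the lower triangle selects positions with col <= row + k and triu_indices with offset k + 1 selects col >= row + k + 1, the two index sets should be disjoint and together cover the whole (n, m) array. It also checks that tril_indices_from on a square array matches tril_indices called with that array's size. The specific sizes are arbitrary choices for illustration.

```python
import jax.numpy as jnp

n, m, k = 4, 5, 1

# Count how many times each position is selected by each index set.
lower = jnp.zeros((n, m), dtype=jnp.int32).at[jnp.tril_indices(n, k=k, m=m)].add(1)
upper = jnp.zeros((n, m), dtype=jnp.int32).at[jnp.triu_indices(n, k=k + 1, m=m)].add(1)
coverage = lower + upper  # every entry should be exactly 1

# For an existing (square) array, tril_indices_from reads the size from its shape.
a = jnp.ones((n, n))
from_array = jnp.tril_indices_from(a, k=k)
from_shape = jnp.tril_indices(n, k=k)
```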
Examples
If only n is provided, the indices of the lower triangle of an (n, n) array are returned.

>>> import jax.numpy as jnp
>>> jnp.tril_indices(3)
(Array([0, 1, 1, 2, 2, 2], dtype=int32), Array([0, 0, 1, 0, 1, 2], dtype=int32))
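A common use of these indices is packing the n(n+1)/2 lower-triangle entries of a square matrix into a flat vector and scattering them back. The sketch below assumes a hypothetical `unpack_lower` helper. Because n determines the shapes involved, it is passed to jax.jit as a static argument.

```python
from functools import partial

import jax
import jax.numpy as jnp

x = jnp.arange(9).reshape(3, 3)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

# Gather: indexing with the (rows, cols) tuple returns the lower-triangle
# entries in row-major order.
packed = x[jnp.tril_indices(3)]  # [0, 3, 4, 6, 7, 8]


# Hypothetical helper: scatter a packed vector back into an (n, n) matrix.
# n fixes the output shape, so it must be static under jax.jit.
@partial(jax.jit, static_argnames=("n",))
def unpack_lower(packed, n):
    rows, cols = jnp.tril_indices(n)
    return jnp.zeros((n, n), dtype=packed.dtype).at[rows, cols].set(packed)


restored = unpack_lower(packed, n=3)  # equals jnp.tril(x)
```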
If both n and m are provided, the indices of the lower triangle of an (n, m) array are returned.

>>> jnp.tril_indices(3, m=2)
(Array([0, 1, 1, 2, 2], dtype=int32), Array([0, 0, 1, 0, 1], dtype=int32))
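The rectangular indices can be used directly on a non-square array, as in the sketch below. Note that m is the third positional parameter in the signature, so it is safest to pass it by keyword: a call like jnp.tril_indices(3, 2) sets k=2, not m=2.

```python
import jax.numpy as jnp

x = jnp.arange(6).reshape(3, 2)  # [[0, 1], [2, 3], [4, 5]]
rows, cols = jnp.tril_indices(3, m=2)
lower_values = x[rows, cols]  # [0, 2, 3, 4, 5]

# Pitfall: the second positional parameter is k. This means n=3, k=2 on a
# 3x3 array, which selects all 9 positions rather than a 3x2 lower triangle.
all_positions = jnp.tril_indices(3, 2)
```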
If k=1, the indices on and below the first diagonal above the main diagonal are returned.

>>> jnp.tril_indices(3, k=1)
(Array([0, 0, 1, 1, 1, 2, 2, 2], dtype=int32), Array([0, 1, 0, 1, 2, 0, 1, 2], dtype=int32))
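One way to check what a positive k selects is to turn the indices into a 0/1 mask and compare it with jnp.tril applied to an all-ones matrix with the same k. This is a sketch for illustration; for n=3 and k=1 it selects the 6 lower-triangle positions plus the 2 entries of the first diagonal above the main diagonal.

```python
import jax.numpy as jnp

n, k = 3, 1
rows, cols = jnp.tril_indices(n, k=k)

# Mark each selected position with 1 in an otherwise all-zero matrix.
mask = jnp.zeros((n, n), dtype=jnp.int32).at[rows, cols].set(1)

# jnp.tril keeps entries on and below diagonal k and zeroes the rest.
reference = jnp.tril(jnp.ones((n, n), dtype=jnp.int32), k=k)
```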
If k=-1, the indices on and below the first diagonal below the main diagonal are returned, that is, only the positions strictly below the main diagonal.

>>> jnp.tril_indices(3, k=-1)
(Array([1, 2, 2], dtype=int32), Array([0, 0, 1], dtype=int32))
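With k=-1 the indices cover exactly the strict lower triangle, so zeroing those positions should leave the same result as jnp.triu. The sketch below shows this; JAX arrays are immutable, so .at[...].set(...) returns an updated copy instead of modifying x in place.

```python
import jax.numpy as jnp

x = jnp.arange(1, 10).reshape(3, 3)  # [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

# k=-1 selects only the positions strictly below the main diagonal.
strict_lower = jnp.tril_indices(3, k=-1)

# Returns a new array with those positions set to 0; x is unchanged.
upper_only = x.at[strict_lower].set(0)  # [[1, 2, 3], [0, 5, 6], [0, 0, 9]]
```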
