jax.numpy.diag_indices
- jax.numpy.diag_indices(n, ndim=2)
Return indices for accessing the main diagonal of a multidimensional array.
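The main diagonal of an array whose axes all have length n is the set of positions (i, i, ..., i), so every axis uses the same index vector 0..n-1. The sketch below illustrates that idea. It is not JAX's actual source, and the helper name `diag_indices_sketch` is hypothetical; it only checks that the idea produces the same indices as the library function.

```python
import jax.numpy as jnp

def diag_indices_sketch(n: int, ndim: int = 2):
    # Hypothetical helper, for illustration only (not JAX's implementation).
    # Every axis of the diagonal uses the same index vector 0, 1, ..., n-1.
    idx = jnp.arange(n)
    # Repeat that vector once per dimension, giving a tuple of ndim arrays.
    return (idx,) * ndim

# Compare the sketch against the library function for a 3-D case.
sketch = diag_indices_sketch(4, ndim=3)
reference = jnp.diag_indices(4, ndim=3)
for got, ref in zip(sketch, reference):
    assert (got == ref).all()
```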
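JAX implementation of numpy.diag_indices().
- Parameters:
  - n (int): the size of each dimension of the array whose main diagonal is indexed.
  - ndim (int): optional, default 2. The number of dimensions of the array.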
- Returns:
A tuple of ndim arrays, each of length n, containing the indices to access the main diagonal.
- Return type:
tuple[Array, ...]
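Because the result is a tuple with one index array per axis, it can be used directly to index an array. The example below reads the main diagonal of a small 3 x 3 matrix and then builds a copy with that diagonal zeroed. JAX arrays are immutable, so the update goes through the functional `.at[...].set(...)` syntax, which returns a new array; the variable names are illustrative.

```python
import jax.numpy as jnp

x = jnp.arange(9).reshape(3, 3)
# [[0 1 2]
#  [3 4 5]
#  [6 7 8]]

rows, cols = jnp.diag_indices(3)

# Read the main diagonal: x[0, 0], x[1, 1], x[2, 2] -> values 0, 4, 8.
diag = x[rows, cols]

# "Write" the diagonal by creating a new array; x itself is unchanged.
y = x.at[rows, cols].set(0)
# y:
# [[0 1 2]
#  [3 0 5]
#  [6 7 0]]
```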
Examples
>>> jnp.diag_indices(3)
(Array([0, 1, 2], dtype=int32), Array([0, 1, 2], dtype=int32))
>>> jnp.diag_indices(4, ndim=3)
(Array([0, 1, 2, 3], dtype=int32),
 Array([0, 1, 2, 3], dtype=int32),
 Array([0, 1, 2, 3], dtype=int32))
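With ndim=3 the tuple holds three identical index arrays, so indexing a 3 x 3 x 3 array with it selects the entries at (0, 0, 0), (1, 1, 1) and (2, 2, 2). The sketch below uses an arbitrary `arange`-filled cube as example data; with that data, entry (i, i, i) sits at flat position 9i + 3i + i = 13i.

```python
import jax.numpy as jnp

# Example data: cube[i, j, k] == 9*i + 3*j + k.
cube = jnp.arange(27).reshape(3, 3, 3)

# Three copies of [0, 1, 2], one per axis.
idx = jnp.diag_indices(3, ndim=3)

# Selects cube[0, 0, 0], cube[1, 1, 1], cube[2, 2, 2] -> values 0, 13, 26.
diag = cube[idx]
```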
