
jax.numpy.diag_indices


jax.numpy.diag_indices(n, ndim=2)

Return indices for accessing the main diagonal of a multidimensional array.

JAX implementation of numpy.diag_indices().
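The returned tuple is meant to be used directly as an index into an array. The following is a minimal sketch, assuming a small illustrative 3×3 matrix built with `jnp.arange`. It reads the main diagonal through `diag_indices` and checks the result against `jnp.diagonal`:

```python
import jax.numpy as jnp

# Illustrative 3x3 matrix: [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
a = jnp.arange(9).reshape(3, 3)

# diag_indices(3) gives (rows, cols) = ([0, 1, 2], [0, 1, 2]),
# so indexing with the tuple gathers a[0, 0], a[1, 1], a[2, 2].
idx = jnp.diag_indices(3)
diag = a[idx]

# jnp.diagonal extracts the same main diagonal directly.
assert jnp.array_equal(diag, jnp.diagonal(a))
print(diag)
```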

Parameters:
  • n (int) – The size of each dimension of the square array.

  • ndim (int) – Optional, default=2. The number of dimensions of the array.

Returns:

A tuple of ndim arrays, each of length n, containing the indices to access the main diagonal.
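Every array in the returned tuple holds the same values 0, 1, …, n-1, so the result behaves like ndim copies of `jnp.arange(n)`. The sketch below checks this for `n=4, ndim=3` and then uses the tuple to select the elements `c[i, i, i]` of a cube. The 4×4×4 array `c` is an illustrative assumption:

```python
import jax.numpy as jnp

n, ndim = 4, 3
idx = jnp.diag_indices(n, ndim=ndim)

# One index array per dimension, each equal to arange(n).
assert len(idx) == ndim
assert all(jnp.array_equal(i, jnp.arange(n)) for i in idx)

# Illustrative 4x4x4 cube with c[i, j, k] == 16*i + 4*j + k.
c = jnp.arange(n ** ndim).reshape((n,) * ndim)

# idx selects c[0, 0, 0], c[1, 1, 1], c[2, 2, 2], c[3, 3, 3],
# i.e. the values 21 * i for i in 0..3.
cube_diag = c[idx]
print(cube_diag)
```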

Return type:

tuple[Array, …]

Examples

>>> import jax.numpy as jnp
>>> jnp.diag_indices(3)
(Array([0, 1, 2], dtype=int32), Array([0, 1, 2], dtype=int32))
>>> jnp.diag_indices(4, ndim=3)
(Array([0, 1, 2, 3], dtype=int32),
Array([0, 1, 2, 3], dtype=int32),
Array([0, 1, 2, 3], dtype=int32))
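JAX arrays are immutable, so the NumPy idiom `a[idx] += value` is not available. A functional update through the `.at` property is one way to modify the diagonal instead. The sketch below uses a hypothetical helper, `add_to_diagonal` (not part of JAX), which adds a scalar to the main diagonal, runs it under `jax.jit`, and compares the result with adding a scaled identity matrix:

```python
import jax
import jax.numpy as jnp

def add_to_diagonal(a, lam):
    """Hypothetical helper: return a copy of square matrix `a` with `lam` added to its main diagonal."""
    n = a.shape[0]  # static under jit, so diag_indices gets a Python int
    # .at[...].add returns a new array; `a` itself is left unchanged.
    return a.at[jnp.diag_indices(n)].add(lam)

a = jnp.ones((3, 3))
out = jax.jit(add_to_diagonal)(a, 0.5)

# Same result as adding a scaled identity matrix.
assert jnp.allclose(out, a + 0.5 * jnp.eye(3))
```

Indexing with `diag_indices` touches only the n diagonal entries. By contrast, the `jnp.eye` comparison builds a full n×n matrix.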

