jax.numpy.diagflat

jax.numpy.diagflat(v, k=0)

Return a 2-D array with the flattened input array laid out on the diagonal.

JAX implementation of numpy.diagflat().

This differs from np.diagflat for some scalar values of v. JAX always returns a two-dimensional array, whereas NumPy may return a scalar depending on the type of v.
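A quick check of the behavior described above, as a minimal sketch assuming both NumPy and JAX are installed: for an ordinary array input the two functions produce the same values, and a plain Python scalar passed to jnp.diagflat still yields a 2-D array of shape (1, 1). The NumPy cases that return a scalar depend on the input type and are not exercised here.

```python
import numpy as np
import jax.numpy as jnp

# For a regular 1-D array, JAX and NumPy agree element by element
# (dtypes may differ, e.g. int32 in JAX vs int64 in NumPy).
v = [1, 2, 3]
jax_out = jnp.diagflat(jnp.array(v))
np_out = np.diagflat(np.array(v))
print(np.array_equal(np.asarray(jax_out), np_out))  # True

# A scalar is flattened to a length-1 vector, so JAX returns
# a 1x1 two-dimensional array rather than a scalar.
s = jnp.diagflat(5)
print(s.shape)  # (1, 1)
print(s.ndim)   # 2
```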

Parameters:
  • v (ArrayLike) – Input array. Can be N-dimensional but is flattened to 1D.

  • k (int) – optional, default=0. Diagonal offset. Positive values place the diagonal above the main diagonal; negative values place it below the main diagonal.
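A minimal sketch of both parameters, assuming jax is installed: an N-dimensional v is flattened in row-major (C) order before its elements are placed, and a nonzero k enlarges the result to (n + |k|, n + |k|), where n is the number of elements in v.

```python
import jax.numpy as jnp

# A 2x3 input is flattened row by row to [1, 2, 3, 4, 5, 6].
a = jnp.array([[1, 2, 3],
               [4, 5, 6]])
d = jnp.diagflat(a)
print(d.shape)      # (6, 6)
print(jnp.diag(d))  # [1 2 3 4 5 6]

# A negative offset places the values below the main diagonal.
# With n = 2 and k = -1, the result is (2 + 1) x (2 + 1).
low = jnp.diagflat(jnp.array([1, 2]), k=-1)
print(low)
# [[0 0 0]
#  [1 0 0]
#  [0 2 0]]
```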

Returns:

A 2-D array with the input elements placed along the diagonal with the specified offset (k). The remaining entries are filled with zeros.
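To make the return value concrete, here is one way such an array could be built by hand: flatten the input, allocate an (n + |k|) x (n + |k|) block of zeros, and scatter the values onto the chosen diagonal with JAX's .at[...].set(...) update syntax. This is an illustrative sketch, not necessarily how jnp.diagflat is implemented internally; the helper name diagflat_sketch is hypothetical.

```python
import jax.numpy as jnp

def diagflat_sketch(v, k=0):
    """Hypothetical re-implementation of diagflat, for illustration only."""
    flat = jnp.ravel(jnp.asarray(v))  # flatten to 1-D in row-major order
    n = flat.shape[0]
    size = n + abs(k)                 # output is size x size
    out = jnp.zeros((size, size), dtype=flat.dtype)
    i = jnp.arange(n)
    # For k >= 0 the diagonal starts at row 0, column k;
    # for k < 0 it starts at row -k, column 0.
    rows = i + max(-k, 0)
    cols = i + max(k, 0)
    # Scatter the flattened values onto the selected diagonal.
    return out.at[rows, cols].set(flat)

x = jnp.array([[1, 2], [3, 4]])
for k in (-2, 0, 1):
    same = jnp.array_equal(diagflat_sketch(x, k), jnp.diagflat(x, k))
    print(k, bool(same))  # True for each k
```

Comparing against jnp.diagflat over several offsets is a cheap way to confirm the row/column start rule for positive and negative k.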

Return type:

Array

Examples

>>> import jax.numpy as jnp
>>> jnp.diagflat(jnp.array([1, 2, 3]))
Array([[1, 0, 0],
       [0, 2, 0],
       [0, 0, 3]], dtype=int32)
>>> jnp.diagflat(jnp.array([1, 2, 3]), k=1)
Array([[0, 1, 0, 0],
       [0, 0, 2, 0],
       [0, 0, 0, 3],
       [0, 0, 0, 0]], dtype=int32)
>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.diagflat(a)
Array([[1, 0, 0, 0],
       [0, 2, 0, 0],
       [0, 0, 3, 0],
       [0, 0, 0, 4]], dtype=int32)