
jax.numpy.fill_diagonal


jax.numpy.fill_diagonal(a, val, wrap=False, *, inplace=True)

Return a copy of the array with the diagonal overwritten.

JAX implementation of numpy.fill_diagonal().

The semantics of numpy.fill_diagonal() are to modify arrays in-place, which is not possible for JAX’s immutable arrays. The JAX version returns a modified copy of the input, and adds the inplace parameter, which must be set to False by the user as a reminder of this API difference.
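To make the API difference concrete, here is a small side-by-side sketch, assuming both NumPy and JAX are installed. NumPy's version writes into its argument and returns None; the JAX version leaves its input alone and returns the filled copy. The last part checks that leaving inplace at its default is rejected; it only checks that some exception is raised, not its exact type.

```python
import numpy as np
import jax.numpy as jnp

# NumPy: fill_diagonal writes into its argument and returns None.
a_np = np.zeros((3, 3), dtype=int)
np_result = np.fill_diagonal(a_np, 7)
print(np_result)        # None
print(a_np.diagonal())  # [7 7 7]  (a_np itself was changed)

# JAX: arrays are immutable, so the filled array is returned as a new value.
a_jx = jnp.zeros((3, 3), dtype=int)
filled = jnp.fill_diagonal(a_jx, 7, inplace=False)
print(filled.diagonal())  # [7 7 7]
print(a_jx.diagonal())    # [0 0 0]  (the input is untouched)

# Forgetting inplace=False is an error rather than a silent copy.
try:
    jnp.fill_diagonal(a_jx, 7)
    default_rejected = False
except Exception as err:
    default_rejected = True
    print(type(err).__name__)
```

The usual JAX pattern is therefore to rebind the name, e.g. x = jnp.fill_diagonal(x, val, inplace=False).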

Parameters:
  • a (ArrayLike) – input array. Must have a.ndim >= 2. If a.ndim >= 3, then all dimensions must be the same size.

  • val (ArrayLike) – scalar or array with which to fill the diagonal. If an array, it will be flattened and then repeated or truncated as needed to fill the diagonal entries.

  • wrap (bool) – Not implemented by JAX. Only the default value of False is supported.

  • inplace (bool) – must be set to False to indicate that the input is not modified in-place, but rather a modified copy is returned.
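The constraints in the parameter list can be probed directly. The helper accepts below is hypothetical (it is not part of JAX); it simply reports whether jnp.fill_diagonal raises for a given input. It deliberately catches any exception rather than assuming specific error types.

```python
import jax.numpy as jnp

def accepts(a, **kwargs):
    """Hypothetical helper: True if jnp.fill_diagonal runs without raising."""
    try:
        jnp.fill_diagonal(a, 1, inplace=False, **kwargs)
        return True
    except Exception:
        return False

cases = {
    "1-D input (ndim < 2)": accepts(jnp.zeros(3)),
    "2-D, non-square": accepts(jnp.zeros((3, 5))),
    "3-D, all dims equal": accepts(jnp.zeros((2, 2, 2))),
    "3-D, unequal dims": accepts(jnp.zeros((2, 3, 2))),
    "wrap=True": accepts(jnp.zeros((3, 3)), wrap=True),
}
for name, ok in cases.items():
    print(f"{name}: {'accepted' if ok else 'rejected'}")
```

Only the 2-D arrays (square or not) and the N-D arrays with equal dimensions are expected to pass; wrap=True is rejected regardless of shape.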

Returns:

A copy of a with the diagonal set to val.

Return type:

Array

Examples

>>> import jax.numpy as jnp
>>> x = jnp.zeros((3, 3), dtype=int)
>>> jnp.fill_diagonal(x, jnp.array([1, 2, 3]), inplace=False)
Array([[1, 0, 0],
       [0, 2, 0],
       [0, 0, 3]], dtype=int32)

Unlike numpy.fill_diagonal(), the input x is not modified.

If the diagonal value has too many entries, it will be truncated:

>>> jnp.fill_diagonal(x, jnp.arange(100, 200), inplace=False)
Array([[100,   0,   0],
       [  0, 101,   0],
       [  0,   0, 102]], dtype=int32)

If the diagonal value has too few entries, it will be repeated:

>>> x = jnp.zeros((4, 4), dtype=int)
>>> jnp.fill_diagonal(x, jnp.array([3, 4]), inplace=False)
Array([[3, 0, 0, 0],
       [0, 4, 0, 0],
       [0, 0, 3, 0],
       [0, 0, 0, 4]], dtype=int32)
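The val parameter is described as being flattened before it is repeated. A quick check with a 2-D val (the shapes here are arbitrary choices for illustration): the values [[1, 2], [3, 4]] are read in row-major order and cycled along the diagonal of a (6, 6) array.

```python
import jax.numpy as jnp

x = jnp.zeros((6, 6), dtype=int)
val = jnp.array([[1, 2], [3, 4]])  # 2-D val: flattened to [1, 2, 3, 4]

out = jnp.fill_diagonal(x, val, inplace=False)
print(jnp.diagonal(out))  # [1 2 3 4 1 2]
```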

For non-square arrays, the diagonal of the leading square slice is filled:

>>> x = jnp.zeros((3, 5), dtype=int)
>>> jnp.fill_diagonal(x, 1, inplace=False)
Array([[1, 0, 0, 0, 0],
       [0, 1, 0, 0, 0],
       [0, 0, 1, 0, 0]], dtype=int32)
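The same rule applies to tall arrays: only the leading square block, of size min(a.shape), has its diagonal filled, and the remaining rows are left as they were. A small sketch with a (5, 3) input:

```python
import jax.numpy as jnp

tall = jnp.zeros((5, 3), dtype=int)
out = jnp.fill_diagonal(tall, 1, inplace=False)
print(out)
# Only min(tall.shape) = 3 entries are set: (0, 0), (1, 1), (2, 2).
print(int(out.sum()))  # 3
```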

And for square N-dimensional arrays, the N-dimensional diagonal is filled:

>>> y = jnp.zeros((2, 2, 2))
>>> jnp.fill_diagonal(y, 1, inplace=False)
Array([[[1., 0.],
        [0., 0.]],

       [[0., 0.],
        [0., 1.]]], dtype=float32)
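The behavior shown in these examples (flatten val, repeat or truncate it to the diagonal length, then fill the leading diagonal) can be reproduced with public jax.numpy building blocks. fill_diagonal_sketch below is a hypothetical re-implementation for illustration only, not JAX's actual source; it assumes the diagonal length is min(a.shape) and uses jnp.resize for the repeat-or-truncate step, then checks itself against jnp.fill_diagonal on the inputs used above.

```python
import jax.numpy as jnp

def fill_diagonal_sketch(a, val):
    """Hypothetical functional re-implementation, for illustration only."""
    a = jnp.asarray(a)
    n = min(a.shape)                         # length of the leading diagonal
    idx = jnp.diag_indices(n, ndim=a.ndim)   # (arange(n), arange(n), ...)
    vals = jnp.resize(jnp.ravel(jnp.asarray(val)), n)  # flatten, then cycle or truncate to n
    return a.at[idx].set(vals.astype(a.dtype))          # functional update: returns a new array

cases = [
    (jnp.zeros((3, 3), dtype=int), jnp.array([1, 2, 3])),
    (jnp.zeros((3, 3), dtype=int), jnp.arange(100, 200)),
    (jnp.zeros((4, 4), dtype=int), jnp.array([3, 4])),
    (jnp.zeros((3, 5), dtype=int), 1),
    (jnp.zeros((2, 2, 2)), 1),
]
matches = [
    bool(jnp.array_equal(fill_diagonal_sketch(a, v),
                         jnp.fill_diagonal(a, v, inplace=False)))
    for a, v in cases
]
print(matches)
```

The key design point is the .at[idx].set(...) update, which is JAX's general mechanism for producing a modified copy of an immutable array.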
