jax.numpy.trace

jax.numpy.trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None)

Calculate the sum of the diagonal of the input along the given axes.

JAX implementation of numpy.trace().

Parameters:
  • a (ArrayLike) – input array. Must have a.ndim >= 2.

  • offset (int | ArrayLike) – optional, int, default=0. Diagonal offset from the main diagonal. Can be positive or negative.

  • axis1 (int) – optional, default=0. The first axis along which to take the sum of the diagonal. Must be a static integer value.

  • axis2 (int) – optional, default=1. The second axis along which to take the sum of the diagonal. Must be a static integer value.

  • dtype (DTypeLike | None) – optional. The dtype of the output array. Should be provided as a static argument in JIT compilation.

  • out (None) – Not used by JAX.
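The static-argument requirements above can be illustrated with a small sketch. The wrapper name batched_trace is hypothetical and not part of JAX; it assumes the usual pattern of marking axis1, axis2, and dtype as static through the static_argnames option of jax.jit, so they are treated as compile-time constants rather than traced values.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical wrapper: axis1, axis2 and dtype are declared static so that
# jit sees them as Python values instead of tracers.
@partial(jax.jit, static_argnames=("axis1", "axis2", "dtype"))
def batched_trace(a, axis1=0, axis2=1, dtype=None):
    return jnp.trace(a, axis1=axis1, axis2=axis2, dtype=dtype)


x = jnp.arange(1, 9).reshape(2, 2, 2)

# Sum the diagonal of each 2x2 matrix along the last two axes, as float32.
result = batched_trace(x, axis1=1, axis2=2, dtype=jnp.float32)
print(result)  # [ 5. 13.]
```

Calling the wrapper with a different static value (for example another dtype) is expected to trigger a fresh compilation, which is the usual trade-off of static arguments.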

Returns:

An array of dimension a.ndim - 2 containing the sum of the diagonal elements along axes (axis1, axis2).

Return type:

Array
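As a rough illustration of the return shape, the sketch below traces over the two middle axes of a 4-dimensional array and compares the result with summing the output of jnp.diagonal over its last axis. The array shape and variable names are chosen only for illustration.

```python
import jax.numpy as jnp

# A 4-D integer array; axes 1 and 2 hold 4x4 matrices.
a = jnp.arange(3 * 4 * 4 * 5).reshape(3, 4, 4, 5)

# Tracing over axes (1, 2) removes both, leaving shape (3, 5).
t = jnp.trace(a, axis1=1, axis2=2)
print(t.shape)  # (3, 5)

# jnp.diagonal appends the extracted diagonal as the last axis,
# so summing over that axis should reproduce the trace.
via_diagonal = jnp.diagonal(a, axis1=1, axis2=2).sum(axis=-1)
print(bool(jnp.array_equal(t, via_diagonal)))  # True
```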

Examples

>>> x = jnp.arange(1, 9).reshape(2, 2, 2)
>>> x
Array([[[1, 2],
        [3, 4]],

       [[5, 6],
        [7, 8]]], dtype=int32)
>>> jnp.trace(x)
Array([ 8, 10], dtype=int32)
>>> jnp.trace(x, offset=1)
Array([3, 4], dtype=int32)
>>> jnp.trace(x, axis1=1, axis2=2)
Array([ 5, 13], dtype=int32)
>>> jnp.trace(x, offset=1, axis1=1, axis2=2)
Array([2, 6], dtype=int32)