jax.numpy.trace
- jax.numpy.trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None)
Calculate the sum of the diagonal of the input along the given axes.
JAX implementation of numpy.trace().
- Parameters:
a (ArrayLike) – input array. Must have a.ndim >= 2.
offset (int | ArrayLike) – optional, int, default=0. Diagonal offset from the main diagonal. Can be positive or negative.
axis1 (int) – optional, default=0. The first axis along which to take the sum of the diagonal. Must be a static integer value.
axis2 (int) – optional, default=1. The second axis along which to take the sum of the diagonal. Must be a static integer value.
dtype (DTypeLike | None) – optional. The dtype of the output array. Should be provided as a static argument in JIT compilation.
out (None) – Not used by JAX.
- Returns:
An array of dimension a.ndim - 2 containing the sum of the diagonal elements along axes (axis1, axis2).
- Return type:
Array
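The notes on static arguments and on the output dimension can be illustrated with a minimal sketch. It assumes `jax.jit` is applied with `static_argnames` so that `axis1` and `axis2` stay ordinary Python integers during tracing; `batched_trace` is a hypothetical helper name used only for illustration and is not part of JAX.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Mark axis1/axis2 as static so they are not traced under jit.
# `batched_trace` is a hypothetical wrapper, not a JAX function.
@partial(jax.jit, static_argnames=("axis1", "axis2"))
def batched_trace(a, axis1=0, axis2=1):
    return jnp.trace(a, axis1=axis1, axis2=axis2)


a = jnp.ones((3, 4, 4, 5))
out = batched_trace(a, axis1=1, axis2=2)
print(out.shape)  # the two traced axes are removed: (3, 5)

# Passing dtype sets the dtype of the result.
as_float = jnp.trace(jnp.arange(4).reshape(2, 2), dtype=jnp.float32)
print(as_float.dtype)  # float32
```

Here the input has four dimensions and the result has two, matching the a.ndim - 2 rule described under Returns.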
See also
jax.numpy.diag(): Returns the specified diagonal or constructs a diagonal array.
jax.numpy.diagonal(): Returns the specified diagonal of an array.
jax.numpy.diagflat(): Returns a 2-D array with the flattened input array laid out on the diagonal.
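As a rough sketch of how these functions relate to `jnp.trace`, the snippet below compares the trace with a sum over the output of `jnp.diagonal`, and takes the trace of a matrix built by `jnp.diag`. This shows an equivalence of results and is not a claim about how JAX implements `trace` internally.

```python
import jax.numpy as jnp

x = jnp.arange(1, 9).reshape(2, 2, 2)

# jnp.diagonal places the extracted diagonal on the last axis,
# so summing over axis=-1 gives the same values as jnp.trace.
via_trace = jnp.trace(x, offset=1, axis1=1, axis2=2)
via_diagonal = jnp.diagonal(x, offset=1, axis1=1, axis2=2).sum(axis=-1)
print(via_trace, via_diagonal)

# jnp.diag turns a 1-D vector into a 2-D diagonal matrix,
# so the trace of that matrix is the sum of the vector.
v = jnp.array([1.0, 2.0, 3.0])
trace_of_diag = jnp.trace(jnp.diag(v))
print(trace_of_diag)
```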
Examples
>>> import jax.numpy as jnp
>>> x = jnp.arange(1, 9).reshape(2, 2, 2)
>>> x
Array([[[1, 2],
        [3, 4]],

       [[5, 6],
        [7, 8]]], dtype=int32)
>>> jnp.trace(x)
Array([ 8, 10], dtype=int32)
>>> jnp.trace(x, offset=1)
Array([3, 4], dtype=int32)
>>> jnp.trace(x, axis1=1, axis2=2)
Array([ 5, 13], dtype=int32)
>>> jnp.trace(x, offset=1, axis1=1, axis2=2)
Array([2, 6], dtype=int32)
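The examples above use a three-dimensional input. As a supplementary sketch, the same call on a non-square 2-D matrix shows how positive and negative values of offset select diagonals above and below the main one. The matrix `m` is an arbitrary illustration.

```python
import jax.numpy as jnp

# m = [[0, 1,  2,  3],
#      [4, 5,  6,  7],
#      [8, 9, 10, 11]]
m = jnp.arange(12).reshape(3, 4)

main = jnp.trace(m)              # 0 + 5 + 10 = 15
below = jnp.trace(m, offset=-1)  # 4 + 9 = 13
above = jnp.trace(m, offset=2)   # 2 + 7 = 9
print(main, below, above)
```

For a 2-D input the result is a zero-dimensional array, which is consistent with the a.ndim - 2 rule.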
