jax.numpy.sum

jax.numpy.sum(a, axis=None, dtype=None, out=None, keepdims=False, initial=None, where=None, promote_integers=True)

Sum of the elements of the array over a given axis.

JAX implementation of numpy.sum().

Parameters:
  • a (ArrayLike) – Input array.

  • axis (Axis) – int or array, default=None. Axis along which the sum is computed. If None, the sum is computed along all the axes.

  • dtype (DTypeLike | None) – The type of the output array. Default=None.

  • out (None) – Unused by JAX.

  • keepdims (bool) – bool, default=False. If True, reduced axes are left in the result with size 1.

  • initial (ArrayLike | None) – int or array, default=None. Initial value for the sum.

  • where (ArrayLike | None) – int or array, default=None. The elements to be used in the sum. The array should be broadcast-compatible with the input.

  • promote_integers (bool) – bool, default=True. If True, integer inputs are promoted to the widest available integer dtype, following numpy’s behavior. If False, the result has the same dtype as the input. promote_integers is ignored if dtype is specified.
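The dtype, initial, and promote_integers parameters can be illustrated with a minimal sketch. The int8 input and the variable names below are assumptions chosen for illustration; the promoted dtype depends on whether 64-bit types are enabled in your JAX configuration, so it is printed rather than stated.

```python
import jax.numpy as jnp

# Hypothetical int8 input, chosen so that integer promotion is visible.
x = jnp.array([[1, 3, 4, 2],
               [5, 2, 6, 3],
               [8, 1, 3, 9]], dtype=jnp.int8)

# dtype: request the output type explicitly.
as_float = jnp.sum(x, dtype=jnp.float32)

# initial: a starting value that the summed elements are added to.
with_initial = jnp.sum(x, axis=1, initial=10)

# promote_integers: by default the int8 input is promoted to a wider
# integer dtype; with promote_integers=False the input dtype is kept.
promoted = jnp.sum(x)
unpromoted = jnp.sum(x, promote_integers=False)

print(as_float, as_float.dtype)
print(with_initial)
print(promoted.dtype, unpromoted.dtype)
```

Keeping the input dtype with promote_integers=False can overflow for narrow integer types, so it is probably best reserved for cases where the range of the result is known.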

Returns:

An array of the sum along the given axis.

Return type:

Array

Examples

By default, the sum is computed along all the axes.

>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 3, 4, 2],
...                [5, 2, 6, 3],
...                [8, 1, 3, 9]])
>>> jnp.sum(x)
Array(47, dtype=int32)

If axis=1, the sum is computed along axis 1.

>>> jnp.sum(x, axis=1)
Array([10, 16, 21], dtype=int32)
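Other axis choices behave analogously. The sketch below reuses the same input array and assumes that axis also accepts a negative index or a tuple of axes, as it does in numpy.sum(); the variable names are illustrative only.

```python
import jax.numpy as jnp

x = jnp.array([[1, 3, 4, 2],
               [5, 2, 6, 3],
               [8, 1, 3, 9]])

# axis=0 collapses the rows, leaving one sum per column.
col_sums = jnp.sum(x, axis=0)

# A negative axis counts from the end; for a 2-D array, -1 is the same as 1.
row_sums = jnp.sum(x, axis=-1)

# A tuple of axes reduces over all listed axes at once.
total = jnp.sum(x, axis=(0, 1))

print(col_sums)  # values: [14, 6, 13, 14]
print(row_sums)  # values: [10, 16, 21]
print(total)     # value: 47
```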

If keepdims=True, ndim of the output is equal to that of the input.

>>> jnp.sum(x, axis=1, keepdims=True)
Array([[10],
       [16],
       [21]], dtype=int32)

To include only specific elements in the sum, you can use where.

>>> where = jnp.array([[0, 0, 1, 0],
...                    [0, 0, 1, 1],
...                    [1, 1, 1, 0]], dtype=bool)
>>> jnp.sum(x, axis=1, keepdims=True, where=where)
Array([[ 4],
       [ 9],
       [12]], dtype=int32)
>>> where = jnp.array([[False],
...                    [False],
...                    [False]])
>>> jnp.sum(x, axis=0, keepdims=True, where=where)
Array([[0, 0, 0, 0]], dtype=int32)
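In practice the mask is often computed from the data rather than written out by hand. The following sketch shows two such cases under that assumption: a mask derived from a comparison, and a 1-D mask that is broadcast across the rows. The names big_only, col_mask, and selected_cols are hypothetical.

```python
import jax.numpy as jnp

x = jnp.array([[1, 3, 4, 2],
               [5, 2, 6, 3],
               [8, 1, 3, 9]])

# Mask computed from the data: sum only the elements greater than 3.
big_only = jnp.sum(x, axis=1, where=x > 3)

# A 1-D mask of shape (4,) is broadcast against the (3, 4) input,
# so the same columns (0 and 2) are selected in every row.
col_mask = jnp.array([True, False, True, False])
selected_cols = jnp.sum(x, axis=1, where=col_mask)

print(big_only)       # values: [4, 11, 17]
print(selected_cols)  # values: [5, 11, 11]
```

Elements excluded by the mask contribute nothing to the result, which is why a row with no selected elements sums to 0.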