
jax.numpy.nansum


jax.numpy.nansum(a, axis=None, dtype=None, out=None, keepdims=False, initial=None, where=None)

Return the sum of the array elements along a given axis, ignoring NaNs.

JAX implementation of numpy.nansum().
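To check the correspondence with numpy.nansum(), the following sketch runs the same reductions in NumPy and in JAX and compares the results. It assumes both libraries are installed; the array values are made up for illustration.

```python
import numpy as np
import jax.numpy as jnp

# Illustrative float32 input containing NaNs.
x_np = np.array([[1.0, np.nan, 2.0],
                 [np.nan, np.nan, 3.0]], dtype=np.float32)

# Full reduction in both libraries.
np_total = np.nansum(x_np)
jnp_total = jnp.nansum(jnp.asarray(x_np))

# Per-row reduction in both libraries.
np_rows = np.nansum(x_np, axis=1)
jnp_rows = jnp.nansum(jnp.asarray(x_np), axis=1)

# The JAX results match NumPy up to floating-point tolerance.
np.testing.assert_allclose(np.asarray(jnp_total), np_total)
np.testing.assert_allclose(np.asarray(jnp_rows), np_rows)
```

One practical difference is that JAX always returns a new array, which is why the out parameter below is unused.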

Parameters:
  • a (ArrayLike) – Input array.

  • axis (Axis) – int or sequence of ints, default=None. Axis along which the sum is computed. If None, the sum is computed over the flattened array.

  • dtype (DTypeLike | None) – default=None. The data type of the output array.

  • keepdims (bool) – bool, default=False. If True, reduced axes are left in the result with size 1.

  • initial (ArrayLike | None) – int or array, default=None. Initial value for the sum.

  • where (ArrayLike | None) – array of boolean dtype, default=None. The elements to include in the sum. Must be broadcast-compatible with the input.

  • out (None) – Unused by JAX.
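The dtype and initial parameters are not exercised in the examples below. A minimal sketch of both, with illustrative values:

```python
import jax.numpy as jnp

x = jnp.array([1.0, jnp.nan, 2.5, 4.0])

# initial is added to the sum of the non-NaN elements: 10 + 1 + 2.5 + 4.
with_initial = jnp.nansum(x, initial=10.0)

# dtype sets the type of the output; here an integer input is summed
# and returned as float32.
ints = jnp.array([1, 2, 3])
as_float = jnp.nansum(ints, dtype=jnp.float32)
```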

Returns:

An array containing the sum of array elements along the given axis, ignoring NaNs. If all elements along the given axis are NaNs, returns 0.
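The all-NaN rule can be checked directly. This sketch (illustrative values) also contrasts jnp.nansum with jnp.sum, which propagates NaNs, and shows that the result matches summing after replacing NaNs with zeros.

```python
import jax.numpy as jnp

nan = jnp.nan
x = jnp.array([[nan, nan, nan],
               [1.0, nan, 2.0]])

# A row made entirely of NaNs reduces to 0 rather than NaN.
row_sums = jnp.nansum(x, axis=1)

# Plain jnp.sum propagates NaN, so both rows come out as NaN.
plain_sums = jnp.sum(x, axis=1)

# Treating NaNs as zeros and then summing gives the same result as nansum.
zeroed = jnp.where(jnp.isnan(x), 0.0, x)
assert bool(jnp.array_equal(row_sums, jnp.sum(zeroed, axis=1)))
```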

Return type:

Array

See also

Examples

By default, jnp.nansum computes the sum of elements along the flattened array.

>>> import jax.numpy as jnp
>>> nan = jnp.nan
>>> x = jnp.array([[3, nan, 4, 5],
...                [nan, -2, nan, 7],
...                [2, 1, 6, nan]])
>>> jnp.nansum(x)
Array(26., dtype=float32)
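Since axis also accepts a sequence of ints, reducing over every axis of x gives the same total as the default axis=None. A small sketch, redefining the same x so the snippet runs on its own:

```python
import jax.numpy as jnp

nan = jnp.nan
x = jnp.array([[3, nan, 4, 5],
               [nan, -2, nan, 7],
               [2, 1, 6, nan]])

# Default: reduce over the flattened array.
flat_total = jnp.nansum(x)

# Passing every axis explicitly as a tuple gives the same value.
both_axes = jnp.nansum(x, axis=(0, 1))

# Flattening by hand first is equivalent as well.
raveled = jnp.nansum(x.ravel())
```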

If axis=1, the sum is computed along axis 1.

>>> jnp.nansum(x, axis=1)
Array([12.,  5.,  9.], dtype=float32)
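Reducing along axis 0 instead produces one sum per column, and negative axes count from the end, so axis=-1 is the same as axis=1 for this 2-D input. A sketch with the same x:

```python
import jax.numpy as jnp

nan = jnp.nan
x = jnp.array([[3, nan, 4, 5],
               [nan, -2, nan, 7],
               [2, 1, 6, nan]])

# One NaN-ignoring sum per column: [5., -1., 10., 12.]
col_sums = jnp.nansum(x, axis=0)

# axis=-1 refers to the last axis, i.e. axis 1 here: [12., 5., 9.]
last_axis = jnp.nansum(x, axis=-1)
```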

If keepdims=True, the output has the same ndim as the input.

>>> jnp.nansum(x, axis=1, keepdims=True)
Array([[12.],
       [ 5.],
       [ 9.]], dtype=float32)
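A common reason to pass keepdims=True is that the result then broadcasts against the input. The sketch below, using the same x, divides each row by its NaN-ignoring total. The row normalization is only an illustration, not something jnp.nansum does itself.

```python
import jax.numpy as jnp

nan = jnp.nan
x = jnp.array([[3, nan, 4, 5],
               [nan, -2, nan, 7],
               [2, 1, 6, nan]])

# Shape (3, 1): the reduced axis is kept with size 1.
row_totals = jnp.nansum(x, axis=1, keepdims=True)

# (3, 4) / (3, 1) broadcasts row-wise; NaN entries stay NaN.
shares = x / row_totals

# Without keepdims the result has shape (3,), which would not
# broadcast row-wise against x.
flat_rows = jnp.nansum(x, axis=1)
```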

To include only specific elements in computing the sum, you can use where.

>>> where = jnp.array([[1, 0, 1, 0],
...                    [0, 0, 1, 1],
...                    [1, 1, 1, 0]], dtype=bool)
>>> jnp.nansum(x, axis=1, keepdims=True, where=where)
Array([[7.],
       [7.],
       [9.]], dtype=float32)
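For intuition, excluding elements with where gives the same sums as replacing the excluded elements with NaN and letting jnp.nansum skip them. The sketch below checks this on the inputs above; the variable names are illustrative.

```python
import jax.numpy as jnp

nan = jnp.nan
x = jnp.array([[3, nan, 4, 5],
               [nan, -2, nan, 7],
               [2, 1, 6, nan]])
mask = jnp.array([[1, 0, 1, 0],
                  [0, 0, 1, 1],
                  [1, 1, 1, 0]], dtype=bool)

# Exclude elements through the where parameter.
via_where = jnp.nansum(x, axis=1, where=mask)

# Exclude the same elements by turning them into NaN.
via_nan = jnp.nansum(jnp.where(mask, x, nan), axis=1)
```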

If where is False at all elements, jnp.nansum returns 0 along the given axis.

>>> where = jnp.array([[False],
...                    [False],
...                    [False]])
>>> jnp.nansum(x, axis=0, keepdims=True, where=where)
Array([[0., 0., 0., 0.]], dtype=float32)
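The 0 in the fully masked case is the value of an empty sum. If initial is also given, that value is returned instead, because initial is added to the empty sum. A sketch with the same x:

```python
import jax.numpy as jnp

nan = jnp.nan
x = jnp.array([[3, nan, 4, 5],
               [nan, -2, nan, 7],
               [2, 1, 6, nan]])

# A (3, 1) all-False mask broadcasts to x.shape and excludes everything.
none_selected = jnp.zeros((3, 1), dtype=bool)

# Empty sums: [0., 0., 0., 0.]
default = jnp.nansum(x, axis=0, where=none_selected)

# initial is added to each empty sum: [1., 1., 1., 1.]
seeded = jnp.nansum(x, axis=0, where=none_selected, initial=1.0)
```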