jax.numpy.nanprod


jax.numpy.nanprod(a, axis=None, dtype=None, out=None, keepdims=False, initial=None, where=None)

Return the product of the array elements along a given axis, ignoring NaNs.

JAX implementation of numpy.nanprod().
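As a rough illustration of what "ignoring NaNs" means here, the sketch below contrasts jnp.nanprod with jnp.prod on a small array and compares the result with replacing NaNs by 1 by hand. The variable names are illustrative only, and the manual version is an equivalent formulation for this input, not a statement about how JAX implements the function.

```python
import jax.numpy as jnp

x = jnp.array([2.0, jnp.nan, 3.0])

# jnp.prod propagates the NaN into the result.
plain = jnp.prod(x)

# jnp.nanprod skips the NaN entry, so only 2.0 and 3.0 are multiplied.
ignoring_nans = jnp.nanprod(x)

# Replacing NaNs with the multiplicative identity by hand gives the same value.
manual = jnp.prod(jnp.where(jnp.isnan(x), 1.0, x))

print(plain, ignoring_nans, manual)
```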

Parameters:
  • a (ArrayLike) – Input array.

  • axis (Axis) – int or sequence of ints, default=None. Axis along which the product is computed. If None, the product is computed along the flattened array.

  • dtype (DTypeLike | None) – The type of the output array. Default=None.

  • keepdims (bool) – bool, default=False. If True, reduced axes are left in the result with size 1.

  • initial (ArrayLike | None) – int or array, default=None. Initial value for the product.

  • where (ArrayLike | None) – array of boolean dtype, default=None. The elements to be used in the product. The array should be broadcast-compatible with the input.

  • out (None) – Unused by JAX.
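The initial and dtype parameters are not exercised by the examples on this page, so here is a minimal sketch of both. The input values and variable names are made up for illustration.

```python
import jax.numpy as jnp

x = jnp.array([[jnp.nan, 3.0, 4.0],
               [5.0, jnp.nan, 2.0]])

# `initial` seeds each product, so every row result is multiplied by 10.
scaled = jnp.nanprod(x, axis=1, initial=10.0)

# `dtype` sets the type of the returned array.
half = jnp.nanprod(x, axis=1, dtype=jnp.float16)

print(scaled)      # row products 12 and 10, each times the initial value 10
print(half.dtype)
```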

Returns:

An array containing the product of array elements along the given axis, ignoring NaNs. If all elements along the given axis are NaNs, returns 1.

Return type:

Array
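The all-NaN case described under Returns can be checked with a small sketch like the one below. Because a result of 1 can be either a real product or an all-NaN slice, the sketch also builds a mask to tell the two apart. The mask is an illustrative pattern, not part of the nanprod API.

```python
import jax.numpy as jnp

x = jnp.array([[jnp.nan, jnp.nan],
               [2.0, jnp.nan]])

# The first row contains only NaNs, so its product is the identity, 1.
row_products = jnp.nanprod(x, axis=1)

# Optional: flag rows that had no valid (non-NaN) entries at all.
all_nan_rows = jnp.all(jnp.isnan(x), axis=1)

print(row_products)
print(all_nan_rows)
```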

Examples

By default, jnp.nanprod computes the product of elements along the flattened array.

>>> nan = jnp.nan
>>> x = jnp.array([[nan, 3, 4, nan],
...                [5, nan, 1, 3],
...                [2, 1, nan, 1]])
>>> jnp.nanprod(x)
Array(360., dtype=float32)

If axis=1, the product will be computed along axis 1.

>>> jnp.nanprod(x, axis=1)
Array([12., 15.,  2.], dtype=float32)

If keepdims=True, the output has the same number of dimensions (ndim) as the input.

>>> jnp.nanprod(x, axis=1, keepdims=True)
Array([[12.],
       [15.],
       [ 2.]], dtype=float32)

To include only specific elements in computing the product, you can use where.

>>> where = jnp.array([[1, 0, 1, 0],
...                    [0, 0, 1, 1],
...                    [1, 1, 1, 0]], dtype=bool)
>>> jnp.nanprod(x, axis=1, keepdims=True, where=where)
Array([[4.],
       [3.],
       [2.]], dtype=float32)

If where is False at all elements, jnp.nanprod returns 1 along the given axis.

>>> where = jnp.array([[False],
...                    [False],
...                    [False]])
>>> jnp.nanprod(x, axis=0, keepdims=True, where=where)
Array([[1., 1., 1., 1.]], dtype=float32)