

jax.numpy.nancumprod

jax.numpy.nancumprod(a, axis=None, dtype=None, out=None)

Cumulative product of elements along an axis, ignoring NaN values.

JAX implementation of numpy.nancumprod().
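Since this function mirrors numpy.nancumprod(), one way to sanity-check it is to compare the two directly. The sketch below assumes NumPy is installed alongside JAX; the `data` array is hypothetical sample input, chosen so that every intermediate product is exactly representable in float32.

```python
import numpy as np
import jax.numpy as jnp

# Hypothetical 1-D input with NaNs interleaved between finite values.
data = np.array([3., np.nan, 0.5, np.nan, 4.], dtype=np.float32)

expected = np.nancumprod(data)               # NumPy reference result
result = jnp.nancumprod(jnp.asarray(data))   # JAX result on the same data

# Both skip the NaNs, so the running products should agree.
print(np.allclose(np.asarray(result), expected))  # True
```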

Parameters:
  • a (ArrayLike) – N-dimensional array to be accumulated.

  • axis (int | None) – integer axis along which to accumulate. If None (default), then the array will be flattened and accumulated along the flattened axis.

  • dtype (DTypeLike | None) – optionally specify the dtype of the output. If not specified, then the output dtype will match the input dtype.

  • out (None) – unused by JAX
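The dtype parameter can be exercised with an integer input. This is a minimal sketch assuming the default 32-bit JAX configuration (x64 disabled); the array `a` is a hypothetical example. Integer arrays cannot hold NaN, so nothing is skipped here and the result is an ordinary cumulative product, with `dtype` requesting a floating-point output.

```python
import jax.numpy as jnp

# Hypothetical integer input; integers have no NaN representation.
a = jnp.array([1, 2, 3, 4], dtype=jnp.int32)

# Ask for a float32 result instead of the input's int32.
out_float = jnp.nancumprod(a, dtype=jnp.float32)

print(out_float.dtype)  # float32
print(out_float)        # [ 1.  2.  6. 24.]
```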

Returns:

An array containing the accumulated product along the given axis.

Return type:

Array
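The shape of the returned array depends on axis: with the default axis=None the input is flattened first, so the result is one-dimensional with a.size elements, while an integer axis (negative values count from the end) keeps the input's shape. A short sketch with a hypothetical input of shape (2, 3, 4) containing one NaN:

```python
import jax.numpy as jnp

# Hypothetical 3-D input with a single NaN entry.
x = jnp.arange(1.0, 25.0).reshape(2, 3, 4)
x = x.at[0, 1, 2].set(jnp.nan)

flat = jnp.nancumprod(x)                  # axis=None: flattened, 1-D result
along_last = jnp.nancumprod(x, axis=-1)   # accumulate within each length-4 row

print(flat.shape)        # (24,)
print(along_last.shape)  # (2, 3, 4)
```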

See also

Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([[1., 2., jnp.nan],
...                [4., jnp.nan, 6.]])

The standard cumulative product will propagate NaN values:

>>> jnp.cumprod(x)
Array([ 1.,  2., nan, nan, nan, nan], dtype=float32)

nancumprod() will ignore NaN values, effectively replacing them with ones:

>>> jnp.nancumprod(x)
Array([ 1.,  2.,  2.,  8.,  8., 48.], dtype=float32)
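The "replacing them with ones" description can be checked directly: substituting the multiplicative identity 1 for each NaN and then calling jnp.cumprod() gives the same result. This is a sketch of the equivalence only, not necessarily how JAX implements the function internally.

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., jnp.nan],
               [4., jnp.nan, 6.]])

# Replace NaNs with 1 (the multiplicative identity), then take a plain cumprod.
filled = jnp.where(jnp.isnan(x), 1.0, x)
manual = jnp.cumprod(filled)

print(manual)                                      # [ 1.  2.  2.  8.  8. 48.]
print(jnp.array_equal(manual, jnp.nancumprod(x)))  # True
```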

Cumulative product along axis 1:

>>> jnp.nancumprod(x, axis=1)
Array([[ 1.,  2.,  2.],
       [ 4.,  4., 24.]], dtype=float32)
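Two further consequences of the replace-with-ones behavior: a slice that is entirely NaN accumulates to ones, and the call can be traced under jax.jit like other jax.numpy functions. The wrapper `running_product` is a hypothetical helper for illustration; it accumulates along axis 0, i.e. down each column.

```python
import jax
import jax.numpy as jnp

@jax.jit
def running_product(x):
    # Hypothetical helper: NaN-skipping cumulative product down each column.
    return jnp.nancumprod(x, axis=0)

# First column is all NaN, second column is finite.
x = jnp.array([[jnp.nan, 2.],
               [jnp.nan, 3.]])

print(running_product(x))
# [[1. 2.]
#  [1. 6.]]
```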