JAX documentation

jax.numpy.prod


jax.numpy.prod(a, axis=None, dtype=None, out=None, keepdims=False, initial=None, where=None, promote_integers=True)

Return product of the array elements over a given axis.

JAX implementation of numpy.prod().
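
Because jnp.prod is meant to mirror numpy.prod(), the two should produce the same values on ordinary inputs. A minimal sketch of that check, assuming both NumPy and JAX are installed; only the values are compared, since the default integer dtypes of the two libraries can differ:

```python
import numpy as np
import jax.numpy as jnp

data = [[1, 3, 4, 2],
        [5, 2, 1, 3],
        [2, 1, 3, 1]]

np_result = np.prod(np.array(data), axis=1)     # NumPy reference
jnp_result = jnp.prod(jnp.array(data), axis=1)  # JAX counterpart

# Compare values only: NumPy commonly returns int64 here, while JAX
# returns int32 unless x64 mode is enabled.
print(np_result.tolist(), jnp_result.tolist())       # [24, 30, 6] [24, 30, 6]
print(np.array_equal(np_result, np.asarray(jnp_result)))  # True
```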

Parameters:
  • a (ArrayLike) – Input array.

  • axis (Axis) – int or sequence of ints, default=None. Axis or axes along which the product is computed. If None, the product is computed over all axes.

  • dtype (DTypeLike | None) – default=None. The dtype of the output array.

  • keepdims (bool) – bool, default=False. If True, reduced axes are left in the result with size 1.

  • initial (ArrayLike | None) – int or array, default=None. Initial value for the product; it is multiplied into the result.

  • where (ArrayLike | None) – array of bool, default=None. Mask selecting the elements to include in the product. Must be broadcast-compatible with the input.

  • promote_integers (bool) – bool, default=True. If True, integer inputs are promoted to the widest available integer dtype, following NumPy's behavior. If False, the result has the same dtype as the input. promote_integers is ignored if dtype is specified.

  • out (None) – Unused by JAX.
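
The dtype-related parameters and initial interact as sketched below. The int8 input is a hypothetical example chosen so every product stays small; the promoted dtype assumes JAX's default configuration, where narrow integers widen to int32 (int64 when x64 mode is enabled):

```python
import jax.numpy as jnp

small = jnp.array([3, 4, 5], dtype=jnp.int8)

# Default: integer inputs are promoted to the default integer dtype.
promoted = jnp.prod(small)
# promote_integers=False keeps int8, so larger products could overflow.
kept = jnp.prod(small, promote_integers=False)
# An explicit dtype takes precedence over promote_integers.
as_float = jnp.prod(small, dtype=jnp.float32)
# initial is multiplied into the product: 2 * 3 * 4 * 5.
with_initial = jnp.prod(small, initial=2)

print(promoted.dtype, kept.dtype, as_float.dtype)
print(int(promoted), int(kept), float(as_float), int(with_initial))  # 60 60 60.0 120
```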

Returns:

An array of the product along the given axis.

Return type:

Array
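
The returned array drops each reduced axis from the input's shape (unless keepdims=True). A small shape check on an arbitrary all-ones input; passing a tuple for axis reduces over several axes at once, and negative values count from the last axis:

```python
import jax.numpy as jnp

x = jnp.ones((3, 4, 5))

# Each reduced axis disappears from the result shape.
print(jnp.prod(x).shape)               # ()
print(jnp.prod(x, axis=0).shape)       # (4, 5)
print(jnp.prod(x, axis=(0, 2)).shape)  # (4,)
print(jnp.prod(x, axis=-1).shape)      # (3, 4)
```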


Examples

By default, jnp.prod computes the product over all axes.

>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 3, 4, 2],
...                [5, 2, 1, 3],
...                [2, 1, 3, 1]])
>>> jnp.prod(x)
Array(4320, dtype=int32)
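
Reducing over all axes gives the same value as reducing one axis at a time, or as multiplying the flattened elements. A quick check on the same matrix (redefined so the snippet runs on its own):

```python
import math
import jax.numpy as jnp

x = jnp.array([[1, 3, 4, 2],
               [5, 2, 1, 3],
               [2, 1, 3, 1]])

total = jnp.prod(x)                            # all axes at once
row_then_col = jnp.prod(jnp.prod(x, axis=1))   # row products [24, 30, 6], then their product
flat = math.prod(int(v) for v in x.ravel())    # plain Python over the flattened values

print(int(total), int(row_then_col), flat)  # 4320 4320 4320
```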

If axis=1, the product is computed along axis 1.

>>> jnp.prod(x, axis=1)
Array([24, 30,  6], dtype=int32)
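
The same matrix can be reduced along the other axis, and a negative axis counts from the end, so axis=-1 is axis 1 for a 2-D input. A short sketch:

```python
import jax.numpy as jnp

x = jnp.array([[1, 3, 4, 2],
               [5, 2, 1, 3],
               [2, 1, 3, 1]])

cols = jnp.prod(x, axis=0)   # product down each column
last = jnp.prod(x, axis=-1)  # same as axis=1 for a 2-D array

print(cols.tolist())  # [10, 6, 12, 6]
print(last.tolist())  # [24, 30, 6]
```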

If keepdims=True, the output has the same ndim as the input.

>>> jnp.prod(x, axis=1, keepdims=True)
Array([[24],
       [30],
       [ 6]], dtype=int32)
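
A practical reason to keep the reduced axis is broadcasting: the (3, 1) result combines directly with the (3, 4) input. A purely illustrative sketch that divides each row by its own product:

```python
import jax.numpy as jnp

x = jnp.array([[1, 3, 4, 2],
               [5, 2, 1, 3],
               [2, 1, 3, 1]])

row_prod = jnp.prod(x, axis=1, keepdims=True)  # shape (3, 1)
scaled = x / row_prod                          # broadcasts across the 4 columns

print(row_prod.shape, scaled.shape)  # (3, 1) (3, 4)
# Without keepdims, shapes (3, 4) and (3,) would not broadcast row-wise.
```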

To include only specific elements in the product, pass a boolean mask as where.

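>>> where = jnp.array([[1, 0, 1, 0],
...                    [0, 0, 1, 1],
...                    [1, 1, 1, 0]], dtype=bool)
>>> jnp.prod(x, axis=1, keepdims=True, where=where)
Array([[4],
       [3],
       [6]], dtype=int32)

If the mask excludes every element along the reduced axis, the result is 1, the value of an empty product. The (3, 1) mask below broadcasts against the (3, 4) input:

>>> where = jnp.array([[False],
...                    [False],
...                    [False]])
>>> jnp.prod(x, axis=1, keepdims=True, where=where)
Array([[1],
       [1],
       [1]], dtype=int32)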
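
A where mask gives the same result as replacing the excluded entries with 1, the multiplicative identity, and then taking the plain product. A sketch comparing the two formulations with the same matrix and mask:

```python
import jax.numpy as jnp

x = jnp.array([[1, 3, 4, 2],
               [5, 2, 1, 3],
               [2, 1, 3, 1]])
mask = jnp.array([[1, 0, 1, 0],
                  [0, 0, 1, 1],
                  [1, 1, 1, 0]], dtype=bool)

masked = jnp.prod(x, axis=1, where=mask)
# Equivalent formulation: fill excluded entries with 1, then reduce normally.
filled = jnp.prod(jnp.where(mask, x, 1), axis=1)

print(masked.tolist(), filled.tolist())  # [4, 3, 6] [4, 3, 6]
```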