
jax.numpy.count_nonzero


jax.numpy.count_nonzero(a, axis=None, keepdims=False)

Return the number of nonzero elements along a given axis.

JAX implementation of numpy.count_nonzero().
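Because this function mirrors numpy.count_nonzero(), the two should agree on the same input. The sketch below compares them on a small, arbitrarily chosen array, both over the whole array and along axis 0. The `data` values are purely illustrative.

```python
import numpy as np
import jax.numpy as jnp

# Illustrative 2x3 input: three entries (1, 2, 3) are nonzero.
data = [[1, 0, 2],
        [0, 0, 3]]

# Count over the flattened array with NumPy and with JAX.
np_count = np.count_nonzero(np.array(data))
jnp_count = jnp.count_nonzero(jnp.array(data))

# Count per column (axis=0) with both libraries.
np_axis0 = np.count_nonzero(np.array(data), axis=0)
jnp_axis0 = jnp.count_nonzero(jnp.array(data), axis=0)

print(np_count, int(jnp_count))          # both count 3 nonzeros
print(np_axis0, np.asarray(jnp_axis0))   # per-column counts [1 0 2]
```

One difference to keep in mind: the JAX version returns a JAX Array even for a full reduction, so use `int(...)` when a plain Python integer is needed.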

Parameters:
  • a (ArrayLike) – input array.

  • axis (Axis) – optional, int or sequence of ints, default=None. Axis or axes along which the number of nonzero elements is counted. If None, counts over the flattened array.

  • keepdims (bool) – bool, default=False. If True, the reduced axes are left in the result with size 1.
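Since axis also accepts a sequence of ints, several axes can be reduced at once. This sketch, using an illustrative 3-D array, counts over axes 0 and 2 together so that only axis 1 remains, and shows how keepdims=True keeps both reduced axes with size 1.

```python
import jax.numpy as jnp

# Illustrative array of shape (2, 2, 3).
x = jnp.array([[[1, 0, 0], [0, 2, 0]],
               [[3, 4, 0], [0, 0, 0]]])

# Reduce over axes 0 and 2; only axis 1 (length 2) survives.
per_row = jnp.count_nonzero(x, axis=(0, 2))

# Same counts, but the reduced axes are kept with size 1.
per_row_kept = jnp.count_nonzero(x, axis=(0, 2), keepdims=True)

print(per_row)             # counts for x[:, 0, :] and x[:, 1, :]
print(per_row_kept.shape)  # (1, 2, 1)
```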

Returns:

An array containing the number of nonzero elements along the specified axis of the input.

Return type:

Array
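A "nonzero" element is one that compares unequal to zero, so the same counts can be obtained by summing a boolean mask. The sketch below checks this on an illustrative integer array and on a float array, where -0.0 compares equal to zero and is therefore not counted. The result dtypes of the two approaches may differ, so only the values are compared.

```python
import jax.numpy as jnp

# Illustrative integer input.
x = jnp.array([[1, 0, 0, 0],
               [0, 0, 1, 0],
               [1, 1, 1, 0]])

# Count via count_nonzero and via an explicit boolean mask.
direct = jnp.count_nonzero(x, axis=1)
via_mask = jnp.sum(x != 0, axis=1)

# Floats: 0.0 and -0.0 both equal zero; 0.5 and 2.0 do not.
f = jnp.array([0.0, -0.0, 0.5, 2.0])
float_count = jnp.count_nonzero(f)

print(direct, via_mask, float_count)
```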

Examples

By default, jnp.count_nonzero counts the nonzero values along all axes.

>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 0, 0, 0],
...                [0, 0, 1, 0],
...                [1, 1, 1, 0]])
>>> jnp.count_nonzero(x)
Array(5, dtype=int32)

If axis=1, the count is taken along axis 1, giving one count per row.

>>> jnp.count_nonzero(x, axis=1)
Array([1, 1, 3], dtype=int32)

To preserve the dimensions of the input, set keepdims=True.

>>> jnp.count_nonzero(x, axis=1, keepdims=True)
Array([[1],
       [1],
       [3]], dtype=int32)
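A result kept at shape (rows, 1) broadcasts directly against per-row sums that also use keepdims=True. The sketch below uses this to compute the mean of the nonzero entries in each row inside jax.jit. The output shape depends only on the input shape, not on how many entries are nonzero, so the computation traces cleanly. The helper name `mean_of_nonzeros` and the input values are hypothetical. Each row is assumed to contain at least one nonzero element, since an all-zero row would divide by zero.

```python
import jax
import jax.numpy as jnp

@jax.jit
def mean_of_nonzeros(a):
    # Hypothetical helper: per-row mean over nonzero entries only.
    # Both reductions use keepdims=True, so they have shape (rows, 1).
    totals = jnp.sum(a, axis=1, keepdims=True)
    counts = jnp.count_nonzero(a, axis=1, keepdims=True)
    return totals / counts  # assumes no row is all zeros

y = jnp.array([[2, 0, 4],
               [0, 0, 6],
               [1, 1, 1]])
result = mean_of_nonzeros(y)
print(result)  # shape (3, 1): row means 3.0, 6.0, 1.0
```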
