jax.numpy.size
jax.numpy.size(a, axis=None)
Return the number of elements along a given axis.
JAX implementation of numpy.size(). Unlike np.size, this function raises a TypeError if the input is a collection such as a list or tuple.

Parameters:
- a (ArrayLike | SupportsSize | SupportsShape) – array-like object, or any object with a size attribute when axis is not specified, or with a shape attribute when axis is specified.
- axis (int | Sequence[int] | None) – optional integer or sequence of integers indicating which axis or axes to count elements along. None (the default) returns the total number of elements.
Returns:
An integer specifying the number of elements in a.

Return type:
int
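Because the count depends only on the array's shape, not on its values, it should be available as an ordinary integer even while a function is being traced. The sketch below assumes that behavior and uses the count to size a new array inside jax.jit; flatten_and_pad is a hypothetical function made up for this illustration.

```python
import jax
import jax.numpy as jnp


@jax.jit
def flatten_and_pad(x):
    # The shape of x is static under jit, so jnp.size(x) yields a plain
    # integer that can be used to build a new array of that length.
    n = jnp.size(x)
    return jnp.concatenate([x.ravel(), jnp.zeros(n, dtype=x.dtype)])


out = flatten_and_pad(jnp.ones((2, 3)))
print(out.shape)
```

Here the input has 6 elements, so the result holds the 6 flattened ones followed by 6 zeros.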
Examples
Size for arrays:
>>> x = jnp.arange(10)
>>> jnp.size(x)
10
>>> y = jnp.ones((2, 3))
>>> jnp.size(y)
6
>>> jnp.size(y, axis=1)
3
>>> jnp.size(y, axis=(1,))
3
>>> jnp.size(y, axis=(0, 1))
6
This also works for scalars:
>>> jnp.size(3.14)
1
For arrays, this can also be accessed via the jax.Array.size property:

>>> y.size
6
