jax.numpy.shape
jax.numpy.shape(a)
Return the shape of an array.
JAX implementation of numpy.shape(). Unlike np.shape, this function raises a TypeError if the input is a collection such as a list or tuple.

Parameters:
a (ArrayLike | SupportsShape) – array-like object, or any object with a shape attribute.

Returns:
A tuple of integers representing the shape of a.

Examples
Shape for arrays:
>>> x = jnp.arange(10)
>>> jnp.shape(x)
(10,)
>>> y = jnp.ones((2, 3))
>>> jnp.shape(y)
(2, 3)

This also works for scalars:
>>> jnp.shape(3.14)
()

For arrays, this can also be accessed via the jax.Array.shape property:
>>> x.shape
(10,)
