jax.numpy.unstack
jax.numpy.unstack(x, /, *, axis=0)
Unstack an array along an axis.
JAX implementation of array_api.unstack().
- Parameters:
  x (ArrayLike) – array to unstack. Must have x.ndim >= 1.
  axis (int) – integer axis along which to unstack. Must satisfy -x.ndim <= axis < x.ndim. Defaults to 0.
- Returns:
  tuple of x.shape[axis] unstacked arrays, each with the shape of x with axis removed.
- Return type:
  tuple[Array, ...]
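To make the axis constraint and the shape of the result concrete, here is a small sketch. It assumes a JAX release recent enough to provide jnp.unstack, and the array x is just illustrative data. Unstacking along axis=1 yields x.shape[1] pieces, each equal to the slice x[:, i], and the negative axis -1 names the same axis.

```python
import jax.numpy as jnp

# Illustrative 2x3 input: [[0, 1, 2], [3, 4, 5]]
x = jnp.arange(6).reshape(2, 3)

# Unstacking along axis=1 yields x.shape[1] == 3 arrays of shape (2,).
cols = jnp.unstack(x, axis=1)
print(len(cols), cols[0].shape)

# Piece i is the same as indexing position i along the unstacked axis.
for i, col in enumerate(cols):
    assert (col == x[:, i]).all()

# A negative axis counts from the end, so axis=-1 means axis=1 for a 2-D array.
cols_neg = jnp.unstack(x, axis=-1)
assert all((a == b).all() for a, b in zip(cols, cols_neg))
```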
See also
- jax.numpy.stack(): inverse of unstack.
- jax.numpy.split(): split an array into batches along an axis.
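The relationship to jax.numpy.split() can be checked directly. The sketch below assumes jnp.unstack is available and uses an arbitrary 3-D test array: splitting into x.shape[axis] equal sections keeps a size-1 axis in each chunk, so squeezing that axis gives the same arrays that unstack returns.

```python
import jax.numpy as jnp

# Arbitrary 3-D test array of shape (2, 3, 4).
x = jnp.arange(24).reshape(2, 3, 4)
axis = 1

# unstack: x.shape[axis] pieces, each with the axis removed -> shape (2, 4).
pieces = jnp.unstack(x, axis=axis)

# split into x.shape[axis] equal sections keeps the axis with size 1 -> (2, 1, 4).
chunks = jnp.split(x, x.shape[axis], axis=axis)

# Squeezing the size-1 axis recovers the unstacked pieces.
squeezed = [jnp.squeeze(c, axis=axis) for c in chunks]

for p, s in zip(pieces, squeezed):
    assert p.shape == s.shape == (2, 4)
    assert (p == s).all()
```

In other words, unstack behaves like split followed by squeeze, but returns arrays without the leftover length-1 dimension.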
Examples
>>> arr = jnp.array([[1, 2, 3],
...                  [4, 5, 6]])
>>> arrs = jnp.unstack(arr)
>>> print(*arrs)
[1 2 3] [4 5 6]
stack() provides the inverse of this:

>>> jnp.stack(arrs)
Array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
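The inverse relationship also holds for a non-default axis, as long as the same axis value is passed to both functions. A minimal check, assuming a recent JAX and using an arbitrary 3-D test array, loops over every valid axis, including negative ones:

```python
import jax.numpy as jnp

# Arbitrary 3-D test array of shape (2, 3, 4).
x = jnp.arange(24).reshape(2, 3, 4)

# For every valid axis, stacking the unstacked pieces back along the
# same axis reproduces the original array.
for axis in range(-x.ndim, x.ndim):
    pieces = jnp.unstack(x, axis=axis)
    restored = jnp.stack(pieces, axis=axis)
    assert restored.shape == x.shape
    assert (restored == x).all()
```

Negative axes work here because jnp.stack normalizes its axis against the output rank, which equals x.ndim, so both calls refer to the same dimension.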
