jax.numpy.array_split
jax.numpy.array_split(ary, indices_or_sections, axis=0)
Split an array into sub-arrays.
JAX implementation of numpy.array_split().
Refer to the documentation of jax.numpy.split() for details; array_split is equivalent to split, but allows an integer indices_or_sections that does not evenly divide the split axis.

Examples
>>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> chunks = jnp.array_split(x, 4)
>>> print(*chunks)
[1 2 3] [4 5] [6 7] [8 9]
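The chunk sizes in the example above follow the rule documented for numpy.array_split(): when an axis of length n is split into k sections, the first n % k sections receive n // k + 1 elements and the remaining sections receive n // k. The following is a minimal sketch of that arithmetic in plain Python, not the library's implementation; the helper names section_sizes and split_points are hypothetical and exist only for illustration.

```python
def section_sizes(length: int, sections: int) -> list[int]:
    """Sizes of the pieces when `length` items are cut into `sections` parts."""
    if sections <= 0:
        raise ValueError("sections must be a positive integer")
    base, extra = divmod(length, sections)
    # The first `extra` pieces take one additional element each.
    return [base + 1] * extra + [base] * (sections - extra)


def split_points(length: int, sections: int) -> list[int]:
    """Interior indices at which the axis is cut."""
    points, position = [], 0
    # The last piece ends at the end of the axis, so it adds no cut point.
    for size in section_sizes(length, sections)[:-1]:
        position += size
        points.append(position)
    return points


print(section_sizes(9, 4))  # [3, 2, 2, 2]
print(split_points(9, 4))   # [3, 5, 7]
```

For the nine-element array above and four sections, this gives pieces of length 3, 2, 2 and 2, matching the printed chunks. Passing the cut points [3, 5, 7] as indices_or_sections to jax.numpy.split() should produce the same pieces.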
See also
- jax.numpy.split(): split an array along any axis.
- jax.numpy.vsplit(): split vertically, i.e. along axis=0
- jax.numpy.hsplit(): split horizontally, i.e. along axis=1
- jax.numpy.dsplit(): split depth-wise, i.e. along axis=2
