jax.numpy.hsplit
jax.numpy.hsplit(ary, indices_or_sections)
Split an array into sub-arrays horizontally.
JAX implementation of numpy.hsplit(). Refer to the documentation of jax.numpy.split() for details. hsplit is equivalent to split with axis=1, or with axis=0 for one-dimensional arrays.

Examples
1D array:
>>> x = jnp.array([1, 2, 3, 4, 5, 6])
>>> x1, x2 = jnp.hsplit(x, 2)
>>> print(x1, x2)
[1 2 3] [4 5 6]
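Because hsplit defers to jax.numpy.split(), indices_or_sections may also be a list of split points rather than a section count. The short sketch below (split points chosen for illustration) shows that form on a 1D array and checks the equivalence with split along axis=0 stated above.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4, 5, 6])

# A list gives split points instead of a section count:
# the pieces are x[:2], x[2:5] and x[5:].
parts = jnp.hsplit(x, [2, 5])

# For a one-dimensional array, hsplit should match split along axis=0.
ref = jnp.split(x, [2, 5], axis=0)

for p, r in zip(parts, ref):
    print(p, bool(jnp.array_equal(p, r)))
# [1 2] True
# [3 4 5] True
# [6] True
```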
2D array:
>>> x = jnp.array([[1, 2, 3, 4],
...                [5, 6, 7, 8]])
>>> x1, x2 = jnp.hsplit(x, 2)
>>> print(x1)
[[1 2]
 [5 6]]
>>> print(x2)
[[3 4]
 [7 8]]
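For inputs with two or more dimensions, the split happens along axis=1. The sketch below compares hsplit against split(..., axis=1) on the same 2D array, then applies hsplit to a 3D array (an illustrative shape) to show that the second axis is still the one divided.

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 3, 4],
               [5, 6, 7, 8]])

# hsplit on a 2D array should agree with split along axis=1 (columns).
h1, h2 = jnp.hsplit(x, 2)
s1, s2 = jnp.split(x, 2, axis=1)
print(bool(jnp.array_equal(h1, s1)), bool(jnp.array_equal(h2, s2)))  # True True

# For a (2, 4, 3) array the second axis (length 4) is split in two,
# giving two pieces of shape (2, 2, 3).
y = jnp.arange(24).reshape(2, 4, 3)
print([p.shape for p in jnp.hsplit(y, 2)])  # [(2, 2, 3), (2, 2, 3)]
```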
See also
- jax.numpy.split(): split an array along any axis.
- jax.numpy.vsplit(): split vertically, i.e. along axis=0.
- jax.numpy.dsplit(): split depth-wise, i.e. along axis=2.
- jax.numpy.array_split(): like split, but allows indices_or_sections to be an integer that does not evenly divide the size of the array.
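The difference between hsplit and array_split shows up when the section count does not divide the axis length. The sketch below assumes, as in NumPy, that hsplit rejects such a count with a ValueError, while array_split gives the leading pieces one extra element each.

```python
import jax.numpy as jnp

x = jnp.arange(7)  # length 7 is not divisible by 3

# Assumption (mirrors NumPy): an uneven integer count is rejected by hsplit.
try:
    jnp.hsplit(x, 3)
    hsplit_raised = False
except ValueError:
    hsplit_raised = True
print(hsplit_raised)  # True

# array_split accepts it: 7 = 3 + 2 + 2, so the first 7 % 3 = 1 piece
# receives the extra element.
parts = jnp.array_split(x, 3)
print([p.tolist() for p in parts])  # [[0, 1, 2], [3, 4], [5, 6]]
```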
