jax.numpy.split
jax.numpy.split(ary, indices_or_sections, axis=0)
Split an array into sub-arrays.
JAX implementation of numpy.split().
Parameters:
- ary (ArrayLike) – N-dimensional array-like object to split.
- indices_or_sections (int | Sequence[int] | ArrayLike) – either a single integer or a sequence of indices.
  - If indices_or_sections is an integer N, then N must evenly divide ary.shape[axis], and ary will be divided into N equally sized chunks along axis.
  - If indices_or_sections is a sequence of integers, then these integers specify the boundaries between unevenly sized chunks along axis; see the examples below.
- axis (int) – the axis along which to split; defaults to 0.
Returns:
A list of arrays. If indices_or_sections is an integer N, then the list has length N. If indices_or_sections is a sequence seq, then the list has length len(seq) + 1.
Return type:
list[Array]
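As a quick check of the rules above, the following sketch (assuming only that `jax` is installed) splits the same array once with an integer and once with the equivalent sequence of boundaries, and confirms the length of each returned list. The helper `even_boundaries` is hypothetical, written for illustration; it is not part of JAX.

```python
import jax.numpy as jnp

def even_boundaries(size, n):
    """Hypothetical helper: interior split points for n equal chunks of `size`."""
    if size % n != 0:
        raise ValueError(f"{n} does not evenly divide {size}")
    step = size // n
    return [i * step for i in range(1, n)]

x = jnp.arange(12)

# Integer form: N chunks, so the list has length N.
by_count = jnp.split(x, 4)

# Sequence form: len(seq) + 1 chunks.
seq = even_boundaries(x.shape[0], 4)  # interior boundaries 3, 6, 9
by_index = jnp.split(x, seq)

assert len(by_count) == 4
assert len(by_index) == len(seq) + 1

# With evenly spaced boundaries, both forms yield identical chunks.
for a, b in zip(by_count, by_index):
    assert jnp.array_equal(a, b)
```

The integer form is effectively shorthand for evenly spaced boundaries; the sequence form is what allows chunks of different sizes.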
Examples
Splitting a 1-dimensional array:
>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
Split into three equal sections:
>>> chunks = jnp.split(x, 3)
>>> print(*chunks)
[1 2 3] [4 5 6] [7 8 9]
Split into sections by index:
>>> chunks = jnp.split(x, [2, 7])  # [x[0:2], x[2:7], x[7:]]
>>> print(*chunks)
[1 2] [3 4 5 6 7] [8 9]
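The comment in the index-based example lists the slice that each pair of boundaries produces. The sketch below reproduces that mapping with explicit slicing via `jax.lax.slice_in_dim` and compares the result with `jnp.split`. `split_at` is a hypothetical helper for illustration only, and it assumes the indices are sorted and lie within `[0, x.shape[axis]]`.

```python
import jax
import jax.numpy as jnp

def split_at(x, indices, axis=0):
    # Hypothetical helper: boundaries [0, i0, ..., ik, size] define the slices
    # x[0:i0], x[i0:i1], ..., x[ik:] along `axis`.
    bounds = [0, *indices, x.shape[axis]]
    return [jax.lax.slice_in_dim(x, start, stop, axis=axis)
            for start, stop in zip(bounds[:-1], bounds[1:])]

x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
manual = split_at(x, [2, 7])
builtin = jnp.split(x, [2, 7])

# Each manually sliced chunk should match the corresponding jnp.split chunk.
for m, b in zip(manual, builtin):
    print(m, b, bool(jnp.array_equal(m, b)))
```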
Splitting a two-dimensional array along axis 1:
>>> x = jnp.array([[1, 2, 3, 4],
...                [5, 6, 7, 8]])
>>> x1, x2 = jnp.split(x, 2, axis=1)
>>> print(x1)
[[1 2]
 [5 6]]
>>> print(x2)
[[3 4]
 [7 8]]
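Because split only partitions the array, concatenating the pieces along the same axis recovers the input. The sketch below, assuming `jax` is installed, checks that round trip for the two-dimensional example using `jnp.concatenate`, and also splits along axis 0 for comparison.

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 3, 4],
               [5, 6, 7, 8]])

# Split into two column blocks along axis 1.
left, right = jnp.split(x, 2, axis=1)

# Concatenating along the same axis restores the original array.
restored = jnp.concatenate([left, right], axis=1)
print(bool(jnp.array_equal(restored, x)))  # True

# Splitting along axis 0 separates the rows instead.
top, bottom = jnp.split(x, 2, axis=0)
print(top)     # [[1 2 3 4]]
print(bottom)  # [[5 6 7 8]]
```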
See also
- jax.numpy.array_split(): like split, but allows indices_or_sections to be an integer that does not evenly divide the size of the array.
- jax.numpy.vsplit(): split vertically, i.e. along axis=0.
- jax.numpy.hsplit(): split horizontally, i.e. along axis=1.
- jax.numpy.dsplit(): split depth-wise, i.e. along axis=2.
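The related functions in this list can be compared with split directly. This sketch assumes `jax` is installed; the uneven chunk sizes from `array_split` follow the NumPy convention that the leading chunks absorb the remainder.

```python
import jax.numpy as jnp

x = jnp.arange(10)

# 3 does not evenly divide 10, so array_split is needed instead of split.
# The first 10 % 3 == 1 chunk receives one extra element.
chunks = jnp.array_split(x, 3)
print([int(c.shape[0]) for c in chunks])  # [4, 3, 3]

# For a 2-D array, hsplit and vsplit agree with split along axis 1 and axis 0.
m = jnp.arange(12).reshape(3, 4)
for a, b in zip(jnp.hsplit(m, 2), jnp.split(m, 2, axis=1)):
    assert jnp.array_equal(a, b)
for a, b in zip(jnp.vsplit(m, 3), jnp.split(m, 3, axis=0)):
    assert jnp.array_equal(a, b)
```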
