
jax.numpy.split

jax.numpy.split(ary, indices_or_sections, axis=0)

Split an array into sub-arrays.

JAX implementation of numpy.split().

Parameters:
  • ary (ArrayLike) – N-dimensional array-like object to split

  • indices_or_sections (int | Sequence[int] | ArrayLike) –

    either a single integer or a sequence of indices.

    • if indices_or_sections is an integer N, then N must evenly divide ary.shape[axis], and ary will be divided into N equally-sized chunks along axis.

    • if indices_or_sections is a sequence of integers, then these integers specify the boundaries between unevenly-sized chunks along axis; see the examples below.

  • axis (int) – the axis along which to split; defaults to 0.
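The two forms of indices_or_sections can be related with a small sketch: a sequence of boundaries is padded with 0 and ary.shape[axis], and consecutive differences give the chunk lengths, while an integer N gives N chunks of length ary.shape[axis] // N. The helper chunk_sizes below is hypothetical, written only for illustration (it assumes sorted boundaries within the axis length), and its predictions are compared with the shapes jnp.split actually returns.

```python
import jax.numpy as jnp


def chunk_sizes(length, indices_or_sections):
    """Hypothetical helper: predict chunk lengths along the split axis.

    Assumes sequence boundaries are sorted and lie within [0, length].
    """
    if isinstance(indices_or_sections, int):
        n = indices_or_sections
        # An integer N must divide the axis length exactly.
        if length % n != 0:
            raise ValueError(f"{length} is not divisible by {n}")
        return [length // n] * n
    # Boundaries [b0, b1, ...] give slices [0:b0], [b0:b1], ..., [b_last:length].
    bounds = [0, *indices_or_sections, length]
    return [stop - start for start, stop in zip(bounds[:-1], bounds[1:])]


x = jnp.arange(10)
for spec in (5, [3, 4, 8]):
    predicted = chunk_sizes(x.shape[0], spec)
    actual = [chunk.shape[0] for chunk in jnp.split(x, spec)]
    print(spec, predicted, actual)
```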

Returns:

A list of arrays. If indices_or_sections is an integer N, then the list is of length N. If indices_or_sections is a sequence seq, then the list is of length len(seq) + 1.

Return type:

list[Array]
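Because the result is an ordinary Python list, both length rules above can be checked with len(), and concatenating the pieces along the same axis should reproduce the input. A short check under those assumptions, using only jax.numpy:

```python
import jax.numpy as jnp

x = jnp.arange(12).reshape(3, 4)

# Integer N: the list has exactly N entries.
by_count = jnp.split(x, 2, axis=1)
assert isinstance(by_count, list) and len(by_count) == 2

# Sequence seq: the list has len(seq) + 1 entries.
seq = [1, 3]
by_index = jnp.split(x, seq, axis=1)
assert len(by_index) == len(seq) + 1

# Joining the pieces back along the split axis recovers the original array.
assert jnp.array_equal(jnp.concatenate(by_index, axis=1), x)
```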

Examples

Splitting a 1-dimensional array:

>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9])

Split into three equal sections:

>>> chunks = jnp.split(x, 3)
>>> print(*chunks)
[1 2 3] [4 5 6] [7 8 9]
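The integer form only works when N divides the axis length exactly. With a nine-element array, 3 sections succeed but 4 do not. The sketch below assumes the failure surfaces as a ValueError, as it does for numpy.split; it is worth confirming the exception type against your installed JAX version.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9])

# 3 divides len(x) == 9, so this yields three chunks of length 3.
assert [c.shape[0] for c in jnp.split(x, 3)] == [3, 3, 3]

# 4 does not divide 9, so the equal-division requirement is violated.
try:
    jnp.split(x, 4)
    raised = False
except ValueError:
    raised = True
print("ValueError raised:", raised)
```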

Split into sections by index:

>>> chunks = jnp.split(x, [2, 7])  # [x[0:2], x[2:7], x[7:]]
>>> print(*chunks)
[1 2] [3 4 5 6 7] [8 9]
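The comment in the index example describes each chunk as a slice of x, with each boundary closing one slice and opening the next. That correspondence can be verified directly with a small sketch that rebuilds the same x and boundaries:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
chunks = jnp.split(x, [2, 7])

# Boundaries 2 and 7 correspond to the slices [0:2], [2:7], and [7:].
expected = [x[0:2], x[2:7], x[7:]]
for got, want in zip(chunks, expected):
    assert jnp.array_equal(got, want)
```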

Splitting a two-dimensional array along axis 1:

>>> x = jnp.array([[1, 2, 3, 4],
...                [5, 6, 7, 8]])
>>> x1, x2 = jnp.split(x, 2, axis=1)
>>> print(x1)
[[1 2]
 [5 6]]
>>> print(x2)
[[3 4]
 [7 8]]
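The sequence form works the same way on a two-dimensional array: a boundary list along axis 1 produces unevenly sized column blocks, while the default axis=0 splits rows. A brief sketch using the same 2 × 4 array:

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 3, 4],
               [5, 6, 7, 8]])

# A boundary at column 1 gives a (2, 1) piece and a (2, 3) piece.
left, right = jnp.split(x, [1], axis=1)
print(left.shape, right.shape)

# The default axis=0 splits rows instead: two pieces of shape (1, 4).
top, bottom = jnp.split(x, 2)
print(top.shape, bottom.shape)
```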

