
jax.numpy.array_split


jax.numpy.array_split(ary, indices_or_sections, axis=0)

Split an array into sub-arrays.

JAX implementation of numpy.array_split().

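Refer to the documentation of jax.numpy.split() for details; array_split is equivalent to split, but allows an integer indices_or_sections that does not evenly divide the split axis. For an axis of length l split into n sections, the first l % n sub-arrays have size l // n + 1 and the remaining ones have size l // n.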
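The size rule for an integer number of sections can be illustrated on a plain Python list. This is a minimal sketch, not JAX's implementation: the helper name array_split_1d is hypothetical, it handles only one-dimensional lists, and it ignores the axis argument.

```python
def array_split_1d(values: list, sections: int) -> list[list]:
    """Split a list into `sections` chunks following the array_split size rule.

    Hypothetical helper for illustration; not part of JAX or NumPy.
    """
    if sections <= 0:
        raise ValueError("number of sections must be larger than 0")
    base, extra = divmod(len(values), sections)
    # The first `extra` chunks are one element longer than the rest.
    sizes = [base + 1] * extra + [base] * (sections - extra)
    chunks, start = [], 0
    for size in sizes:
        chunks.append(values[start:start + size])
        start += size
    return chunks


print(array_split_1d([1, 2, 3, 4, 5, 6, 7, 8, 9], 4))
# [[1, 2, 3], [4, 5], [6, 7], [8, 9]]
```

With 9 elements and 4 sections, divmod gives a base size of 2 and one leftover element, so only the first chunk has 3 elements, matching the jnp.array_split example on this page.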

Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> chunks = jnp.array_split(x, 4)
>>> print(*chunks)
[1 2 3] [4 5] [6 7] [8 9]

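See also

  • jax.numpy.split(): same arguments, but an integer indices_or_sections must evenly divide the split axis.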

Parameters:
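  • ary (ArrayLike): the array to split.

  • indices_or_sections (int | Sequence[int] | ArrayLike): either an integer number of sections, which need not evenly divide the length of the split axis, or a sequence of indices at which to split.

  • axis (int): the axis along which to split; defaults to 0.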
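When indices_or_sections is a sequence of indices, each index marks a split point. The sketch below shows that mapping on a plain Python list, under the assumption of NumPy-style semantics: indices [i0, i1] yield values[:i0], values[i0:i1] and values[i1:]. The helper name split_at_indices is hypothetical and is not part of JAX or NumPy.

```python
from typing import Sequence


def split_at_indices(values: list, indices: Sequence[int]) -> list[list]:
    """Split a list at the given indices, mirroring array_split with a sequence.

    Hypothetical helper for illustration; not part of JAX or NumPy.
    """
    # Boundaries are 0, each requested index, then the full length, so
    # indices [i0, i1] yield values[:i0], values[i0:i1], values[i1:].
    bounds = [0, *indices, len(values)]
    return [values[start:stop] for start, stop in zip(bounds[:-1], bounds[1:])]


print(split_at_indices([1, 2, 3, 4, 5, 6, 7, 8, 9], [2, 5]))
# [[1, 2], [3, 4, 5], [6, 7, 8, 9]]
```

A sequence of k indices always produces k + 1 pieces; in this list-based sketch an index past the end simply yields an empty piece, as slicing a list does.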

Return type:

list[Array]
