
jax.numpy.unstack


jax.numpy.unstack(x, /, *, axis=0)

Unstack an array along an axis.

JAX implementation of array_api.unstack().

Parameters:
  • x (ArrayLike) – array to unstack. Must have x.ndim >= 1.

  • axis (int) – integer axis along which to unstack. Must satisfy -x.ndim <= axis < x.ndim.
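To make the axis parameter concrete, here is a small sketch (assuming a JAX release recent enough to provide jnp.unstack) showing how the choice of axis determines how many arrays come back and what shape each one has. The array x and its shape (2, 3, 4) are arbitrary illustrative choices.

```python
import jax.numpy as jnp

# Illustrative 3-D input of shape (2, 3, 4) holding the values 0..23.
x = jnp.arange(24).reshape(2, 3, 4)

# axis=0: x.shape[0] == 2 pieces, each of shape (3, 4).
parts0 = jnp.unstack(x, axis=0)
print(len(parts0), parts0[0].shape)  # 2 (3, 4)

# axis=1: x.shape[1] == 3 pieces, each of shape (2, 4).
parts1 = jnp.unstack(x, axis=1)
print(len(parts1), parts1[0].shape)  # 3 (2, 4)

# Negative axes count from the end, so axis=-1 means axis=2 here:
# 4 pieces, each of shape (2, 3).
parts_last = jnp.unstack(x, axis=-1)
print(len(parts_last), parts_last[0].shape)  # 4 (2, 3)
```

Each piece is the corresponding slice of x with the unstacked axis dropped; for example, parts_last[1] equals x[:, :, 1].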

Returns:

tuple of x.shape[axis] unstacked arrays, each with the shape of x with axis removed.

Return type:

tuple[Array, …]
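One way to model the return value: unstacking along axis yields the same pieces as moving axis to the front and iterating over the result. The helper unstack_reference below is hypothetical and written only to illustrate these semantics; it is not how JAX implements unstack.

```python
import jax.numpy as jnp


def unstack_reference(x, axis=0):
    """Hypothetical helper modelling what jnp.unstack returns.

    Not JAX's implementation; it only illustrates the semantics.
    """
    x = jnp.asarray(x)
    if x.ndim < 1:
        raise ValueError("unstack requires x.ndim >= 1")
    # Move the requested axis to position 0, then iterate:
    # iterating a JAX array yields its sub-arrays along axis 0.
    return tuple(jnp.moveaxis(x, axis, 0))


x = jnp.arange(24).reshape(2, 3, 4)
for axis in (0, 1, 2, -1):
    expected = jnp.unstack(x, axis=axis)
    got = unstack_reference(x, axis=axis)
    # Same number of pieces, and every piece matches element for element.
    assert len(got) == len(expected) == x.shape[axis]
    assert all(bool((g == e).all()) for g, e in zip(got, expected))
```

Because the result is an ordinary Python tuple, it can be unpacked directly, e.g. a, b = jnp.unstack(jnp.array([[1, 2, 3], [4, 5, 6]])).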

See also

  • jax.numpy.stack(): inverse of unstack.

Examples

>>> arr = jnp.array([[1, 2, 3],
...                  [4, 5, 6]])
>>> arrs = jnp.unstack(arr)
>>> print(*arrs)
[1 2 3] [4 5 6]

jnp.stack() provides the inverse of this operation:

>>> jnp.stack(arrs)
Array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
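The inverse relationship also holds for non-default axes, provided the same axis is passed to both functions. A short check, using an arbitrary example array:

```python
import jax.numpy as jnp

x = jnp.arange(24).reshape(2, 3, 4)

# For every valid axis, stacking the pieces back along that same axis
# reconstructs the original array exactly.
for axis in (0, 1, 2, -1, -3):
    pieces = jnp.unstack(x, axis=axis)
    rebuilt = jnp.stack(pieces, axis=axis)
    assert rebuilt.shape == x.shape
    assert bool(jnp.array_equal(rebuilt, x))

# Stacking along a different axis generally does not reproduce x:
# the 3 pieces of shape (2, 4) from axis=1 are stacked on axis 0 here.
print(jnp.stack(jnp.unstack(x, axis=1)).shape)  # (3, 2, 4)
```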