
jax.numpy.dstack

jax.numpy.dstack(tup, dtype=None)

Stack arrays depth-wise.

JAX implementation of numpy.dstack().

For arrays of three or more dimensions, this is equivalent to jax.numpy.concatenate() with axis=2.
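The equivalence can be checked numerically. In this sketch the shapes are arbitrary illustrative choices: the two 3-D inputs differ only in their third axis. For the 1-D case, each input is first promoted explicitly with jax.numpy.atleast_3d(), which is assumed here to match the rank-3 promotion that dstack applies to lower-rank inputs.

```python
import jax.numpy as jnp

# Two 3-D arrays that agree on every axis except the third (depth) axis.
a = jnp.ones((2, 3, 1))
b = jnp.zeros((2, 3, 4))

stacked = jnp.dstack([a, b])
concatenated = jnp.concatenate([a, b], axis=2)

print(stacked.shape)                                   # (2, 3, 5)
print(bool(jnp.array_equal(stacked, concatenated)))    # True

# For lower-rank inputs the equivalence holds once each input is promoted
# to rank 3 (a 1-D array of shape (N,) becomes (1, N, 1)).
x = jnp.arange(3.0)
y = jnp.ones(3)
promoted = jnp.concatenate([jnp.atleast_3d(x), jnp.atleast_3d(y)], axis=2)
print(bool(jnp.array_equal(jnp.dstack([x, y]), promoted)))  # True
```

The 3-D check is the case the sentence above describes directly; the 1-D check shows why dstack is more permissive than a bare concatenate along axis 2, which would fail on 1-D inputs.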

Parameters:
  • tup (np.ndarray | Array | Sequence[ArrayLike]) – a sequence of arrays to stack; each must have the same shape along all but the third axis. Input arrays will be promoted to at least rank 3. If a single array is given, it will be treated equivalently to tup = unstack(tup), but the implementation will avoid explicit unstacking.

  • dtype (DTypeLike | None) – optional dtype of the resulting array. If not specified, the dtype will be determined via the type promotion rules described in Type promotion semantics.
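Both parameters can be exercised in a short sketch. The input values and shapes are arbitrary; jax.numpy.atleast_3d() is used only to make the rank-3 promotion visible, and the single-array check compares against explicitly sliced inputs rather than calling unstack.

```python
import jax.numpy as jnp

# tup: inputs are promoted to at least rank 3 before stacking,
# so a 1-D (N,) array becomes (1, N, 1) and a 2-D (M, N) array becomes (M, N, 1).
print(jnp.atleast_3d(jnp.ones(3)).shape)       # (1, 3, 1)
print(jnp.atleast_3d(jnp.ones((2, 3))).shape)  # (2, 3, 1)

# tup: a single array is treated like the sequence of its slices along axis 0.
x = jnp.arange(6).reshape(3, 2)
from_array = jnp.dstack(x)
from_slices = jnp.dstack([x[0], x[1], x[2]])
print(from_array.shape)                                # (1, 2, 3)
print(bool(jnp.array_equal(from_array, from_slices)))  # True

# dtype: by default the result dtype follows type promotion (integer + float -> float).
ints = jnp.arange(3)
floats = jnp.ones(3)
promoted = jnp.dstack([ints, floats])
print(bool(jnp.issubdtype(promoted.dtype, jnp.floating)))  # True

# dtype: an explicit value overrides promotion, here turning integer inputs into float32.
explicit = jnp.dstack([ints, ints], dtype=jnp.float32)
print(explicit.dtype)  # float32
print(explicit.shape)  # (1, 3, 2)
```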

Returns:

the stacked result.

Return type:

Array
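Because the result is an ordinary Array, its shape and dtype can be inspected without running the computation. This sketch uses jax.eval_shape() with jax.ShapeDtypeStruct placeholders; the particular shapes and dtypes are illustrative assumptions, and the float32 result reflects JAX's promotion of int32 with float32.

```python
import jax
import jax.numpy as jnp

# Abstract inputs: only shapes and dtypes are described, no data is allocated.
a = jax.ShapeDtypeStruct((4, 5), jnp.float32)
b = jax.ShapeDtypeStruct((4, 5), jnp.int32)

# Trace jnp.dstack abstractly to obtain the output shape and dtype.
out = jax.eval_shape(jnp.dstack, [a, b])
print(out.shape)  # (4, 5, 2)
print(out.dtype)  # float32
```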

See also

Examples

Scalar values:

>>> import jax.numpy as jnp
>>> jnp.dstack([1, 2, 3])
Array([[[1, 2, 3]]], dtype=int32, weak_type=True)

1D arrays:

>>> x = jnp.arange(3)
>>> y = jnp.ones(3)
>>> jnp.dstack([x, y])
Array([[[0., 1.],
        [1., 1.],
        [2., 1.]]], dtype=float32)

2D arrays:

>>> x = x.reshape(1, 3)
>>> y = y.reshape(1, 3)
>>> jnp.dstack([x, y])
Array([[[0., 1.],
        [1., 1.],
        [2., 1.]]], dtype=float32)
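dstack can also be used inside functions compiled with jax.jit(), where input shapes are static. One common use is assembling a channels-last image from separate 2-D planes; the function name to_image and the (H, W) plane layout below are hypothetical choices for illustration.

```python
import jax
import jax.numpy as jnp

@jax.jit
def to_image(red, green, blue):
    # Each (H, W) plane is promoted to (H, W, 1), and the planes are
    # concatenated along axis 2, giving an (H, W, 3) array.
    return jnp.dstack([red, green, blue])

h, w = 2, 4
red = jnp.full((h, w), 0.1)
green = jnp.full((h, w), 0.5)
blue = jnp.full((h, w), 0.9)

image = to_image(red, green, blue)
print(image.shape)                                 # (2, 4, 3)
print(bool(jnp.array_equal(image[..., 1], green)))  # True
```

Each input plane ends up as one slice along the last axis, matching the 2D example above, where two (1, 3) arrays become a (1, 3, 2) result.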