jax.numpy.dstack
- jax.numpy.dstack(tup, dtype=None)
Stack arrays depth-wise.
JAX implementation of numpy.dstack().
For arrays of three or more dimensions, this is equivalent to jax.numpy.concatenate() with axis=2.
- Parameters:
tup (np.ndarray | Array | Sequence[ArrayLike]) – a sequence of arrays to stack; each must have the same shape along all but the third axis. Input arrays will be promoted to at least rank 3. If a single array is given, it will be treated equivalently to tup = unstack(tup), but the implementation will avoid explicit unstacking.
dtype (DTypeLike | None) – optional dtype of the resulting array. If not specified, the dtype will be determined via the type promotion rules described in Type promotion semantics.
- Returns:
the stacked result.
- Return type:
Array
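The rules described above (equivalence to concatenation along axis 2, promotion to rank 3, and the single-array case) can be checked with a short script. This is a sketch that assumes a default JAX configuration (64-bit mode disabled). The helper `dstack_reference` is hypothetical: it only mirrors the documented semantics and is not JAX's actual implementation.

```python
import jax.numpy as jnp

def dstack_reference(arrays):
    # Hypothetical helper: promote each input to rank 3, then join along axis 2.
    # jnp.atleast_3d maps (N,) -> (1, N, 1) and (M, N) -> (M, N, 1).
    return jnp.concatenate([jnp.atleast_3d(a) for a in arrays], axis=2)

# Rank-3 inputs: dstack is concatenate with axis=2.
a = jnp.zeros((2, 3, 4))
b = jnp.ones((2, 3, 5))
out3 = jnp.dstack([a, b])
print(out3.shape)  # (2, 3, 9)
assert jnp.array_equal(out3, jnp.concatenate([a, b], axis=2))

# Lower-rank inputs are promoted to rank 3 before joining.
x = jnp.arange(3)      # shape (3,)   -> (1, 3, 1)
m = jnp.ones((1, 3))   # shape (1, 3) -> (1, 3, 1)
out_mixed = jnp.dstack([x, m])
print(out_mixed.shape)  # (1, 3, 2)
assert jnp.array_equal(out_mixed, dstack_reference([x, m]))

# A single array is treated like the sequence of its slices along axis 0.
stacked = jnp.arange(6).reshape(2, 3)
from_array = jnp.dstack(stacked)
from_list = jnp.dstack([stacked[i] for i in range(stacked.shape[0])])
print(from_array.shape)  # (1, 3, 2)
assert jnp.array_equal(from_array, from_list)
```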
See also
- jax.numpy.stack(): stack along arbitrary axes.
- jax.numpy.concatenate(): concatenation along existing axes.
- jax.numpy.vstack(): stack vertically, i.e. along axis 0.
- jax.numpy.hstack(): stack horizontally, i.e. along axis 1.
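To make the distinctions between these functions concrete, the following sketch prints the result shape of each one for the same pair of inputs (the input shapes are arbitrary choices for illustration). For 2D inputs, jnp.dstack and jnp.stack(..., axis=2) produce the same shape; for 3D inputs they differ, because stack inserts a new axis while dstack extends the existing third axis.

```python
import jax.numpy as jnp

a = jnp.zeros((2, 3))
b = jnp.ones((2, 3))

# 2D inputs: dstack promotes each to (2, 3, 1) and joins along axis 2.
print(jnp.dstack([a, b]).shape)               # (2, 3, 2)
print(jnp.stack([a, b], axis=2).shape)        # (2, 3, 2): new trailing axis
print(jnp.vstack([a, b]).shape)               # (4, 3): joined along axis 0
print(jnp.hstack([a, b]).shape)               # (2, 6): joined along axis 1
print(jnp.concatenate([a, b], axis=1).shape)  # (2, 6)

# 3D inputs: stack adds a new axis, dstack extends the existing axis 2.
p = jnp.zeros((2, 3, 4))
q = jnp.ones((2, 3, 4))
print(jnp.stack([p, q], axis=2).shape)        # (2, 3, 2, 4)
print(jnp.dstack([p, q]).shape)               # (2, 3, 8)
```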
Examples
Scalar values:
>>> import jax.numpy as jnp
>>> jnp.dstack([1, 2, 3])
Array([[[1, 2, 3]]], dtype=int32, weak_type=True)
1D arrays:
>>> x = jnp.arange(3)
>>> y = jnp.ones(3)
>>> jnp.dstack([x, y])
Array([[[0., 1.],
        [1., 1.],
        [2., 1.]]], dtype=float32)
2D arrays:
>>> x = x.reshape(1, 3)
>>> y = y.reshape(1, 3)
>>> jnp.dstack([x, y])
Array([[[0., 1.],
        [1., 1.],
        [2., 1.]]], dtype=float32)
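The optional dtype argument overrides the type-promoted result dtype. The sketch below assumes 64-bit mode is disabled (the JAX default), so jnp.arange(3) is int32 and the promoted dtype of an int32 and a float32 input is float32; the variable names are illustrative only.

```python
import jax.numpy as jnp

x = jnp.arange(3)   # int32 under the default configuration
y = jnp.ones(3)     # float32

# Without dtype, type promotion of int32 and float32 gives float32.
promoted = jnp.dstack([x, y])
print(promoted.dtype)  # float32

# With dtype, the inputs are converted so the result has the requested dtype,
# even though both inputs here are integers.
as_float = jnp.dstack([x, x], dtype=jnp.float32)
print(as_float.dtype)  # float32
print(as_float.shape)  # (1, 3, 2)
```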
