jax.numpy.column_stack
- jax.numpy.column_stack(tup)
Stack arrays column-wise.
JAX implementation of numpy.column_stack().
For arrays of two or more dimensions, this is equivalent to jax.numpy.concatenate() with axis=1.
- Parameters:
tup (np.ndarray | Array | Sequence[ArrayLike]) – a sequence of arrays to stack; each must have the same leading dimension. Input arrays will be promoted to at least rank 2. If a single array is given, it will be treated equivalently to tup = unstack(tup), but the implementation will avoid explicit unstacking.
dtype – optional dtype of the resulting array. If not specified, the dtype will be determined via type promotion rules described in Type promotion semantics.
- Returns:
the stacked result.
- Return type:
Array
See also
- jax.numpy.stack(): stack along arbitrary axes.
- jax.numpy.concatenate(): concatenation along existing axes.
- jax.numpy.vstack(): stack vertically, i.e. along axis 0.
- jax.numpy.hstack(): stack horizontally, i.e. along axis 1.
- jax.numpy.dstack(): stack depth-wise, i.e. along axis 2.
Examples
Scalar values:
>>> jnp.column_stack([1, 2, 3])
Array([[1, 2, 3]], dtype=int32, weak_type=True)
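Single array input: the parameter description says a single array is treated as if it had been unstacked first. The sketch below illustrates that reading. The variable names are illustrative only, and splitting along the leading axis is emulated here with `list(v)` and explicit row indexing rather than a call to `unstack`.

```python
import jax.numpy as jnp

# A 1D array passed directly behaves like the sequence of its scalar elements.
v = jnp.array([1, 2, 3])
from_array = jnp.column_stack(v)
from_list = jnp.column_stack(list(v))  # list(v) splits v along its leading axis

# A 2D array passed directly is split into its rows, and each row
# becomes one column of the result.
m = jnp.array([[1, 2], [3, 4]])
from_matrix = jnp.column_stack(m)
from_rows = jnp.column_stack([m[0], m[1]])

print(from_array.shape)   # one row, three columns
print(from_matrix)        # the rows of m appear as columns
```

For a square 2D input such as `m`, the result therefore reads as the transpose of the input.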
1D arrays:
>>> x = jnp.arange(3)
>>> y = jnp.ones(3)
>>> jnp.column_stack([x, y])
Array([[0., 1.],
       [1., 1.],
       [2., 1.]], dtype=float32)
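For 1D inputs, each array becomes one column of the result, which should match stacking along a new trailing axis with jax.numpy.stack(). A minimal sketch of that comparison, assuming default JAX settings; the variable names are illustrative only:

```python
import jax.numpy as jnp

x = jnp.arange(3)  # integer dtype
y = jnp.ones(3)    # floating-point dtype

# Each 1D input becomes one column of a (3, 2) result.
stacked_cols = jnp.column_stack([x, y])

# Stacking the same 1D inputs along a new axis 1 gives the same layout.
stacked_axis1 = jnp.stack([x, y], axis=1)

print(stacked_cols.shape)
print(stacked_cols.dtype)  # the integer input is promoted to the float dtype
```

Both calls promote the mixed integer and floating-point inputs to a common floating-point dtype.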
2D arrays:
>>> x = x.reshape(3, 1)
>>> y = y.reshape(3, 1)
>>> jnp.column_stack([x, y])
Array([[0., 1.],
       [1., 1.],
       [2., 1.]], dtype=float32)
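For inputs that are already 2D, the description above says the result is equivalent to jax.numpy.concatenate() with axis=1. The following sketch checks that on two arrays with the same leading dimension but different column counts; the array names are illustrative only.

```python
import jax.numpy as jnp

a = jnp.arange(6).reshape(3, 2)  # shape (3, 2)
b = jnp.arange(3).reshape(3, 1)  # shape (3, 1)

# The leading dimension (3) matches, so the columns are joined side by side.
via_column_stack = jnp.column_stack([a, b])
via_concatenate = jnp.concatenate([a, b], axis=1)

print(via_column_stack.shape)  # 2 + 1 columns
print(via_column_stack)
```

Only the leading dimension has to agree; the number of columns contributed by each input can differ.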
