jax.numpy.r_
jax.numpy.r_ = <jax._src.numpy.index_tricks.RClass object>
Concatenate slices, scalars and array-like objects along the first axis.
LAX-backend implementation of numpy.r_.

See also
jnp.c_: Concatenates slices, scalars and array-like objects along the last axis.

Examples
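The difference between jnp.r_ and jnp.c_ is easiest to see on the same pair of flat inputs. The sketch below assumes that jnp.c_ mirrors NumPy's np.c_, which NumPy documents as shorthand for r_ with the directive '-1,2,0'; the last comparison checks that assumption rather than taking it from this page.

```python
import jax.numpy as jnp

a = [1, 2, 3]
b = [4, 5, 6]

# jnp.r_ joins 1-D inputs end to end along the first axis: shape (6,).
rows = jnp.r_[a, b]

# jnp.c_ turns each 1-D input into a column, then joins along the last
# axis: shape (3, 2).
cols = jnp.c_[a, b]

# Assumption borrowed from NumPy: c_ acts like r_ with the directive
# '-1,2,0' (last axis, at least 2-D, original axis placed first).
cols_via_r = jnp.r_['-1,2,0', a, b]

print(rows.shape)                               # (6,)
print(cols.shape)                               # (3, 2)
print(bool(jnp.array_equal(cols, cols_via_r)))  # True
```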
Passing slices in the form [start:stop:step] generates the same values as jnp.arange(start, stop, step):
>>> jnp.r_[-1:5:1, 0, 0, jnp.array([1, 2, 3])]
Array([-1,  0,  1,  2,  3,  4,  0,  0,  1,  2,  3], dtype=int32)
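To make the expansion concrete, the same array can be assembled by hand from jnp.arange and jnp.concatenate. This is a minimal sketch for illustration: it assumes only that scalar entries become length-1 pieces (as the output above shows) and does not reproduce how jnp.r_ is implemented internally.

```python
import jax.numpy as jnp

# jnp.r_ expands the slice -1:5:1 and appends the scalars and the array.
via_r = jnp.r_[-1:5:1, 0, 0, jnp.array([1, 2, 3])]

# The same result built piece by piece: the slice as jnp.arange,
# the two scalars as a length-2 array, and the array as-is.
via_concat = jnp.concatenate([
    jnp.arange(-1, 5, 1),
    jnp.array([0, 0]),
    jnp.array([1, 2, 3]),
])

print(bool(jnp.array_equal(via_r, via_concat)))  # True
```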
An imaginary step such as 6j instead generates evenly spaced values as with jnp.linspace: the magnitude of the imaginary part gives the number of points, and the right endpoint is included:
>>> jnp.r_[-1:1:6j, 0, jnp.array([1, 2, 3])]
Array([-1. , -0.6, -0.2,  0.2,  0.6,  1. ,  0. ,  1. ,  2. ,  3. ], dtype=float32)
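The endpoint rule is the main practical difference between a real and an imaginary step. The sketch below picks a spacing of 0.25 (an arbitrary, exactly representable value) so that both forms produce the same grid points, and only the inclusion of the stop value differs.

```python
import jax.numpy as jnp

# Imaginary step 5j: five evenly spaced points, like jnp.linspace(0, 1, 5).
# The stop value 1 is included.
with_endpoint = jnp.r_[0:1:5j]

# Real step 0.25: same spacing, but jnp.arange semantics exclude the stop.
without_endpoint = jnp.r_[0:1:0.25]

print(with_endpoint.tolist())     # [0.0, 0.25, 0.5, 0.75, 1.0]
print(without_endpoint.tolist())  # [0.0, 0.25, 0.5, 0.75]
```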
Use a string directive of the form "axis,dims,trans1d" as the first argument to specify the concatenation axis, the minimum number of dimensions, and the position of the upgraded array's original dimensions in the resulting array's shape tuple:
>>> jnp.r_['0,2', [1, 2, 3], [4, 5, 6]]  # concatenate along first axis, 2D output
Array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
>>> jnp.r_['0,2,0', [1, 2, 3], [4, 5, 6]]  # push last input axis to the front
Array([[1],
       [2],
       [3],
       [4],
       [5],
       [6]], dtype=int32)
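The role of trans1d is clearer with dims=3, where there is more than one slot the original axis can occupy. The sketch below upgrades a single 1-D input of length 3 and only inspects shapes; the input x is an arbitrary example, and the positional reading of trans1d shown here applies to 1-D inputs.

```python
import jax.numpy as jnp

x = [1, 2, 3]  # a 1-D input of length 3

# With dims=3 the input is upgraded to 3-D; trans1d picks the position
# that its original axis (length 3) takes in the new shape.
s0 = jnp.r_['0,3,0', x].shape        # (3, 1, 1)
s1 = jnp.r_['0,3,1', x].shape        # (1, 3, 1)
s2 = jnp.r_['0,3,2', x].shape        # (1, 1, 3)
s_default = jnp.r_['0,3', x].shape   # (1, 1, 3): trans1d defaults to -1

print(s0, s1, s2, s_default)
```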
Negative values for trans1d offset the last axis towards the start of the shape tuple:
>>> jnp.r_['0,2,-2', [1, 2, 3], [4, 5, 6]]
Array([[1],
       [2],
       [3],
       [4],
       [5],
       [6]], dtype=int32)
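With three dimensions, negative trans1d values can be read as positions counted from the end of the upgraded shape, in the usual Python indexing style. This is a shape-only sketch for a 1-D input (x is an arbitrary example); it illustrates the offset described above rather than documenting every case.

```python
import jax.numpy as jnp

x = [1, 2, 3]  # a 1-D input of length 3

# -1 leaves the original axis last, -2 moves it one slot toward the
# front, and -3 moves it all the way to the front.
n1 = jnp.r_['0,3,-1', x].shape  # (1, 1, 3)
n2 = jnp.r_['0,3,-2', x].shape  # (1, 3, 1)
n3 = jnp.r_['0,3,-3', x].shape  # (3, 1, 1)

print(n1, n2, n3)
```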
Use the special directives "r" or "c" as the first argument on flat inputs to create an array with an extra row or column axis, respectively:
>>> jnp.r_['r', [1, 2, 3], [4, 5, 6]]
Array([[1, 2, 3, 4, 5, 6]], dtype=int32)
>>> jnp.r_['c', [1, 2, 3], [4, 5, 6]]
Array([[1],
       [2],
       [3],
       [4],
       [5],
       [6]], dtype=int32)
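For flat inputs, the effect of the two directives can be reproduced by adding an axis to the plain 1-D concatenation with jnp.expand_dims. The sketch below checks that reading; it describes the observed results, not necessarily how jnp.r_ produces them.

```python
import jax.numpy as jnp

a, b = [1, 2, 3], [4, 5, 6]
flat = jnp.r_[a, b]  # plain concatenation, shape (6,)

# 'r' adds a leading axis (shape (1, 6)); 'c' adds a trailing one (shape (6, 1)).
row = jnp.r_['r', a, b]
col = jnp.r_['c', a, b]

# jnp.array_equal also returns False on a shape mismatch, so these
# checks cover both the values and the added axis.
assert bool(jnp.array_equal(row, jnp.expand_dims(flat, 0)))
assert bool(jnp.array_equal(col, jnp.expand_dims(flat, 1)))
```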
For higher-dimensional inputs (dim >= 2), both directives "r" and "c" give the same result.
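A quick check with a 2-D input (the matrix m is an arbitrary example): because the concatenated result already has two dimensions, both directives are expected to return the same array as plain jnp.r_, stacked along the first axis.

```python
import jax.numpy as jnp

m = jnp.array([[1, 2], [3, 4]])  # a 2-D input

plain = jnp.r_[m, m]        # concatenated along axis 0, shape (4, 2)
as_row = jnp.r_['r', m, m]
as_col = jnp.r_['c', m, m]

# The result is already 2-D, so neither directive adds an axis.
print(plain.shape)                            # (4, 2)
print(bool(jnp.array_equal(as_row, as_col)))  # True
print(bool(jnp.array_equal(as_row, plain)))   # True
```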
