jax.numpy.atleast_3d
jax.numpy.atleast_3d(*arys)
Convert inputs to arrays with at least 3 dimensions.
JAX implementation of numpy.atleast_3d().
Parameters:
arys (ArrayLike): zero or more arraylike arguments.
Returns:
an array or list of arrays corresponding to the input values. Arrays of shape () are converted to shape (1, 1, 1), 1D arrays of shape (N,) are converted to shape (1, N, 1), 2D arrays of shape (M, N) are converted to shape (M, N, 1), and arrays of all other shapes are returned unchanged.
Examples
Scalar arguments are converted to 3D, size-1 arrays:
>>> x = jnp.float32(1.0)
>>> jnp.atleast_3d(x)
Array([[[1.]]], dtype=float32)
1D arrays have a unit dimension prepended and appended:
>>> y = jnp.arange(4)
>>> jnp.atleast_3d(y).shape
(1, 4, 1)
2D arrays have a unit dimension appended:
>>> z = jnp.ones((2, 3))
>>> jnp.atleast_3d(z).shape
(2, 3, 1)
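The examples so far stop at 2D inputs. As a minimal sketch (assuming `jax.numpy` is imported as `jnp`; the names `inputs` and `shapes` are illustrative only), the following collects the output shape for inputs of rank 0 through 4, which should show that arrays with three or more dimensions come back with their shape unchanged:

```python
import jax.numpy as jnp

# Illustrative inputs of rank 0, 1, 2, 3 and 4.
inputs = [
    jnp.float32(1.0),        # shape ()
    jnp.arange(4),           # shape (4,)
    jnp.ones((2, 3)),        # shape (2, 3)
    jnp.ones((2, 3, 4)),     # shape (2, 3, 4)
    jnp.ones((2, 3, 4, 5)),  # shape (2, 3, 4, 5)
]

# Only inputs with fewer than 3 dimensions should gain unit dimensions.
shapes = [jnp.atleast_3d(a).shape for a in inputs]

for a, s in zip(inputs, shapes):
    print(a.shape, "->", s)
```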
Multiple arguments can be passed to the function at once, in which case a list of results is returned:
>>> x3, y3 = jnp.atleast_3d(x, y)
>>> print(x3)
[[[1.]]]
>>> print(y3)
[[[0]
  [1]
  [2]
  [3]]]
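For a 1D input, the documented (1, N, 1) result can also be reproduced by hand with `jnp.newaxis` indexing. The sketch below is an illustration under that reading of the shape rules, not a description of how the function is implemented; the names `manual`, `auto`, `a3` and `b3` are hypothetical.

```python
import jax.numpy as jnp

y = jnp.arange(4)

# Manual equivalent for a 1D input: add a leading and a trailing unit axis.
manual = y[jnp.newaxis, :, jnp.newaxis]
auto = jnp.atleast_3d(y)
print(manual.shape, auto.shape)

# Several inputs in one call; the result unpacks like any Python sequence.
a3, b3 = jnp.atleast_3d(jnp.float32(1.0), jnp.ones((2, 3)))
print(a3.shape, b3.shape)
```

Explicit indexing makes the placement of the new axes visible at the call site, while `jnp.atleast_3d` is convenient when the rank of the input is not known in advance.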
