jax.numpy.reshape
jax.numpy.reshape(a, shape, order='C', *, copy=None, out_sharding=None)
Return a reshaped copy of an array.
JAX implementation of numpy.reshape(), implemented in terms of jax.lax.reshape().
- Parameters:
a (ArrayLike) – input array to reshape
shape (DimSize | Shape) – integer or sequence of integers giving the new shape, which must match the size of the input array. If any single dimension is given size -1, it will be replaced with a value such that the output has the correct size.
order (str) – 'F' or 'C', specifies whether the reshape should apply column-major (Fortran-style, "F") or row-major (C-style, "C") order; default is "C". JAX does not support order="A".
copy (bool | None) – unused by JAX; JAX always returns a copy, though under JIT the compiler may optimize such copies away.
- Returns:
reshaped copy of input array with the specified shape.
- Return type:
Array
Notes
Unlike numpy.reshape(), jax.numpy.reshape() will return a copy rather than a view of the input array. However, under JIT, the compiler will optimize away such copies when possible, so this doesn’t have performance impacts in practice.
See also
jax.Array.reshape(): equivalent functionality via an array method.
jax.numpy.ravel(): flatten an array into a 1D shape.
jax.numpy.squeeze(): remove one or more length-1 axes from an array’s shape.
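The copy-versus-view distinction in the Notes can be illustrated with a small sketch. It assumes NumPy is installed alongside JAX; the names a_np, v_np, a, b, c, and row_sums are illustrative only and do not come from the JAX documentation. Because JAX arrays are immutable, the "update" on the JAX side uses the .at[...].set(...) syntax, which returns a new array.

```python
import numpy as np
import jax
import jax.numpy as jnp

# NumPy: reshaping a contiguous array returns a view that shares memory,
# so writing through the view also changes the original array.
a_np = np.arange(6)
v_np = np.reshape(a_np, (2, 3))
v_np[0, 0] = 100          # a_np[0] is now 100 as well

# JAX: arrays are immutable and reshape returns a new array value.
a = jnp.arange(6)
b = jnp.reshape(a, (2, 3))
c = b.at[0, 0].set(100)   # functional update; a and b are left unchanged

# Under jit, the reshape is part of the compiled computation, so the
# compiler is free to avoid materializing a separate copy where it can.
@jax.jit
def row_sums(v):
    return jnp.reshape(v, (2, 3)).sum(axis=1)

sums = row_sums(a)        # row sums of [[0, 1, 2], [3, 4, 5]]
```

Whether a copy is actually elided under JIT depends on the compiler and backend, so this sketch only shows the semantics, not the memory behavior.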
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.reshape(x, 6)
Array([1, 2, 3, 4, 5, 6], dtype=int32)
>>> jnp.reshape(x, (3, 2))
Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)
You can use -1 to automatically compute a shape that is consistent with the input size:
>>> jnp.reshape(x, -1)  # -1 is inferred to be 6
Array([1, 2, 3, 4, 5, 6], dtype=int32)
>>> jnp.reshape(x, (-1, 2))  # -1 is inferred to be 3
Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)
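The arithmetic behind the -1 placeholder can be spelled out by hand: the inferred dimension is the input size divided by the product of the other requested dimensions. The following is a minimal sketch of that calculation, not a description of JAX internals; the names requested, known, and inferred are hypothetical.

```python
import math
import jax.numpy as jnp

x = jnp.arange(24)                 # 24 elements in total
requested = (2, -1, 4)             # one dimension left for JAX to infer

# Product of the explicitly given dimensions: 2 * 4 = 8.
known = math.prod(d for d in requested if d != -1)

# The placeholder must make the total size match: 24 // 8 = 3.
inferred = x.size // known

y = jnp.reshape(x, requested)      # y.shape is (2, 3, 4)
```

Only one dimension may be given as -1, since otherwise the missing sizes would not be uniquely determined.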
The default ordering of axes in the reshape is C-style row-major ordering. To use Fortran-style column-major ordering, specify order='F':
>>> jnp.reshape(x, 6, order='F')
Array([1, 4, 2, 5, 3, 6], dtype=int32)
>>> jnp.reshape(x, (3, 2), order='F')
Array([[1, 5],
       [4, 3],
       [2, 6]], dtype=int32)
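One way to reason about order='F' is through transposes: a column-major reshape gives the same result as transposing the input, doing a row-major reshape to the reversed target shape, and transposing back. The sketch below checks this equivalence for the array used above and compares against NumPy; it is offered as a mental model under that assumption, not as a statement of how JAX implements the option. The names new_shape, f_order, via_transpose, and reference are illustrative.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])
new_shape = (3, 2)

# Column-major reshape requested directly.
f_order = jnp.reshape(x, new_shape, order='F')

# Same result built from row-major pieces:
# transpose, reshape to the reversed shape, transpose back.
via_transpose = jnp.reshape(x.T, new_shape[::-1]).T

# NumPy's column-major reshape as an independent reference.
reference = np.reshape(np.asarray(x), new_shape, order='F')
```

All three arrays hold the same values, matching the order='F' output shown in the example above.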
For convenience, this functionality is also available via the jax.Array.reshape() method:
>>> x.reshape(3, 2)
Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)
