jax.numpy.ravel
- jax.numpy.ravel(a, order='C', *, out_sharding=None)
Flatten array into a 1-dimensional shape.
JAX implementation of numpy.ravel(), implemented in terms of jax.lax.reshape().
ravel(arr, order=order) is equivalent to reshape(arr, -1, order=order).
- Parameters:
a (ArrayLike) – array to be flattened.
order (str) –
'F' or 'C', specifies whether the reshape should apply column-major (Fortran-style, "F") or row-major (C-style, "C") order; default is "C". JAX does not support order="A" or order="K".
- Returns:
flattened copy of input array.
- Return type:
Array
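Only "C" and "F" are accepted for order. The sketch below tries each of the four NumPy order values using a hypothetical helper, try_ravel, which is not part of JAX. It assumes that the unsupported values raise an exception when ravel is called. This page does not name the exception type, so the sketch catches both NotImplementedError and ValueError.

```python
import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

def try_ravel(order):
    """Hypothetical helper: return the flattened array, or the name of the
    exception raised for an unsupported order value."""
    try:
        return jnp.ravel(x, order=order)
    except (NotImplementedError, ValueError) as e:
        return type(e).__name__

for order in ["C", "F", "A", "K"]:
    print(order, try_ravel(order))
```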
Notes
Unlike numpy.ravel(), jax.numpy.ravel() returns a copy rather than a view of the input array. However, under JIT the compiler optimizes away such copies when possible, so in practice this generally has no performance impact.
See also
jax.Array.ravel(): equivalent functionality via an array method.
jax.numpy.reshape(): general array reshape.
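The equivalence with jax.numpy.reshape() stated in the description can be checked directly. This is a minimal sketch that uses a small 3-D integer array so that C and F orders differ along every axis. It also checks the identity that Fortran-order flattening of x matches C-order flattening of x.T, because x.T reverses the order of all axes.

```python
import jax.numpy as jnp

# x[i, j, k] == 12*i + 4*j + k
x = jnp.arange(24).reshape(2, 3, 4)

for order in ("C", "F"):
    # ravel(arr, order=order) should match reshape(arr, -1, order=order).
    assert jnp.array_equal(jnp.ravel(x, order=order),
                           jnp.reshape(x, -1, order=order))

# Fortran-order flattening of x equals C-order flattening of its transpose,
# since x.T reverses all axes (so the first axis of x varies fastest).
assert jnp.array_equal(jnp.ravel(x, order="F"), jnp.ravel(x.T))
```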
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
By default, ravel flattens in C-style, row-major order:
>>> jnp.ravel(x)
Array([1, 2, 3, 4, 5, 6], dtype=int32)
Optionally, ravel in Fortran-style, column-major order:
>>> jnp.ravel(x, order='F')
Array([1, 4, 2, 5, 3, 6], dtype=int32)
For convenience, the same functionality is available via the jax.Array.ravel() method:
>>> x.ravel()
Array([1, 2, 3, 4, 5, 6], dtype=int32)
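Because ravel is implemented in terms of jax.lax.reshape(), tracing it with jax.make_jaxpr() should show a reshape primitive rather than an explicit copy. This sketch assumes a recent JAX release. The exact printed form of the jaxpr (parameter names, nested pjit wrappers) differs between versions, so only the presence of reshape is relied on here.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])

# Trace jnp.ravel into a jaxpr; the flattening shows up as a reshape primitive.
jaxpr = jax.make_jaxpr(jnp.ravel)(x)
print(jaxpr)

# Compile the same function with jax.jit. Per the Notes above, the compiler
# can avoid materializing a separate copy when possible.
flat = jax.jit(jnp.ravel)(x)
print(flat)
```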
