jax.numpy.array
- jax.numpy.array(object, dtype=None, copy=True, order='K', ndmin=0, *, device=None, out_sharding=None)
Convert an object to a JAX array.
JAX implementation of numpy.array().
- Parameters:
object (Any) – an object that is convertible to an array. This includes JAX arrays, NumPy arrays, Python scalars, Python collections like lists and tuples, objects with a __jax_array__ method, and objects supporting the Python buffer protocol.
dtype (str | type[Any] | dtype | SupportsDType | None) – optionally specify the dtype of the output array. If not specified, it will be inferred from the input.
copy (bool) – specify whether to force a copy of the input. Default: True.
order (str | None) – not implemented in JAX.
ndmin (int) – integer specifying the minimum number of dimensions in the output array.
device (Device | Sharding | None) – optional Device or Sharding to which the created array will be committed.
out_sharding (NamedSharding | PartitionSpec | None) – (optional) PartitionSpec or NamedSharding representing the sharding of the created array (see explicit sharding for more details). This argument exists for consistency with other array creation routines across JAX. Specifying both out_sharding and device will result in an error.
- Returns:
A JAX array constructed from the input.
- Return type:
Array
See also
jax.numpy.asarray(): like array, but by default only copies when necessary.
jax.numpy.from_dlpack(): construct a JAX array from an object that implements the dlpack interface.
jax.numpy.frombuffer(): construct a JAX array from an object that implements the buffer interface.
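As a rough illustration of two of the related routines listed above, the sketch below converts an existing JAX array with jnp.asarray and reinterprets a raw bytes object with jnp.frombuffer. The inputs are arbitrary illustrative values.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

# asarray: same conversion rules as array, but copy defaults to "only if needed".
y = jnp.asarray(x, dtype=jnp.float32)
print(y)  # [1. 2. 3.]

# frombuffer: interpret the raw bytes as a 1D array of the requested dtype.
b = jnp.frombuffer(b"\x01\x02\x03", dtype=jnp.uint8)
print(b)  # [1 2 3]
```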
Examples
Constructing JAX arrays from Python scalars:
>>> import jax.numpy as jnp
>>> jnp.array(True)
Array(True, dtype=bool)
>>> jnp.array(42)
Array(42, dtype=int32, weak_type=True)
>>> jnp.array(3.5)
Array(3.5, dtype=float32, weak_type=True)
>>> jnp.array(1+1j)
Array(1.+1.j, dtype=complex64, weak_type=True)
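The weak_type=True flag in the outputs above marks arrays whose dtype was inferred from a Python scalar. A minimal sketch of what that means in practice, under JAX's standard type promotion rules: a weakly typed scalar defers to the dtype of the other operand, while passing an explicit dtype yields an ordinary (strongly typed) array.

```python
import jax.numpy as jnp

weak = jnp.array(3.5)                       # dtype inferred from a Python float: weakly typed
strong = jnp.array(3.5, dtype=jnp.float32)  # explicit dtype: not weakly typed
half = jnp.array([1, 2], dtype=jnp.float16)

print((weak + half).dtype)    # float16: the weak scalar adopts the other operand's dtype
print((strong + half).dtype)  # float32: ordinary promotion of float32 with float16
```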
Constructing JAX arrays from Python collections:
>>> jnp.array([1, 2, 3])  # list of ints -> 1D array
Array([1, 2, 3], dtype=int32)
>>> jnp.array([(1, 2, 3), (4, 5, 6)])  # list of tuples of ints -> 2D array
Array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
>>> jnp.array(range(5))
Array([0, 1, 2, 3, 4], dtype=int32)
Constructing JAX arrays from NumPy arrays:
>>> import numpy as np
>>> jnp.array(np.linspace(0, 2, 5))
Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32)
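In the example above, np.linspace returns a float64 array but the JAX result is float32. This is because JAX canonicalizes 64-bit dtypes to 32-bit unless the jax_enable_x64 flag is set. A small sketch that checks this against whatever configuration is active, using jax.dtypes.canonicalize_dtype to compute the expected dtype rather than hard-coding it:

```python
import jax
import jax.numpy as jnp
import numpy as np

a = np.linspace(0, 2, 5)  # NumPy produces float64 here
b = jnp.array(a)

# float32 under the default configuration; float64 if jax_enable_x64 is enabled.
expected = jax.dtypes.canonicalize_dtype(np.float64)
print(a.dtype, b.dtype, expected)
```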
Constructing a JAX array via the Python buffer interface, using Python’s built-in array module:
>>> from array import array
>>> pybuffer = array('i', [2, 3, 5, 7])
>>> jnp.array(pybuffer)
Array([2, 3, 5, 7], dtype=int32)
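The object parameter also accepts objects that define a __jax_array__ method. The sketch below assumes jnp.array calls that method to obtain the array to convert; the Temperatures class is hypothetical and exists only for illustration.

```python
import jax.numpy as jnp


class Temperatures:
    """Hypothetical container type, used only to illustrate __jax_array__."""

    def __init__(self, values):
        self.values = list(values)

    def __jax_array__(self):
        # Return the JAX array that jnp.array should use for this object.
        return jnp.asarray(self.values, dtype=jnp.float32)


t = jnp.array(Temperatures([20.5, 21.0, 19.5]))
print(t.dtype, t.shape)
```

Defining __jax_array__ lets a wrapper type be passed to jnp.array without first unpacking it by hand.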
