jax.numpy.array

jax.numpy.array(object, dtype=None, copy=True, order='K', ndmin=0, *, device=None, out_sharding=None)

Convert an object to a JAX array.

JAX implementation of numpy.array().

Parameters:
  • object (Any) – an object that is convertible to an array. This includes JAX arrays, NumPy arrays, Python scalars, Python collections like lists and tuples, objects with a __jax_array__ method, and objects supporting the Python buffer protocol.

  • dtype (str | type[Any] | dtype | SupportsDType | None) – optionally specify the dtype of the output array. If not specified, it will be inferred from the input.

  • copy (bool) – specify whether to force a copy of the input. Default: True.

  • order (str | None) – not implemented in JAX.

  • ndmin (int) – integer specifying the minimum number of dimensions in the output array.

  • device (Device | Sharding | None) – optional Device or Sharding to which the created array will be committed.

  • out_sharding (NamedSharding | PartitionSpec | None) – (optional) PartitionSpec or NamedSharding representing the sharding of the created array (see explicit sharding for more details). This argument exists for consistency with other array creation routines across JAX. Specifying both out_sharding and device will result in an error.
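As a rough illustration of the dtype and ndmin parameters described above, the following sketch assumes a default JAX installation imported as jnp. The variable names x, y, and z are only for this example.

```python
import jax.numpy as jnp

# Request a floating-point dtype instead of the one inferred from the ints.
x = jnp.array([1, 2, 3], dtype=jnp.float32)
print(x.dtype)   # float32

# ndmin=2 prepends a length-1 axis so the result has at least 2 dimensions.
y = jnp.array([1, 2, 3], ndmin=2)
print(y.shape)   # (1, 3)

# ndmin has no effect when the input already has enough dimensions.
z = jnp.array([[1, 2], [3, 4]], ndmin=1)
print(z.shape)   # (2, 2)
```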

Returns:

A JAX array constructed from the input.

Return type:

Array

Examples

Constructing JAX arrays from Python scalars:

>>> import jax.numpy as jnp
>>> jnp.array(True)
Array(True, dtype=bool)
>>> jnp.array(42)
Array(42, dtype=int32, weak_type=True)
>>> jnp.array(3.5)
Array(3.5, dtype=float32, weak_type=True)
>>> jnp.array(1 + 1j)
Array(1.+1.j, dtype=complex64, weak_type=True)
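The weak_type=True flag in the scalar outputs above affects later type promotion. The sketch below is one way to observe this, assuming JAX's default promotion rules. The names weak, strong, and half are illustrative only.

```python
import jax.numpy as jnp

weak = jnp.array(3.5)                       # Python float: weakly typed
strong = jnp.array(3.5, dtype=jnp.float32)  # explicit dtype: strongly typed
half = jnp.ones(3, dtype=jnp.float16)

print(weak.weak_type, strong.weak_type)     # True False
print((weak + half).dtype)                  # float16: the weak operand defers
print((strong + half).dtype)                # float32: ordinary promotion
```

Passing an explicit dtype is therefore a simple way to pin the precision of a scalar constant.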

Constructing JAX arrays from Python collections:

>>> jnp.array([1, 2, 3])  # list of ints -> 1D array
Array([1, 2, 3], dtype=int32)
>>> jnp.array([(1, 2, 3), (4, 5, 6)])  # list of tuples of ints -> 2D array
Array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
>>> jnp.array(range(5))
Array([0, 1, 2, 3, 4], dtype=int32)

Constructing JAX arrays from NumPy arrays:

>>> import numpy as np
>>> jnp.array(np.linspace(0, 2, 5))
Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32)
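The NumPy input above is float64, yet the result is float32. The sketch below shows how this can be checked. It assumes the cause is JAX's default 32-bit mode, controlled by the jax_enable_x64 configuration flag, and it reads that flag rather than hard-coding the outcome.

```python
import numpy as np
import jax
import jax.numpy as jnp

src = np.linspace(0, 2, 5)   # NumPy defaults to float64
x = jnp.array(src)           # dtype is inferred, subject to JAX's x64 setting

# Assumption: 64-bit types are only kept when jax_enable_x64 is turned on.
x64_enabled = bool(jax.config.jax_enable_x64)
print(src.dtype, x.dtype, x64_enabled)
```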

Constructing a JAX array via the Python buffer interface, using Python’s built-in array module:

>>> from array import array
>>> pybuffer = array('i', [2, 3, 5, 7])
>>> jnp.array(pybuffer)
Array([2, 3, 5, 7], dtype=int32)
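The parameter description also lists objects with a __jax_array__ method as valid input. A minimal sketch of that case follows. The Celsius class is hypothetical and exists only for this illustration.

```python
import jax.numpy as jnp


class Celsius:
    """Hypothetical wrapper type, used only for this illustration."""

    def __init__(self, degrees):
        self.degrees = degrees

    def __jax_array__(self):
        # Supplies the array representation that jnp.array converts.
        return jnp.asarray(self.degrees, dtype=jnp.float32)


temps = jnp.array(Celsius([21.5, 23.0, 19.0]))
print(temps.shape, temps.dtype)  # (3,) float32
```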