
jax.sharding module#

Classes#

class jax.sharding.Sharding(*args, **kwargs)#

Describes how a jax.Array is laid out across devices.

property addressable_devices: set[Device]#

The set of devices in the Sharding that are addressable by the current process.

addressable_devices_indices_map(global_shape)#

A mapping from addressable devices to the slice of array data each contains.

addressable_devices_indices_map contains the part of devices_indices_map that applies to the addressable devices.

Parameters:

global_shape (Shape)

Return type:

Mapping[Device, Index | None]

property device_set: set[Device]#

The set of devices that this Sharding spans.

In multi-controller JAX, the set of devices is global, i.e., it includes non-addressable devices from other processes.

devices_indices_map(global_shape)#

Returns a mapping from devices to the array slices each contains.

The mapping includes all global devices, i.e., including non-addressable devices from other processes.

Parameters:

global_shape (Shape)

Return type:

Mapping[Device, Index]
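As a minimal runnable sketch of the mapping's shape (assuming a standard JAX installation with at least one device), a single-device sharding maps its one device to the entire array:

```python
import jax
from jax.sharding import SingleDeviceSharding

# A single-device sharding places the whole array on one device, so the
# mapping returned by devices_indices_map has exactly one entry, keyed
# by that device.
sharding = SingleDeviceSharding(jax.devices()[0])
index_map = sharding.devices_indices_map((4, 2))
print(len(index_map))  # 1
```

With a multi-device sharding, the mapping would instead contain one entry per device, each holding the index (a tuple of slices) of that device's shard.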

is_equivalent_to(other, ndim)#

Returns True if two shardings are equivalent.

Two shardings are equivalent if they place the same logical array shards on the same devices.

Parameters:
  • other (Sharding) – The sharding to compare against.

  • ndim (int) – The number of array dimensions to consider.

Return type:

bool
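For instance (a sketch assuming a default JAX install), two distinct single-device shardings on the same device compare as equivalent:

```python
import jax
from jax.sharding import SingleDeviceSharding

# Equivalence compares placement, not object identity: both shardings
# put the same shards on the same device, so they are equivalent.
dev = jax.devices()[0]
a = SingleDeviceSharding(dev)
b = SingleDeviceSharding(dev)
print(a.is_equivalent_to(b, 2))  # True
```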

property is_fully_addressable: bool#

Is this sharding fully addressable?

A sharding is fully addressable if the current process can address all of the devices named in the Sharding. is_fully_addressable is equivalent to “is_local” in multi-process JAX.

property is_fully_replicated: bool#

Is this sharding fully replicated?

A sharding is fully replicated if each device has a complete copy of the entire data.
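As a small sketch (assuming a default JAX install), a sharding over a single device trivially gives that device a complete copy of the data:

```python
import jax
from jax.sharding import SingleDeviceSharding

# With only one device in the sharding, that device necessarily holds
# the entire array, so the sharding reports as fully replicated.
sharding = SingleDeviceSharding(jax.devices()[0])
print(sharding.is_fully_replicated)  # True
```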

property memory_kind: str | None#

Returns the memory kind of the sharding.

property num_devices: int#

Number of devices that the sharding contains.

shard_shape(global_shape)#

Returns the shape of the data on each device.

The shard shape returned by this function is calculated from global_shape and the properties of the sharding.

Parameters:

global_shape (Shape)

Return type:

Shape
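As a minimal sketch (assuming a default JAX install), a single-device sharding splits nothing, so the per-device shard shape equals the global shape:

```python
import jax
from jax.sharding import SingleDeviceSharding

# On a single device nothing is partitioned, so each (and only) shard
# covers the full global shape.
sharding = SingleDeviceSharding(jax.devices()[0])
print(sharding.shard_shape((8, 4)))  # (8, 4)
```

With a sharded layout, each dimension of the global shape would instead be divided by the number of shards along that dimension.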

with_memory_kind(kind)#

Returns a new Sharding instance with the specified memory kind.

Parameters:

kind (str)

Return type:

Sharding

class jax.sharding.SingleDeviceSharding(*args, **kwargs)#

Bases: Sharding

A Sharding that places its data on a single device.

Parameters:

device – A single Device.

Examples

>>> single_device_sharding = jax.sharding.SingleDeviceSharding(
...     jax.devices()[0])

property device_set: set[Device]#

The set of devices that this Sharding spans.

In multi-controller JAX, the set of devices is global, i.e., it includes non-addressable devices from other processes.

devices_indices_map(global_shape)#

Returns a mapping from devices to the array slices each contains.

The mapping includes all global devices, i.e., including non-addressable devices from other processes.

Parameters:

global_shape (Shape)

Return type:

Mapping[Device, Index]

property is_fully_addressable: bool#

Is this sharding fully addressable?

A sharding is fully addressable if the current process can address all of the devices named in the Sharding. is_fully_addressable is equivalent to “is_local” in multi-process JAX.

property is_fully_replicated: bool#

Is this sharding fully replicated?

A sharding is fully replicated if each device has a complete copy of the entire data.

property memory_kind: str | None#

Returns the memory kind of the sharding.

property num_devices: int#

Number of devices that the sharding contains.

with_memory_kind(kind)#

Returns a new Sharding instance with the specified memory kind.

Parameters:

kind (str)

Return type:

SingleDeviceSharding

class jax.sharding.NamedSharding(*args, **kwargs)#

Bases: Sharding

A NamedSharding expresses sharding using named axes.

A NamedSharding is a pair of a Mesh of devices and a PartitionSpec which describes how to shard an array across that mesh.

A Mesh is a multidimensional NumPy array of JAX devices, where each axis of the mesh has a name, e.g. 'x' or 'y'.

A PartitionSpec is a tuple whose elements can be None, a mesh axis, or a tuple of mesh axes. Each element describes how an input dimension is partitioned across zero or more mesh dimensions. For example, PartitionSpec('x', 'y') says that the first dimension of data is sharded across the x axis of the mesh, and the second dimension is sharded across the y axis of the mesh.

The Distributed arrays and automatic parallelization and Explicit Sharding tutorials have more details and diagrams that explain how Mesh and PartitionSpec are used.
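As a runnable sketch of the Mesh/PartitionSpec pairing (using a 1x1 mesh so it works on any single-device host; with a 2x4 mesh of eight devices the same spec would split the array eight ways):

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1x1 mesh so the example runs anywhere. Each array dimension named in
# the spec is divided by the size of the mesh axis it is sharded over.
mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
print(sharding.shard_shape((8, 4)))  # (8, 4), since both mesh axes have size 1
```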

Parameters:
  • mesh – A Mesh of devices across which to shard.

  • spec – A PartitionSpec describing how to shard an array across the mesh.

Examples

>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> spec = P('x', 'y')
>>> named_sharding = jax.sharding.NamedSharding(mesh, spec)

property addressable_devices: set[Device]#

The set of devices in the Sharding that are addressable by the current process.

property device_set: set[Device]#

The set of devices that this Sharding spans.

In multi-controller JAX, the set of devices is global, i.e., it includes non-addressable devices from other processes.

property is_fully_addressable: bool#

Is this sharding fully addressable?

A sharding is fully addressable if the current process can address all of the devices named in the Sharding. is_fully_addressable is equivalent to “is_local” in multi-process JAX.

property is_fully_replicated: bool#

Is this sharding fully replicated?

A sharding is fully replicated if each device has a complete copy of the entire data.

property memory_kind: str | None#

Returns the memory kind of the sharding.

property mesh#

The Mesh of this NamedSharding.

property num_devices: int#

Number of devices that the sharding contains.

property spec#

The PartitionSpec of this NamedSharding.

with_memory_kind(kind)#

Returns a new Sharding instance with the specified memory kind.

Parameters:

kind (str)

Return type:

NamedSharding

class jax.sharding.PmapSharding(*args, **kwargs)#

Bases: Sharding

Describes a sharding used by jax.pmap().

classmethod default(shape, sharded_dim=0, devices=None)#

Creates a PmapSharding which matches the default placement used by jax.pmap().

Parameters:
  • shape (Shape) – The shape of the input array.

  • sharded_dim (int |None) – Dimension the input array is sharded on. Defaults to 0.

  • devices (Sequence[xc.Device] | None) – Optional sequence of devices to use. If omitted, the implicit device order used by pmap is used, which is the order of jax.local_devices().

Return type:

PmapSharding

property device_set: set[Device]#

The set of devices that this Sharding spans.

In multi-controller JAX, the set of devices is global, i.e., it includes non-addressable devices from other processes.

property devices#

The devices of this sharding, as a numpy.ndarray.

devices_indices_map(global_shape)#

Returns a mapping from devices to the array slices each contains.

The mapping includes all global devices, i.e., including non-addressable devices from other processes.

Parameters:

global_shape (Shape)

Return type:

Mapping[Device, Index]

is_equivalent_to(other, ndim)#

Returns True if two shardings are equivalent.

Two shardings are equivalent if they place the same logical array shards on the same devices.

Parameters:
  • other (Sharding) – The sharding to compare against.

  • ndim (int) – The number of array dimensions to consider.

Return type:

bool

property is_fully_addressable: bool#

Is this sharding fully addressable?

A sharding is fully addressable if the current process can address all of the devices named in the Sharding. is_fully_addressable is equivalent to “is_local” in multi-process JAX.

property is_fully_replicated: bool#

Is this sharding fully replicated?

A sharding is fully replicated if each device has a complete copy of the entire data.

property memory_kind: str | None#

Returns the memory kind of the sharding.

property num_devices: int#

Number of devices that the sharding contains.

shard_shape(global_shape)#

Returns the shape of the data on each device.

The shard shape returned by this function is calculated from global_shape and the properties of the sharding.

Parameters:

global_shape (Shape)

Return type:

Shape

property sharding_spec#

The ShardingSpec of this PmapSharding.

with_memory_kind(kind)#

Returns a new Sharding instance with the specified memory kind.

Parameters:

kind (str)

Return type:

PmapSharding

class jax.sharding.PartitionSpec(*args, **kwargs)#

Tuple describing how to partition an array across a mesh of devices.

Each element is either None, a string, or a tuple of strings. See the documentation of jax.sharding.NamedSharding for more details.

This class exists so JAX’s pytree utilities can distinguish partition specifications from tuples that should be treated as pytrees.
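A small sketch of the tuple-like behavior (no devices required, only the jax.sharding import):

```python
from jax.sharding import PartitionSpec as P

# PartitionSpec supports tuple-style len() and indexing, but is its own
# class, so pytree flattening treats it as a leaf specification rather
# than recursing into it as a container.
spec = P('x', None)
print(len(spec))  # 2
print(spec[0])    # 'x'
```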

property reduced#

(self) -> frozenset

property unreduced#

(self) -> frozenset

class jax.sharding.Mesh(devices, axis_names, axis_types=None)#

Declare the hardware resources available in the scope of this manager.

See the Distributed arrays and automatic parallelization and Explicit Sharding tutorials.

Parameters:
  • devices (np.ndarray) – A NumPy ndarray object containing JAX device objects (as obtained e.g. from jax.devices()).

  • axis_names (tuple[MeshAxisName, ...]) – A sequence of resource axis names to be assigned to the dimensions of the devices argument. Its length should match the rank of devices.

  • axis_types (tuple[AxisType, ...]) – An optional tuple of jax.sharding.AxisType entries corresponding to the axis_names. See Explicit Sharding for more information.

Examples

>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P, NamedSharding
>>> import numpy as np
...
>>> # Declare a 2D mesh with axes `x` and `y`.
>>> devices = np.array(jax.devices()).reshape(4, 2)
>>> mesh = Mesh(devices, ('x', 'y'))
>>> inp = np.arange(16).reshape(8, 2)
>>> arr = jax.device_put(inp, NamedSharding(mesh, P('x', 'y')))
>>> out = jax.jit(lambda x: x * 2)(arr)
>>> assert out.sharding == NamedSharding(mesh, P('x', 'y'))