jax.sharding module
Classes
- class jax.sharding.Sharding(*args, **kwargs)
Describes how a jax.Array is laid out across devices.
- property addressable_devices: set[Device]
The set of devices in the Sharding that are addressable by the current process.
- addressable_devices_indices_map(global_shape)
A mapping from addressable devices to the slice of array data each contains.
addressable_devices_indices_map contains the part of devices_indices_map that applies to the addressable devices.
- Parameters:
global_shape (Shape)
- Return type:
Mapping[Device, Index | None]
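As an illustrative sketch (assuming a single-process setup with at least one local device): in single-process JAX every device is addressable, so the addressable map matches the global map.

```python
import jax
from jax.sharding import SingleDeviceSharding

# In single-process JAX every device is addressable, so the map of
# addressable devices coincides with the global devices_indices_map.
sharding = SingleDeviceSharding(jax.devices()[0])
global_shape = (8, 2)
addr_map = sharding.addressable_devices_indices_map(global_shape)
full_map = sharding.devices_indices_map(global_shape)
assert set(addr_map) == set(full_map)
assert len(addr_map) == 1  # a single device holds the whole array
```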
- property device_set: set[Device]
The set of devices that this Sharding spans.
In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.
- devices_indices_map(global_shape)
Returns a mapping from devices to the array slices each contains.
The mapping includes all global devices, i.e., including non-addressable devices from other processes.
- Parameters:
global_shape (Shape)
- Return type:
Mapping[Device, Index]
- is_equivalent_to(other, ndim)
Returns True if two shardings are equivalent.
Two shardings are equivalent if they place the same logical array shards on the same devices.
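A minimal single-device sketch (assuming at least one local device): two distinct Sharding objects that place the whole array on the same device compare as equivalent.

```python
import jax
from jax.sharding import SingleDeviceSharding

dev = jax.devices()[0]
s1 = SingleDeviceSharding(dev)
s2 = SingleDeviceSharding(dev)
# Distinct objects, but the same placement of the same logical
# shards, so they are equivalent (here for a rank-2 array).
assert s1.is_equivalent_to(s2, 2)
```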
- property is_fully_addressable: bool
Is this sharding fully addressable?
A sharding is fully addressable if the current process can address all of the devices named in the Sharding. is_fully_addressable is equivalent to "is_local" in multi-process JAX.
- property is_fully_replicated: bool
Is this sharding fully replicated?
A sharding is fully replicated if each device has a complete copy of the entire data.
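For instance (a sketch assuming a single local device), a SingleDeviceSharding is trivially both fully addressable and fully replicated:

```python
import jax
from jax.sharding import SingleDeviceSharding

s = SingleDeviceSharding(jax.devices()[0])
# The one device belongs to this process, and it holds the entire
# array, so both properties hold.
assert s.is_fully_addressable
assert s.is_fully_replicated
```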
- class jax.sharding.SingleDeviceSharding(*args, **kwargs)
Bases: Sharding
A Sharding that places its data on a single device.
- Parameters:
device – A single Device.
Examples
>>> single_device_sharding = jax.sharding.SingleDeviceSharding(
...     jax.devices()[0])
- property device_set: set[Device]
The set of devices that this Sharding spans.
In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.
- devices_indices_map(global_shape)
Returns a mapping from devices to the array slices each contains.
The mapping includes all global devices, i.e., including non-addressable devices from other processes.
- Parameters:
global_shape (Shape)
- Return type:
Mapping[Device, Index]
- property is_fully_addressable: bool
Is this sharding fully addressable?
A sharding is fully addressable if the current process can address all of the devices named in the Sharding. is_fully_addressable is equivalent to "is_local" in multi-process JAX.
- property is_fully_replicated: bool
Is this sharding fully replicated?
A sharding is fully replicated if each device has a complete copy of the entire data.
- class jax.sharding.NamedSharding(*args, **kwargs)
Bases: Sharding
A NamedSharding expresses sharding using named axes.
A NamedSharding is a pair of a Mesh of devices and a PartitionSpec which describes how to shard an array across that mesh.
A Mesh is a multidimensional NumPy array of JAX devices, where each axis of the mesh has a name, e.g. 'x' or 'y'.
A PartitionSpec is a tuple whose elements can be None, a mesh axis, or a tuple of mesh axes. Each element describes how an input dimension is partitioned across zero or more mesh dimensions. For example, PartitionSpec('x', 'y') says that the first dimension of data is sharded across the x axis of the mesh, and the second dimension is sharded across the y axis of the mesh.
The Distributed arrays and automatic parallelization and Explicit Sharding tutorials have more details and diagrams that explain how Mesh and PartitionSpec are used.
- Parameters:
mesh – A jax.sharding.Mesh object.
spec – A jax.sharding.PartitionSpec object.
Examples
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> spec = P('x', 'y')
>>> named_sharding = jax.sharding.NamedSharding(mesh, spec)
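A self-contained sketch of how such a (2, 4) mesh tiles an array. This assumes 8 devices; on a single host, virtual CPU devices can be forced via the XLA_FLAGS environment variable (an assumption for illustration, which must be set before importing jax):

```python
import os
# Assumption for illustration: force 8 virtual CPU devices so the
# example runs on one host. Must be set before jax is imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))

# An (8, 8) array is tiled into (4, 2) shards: dim 0 is split over
# the 2 devices on 'x', dim 1 over the 4 devices on 'y'.
assert sharding.shard_shape((8, 8)) == (4, 2)
# Each of the 8 devices receives exactly one distinct tile.
assert len(sharding.devices_indices_map((8, 8))) == 8
```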
- property addressable_devices: set[Device]
The set of devices in the Sharding that are addressable by the current process.
- property device_set: set[Device]
The set of devices that this Sharding spans.
In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.
- property is_fully_addressable: bool
Is this sharding fully addressable?
A sharding is fully addressable if the current process can address all of the devices named in the Sharding. is_fully_addressable is equivalent to "is_local" in multi-process JAX.
- property is_fully_replicated: bool
Is this sharding fully replicated?
A sharding is fully replicated if each device has a complete copy of the entire data.
- property mesh
(self) -> object
- property spec
(self) -> jax::PartitionSpec
- class jax.sharding.PmapSharding(*args, **kwargs)
Bases: Sharding
Describes a sharding used by jax.pmap().
- classmethod default(shape, sharded_dim=0, devices=None)
Creates a PmapSharding which matches the default placement used by jax.pmap().
- Parameters:
shape (Shape) – The shape of the input array.
sharded_dim (int | None) – Dimension the input array is sharded on. Defaults to 0.
devices (Sequence[xc.Device] | None) – Optional sequence of devices to use. If omitted, the implicit device order used by pmap is used, which is the order of jax.local_devices().
- Return type:
PmapSharding
- property device_set: set[Device]
The set of devices that this Sharding spans.
In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.
- property devices
(self) -> numpy.ndarray
- devices_indices_map(global_shape)
Returns a mapping from devices to the array slices each contains.
The mapping includes all global devices, i.e., including non-addressable devices from other processes.
- Parameters:
global_shape (Shape)
- Return type:
Mapping[Device, Index]
- is_equivalent_to(other, ndim)
Returns True if two shardings are equivalent.
Two shardings are equivalent if they place the same logical array shards on the same devices.
- Parameters:
self (PmapSharding)
other (PmapSharding)
ndim (int)
- Return type:
bool
- property is_fully_addressable: bool
Is this sharding fully addressable?
A sharding is fully addressable if the current process can address all of the devices named in the Sharding. is_fully_addressable is equivalent to "is_local" in multi-process JAX.
- property is_fully_replicated: bool
Is this sharding fully replicated?
A sharding is fully replicated if each device has a complete copy of the entire data.
- shard_shape(global_shape)
Returns the shape of the data on each device.
The shard shape returned by this function is calculated from global_shape and the properties of the sharding.
- Parameters:
global_shape (Shape)
- Return type:
Shape
- property sharding_spec
(self) -> jax::ShardingSpec
- class jax.sharding.PartitionSpec(*args, **kwargs)
Tuple describing how to partition an array across a mesh of devices.
Each element is either None, a string, or a tuple of strings. See the documentation of jax.sharding.NamedSharding for more details.
This class exists so JAX's pytree utilities can distinguish a partition specification from tuples that should be treated as pytrees.
- property reduced
(self) -> frozenset
- property unreduced
(self) -> frozenset
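A brief sketch of how a spec reads, element by element:

```python
from jax.sharding import PartitionSpec as P

# P('x', None): dimension 0 is sharded over mesh axis 'x',
# dimension 1 is replicated (not partitioned).
spec = P('x', None)
assert spec[0] == 'x'
assert spec[1] is None
assert len(spec) == 2
```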
- class jax.sharding.Mesh(devices, axis_names, axis_types=None)
Declare the hardware resources available in the scope of this manager.
See the Distributed arrays and automatic parallelization and Explicit Sharding tutorials.
- Parameters:
devices (np.ndarray) – A NumPy ndarray object containing JAX device objects (as obtained e.g. from jax.devices()).
axis_names (tuple[MeshAxisName, ...]) – A sequence of resource axis names to be assigned to the dimensions of the devices argument. Its length should match the rank of devices.
axis_types (tuple[AxisType, ...]) – An optional tuple of jax.sharding.AxisType entries corresponding to the axis_names. See Explicit Sharding for more information.
Examples
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P, NamedSharding
>>> import numpy as np

>>> # Declare a 2D mesh with axes `x` and `y`.
>>> devices = np.array(jax.devices()).reshape(4, 2)
>>> mesh = Mesh(devices, ('x', 'y'))
>>> inp = np.arange(16).reshape(8, 2)
>>> arr = jax.device_put(inp, NamedSharding(mesh, P('x', 'y')))
>>> out = jax.jit(lambda x: x * 2)(arr)
>>> assert out.sharding == NamedSharding(mesh, P('x', 'y'))
