jax.make_mesh
- jax.make_mesh(axis_shapes, axis_names, axis_types=None, *, devices=None)
Creates an efficient mesh with the shape and axis names specified.
This function attempts to automatically compute a good mapping from a set of logical axes to a physical mesh. For example, on a TPU v3 with 8 devices:
>>> mesh = jax.make_mesh((8,), ('x',))
>>> [d.id for d in mesh.devices.flat]
[0, 1, 2, 3, 6, 7, 4, 5]
The above ordering takes into account the physical topology of TPU v3. It orders the devices into a ring, which yields efficient all-reduces on a TPU v3.
Now, let’s see another example with 16 devices of TPU v3:
>>> mesh = jax.make_mesh((2, 8), ('x', 'y'))
>>> [d.id for d in mesh.devices.flat]
[0, 1, 2, 3, 6, 7, 4, 5, 8, 9, 10, 11, 14, 15, 12, 13]
>>> mesh = jax.make_mesh((4, 4), ('x', 'y'))
>>> [d.id for d in mesh.devices.flat]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
As you can see, logical axes (axis_shapes) affect the ordering of the devices.
Use jax.experimental.mesh_utils.create_device_mesh if you need the extra arguments it provides, such as contiguous_submeshes and allow_split_physical_axes.
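As a rough sketch of that lower-level route, the snippet below builds the device array with create_device_mesh and wraps it in a jax.sharding.Mesh by hand. The two keyword arguments are passed with the value False purely to show where they go; the names `device_grid` and `manual_mesh` are illustrative, and the mesh shape assumes all visible devices are placed on one axis.

```python
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

n = jax.device_count()

# create_device_mesh returns a NumPy array of devices with the requested shape.
# The product of the shape must equal the number of devices passed in.
device_grid = mesh_utils.create_device_mesh(
    (n,),
    devices=jax.devices(),
    contiguous_submeshes=False,        # shown only to illustrate the argument
    allow_split_physical_axes=False,   # shown only to illustrate the argument
)

# Unlike jax.make_mesh, the axis names are attached in a separate step.
manual_mesh = Mesh(device_grid, ("x",))
print(manual_mesh)
```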
- Parameters:
axis_shapes (Sequence[int]) – Shape of the mesh. For example, axis_shapes=(4, 2)
axis_names (Sequence[str]) – Names of the mesh axes. For example, axis_names=('x', 'y')
axis_types (tuple[mesh_lib.AxisType, ...] | None) – Optional tuple of jax.sharding.AxisType entries corresponding to the axis_names. See Explicit Sharding for more information.
devices (Sequence[xc.Device] | None) – Optional keyword-only argument that lets you specify the devices you want to create a mesh with.
- Returns:
A jax.sharding.Mesh object.
- Return type:
mesh_lib.Mesh
