JAX documentation


jax.make_mesh

jax.make_mesh(axis_shapes, axis_names, axis_types=None, *, devices=None)

Creates an efficient mesh with the shape and axis names specified.

This function attempts to automatically compute a good mapping from a set of logical axes to a physical mesh. For example, on a TPU v3 with 8 devices:

>>> mesh = jax.make_mesh((8,), ('x',))
>>> [d.id for d in mesh.devices.flat]
[0, 1, 2, 3, 6, 7, 4, 5]

The above ordering takes into account the physical topology of TPU v3. It orders the devices into a ring, which yields efficient all-reduces on a TPU v3.

Here is another example, this time with 16 TPU v3 devices:

>>> mesh = jax.make_mesh((2, 8), ('x', 'y'))
>>> [d.id for d in mesh.devices.flat]
[0, 1, 2, 3, 6, 7, 4, 5, 8, 9, 10, 11, 14, 15, 12, 13]
>>> mesh = jax.make_mesh((4, 4), ('x', 'y'))
>>> [d.id for d in mesh.devices.flat]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]

As you can see, the logical axes (axis_shapes) affect the ordering of the devices.

You can use jax.experimental.mesh_utils.create_device_mesh if you want to use the extra arguments it provides, like contiguous_submeshes and allow_split_physical_axes.
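A minimal sketch of that route: create_device_mesh returns a NumPy array of devices, which is then wrapped in jax.sharding.Mesh by hand. Both options are written out at False here for visibility; they mainly influence TPU device placement, so setting them on other backends may have no visible effect. The `XLA_FLAGS` line is again an assumption for CPU-only hosts.

```python
import os

# Assumption: simulate 8 CPU devices when no accelerator is present.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

n = jax.device_count()
rows = 2 if n % 2 == 0 else 1  # illustrative 2-D split of the devices
cols = n // rows

# create_device_mesh returns a NumPy array of devices, not a Mesh.
device_array = mesh_utils.create_device_mesh(
    (rows, cols),
    contiguous_submeshes=False,       # TPU-oriented option, left at its default
    allow_split_physical_axes=False,  # TPU-oriented option, left at its default
)

# Wrap the device array in a Mesh and attach the logical axis names.
mesh = Mesh(device_array, ("x", "y"))
print(mesh.shape)
```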

Parameters:
  • axis_shapes (Sequence[int]) – Shape of the mesh. For example, axis_shapes=(4, 2).

  • axis_names (Sequence[str]) – Names of the mesh axes. For example, axis_names=('x', 'y').

  • axis_types (tuple[mesh_lib.AxisType, ...] | None) – Optional tuple of jax.sharding.AxisType entries corresponding to the axis_names. See Explicit Sharding for more information.

  • devices (Sequence[xc.Device] | None) – Optional keyword-only argument that lets you specify the devices to create the mesh with.

Returns:

A jax.sharding.Mesh object.

Return type:

mesh_lib.Mesh
