
Manual parallelism with shard_map#

Overview#

shard_map is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or instances, communicate with each other via explicit collective communication operations.

shard_map is complementary to, and composable with, the automatic compiler-based parallelization built into jit. With jit you write code as if for a single device, and the compiler can automatically partition computation over multiple devices, generating per-device code and communication collectives behind the scenes. With shard_map you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed.

If you’re familiar with pmap, think of shard_map as an evolution. It’s more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see a detailed comparison to pmap.)

By reading this tutorial, you’ll learn how to use shard_map to get full control over your multi-device code. You’ll see in detail how it composes with jax.jit’s automatic parallelization and jax.grad’s automatic differentiation. We’ll also give some basic examples of neural network parallelization strategies; for a more detailed example, see The Training Cookbook.

We’ll assume this tutorial is being run in an environment with eight devices:

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'  # Use 8 CPU devices

So, let’s see a shard_map!#

Without further ado, here’s a toy example:

from functools import partial

import jax
import jax.numpy as jnp

from jax.sharding import Mesh, PartitionSpec as P

Explicit = jax.sharding.AxisType.Explicit
Auto = jax.sharding.AxisType.Auto

mesh = jax.make_mesh((4, 2), ('x', 'y'))
jax.set_mesh(mesh)

a = jax.device_put(jnp.arange(8 * 16.).reshape(8, 16), P('x', 'y'))
b = jax.device_put(jnp.arange(16 * 4.).reshape(16, 4), P('y', None))

@jax.shard_map(in_specs=(P('x', 'y'), P('y', None)),
               out_specs=P('x', None))
def matmul_basic(a_block, b_block):
  # a_block: f32[2, 8]
  # b_block: f32[8, 4]
  c_partialsum = jnp.dot(a_block, b_block)
  c_block = jax.lax.psum(c_partialsum, 'y')
  # c_block: f32[2, 4]
  return c_block

c = matmul_basic(a, b)  # c: f32[8, 4]

This function computes a matrix multiply in parallel by performing local block matrix multiplies followed by a collective sum operation. We can check the result is correct:

from jax.tree_util import tree_map, tree_all

def allclose(a, b):
  return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))

allclose(c, jnp.dot(a, b, out_sharding=P('x', None)))
True

The result is sharded along its rows:

jax.debug.visualize_array_sharding(c)
  CPU 0,1
  CPU 2,3
  CPU 4,5
  CPU 6,7

At a high level, shard_map is kind of like vmap or pmap, in that we’re mapping a function over pieces of array data, but notice that

  • shard_map slices up inputs into blocks (and the output is formed by concatenating result blocks), keeping the rank the same, whereas vmap would reduce the rank by mapping away an axis;

  • the mesh argument lets us control precise device placement of computation and results;

  • we’re mapping over multiple data axes at once, and setting up multiple axis names for collectives (both 'x' and 'y' here);

  • since we’re not using jax.jit yet, everything is eagerly evaluated, and we can even print intermediate values for debugging.

The above code is performing the same computation as this jax.jit automatic parallelization code:

from jax.sharding import NamedSharding

a = jax.device_put(a, P('x', 'y'))
b = jax.device_put(b, P('y', None))

@jax.jit
def matmul_reference(a, b):
  return jnp.dot(a, b, out_sharding=P('x', None))

c_ref = matmul_reference(a, b)
allclose(c_ref, jnp.dot(a, b, out_sharding=P('x', None)))
True

We can think of shard_map as performing a device_put or with_sharding_constraint on its inputs according to its mesh and in_specs arguments, so the blocks over which matmul_basic operates are the same as in matmul_reference:

print('a blocks:'); jax.debug.visualize_array_sharding(a)
print('b blocks:'); jax.debug.visualize_array_sharding(b)
print('c blocks:'); jax.debug.visualize_array_sharding(c)
a blocks:
  CPU 0  CPU 1
  CPU 2  CPU 3
  CPU 4  CPU 5
  CPU 6  CPU 7
b blocks:
  CPU 0,2,4,6
  CPU 1,3,5,7
c blocks:
  CPU 0,1
  CPU 2,3
  CPU 4,5
  CPU 6,7
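
The block structure of matmul_basic can also be reproduced on a single device. The following is a minimal NumPy sketch, not how JAX executes the program: it assumes the 4-by-2 mesh from the example, and the names NX, NY, a_blocks, b_blocks, and c_sim are made up for illustration. It splits the operands the way the in_specs describe, forms the local products, and sums over the 'y' index to mimic the psum.

```python
import numpy as np

# Mesh axis sizes from the example above: 'x' has size 4, 'y' has size 2.
NX, NY = 4, 2

a = np.arange(8 * 16.0).reshape(8, 16)
b = np.arange(16 * 4.0).reshape(16, 4)

# in_specs P('x', 'y'): split a's rows over 'x' and its columns over 'y'.
a_blocks = [np.split(row_blk, NY, axis=1) for row_blk in np.split(a, NX, axis=0)]
# in_specs P('y', None): split b's rows over 'y'; every 'x' index sees the same block.
b_blocks = np.split(b, NY, axis=0)

c_rows = []
for ix in range(NX):
  # Each (ix, iy) instance computes a local f32[2, 8] @ f32[8, 4] product.
  partials = [a_blocks[ix][iy] @ b_blocks[iy] for iy in range(NY)]
  # psum over 'y': all instances with this 'x' index end up with the same sum.
  c_rows.append(sum(partials))

# out_specs P('x', None): concatenate along rows, keep one copy along 'y'.
c_sim = np.concatenate(c_rows, axis=0)
print(c_sim.shape)  # (8, 4)
```

Comparing c_sim against a @ b with np.allclose is a quick way to convince yourself that the local products plus the sum over 'y' reproduce the full matrix multiply.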

Slow down, start with the basics!#

Rank-reducing vs rank-preserving maps#

We can think of vmap and pmap as unstacking each array input along an axis (e.g. unpacking a 2D matrix into its 1D rows), applying its body function to each piece, and stacking the results back together, at least when collectives aren’t involved:

def check_vmap(f, xs):
  ans = jax.vmap(f, in_axes=(0,), out_axes=0)(xs)
  expected = jnp.stack([f(x) for x in xs])  # vmap reference semantics
  print(allclose(ans, expected))

check_vmap(lambda x: x @ x, jnp.arange(12).reshape(4, 3))
True

For example, if xs had shape f32[8,5] then each x would have shape f32[5], and if each f(x) had shape f32[3,7] then the final stacked result vmap(f)(xs) would have shape f32[8,3,7]. That is, each application of the body function f takes as argument inputs with one fewer axis than the corresponding argument to vmap(f). We can say these are rank-reducing maps with unstacking/stacking of inputs/outputs.

The number of logical applications of f, or instances of f, is determined by the size of the input axis being mapped over: for example, if we map over an input axis of size 8, semantically we get 8 logical applications of the function.

In contrast, shard_map does not have this rank-reducing behavior. Instead, we can think of it as slicing (or “unconcatenating”) along input axes into blocks, applying the body function, and concatenating the results back together (again when collectives aren’t involved):

import numpy as np
devices = np.array(jax.devices()[:4])
mesh = Mesh(devices, ('i',))  # mesh.shape['i'] = 4
jax.set_mesh(mesh)

def check_shmap(f, y):
  ans = jax.shard_map(f, in_specs=P('i'), out_specs=P('i'))(y)
  expected = jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, mesh.shape['i'])])
  print(allclose(ans, expected))

check_shmap(lambda x: x.T @ x, jnp.arange(32).reshape(8, 4))
True

Recall that jnp.split slices its input into equally-sized blocks with the same rank, so that if in the above example y had shape f32[8,5] then each y_blk would have shape f32[2,5], and if each f(y_blk) had shape f32[3,7] then the final concatenated result shard_map(f, ...)(y) would have shape f32[12,7]. So shard_map maps over shards, or blocks, of its inputs. We can say it’s a rank-preserving map with unconcatenating/concatenating of its inputs/outputs.

The number of logical applications of f is determined by the mesh size, not by any input axis size: for example, if we have a mesh of total size 4 (i.e. over 4 devices) then semantically we get 4 logical applications of the function, corresponding to the 4 devices physically computing them.
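
The shape arithmetic in the two examples above can be written down directly. This is only a sketch of the bookkeeping, with hypothetical helper names (vmap_result_shape and shard_map_result_shape are not JAX functions), and it assumes a one-dimensional mesh with in_specs and out_specs both equal to P('i').

```python
def vmap_result_shape(mapped_axis_size, body_out_shape):
  """Rank-reducing map: per-element outputs are stacked along a new leading axis."""
  return (mapped_axis_size, *body_out_shape)

def shard_map_result_shape(num_instances, block_out_shape):
  """Rank-preserving map: output blocks are concatenated along the existing axis 0."""
  first, *rest = block_out_shape
  return (num_instances * first, *rest)

# xs: f32[8,5], each f(x): f32[3,7]  ->  vmap(f)(xs): f32[8,3,7]
print(vmap_result_shape(8, (3, 7)))       # (8, 3, 7)
# y: f32[8,5] on 4 devices, each f(y_blk): f32[3,7]  ->  f32[12,7]
print(shard_map_result_shape(4, (3, 7)))  # (12, 7)
```

Note that the first argument differs in kind: for vmap it is the size of the mapped input axis, while for shard_map it is the mesh size.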

Controlling how each input is split (unconcatenated) and tiled with in_specs#

Each of the in_specs identifies some of the corresponding input array’s axes with mesh axes by name using PartitionSpecs, representing how to split (or unconcatenate) that input into the blocks to which the body function is applied. That identification determines the shard sizes; when an input axis is identified with a mesh axis, the input is split (unconcatenated) along that logical axis into a number of pieces equal to the corresponding mesh axis size. (It’s an error if the corresponding mesh axis size does not evenly divide the input array axis size.) If an input’s pspec does not mention a mesh axis name, then there’s no splitting over that mesh axis. For example:

mesh = jax.make_mesh((4, 2), ('i', 'j'))
jax.set_mesh(mesh)

@jax.shard_map(in_specs=P('i', None), out_specs=P('i', 'j'))
def f1(x_block):
  print(x_block.shape)  # prints (3, 12)
  return x_block

x1 = jax.device_put(jnp.arange(12 * 12).reshape(12, 12), P('i', None))
y = f1(x1)
(3, 12)

Here, because the input pspec did not mention the mesh axis name 'j', no input array axis is split over that mesh axis; similarly, because the second axis of the input array is not identified with (and hence split over) any mesh axis, application of f1 gets a full view of the input along that axis.

When a mesh axis is not mentioned in an input pspec, we can always rewrite to a less efficient program where all mesh axes are mentioned but the caller performs a jnp.tile, for example:

@jax.shard_map(in_specs=P('i', 'j'), out_specs=P('i', 'j'))
def f2(x_block):
  print(x_block.shape)
  return x_block

x = jnp.arange(12 * 12).reshape(12, 12)
x_ = jnp.tile(x, (1, mesh.shape['j']))  # x_ has shape (12, 24)
x_ = jax.device_put(x_, P('i', 'j'))  # shard the tiled array, not the original x
y = f2(x_)  # prints (3,12), and f1(x) == f2(x_)
(3, 12)

In other words, because each input pspec can mention each mesh axis name zero or one times, rather than having to mention each name exactly once, we can say that in addition to the jnp.split built into its input, shard_map also has a jnp.tile built into its input, at least logically (though the tiling may not need to be carried out physically, depending on the arguments’ physical sharding layout). The tiling to use is not unique; we could also have tiled along the first axis, and used the pspec P(('j', 'i'), None).

Physical data movement is possible on inputs, as each device needs to have a copy of the appropriate data.

Controlling how each output is assembled by concatenation, block transposition, and untiling using out_specs#

Analogously to the input side, each of the out_specs identifies some of the corresponding output array’s axes with mesh axes by name, representing how the output blocks (one for each application of the body function, or equivalently one for each physical device) should be assembled back together to form the final output value. For example, in both the f1 and f2 examples above the out_specs indicate we should form the final output by concatenating together the block results along both axes, resulting in both cases in an array y of shape (12,24). (It’s an error if an output shape of the body function, i.e. an output block shape, has a rank too small for the concatenation described by the corresponding output pspec.)

When a mesh axis name is not mentioned in an output pspec, it represents an un-tiling: when the user writes an output pspec which does not mention one of the mesh axis names, they promise that the output blocks are equal along that mesh axis, and so only one block along that axis is used in the output (rather than concatenating all the blocks together along that mesh axis). For example, using the same mesh as above:

auto_mesh = jax.make_mesh((4, 2), ('i', 'j'), (Auto, Auto))
with jax.set_mesh(auto_mesh):
  x = jnp.array([[3.]])

  z = jax.shard_map(lambda: x, in_specs=(), out_specs=P('i', 'j'))()
  print(z)  # prints the same as jnp.tile(x, (4, 2))

  z = jax.shard_map(lambda: x, in_specs=(), out_specs=P('i', None))()
  print(z)  # prints the same as jnp.tile(x, (4, 1))

  z = jax.shard_map(lambda: x, in_specs=(), out_specs=P(None, None))()
  print(z)  # prints the same as jnp.tile(x, (1, 1)), or just x
[[3. 3.]
 [3. 3.]
 [3. 3.]
 [3. 3.]]
[[3.]
 [3.]
 [3.]
 [3.]]
[[3.]]

The body function closing over an array value is equivalent to passing it as an argument with a corresponding input pspec of P(None, None). As another example, more closely following the other examples above:

@jax.shard_map(in_specs=P('i', 'j'), out_specs=P('i', None))
def f3(x_block):
  return jax.lax.psum(x_block, 'j')

x = jax.device_put(jnp.arange(12 * 12).reshape(12, 12), P('i', 'j'))
y3 = f3(x)
print(y3.shape)
(12, 6)

The result has a second axis size of 6, half the size of the input’s second axis. In this case, the un-tile expressed by not mentioning the mesh axis name 'j' in the output pspec was safe because of the collective psum, which ensures each output block is equal along the corresponding mesh axis. Here are two more examples where we vary which mesh axes are mentioned in the output pspec:

@jax.shard_map(in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f4(x_block):
  return jax.lax.psum(x_block, 'i')

x = jax.device_put(jnp.arange(12 * 12).reshape(12, 12), P('i', 'j'))
y4 = f4(x)
print(y4.shape)  # (3,12)

@jax.shard_map(in_specs=P('i', 'j'), out_specs=P(None, None))
def f5(x_block):
  return jax.lax.psum(x_block, ('i', 'j'))

y5 = f5(x)
print(y5.shape)  # (3,6)
(3, 12)
(3, 6)

On the physical side, not mentioning a mesh axis name in an output pspec assembles an Array from the output device buffers with replicated layout along that mesh axis.

There is no runtime check that the output blocks are actually equal along a mesh axis to be un-tiled along, or equivalently that the corresponding physical buffers have equal values and thus can be interpreted as a replicated layout for a single logical array. But we can provide a static check mechanism which raises an error on all potentially-incorrect programs.

Because the out_specs can mention mesh axis names zero or one times, and because they can be mentioned in any order, we can say that in addition to the jnp.concatenate built into its output, shard_map also has both an untile and a block transpose built into its output.

Physical data movement is not possible on outputs, no matter the output pspec. Instead, out_specs just encodes how to assemble the block outputs into Arrays, or physically how to interpret the buffers across devices as the physical layout of a single logical Array.
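
To make the concatenate-versus-untile rule concrete, here is a small sketch that computes the assembled output shape from a block shape and an output pspec. The helper assembled_shape is hypothetical (it is not part of JAX), the pspec is modeled as a plain tuple of axis names or None, and mesh_shape is a dict standing in for mesh.shape.

```python
def assembled_shape(block_shape, out_spec, mesh_shape):
  """Shape of the logical output assembled from per-instance blocks.

  block_shape: shape returned by each function instance, e.g. (3, 6)
  out_spec:    tuple with one entry per leading axis: a mesh axis name or None
  mesh_shape:  dict mapping mesh axis name -> size, e.g. {'i': 4, 'j': 2}
  """
  # Axes not covered by the spec behave like None entries.
  out_spec = tuple(out_spec) + (None,) * (len(block_shape) - len(out_spec))
  result = []
  for size, name in zip(block_shape, out_spec):
    # A mentioned mesh axis concatenates that many blocks along this array
    # axis; None keeps the block size as-is.
    result.append(size * (1 if name is None else mesh_shape[name]))
  return tuple(result)

mesh_shape = {'i': 4, 'j': 2}
print(assembled_shape((3, 12), ('i', 'j'), mesh_shape))   # (12, 24), as in f1
print(assembled_shape((3, 6), ('i', None), mesh_shape))   # (12, 6), as in f3
print(assembled_shape((3, 6), (None, 'j'), mesh_shape))   # (3, 12), as in f4
print(assembled_shape((3, 6), (None, None), mesh_shape))  # (3, 6), as in f5
```

A mesh axis that appears nowhere in the pspec contributes no factor at all, which is the untiling: only one block along that mesh axis is used.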

Tracking how values vary over manual mesh axes, and check_vma=True#

Under a shard_map, values can vary across function instances, or they can be the same. For example, when we use in_specs to split an argument over a mesh axis, each function instance along that mesh axis gets a different value:

mesh = jax.make_mesh((2,), ('i',))
jax.set_mesh(mesh)

@jax.shard_map(in_specs=P('i'), out_specs=P('i'))
def f(x):
  print(x)
  return 2 * x

x = jax.device_put(jnp.arange(6.), P('i'))
f(x)
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[0. 1. 2.]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[3. 4. 5.]
Array([ 0.,  2.,  4.,  6.,  8., 10.], dtype=float32)

If instead in_specs does not split the argument over a mesh axis, the value is the same for each function instance along that axis:

@jax.shard_map(in_specs=P(), out_specs=P())
def f(x):
  print(x)
  return 2 * x

x = jnp.arange(6.)
f(x)
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[0. 1. 2. 3. 4. 5.]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[0. 1. 2. 3. 4. 5.]
Array([ 0.,  2.,  4.,  6.,  8., 10.], dtype=float32)

A collective’s output may have a different variance than its input. For example, applying a psum produces the same output on each function instance along an axis:

@jax.shard_map(in_specs=P('i'), out_specs=P())
def f(x):
  y = jax.lax.psum(x, 'i')
  print(y)
  return y

x = jax.device_put(jnp.arange(6.), P('i'))
f(x)
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3. 5. 7.]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[3. 5. 7.]
Array([3., 5., 7.], dtype=float32)

In general, each intermediate value in a shard_map can be either unvarying or possibly-varying over each manual mesh axis. That information can be tracked in the JAX type system, enabled by the check_vma=True argument to shard_map:

@jax.shard_map(in_specs=P('i'), out_specs=P())
def f(x):
  print(jax.typeof(x))  # f32[3]{i}
  y = jax.lax.psum(x, 'i')
  print(jax.typeof(y))  # f32[3]
  return y

x = jax.device_put(jnp.arange(6.), P('i'))
f(x)
float32[3]{V:i}
float32[3]
Array([3., 5., 7.], dtype=float32)

Here, the type f32[3]{i} (printed above as float32[3]{V:i}) means that the value of x is varying over mesh axis 'i'. The type of y printing as f32[3] indicates it is unvarying over all mesh axes; that is, empty sets are not printed. We call this part of the type the varying manual axes (VMA), and it can be accessed via jax.typeof(x).vma.

In general, the VMA type of a value can include any subset of the manual mesh axes over which the shard_map is acting:

mesh = jax.make_mesh((4, 2), ('i', 'j'))
jax.set_mesh(mesh)

@jax.shard_map(in_specs=P('i', 'j'), out_specs=P('i'))
def f(x):
  print(jax.typeof(x))  # f32[2,2]{i,j}
  y = jax.lax.psum(x, 'j')
  assert jax.typeof(y).vma == {'i'}
  print(jax.typeof(y))  # f32[2,2]{i}
  return y

x = jax.device_put(jnp.arange(8 * 4.).reshape(8, 4), P('i', 'j'))
f(x)
float32[2,2]{V:(i,j)}
float32[2,2]{V:i}
Array([[ 2.,  4.],
       [10., 12.],
       [18., 20.],
       [26., 28.],
       [34., 36.],
       [42., 44.],
       [50., 52.],
       [58., 60.]], dtype=float32)

Tracking varying manual axes can be useful:

  1. Your code can include prints, assertions, or conditionals about whether values are varying over expected mesh axes;

  2. It enables efficient reverse-mode autodiff that doesn’t require defensive psums (see JEP);

  3. The correctness of out_specs can be checked, ruling out the potential bug example below.

For example, this out_specs bug is caught with check_vma=True, but uncaught without it:

mesh = jax.make_mesh((2,), ('i',))
jax.set_mesh(mesh)

x = jax.device_put(jnp.arange(6.), P('i'))
try:
  y = jax.shard_map(lambda x: x, in_specs=P('i'), out_specs=P())(x)
except Exception as e:
  print(e)
shard_map applied to the function '_rem_singleton' was given out_specs which require replication which can't be statically inferred given the mesh:

The mesh given has shape (2,) with corresponding axis names ('i',).

out_specs is P() which implies that the corresponding output value is replicated across mesh axis 'i', but could not infer replication over any axes

Check if these output values are meant to be replicated over those mesh axes. If not, consider revising the corresponding out_specs entries. If so, consider disabling the check by passing the check_vma=False argument to `jax.shard_map`.

Here the out_specs incorrectly promise that each function instance along mesh axis 'i' produces the same value and thus we can choose just one of them. With check_vma=True (the default) it raises an exception, while with check_vma=False there is no exception and instead we get silent undefined behavior.

Sometimes we want to treat a value that is unvarying over a mesh axis as varying over that mesh axis. That’s what jax.lax.pcast does:

@jax.shard_map(in_specs=P(), out_specs=None)
def f(x):
  print(jax.typeof(x))  # f32[6]
  y = jax.lax.pcast(x, 'i', to='varying')
  print(jax.typeof(y))  # f32[6]{i}

x = jnp.arange(6.)
f(x)
float32[6]
float32[6]{V:i}

Think of jax.lax.pcast(..., to='varying') as applying a type cast: it’s a no-op at runtime, though under reverse-mode autodiff it transposes to a jax.lax.psum (see JEP). That makes sense because they do opposite things to the VMA: where y: f32[3]{i} = jax.lax.pcast(x: f32[3], 'i', to='varying'), we correspondingly have x_grad: f32[3] = jax.lax.psum(y_grad: f32[3]{i}, 'i').

JAX implicitly inserts jax.lax.pcast(..., to='varying') calls in many cases, especially for binary operations:

@jax.shard_map(in_specs=(P('i'), P()), out_specs=P('i'))
def f(x, y):
  return x * y

x = jax.device_put(jnp.arange(6.), P('i'))
y = jnp.arange(3.)
print(jax.make_jaxpr(f)(x, y))
{ lambda ; a:f32[6@i] b:f32[3]. let
    c:f32[6@i] = shard_map[
      check_vma=True
      in_specs=(P('i',), P())
      jaxpr={ lambda ; d:f32[3]{V:i} e:f32[3]. let
          f:f32[3]{V:i} = pvary[axes=('i',)] e
          g:f32[3]{V:i} = mul d f
        in (g,) }
      manual_axes=frozenset({'i'})
      mesh=AbstractMesh('i': 2, axis_types=(Explicit,), device_kind=cpu, num_cores=None)
      out_specs=(P('i',),)
    ] a b
  in (c,) }

In a jaxpr, the multiplication operation requires the VMA types of its arguments to match, but for convenience the jax.numpy and jax.lax APIs automatically apply jax.lax.pcast(..., to='varying') to make argument VMA types agree. In a jaxpr, these jax.lax.pcast calls show up as pvary since jax.lax.pcast(..., to='varying') dispatches to lax.pvary.
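
The VMA rules described so far can be modeled with plain Python sets. This is a toy model for intuition only, not JAX's implementation: the function names are invented for this sketch, a VMA is represented as a frozenset of axis names, and only psum, the cast to varying, binary operations, and the out_specs check are covered.

```python
def psum_vma(vma, axes):
  """psum over `axes`: the result no longer varies over those axes."""
  return frozenset(vma) - frozenset(axes)

def pvary_vma(vma, axes):
  """pcast(..., to='varying'): mark the value as varying over `axes`."""
  return frozenset(vma) | frozenset(axes)

def binary_op_vma(lhs_vma, rhs_vma):
  """Operands are implicitly cast up to the union of their varying axes."""
  return frozenset(lhs_vma) | frozenset(rhs_vma)

def out_spec_is_safe(vma, mentioned_axes):
  """An output may be untiled only over mesh axes it does not vary over."""
  return frozenset(vma) <= frozenset(mentioned_axes)

x = frozenset({'i'})  # argument split by in_specs=P('i')
y = frozenset()       # argument replicated by in_specs=P()

print(sorted(binary_op_vma(x, y)))                  # ['i'], as for x * y above
print(out_spec_is_safe(x, set()))                   # False: the out_specs=P() bug
print(out_spec_is_safe(psum_vma(x, {'i'}), set()))  # True: psum makes P() safe
print(sorted(pvary_vma(y, {'i'})))                  # ['i']
```

In this model the static check amounts to a subset test between the output's VMA and the axes named in its out_specs entry.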

In some cases, like with jax.lax.scan, you might need to apply jax.lax.pcast yourself to ensure VMA types match as required. For example, this code raises an error:

mesh = jax.make_mesh((2,), ('i',))
jax.set_mesh(mesh)

@jax.shard_map(in_specs=(P('i'), P()), out_specs=P('i'))
def f(x, y):
  def body(carry, _):
    c1, c2 = carry
    return (c2, c1), ()  # swap the carry
  (x_, y_), _ = jax.lax.scan(body, (x, y), (), length=2)
  return x_, y_

x = jnp.arange(6.)
y = jnp.arange(3.)

try:
  f(x, y)
except Exception as e:
  print(e)
in_specs passed to shard_map: P('i',) does not match the specs of the input: P(None,) for arg: float32[6]. `in_specs` is an optional argument so you can omit specifying it and shard_map will infer the in_specs from the arguments. If you want to reshard your inputs, you can use `jax.reshard` on the arguments and then pass those args to shard_map.

To make the types match, we need to apply jax.lax.pcast to some arguments to the scan:

mesh = jax.make_mesh((2,), ('i',))
jax.set_mesh(mesh)

@jax.shard_map(in_specs=(P('i'), P()), out_specs=P('i'))
def f(x, y):
  def body(carry, _):
    c1, c2 = carry
    return (c2, c1), ()  # swap the carry

  y = jax.lax.pcast(y, 'i', to='varying')  # apply pcast to fix the error
  (x_, y_), _ = jax.lax.scan(body, (x, y), (), length=2)
  return x_, y_

x = jax.device_put(jnp.arange(6.), P('i'))
y = jnp.arange(3.)

f(x, y)
(Array([0., 1., 2., 3., 4., 5.], dtype=float32), Array([0., 1., 2., 0., 1., 2.], dtype=float32))

Here’s a summary of collective primitives and how they affect varying manual axis types:

| Name | Device variance type | Example | Lowers to HLO | Transpose |
|---|---|---|---|---|
| psum_invariant | Varying->Invariant | y:f32[3]{j} = psum(x:f32[3]{i,j}, axis='i') | AllReduceSum (communication) | pvary |
| pvary | Invariant->Varying | y:f32[3]{i} = pvary(x:f32[3], 'i') | no-op (no communication) | psum_invariant |
| all_to_all | Varying->Varying | y:f32[16]{i} = all_to_all(x:f32[16]{i}, 'i', 0, 0) | AllToAll (communication) | all_to_all |
| axis_index | ()->Varying | idx:i32[]{i} = axis_index('i') | ReplicaId and some arithmetic (no communication) | n/a |
| psum_scatter | Varying->Varying | y:f32[2]{i} = psum_scatter(x:f32[16]{i}, 'i') | ReduceScatterSum (communication) | all_gather |
| all_gather | Varying->Varying | y:f32[16]{i} = all_gather(x:f32[2]{i}, 'i') | AllGather (communication) | psum_scatter |
| pscatter | Invariant->Varying | y:f32[2]{i} = pscatter(x:f32[16], 'i') | lambda x: x[axis_index('i'), None] (no communication) | all_gather_invariant |
| all_gather_invariant | Varying->Invariant | y:f32[16] = all_gather_invariant(x:f32[2]{i}, 'i') | AllGather (communication) | pscatter |

A few notes on the table:

  • The function jax.lax.psum is a convenience wrapper around psum_invariant.

  • It’s surprising that all_gather is Varying->Varying, but that’s because it’s really the transpose of psum_scatter which is Varying->Varying.

  • Neither pscatter nor all_gather_invariant have user APIs at the time of writing, but they’re described here for completeness.

API Specification#

from jax.sharding import Mesh
Specs = PyTree[PartitionSpec]

def shard_map(
    f: Callable, /, *, out_specs: Specs, mesh: Mesh | None = None,
    in_specs: Specs | None = None,
    axis_names: collections.abc.Set[AxisName] = set(),
    check_vma: bool = True,
) -> Callable:
  ...

where:

  • communication collectives like psum in the body of f can mention the axis names of mesh;

  • mesh encodes devices arranged in an array and with associated axis names, just like it does for sharding.NamedSharding. If None, mesh will be inferred from the context, which can be set via the jax.set_mesh context manager;

  • in_specs are PartitionSpecs which can mention axis names from mesh zero or one times to express slicing/unconcatenation of inputs, with unmentioned names corresponding to replication. If None, all mesh axes must be of type Explicit, in which case the in_specs are inferred from the argument types;

  • out_specs are PartitionSpecs which can mention axis names from mesh zero or one times to express concatenation of outputs, with unmentioned names corresponding to untiling (assert-replicated-so-give-me-one-copy);

  • axis_names is an optional set of axis names corresponding to the subset of names of mesh to treat as manual in the body. If empty, f is manual over all axes of the mesh;

  • check_vma is an optional boolean indicating whether to check statically for any replication errors in out_specs, and also whether to enable a related automatic differentiation optimization (see JEP).

The shapes of the arguments passed to f have the same ranks as the arguments passed to shard_map-of-f, and the shape of an argument to f is computed from the shape shape of the corresponding argument to shard_map-of-f and the corresponding PartitionSpec spec as roughly tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec)).
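
The "roughly" in that formula hides a few details: a pspec can be shorter than the array rank, and one entry can name several mesh axes at once, as in P(('j', 'i'), None). The sketch below spells those cases out. It is an illustration rather than JAX's implementation; block_shape is a hypothetical helper, the pspec is modeled as a plain tuple, and mesh_shape is a dict standing in for mesh.shape.

```python
import math

def block_shape(shape, spec, mesh_shape):
  """Shape of the block each function instance sees for one argument.

  shape:      global argument shape, e.g. (12, 12)
  spec:       one entry per leading axis: None, a mesh axis name, or a tuple of names
  mesh_shape: dict mapping mesh axis name -> size, e.g. {'i': 4, 'j': 2}
  """
  # Trailing array axes not covered by the spec are left unsplit.
  spec = tuple(spec) + (None,) * (len(shape) - len(spec))
  out = []
  for size, entry in zip(shape, spec):
    if entry is None:
      names = ()
    elif isinstance(entry, str):
      names = (entry,)
    else:
      names = tuple(entry)
    # An axis split over several mesh axes is divided by the product of their sizes.
    factor = math.prod(mesh_shape[n] for n in names)
    if size % factor != 0:
      raise ValueError(f"axis of size {size} is not divisible by {factor}")
    out.append(size // factor)
  return tuple(out)

mesh_shape = {'i': 4, 'j': 2}
print(block_shape((12, 12), ('i', None), mesh_shape))         # (3, 12)
print(block_shape((12, 24), ('i', 'j'), mesh_shape))          # (3, 12)
print(block_shape((24, 12), (('j', 'i'), None), mesh_shape))  # (3, 12)
print(block_shape((16, 4), ('i',), mesh_shape))               # (4, 4)
```

The ValueError corresponds to the rule that the mesh axis size must evenly divide the input array axis size.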

Collectives tutorial#

A shard_map need not be a pure map: function applications can communicate with each other via collectives, using axis names defined in the mesh argument.

Recall that shard_map maps a function over shards, or blocks, of input data, so that this:

mesh = Mesh(jax.devices(), ('i',))
x = jnp.arange(16.)
f_shmapped = jax.shard_map(f, in_specs=P('i'), out_specs=P('i'))
y = f_shmapped(x)

computes the same values, evaluating applications of f to the same argument values, as this reference function:

def f_shmapped_ref(x):
  x_blocks = jnp.array_split(x, mesh.shape['i'])
  y_blocks = [f(x_blk) for x_blk in x_blocks]
  return jnp.concatenate(y_blocks)

We call these applications of f to different argument shards function instances. Each function instance is executed on a different device (or subset of devices).

These reference semantics work when f has no communication collectives in it. But what if we want the function instances to communicate, corresponding to having cross-device communication? That is, what are the reference semantics when f contains a collective? Say f has just one collective, and is of the form

def f(x_blk):
  z_blk = f_part1(x_blk)
  u_blk = collective(z_blk, axis_name)
  v_blk = f_part2(x_blk, z_blk, u_blk)
  return v_blk

where we’re assuming there’s only one mesh axis we’re mapping over, and axis_name is the corresponding name for it. Then the reference semantics would look more like:

def f_shmapped_ref(x):
  x_blocks = jnp.array_split(x, mesh.shape[axis_name])  # mesh.shape is keyed by axis name
  z_blocks = [f_part1(x_blk) for x_blk in x_blocks]
  u_blocks = [collective_ref(i, z_blocks) for i in range(len(z_blocks))]
  v_blocks = [f_part2(x_blk, z_blk, u_blk) for x_blk, z_blk, u_blk
              in zip(x_blocks, z_blocks, u_blocks)]
  return jnp.concatenate(v_blocks)

Notice that collective_ref might depend on all the z_blocks. That is, while f_part1 and f_part2 are mapped over blocks independently, a collective introduces some amount of cross-block dependence. Physically, that means communication across devices. Exactly what communication happens, and what values are computed, depend on the collective.

psum#

The simplest collective may be jax.lax.psum, which computes an all-reduce-sum along a device mesh axis (or multiple axes). Here’s a toy example:

Illustration of a psum computation.
import jax
import jax.numpy as jnp
from jax import lax

from jax.sharding import Mesh, PartitionSpec as P

mesh1d = Mesh(jax.devices()[:4], ('i',))
jax.set_mesh(mesh1d)

@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P(None))
def f1(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.psum(x_block, 'i')
  print('AFTER:\n', y_block)
  return y_block

x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f1(x)
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22 20 12 17]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[22 20 12 17]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[22 20 12 17]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[22 20 12 17]
FINAL RESULT:
 [22 20 12 17]

The prints show that each function application starts with its own chunk of the argument value x_block. After the psum, each function application has the same value of y_block, computed by summing the applications’ x_block values together.

In the case where there’s a single axis name in the computation, we could say that the collective_ref reference implementation for psum is

def psum_ref(_, x_blocks):
  tot = sum(x_blocks)
  return [tot] * len(x_blocks)

Notice also that because f1 returns y_block, the result of a psum over 'i', we can use out_specs=P() so the caller gets a single logical copy of the result value, rather than a tiled result.
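
The reference semantics can be checked against the printed values without any devices. The following NumPy sketch is written under a few assumptions: it uses a hypothetical per-instance variant named psum_instance_ref, which returns the value seen by one instance so that it can be called as collective_ref(i, z_blocks), and it stands in np.split for the in_specs=P('i') split over a mesh axis of size 4.

```python
import numpy as np

def psum_instance_ref(i, x_blocks):
  """Value seen by instance i after a psum: the same sum for every i."""
  del i  # the result does not depend on the instance index
  return np.sum(x_blocks, axis=0)

x = np.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
x_blocks = np.split(x, 4)  # in_specs=P('i') with mesh.shape['i'] == 4
y_blocks = [psum_instance_ref(i, x_blocks) for i in range(len(x_blocks))]

print(y_blocks[0])  # [22 20 12 17]

# Every instance holds the same block, which is what makes out_specs=P() safe:
# picking any single block gives the final result.
y = y_blocks[0]
```

Concatenating y_blocks instead of picking one block would give the tiled result that out_specs=P('i') produces.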

When there is more than one mesh axis, we can perform a psum over each one separately, or over multiple axes at once:

mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j'))
jax.set_mesh(mesh2d)

@jax.shard_map(mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f2(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.psum(x_block, 'i')
  print('AFTER:\n', y_block)
  return y_block

y = f2(jnp.arange(16).reshape(4, 4))
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[0 1]
 [4 5]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[2 3]
 [6 7]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8  9]
 [12 13]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[10 11]
 [14 15]]
AFTER:
 On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[ 8 10]
 [16 18]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[12 14]
 [20 22]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8 10]
 [16 18]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[12 14]
 [20 22]]
FINAL RESULT:
 [[ 8 10 12 14]
 [16 18 20 22]]

By applying a psum over mesh axis 'i', we get values of y_block which are equal along axis 'i', but not axis 'j'. (So we can use out_specs=P(None, 'j') to get a single logical result along that axis.)

If we apply the psum over both axes, the y_block value is equal along both axes:

@jax.shard_map(mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, None))
def f3(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.psum(x_block, ('i', 'j'))
  print('AFTER:\n', y_block)
  return y_block

y = f3(jnp.arange(16).reshape(4, 4))
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[0 1]
 [4 5]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[2 3]
 [6 7]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8  9]
 [12 13]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[10 11]
 [14 15]]
AFTER:
 On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[20 24]
 [36 40]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[20 24]
 [36 40]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[20 24]
 [36 40]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[20 24]
 [36 40]]
FINAL RESULT:
 [[20 24]
 [36 40]]

In machine learning, we often use psum to compute total losses or, when we have a grad inside the shard_mapped function body, total gradients.

In the sequel, we’ll see how psum can be implemented in terms of other primitives, which gives some intuition about its communication cost.

all_gather#

Another fundamental operation is gathering array shards along an axis, so that each function application has a full copy of the data along that axis:

Illustration of an all_gather computation.
jax.set_mesh(mesh1d)

@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f4(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.all_gather(x_block, 'i', tiled=True)
  print('AFTER:\n', y_block)
  return y_block

x = jnp.array([3, 9, 5, 2])
y = f4(x)
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[9]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[2]
AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 9 5 2]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[3 9 5 2]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[3 9 5 2]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[3 9 5 2]
FINAL RESULT:
 [3 9 5 2 3 9 5 2 3 9 5 2 3 9 5 2]

The prints show that each function application again starts with its own chunk of the argument value x_block. After the all_gather, they have a common value, computed by concatenating the values of x_block.

(Notice that we actually can’t set out_specs=P() here. For technical reasons related to automatic differentiation, we consider the output of all_gather not to be guaranteed invariant across devices. If we wanted it to be guaranteed invariant, we could use jax.lax.all_gather_invariant, or in this case we could just avoid doing the all_gather in the function body and instead just use out_specs=P('i') to perform the concatenation.)

When tiled=False (the default), results are stacked along a new axis instead of concatenated:

@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f5(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.all_gather(x_block, 'i', tiled=False)
  print('AFTER:\n', y_block)
  return y_block

y = f5(x)
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[9]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[2]
AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[[3]
 [9]
 [5]
 [2]]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[[3]
 [9]
 [5]
 [2]]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[[3]
 [9]
 [5]
 [2]]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[[3]
 [9]
 [5]
 [2]]
FINAL RESULT:
 [[3]
 [9]
 [5]
 [2]
 [3]
 [9]
 [5]
 [2]
 [3]
 [9]
 [5]
 [2]
 [3]
 [9]
 [5]
 [2]]

We could write the collective_ref reference semantics function for all_gather as

def all_gather_ref(_, x_blocks, *, tiled=False):
  combine = jnp.concatenate if tiled else jnp.stack
  return [combine(x_blocks)] * len(x_blocks)
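To try these semantics without any devices, here is a minimal sketch that models the same behavior on plain Python lists. The function name all_gather_ref_py and the list-based representation are our own assumptions for illustration, not part of JAX; the input values are the ones used in f4 and f5 above.

```python
def all_gather_ref_py(x_blocks, *, tiled=False):
  """Host-side model of all_gather on plain lists (not a JAX API)."""
  if tiled:
    # Concatenate the blocks along their leading axis.
    combined = [v for block in x_blocks for v in block]
  else:
    # Stack the blocks along a new leading axis.
    combined = [list(block) for block in x_blocks]
  # Every function instance receives the same gathered value.
  return [combined for _ in x_blocks]

x_blocks = [[3], [9], [5], [2]]  # one block per function instance
tiled_out = all_gather_ref_py(x_blocks, tiled=True)
stacked_out = all_gather_ref_py(x_blocks, tiled=False)
print(tiled_out[0])    # [3, 9, 5, 2]
print(stacked_out[0])  # [[3], [9], [5], [2]]
```

The per-instance values match the AFTER prints of f4 and f5 respectively.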

In deep learning, we might use all_gathers on parameters in fully sharded data parallelism (FSDP).

psum_scatter#

The jax.lax.psum_scatter collective is a bit less intuitive. It’s like psum except each function instance gets only one shard of the result:

Illustration of a psum_scatter computation.
@jax.shard_map(in_specs=P('i'), out_specs=P('i'))
def f6(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True)
  print('AFTER:\n', y_block)
  return y_block

x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f6(x)
print('FINAL RESULT:\n', y)

BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[20]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[12]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[17]
FINAL RESULT:
 [22 20 12 17]

As shown by the prints, each resulting y_block has a smaller size than the argument x_block, unlike with psum. Moreover, compared to psum, here each y_block only represents a slice of the sum of the x_blocks across function instances. (Even though each function instance gets only one shard of the sum, the final output y is the same as in the psum example because here we use out_specs=P('i') to concatenate each function instance’s output.)

In terms of what values are computed, a collective_ref reference implementation might look like:

def psum_scatter_ref(i, x_blocks, *, tiled=False):
  axis_size = len(x_blocks)
  tot = sum(x_blocks)
  if tiled:
    tot = tot.reshape(axis_size, -1, *tot.shape[1:])  # split leading axis
  return [tot[i] for i in range(tot.shape[0])]

It’s not captured in the semantics reference implementation, but psum_scatter is useful because these results can be computed more efficiently, with less communication, than a full psum. In fact, one way to think of psum_scatter is as “the first half of a psum, before an all_gather”. That is, one way to implement psum is:

def psum(x, axis_name):
  summed_chunk = jax.lax.psum_scatter(x, axis_name)
  return jax.lax.all_gather(summed_chunk, axis_name)

Indeed, this implementation is often used on both TPU and GPU!
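As a quick sanity check of that decomposition, the following sketch models both halves on plain Python lists and compares the composition against an elementwise sum. The helper names psum_scatter_tiled and all_gather_tiled are hypothetical, and both helpers model the tiled=True variants (an assumption made to keep the shapes flat); the input is the f6 example above.

```python
def psum_scatter_tiled(x_blocks):
  """Model of psum_scatter(..., tiled=True): instance i keeps chunk i of the sum."""
  n = len(x_blocks)
  total = [sum(vals) for vals in zip(*x_blocks)]  # elementwise sum over instances
  chunk = len(total) // n
  return [total[i * chunk:(i + 1) * chunk] for i in range(n)]

def all_gather_tiled(x_blocks):
  """Model of all_gather(..., tiled=True): every instance gets the concatenation."""
  gathered = [v for block in x_blocks for v in block]
  return [list(gathered) for _ in x_blocks]

x_blocks = [[3, 1, 4, 1], [5, 9, 2, 6], [5, 3, 5, 8], [9, 7, 1, 2]]
scattered = psum_scatter_tiled(x_blocks)
psum_via_two_steps = all_gather_tiled(scattered)
psum_expected = [sum(vals) for vals in zip(*x_blocks)]

print(scattered)              # [[22], [20], [12], [17]]
print(psum_via_two_steps[0])  # [22, 20, 12, 17]
```

Every instance ends up with the same full sum, which is what a psum would give it directly.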

The reason psum_scatter can require about half the communication of a full psum is illustrated in the ppermute section.
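As a rough accounting, suppose the collectives are implemented by passing chunks between neighbors on a ring of $n$ function instances, and each instance holds an array of $D$ bytes for the reduction. The symbols $n$ and $D$, and the ring assumption itself, are ours; real implementations vary by hardware and topology.

```latex
\begin{aligned}
\text{psum\_scatter:} &\quad (n-1)\cdot\frac{D}{n} \ \text{bytes sent per instance} \\
\text{all\_gather (of the } D/n \text{ shards):} &\quad (n-1)\cdot\frac{D}{n} \ \text{bytes sent per instance} \\
\text{psum as psum\_scatter + all\_gather:} &\quad 2\,(n-1)\cdot\frac{D}{n} \ \text{bytes sent per instance}
\end{aligned}
```

Under these assumptions the psum_scatter alone accounts for half of the traffic of the full psum.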

Another intuition is that we can use psum_scatter to implement a distributed matrix multiplication with inputs and outputs sharded over the same axis. In machine learning, psum_scatter can be used in tensor-parallel matrix multiplies or fully-sharded data parallel gradient accumulation, as shown in the examples to follow.

ppermute#

The jax.lax.ppermute collective provides the most direct way for function instances to send data to one another. Given a mesh axis and a list of (source_index, destination_index) pairs representing indices along that mesh axis, ppermute sends its argument value from each source function instance to each destination:

@jax.shard_map(in_specs=P('i'), out_specs=P('i'))
def f7(x_block):
  sz = jax.lax.axis_size('i')
  print('BEFORE:\n', x_block)
  y_block = jax.lax.ppermute(x_block, 'i', [(i, (i + 1) % sz) for i in range(sz)])
  print('AFTER:\n', y_block)
  return y_block

y = f7(jnp.arange(8))
print('FINAL RESULT:\n', y)

BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[0 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[2 3]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[4 5]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[6 7]
AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[6 7]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[0 1]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[2 3]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[4 5]
FINAL RESULT:
 [6 7 0 1 2 3 4 5]

In this case, with four function instances arranged in a ring, each instance’s value of y_block is the previous instance’s value of x_block, with instance 0 receiving from instance 3.

Source indices and destination indices can’t be repeated. If an index does not appear as a destination, then the value of the corresponding function instance’s result is an array of zeros.

A collective_ref reference implementation could look like

def ppermute_ref(i, x_blocks, perm):
  results = [jnp.zeros_like(x_blocks[0])] * len(x_blocks)
  for src, dst in perm:
    results[dst] = x_blocks[src]
  return results
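The zero-filling rule is easy to miss, so here is a small sketch on plain Python lists that exercises it. The function ppermute_ref_py and the permutation named partial_perm are hypothetical names introduced for this illustration; the blocks are the ones from f7.

```python
def ppermute_ref_py(x_blocks, perm):
  """Host-side model of ppermute on plain lists (not a JAX API)."""
  # Instances that receive nothing end up with zeros of the same shape.
  results = [[0] * len(x_blocks[0]) for _ in x_blocks]
  for src, dst in perm:
    results[dst] = list(x_blocks[src])
  return results

x_blocks = [[0, 1], [2, 3], [4, 5], [6, 7]]
ring = [(i, (i + 1) % 4) for i in range(4)]  # full cyclic shift, as in f7
partial_perm = [(0, 1), (1, 2)]              # instances 0 and 3 receive nothing

ring_out = ppermute_ref_py(x_blocks, ring)
partial_out = ppermute_ref_py(x_blocks, partial_perm)
print(ring_out)     # [[6, 7], [0, 1], [2, 3], [4, 5]]
print(partial_out)  # [[0, 0], [0, 1], [2, 3], [0, 0]]
```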

Other collectives can be implemented efficiently, in terms of total communication, using ppermutes where each function passes data only to its neighbors. For example, we could implement psum_scatter using a sequence of ppermutes and local additions this way:

Illustration of a psum_scatter implementation.

Or, with a numerical example:

Illustration of a psum_scatter implementation.

Intuitively, on each iteration each function instance sends ‘up’ the value it received on the previous iteration, and reduces (adds) the value it receives this iteration. In code, it might look like this:

def psum_scatter(x, axis_name, *, tiled=False):
  size = jax.lax.axis_size(axis_name)
  idx = jax.lax.axis_index(axis_name)  # function instance index along axis_name
  if tiled:
    x = x.reshape(size, -1, *x.shape[1:])  # split leading axis
  shift = partial(jax.lax.ppermute, axis_name=axis_name,
                  perm=[(i, (i - 1) % size) for i in range(size)])
  for i in range(1, size):
    update = shift(x[(idx + i) % size])
    x = x.at[(idx + i + 1) % size].add(update)
  return x[idx]

@jax.shard_map(in_specs=P('i'), out_specs=P('i'))
def f8(x_block):
  print('BEFORE:\n', x_block)
  y_block = psum_scatter(x_block, 'i', tiled=True)
  print('AFTER:\n', y_block)
  return y_block

x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f8(x)
print('FINAL RESULT:\n', y)

BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[20]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[12]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[17]
FINAL RESULT:
 [22 20 12 17]
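To follow the data movement step by step, the sketch below replays the same loop on the host, with one list entry standing in for each function instance. It is a plain-Python model under our own naming (ring_psum_scatter is hypothetical), not the JAX implementation; the index arithmetic is copied from the psum_scatter function above.

```python
def ring_psum_scatter(x_blocks):
  """Replay the ppermute-based psum_scatter on the host (tiled=True case)."""
  size = len(x_blocks)
  chunk = len(x_blocks[0]) // size
  # x[idx][c] is chunk c held by instance idx (the tiled reshape).
  x = [[blk[c * chunk:(c + 1) * chunk] for c in range(size)] for blk in x_blocks]
  for i in range(1, size):
    # Every instance idx sends chunk (idx + i) % size to instance idx - 1,
    # so instance idx receives the chunk sent by instance idx + 1.
    sent = [x[idx][(idx + i) % size] for idx in range(size)]
    for idx in range(size):
      update = sent[(idx + 1) % size]
      tgt = (idx + i + 1) % size
      x[idx][tgt] = [a + b for a, b in zip(x[idx][tgt], update)]
    print(f'after step {i}:', x)
  return [x[idx][idx] for idx in range(size)]

x_blocks = [[3, 1, 4, 1], [5, 9, 2, 6], [5, 3, 5, 8], [9, 7, 1, 2]]
result = ring_psum_scatter(x_blocks)
print(result)  # [[22], [20], [12], [17]]
```

Each step moves only one chunk per instance, which is where the communication savings relative to a full psum come from.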

On TPU, there are higher-dimensional variants of this algorithm to exploit multiple bidirectional physical mesh axes.

Notice that psum_scatter is the transpose of all_gather. Indeed, a way to implement all_gather in terms of ppermute looks like the reverse of the above process:

Illustration of an all_gather implementation.
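One way to picture that reverse process is the host-side sketch below, where each instance repeatedly forwards the chunk it received most recently to its neighbor. This is our own plain-Python model: the name ring_all_gather, the slot bookkeeping, and the choice to send toward index idx + 1 are assumptions for illustration, not code from JAX.

```python
def ring_all_gather(x_blocks):
  """Model a ppermute-based all_gather (tiled) with one list entry per instance."""
  size = len(x_blocks)
  # slots[idx][c] will hold chunk c once instance idx has received it.
  slots = [[None] * size for _ in range(size)]
  for idx in range(size):
    slots[idx][idx] = list(x_blocks[idx])  # each instance starts with its own chunk
  for i in range(size - 1):
    # Instance idx forwards chunk (idx - i) % size, the one it obtained last.
    sent = [slots[idx][(idx - i) % size] for idx in range(size)]
    for idx in range(size):
      src = (idx - 1) % size  # receive from the previous instance on the ring
      slots[idx][(src - i) % size] = sent[src]
  # Concatenate the chunks in index order on every instance.
  return [[v for chunk in s for v in chunk] for s in slots]

gathered = ring_all_gather([[3], [9], [5], [2]])
print(gathered[0])  # [3, 9, 5, 2]
```

After size - 1 steps every instance holds every chunk, mirroring the size - 1 steps of the psum_scatter loop.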

In deep learning, we might use ppermute when implementing SPMD pipeline parallelism, where we divide our network along its depth into stages and evaluate the applications of stages in parallel. Or we might use ppermute in parallelizing the evaluation of convolutional layers, where we shard over spatial axes and thus devices must communicate “halos” to each other. Or it may be used under-the-hood in tensor-parallel matrix multiplies.

all_to_all#

A final collective is all_to_all, which is essentially a block matrix transpose operating along one positional axis and one cross-device axis:

Illustration of an all_to_all computation.
@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f9(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.all_to_all(x_block, 'i', split_axis=0, concat_axis=0,
                               tiled=True)
  print('AFTER:\n', y_block)
  return y_block

x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f9(x)
print('FINAL RESULT:\n', y)

BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 5 5 9]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[1 9 3 7]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[4 2 5 1]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[1 6 8 2]
FINAL RESULT:
 [3 5 5 9 1 9 3 7 4 2 5 1 1 6 8 2]

The split_axis argument indicates which positional axis should be sharded and partitioned across the mesh axis. The concat_axis argument indicates the axis along which the communicated results should be concatenated or stacked.

When tiled=False (the default), the split_axis axis size must equal the size of the mesh axis named axis_name, and a new axis of that size is created at position concat_axis for the stacked results. When tiled=True, the split_axis axis size need only be evenly divisible by the size of the mesh axis, and results are concatenated along the existing axis concat_axis.

The collective_ref reference semantics when split_axis=0 and concat_axis=0 might look like:

def all_to_all_ref(_, x_blocks, *, tiled=False):
  axis_size = len(x_blocks)
  if tiled:
    splits = [jnp.array_split(x, axis_size) for x in x_blocks]
    return [jnp.concatenate(s) for s in zip(*splits)]
  else:
    splits = [list(x) for x in x_blocks]
    return [jnp.stack(s) for s in zip(*splits)]
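To make the block-transpose reading concrete, here is a sketch of the tiled case on plain Python lists, run on the f9 input. The function name all_to_all_tiled_py is hypothetical and the list representation is an assumption so the example needs no devices.

```python
def all_to_all_tiled_py(x_blocks):
  """Host-side model of all_to_all with split_axis=0, concat_axis=0, tiled=True."""
  n = len(x_blocks)
  chunk = len(x_blocks[0]) // n
  # splits[src][dst] is the piece that instance src sends to instance dst.
  splits = [[blk[d * chunk:(d + 1) * chunk] for d in range(n)] for blk in x_blocks]
  # Instance dst concatenates the pieces it receives, ordered by source index.
  return [[v for src in range(n) for v in splits[src][dst]] for dst in range(n)]

x_blocks = [[3, 1, 4, 1], [5, 9, 2, 6], [5, 3, 5, 8], [9, 7, 1, 2]]
y_blocks = all_to_all_tiled_py(x_blocks)
print(y_blocks)  # [[3, 5, 5, 9], [1, 9, 3, 7], [4, 2, 5, 1], [1, 6, 8, 2]]
```

Viewing x_blocks as a 4x4 matrix with one row per instance, y_blocks is its transpose, so applying the operation twice returns the original blocks.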

In deep learning, we might use all_to_all in mixture-of-experts routing, where we first sort our local batch of examples according to which expert they should go to, then apply an all_to_all to redistribute examples to experts.

Toy examples#

How might we use shard_map and collective communication in practice? These examples, while simple, give some idea.

Matrix multiplies#

Parallelizing matrix multiplication is central in scaling up deep learning models, both for training and for inference. When jax.jit automatically parallelizes matrix multiplication, it can use one of several different strategies, depending on matrix sizes, hardware details, and other factors. How might we write some of those parallelized routines more explicitly using shard_map? And how can we optimize them to get better compute/communication overlap and thus improve FLOP utilization?

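from functools import partial  # used by the ppermute-based examples below

import jax
import jax.numpy as jnp
from jax import lax  # the examples below call lax.dynamic_slice_in_dim
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices()[:4], ('i',))
jax.set_mesh(mesh)

def device_put(x, pspec):
  return jax.device_put(x, NamedSharding(mesh, pspec))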

Example 1: all-gather on one side#

Consider performing a matrix multiplication where we shard the left-hand side argument (think: parameters) on its leading (non-contracting) dimension:

lhs_spec = P('i', None)
lhs = device_put(jax.random.normal(jax.random.key(0), (8, 8)), lhs_spec)

And we shard the right-hand side argument (think: activations) on its contracting dimension, with a similar sharding for the output:

rhs_spec = P('i', None)
rhs = device_put(jax.random.normal(jax.random.key(1), (8, 4)), rhs_spec)

To perform this matrix multiplication, we can first all-gather the right-hand side and then perform local matrix multiplies against the sharded left-hand side:

@jax.jit
@jax.shard_map(in_specs=(lhs_spec, rhs_spec), out_specs=rhs_spec)
def matmul_allgather(lhs_block, rhs_block):
  rhs = jax.lax.all_gather(rhs_block, 'i', tiled=True)
  return lhs_block @ rhs

out = matmul_allgather(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))

True

That’s great, but we’re not getting any compute/communication overlap here: before we can start the matmul, we need the all_gather to complete. Here’s a profile using the same code, but on larger example shapes ((8192, 8192) for lhs and (8192, 1024) for rhs):

Profile of an all-gather matmul without overlap.

We can get compute/communication overlap if, instead of calling all_gather, we inline the ppermute-based implementation of all_gather illustrated above, then interleave steps of the gather permutation with local matrix multiplies:

@jax.jit
@jax.shard_map(in_specs=(lhs_spec, rhs_spec), out_specs=rhs_spec)
def matmul_allgather_overlapped(lhs_block, rhs_block):
  size = jax.lax.axis_size('i')
  idx = jax.lax.axis_index('i')
  shift = partial(jax.lax.ppermute, axis_name='i',
                  perm=[(i, (i + 1) % size) for i in range(size)])

  B = lhs_block.shape[1] // size
  lhs_blocks = lambda i: lax.dynamic_slice_in_dim(lhs_block, i * B, B, 1)

  out_block = lhs_blocks(idx) @ rhs_block
  for i in range(1, size):
    rhs_block = shift(rhs_block)
    out_block += lhs_blocks((idx - i) % size) @ rhs_block
  return out_block

out = matmul_allgather_overlapped(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))

True
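The index bookkeeping in the loop can be checked on the host. The sketch below is a plain-Python model with small integer matrices of our own choosing (the helpers matmul and add, and the names lhs_rows and rhs_rows, are hypothetical); it only models which column block of the local lhs pairs with which rhs block at each step, not the communication itself.

```python
def matmul(a, b):
  """Plain-Python matrix product of two lists of rows."""
  return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def add(a, b):
  """Elementwise sum of two equally shaped lists of rows."""
  return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

size = 4
lhs = [[(r * 8 + c) % 5 for c in range(8)] for r in range(8)]  # 8x8, arbitrary integers
rhs = [[(r + 2 * c) % 3 for c in range(4)] for r in range(8)]  # 8x4

# Shard both operands by rows, as P('i', None) does: 2 rows per instance.
lhs_rows = [lhs[2 * d:2 * d + 2] for d in range(size)]
rhs_rows = [rhs[2 * d:2 * d + 2] for d in range(size)]

out_blocks = []
for idx in range(size):
  out = [[0] * 4 for _ in range(2)]
  for i in range(size):
    src = (idx - i) % size  # after i shifts, idx holds the rhs block that started on src
    # Columns [2*src, 2*src + 2) of the local lhs rows pair with rhs row block src.
    lhs_cols = [row[2 * src:2 * src + 2] for row in lhs_rows[idx]]
    out = add(out, matmul(lhs_cols, rhs_rows[src]))
  out_blocks.append(out)

result = [row for blk in out_blocks for row in blk]
print(result == matmul(lhs, rhs))  # True
```

Because src takes every value in range(size) exactly once, each instance accumulates its full block of output rows without ever holding the whole rhs.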

This implementation allows overlap between communication and computation, and also avoids gathering a large intermediate onto each device. But on TPU it uses only half the interconnect bandwidth by permuting in only one direction along the ring. To permute bidirectionally, we just split the blocks in half and send one half in each direction:

@jax.jit
@jax.shard_map(in_specs=(lhs_spec, rhs_spec), out_specs=rhs_spec)
def matmul_allgather_overlapped_bidi(lhs_block, rhs_block):
  size = jax.lax.axis_size('i')
  idx = jax.lax.axis_index('i')
  shift_up = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i + 1) % size) for i in range(size)])
  shift_dn = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i - 1) % size) for i in range(size)])

  B = lhs_block.shape[1] // size // 2  # half-size blocks
  lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 1)

  rhs_block_lo, rhs_block_hi = jnp.split(rhs_block, 2, axis=0)
  out_block = lhs_blocks(idx, 0) @ rhs_block_lo
  out_block += lhs_blocks(idx, 1) @ rhs_block_hi
  for i in range(1, size):
    rhs_block_lo = shift_up(rhs_block_lo)
    rhs_block_hi = shift_dn(rhs_block_hi)
    out_block += lhs_blocks((idx - i) % size, 0) @ rhs_block_lo
    out_block += lhs_blocks((idx + i) % size, 1) @ rhs_block_hi
  return out_block

out = matmul_allgather_overlapped_bidi(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))

True

Profile of an all-gather matmul with overlap.

In practice, to reduce compile times we would probably roll this into a jax.lax.fori_loop. We might also have additional axes of parallelism involved.

Example 2: psum_scatter the result#

Another sharding we might start with has both lhs and rhs sharded along their contracting dimensions, with the output sharded like rhs again:

lhs_spec = P(None, 'i')
lhs = device_put(lhs, lhs_spec)

rhs_spec = P('i', None)
rhs = device_put(rhs, rhs_spec)

Here we can use a psum_scatter (a reduce-scatter) to perform the contraction sum over shards:

@jax.shard_map(in_specs=(lhs_spec, rhs_spec), out_specs=rhs_spec)
def matmul_psumscatter(lhs_block, rhs_block):
  out_summand = lhs_block @ rhs_block
  return jax.lax.psum_scatter(out_summand, 'i', tiled=True)

out = matmul_psumscatter(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))

True

But the scattering communication must wait for the entire local matrix multiply to finish before it can start. To get communication/computation overlap, we can inline an implementation of psum_scatter in terms of ppermute, then interleave the communication steps with local matrix multiplies:

@jax.shard_map(in_specs=(lhs_spec, rhs_spec), out_specs=rhs_spec)
def matmul_psumscatter_overlapped(lhs_block, rhs_block):
  size = jax.lax.axis_size('i')
  idx = jax.lax.axis_index('i')
  shift = partial(jax.lax.ppermute, axis_name='i',
                  perm=[(i, (i - 1) % size) for i in range(size)])
  lhs_block = lhs_block.reshape(size, -1, lhs_block.shape[1])  # split 1st axis

  out_summand = lhs_block[(idx + 1) % size] @ rhs_block
  for i in range(1, size):
    out_summand = shift(out_summand)
    out_summand += lhs_block[(idx + i + 1) % size] @ rhs_block
  return out_summand

out = matmul_psumscatter_overlapped(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))

True
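Here the partial sums themselves travel around the ring, which can be replayed on the host. The following is a plain-Python model under our own assumptions: small integer matrices, hypothetical helpers matmul, add and local_product, and a list acc whose entry idx stands for the out_summand held by instance idx. The index arithmetic follows the function above.

```python
def matmul(a, b):
  """Plain-Python matrix product of two lists of rows."""
  return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def add(a, b):
  """Elementwise sum of two equally shaped lists of rows."""
  return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

size = 4
lhs = [[(r * 8 + c) % 5 for c in range(8)] for r in range(8)]  # 8x8, arbitrary integers
rhs = [[(r + 2 * c) % 3 for c in range(4)] for r in range(8)]  # 8x4

# P(None, 'i') shards lhs by columns (8x2 each); P('i', None) shards rhs by rows (2x4 each).
lhs_cols = [[row[2 * d:2 * d + 2] for row in lhs] for d in range(size)]
rhs_rows = [rhs[2 * d:2 * d + 2] for d in range(size)]

def local_product(idx, chunk):
  """Rows [2*chunk, 2*chunk + 2) of instance idx's partial product."""
  return matmul(lhs_cols[idx][2 * chunk:2 * chunk + 2], rhs_rows[idx])

# Each instance starts the running sum for the output chunk owned by a neighbor.
acc = [local_product(idx, (idx + 1) % size) for idx in range(size)]
for i in range(1, size):
  # ppermute with perm (j -> j - 1): instance idx receives from instance idx + 1.
  acc = [acc[(idx + 1) % size] for idx in range(size)]
  acc = [add(acc[idx], local_product(idx, (idx + i + 1) % size)) for idx in range(size)]

result = [row for blk in acc for row in blk]
print(result == matmul(lhs, rhs))  # True
```

On the last step the chunk index is (idx + size) % size, so each instance finishes with the fully summed rows it is supposed to own.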

As in the previous example, to fully utilize interconnects on TPU, we’d run a bidirectional version:

@jax.shard_map(in_specs=(lhs_spec, rhs_spec), out_specs=rhs_spec)
def matmul_psumscatter_overlapped_bidi(lhs_block, rhs_block):
  size = jax.lax.axis_size('i')
  idx = jax.lax.axis_index('i')
  shift_up = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i + 1) % size) for i in range(size)])
  shift_dn = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i - 1) % size) for i in range(size)])

  B = lhs_block.shape[0] // size // 2  # half-size blocks
  lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 0)

  out_summand_lo = lhs_blocks((idx - 1) % size, 0) @ rhs_block
  out_summand_hi = lhs_blocks((idx + 1) % size, 1) @ rhs_block
  for i in range(1, size):
    out_summand_lo = shift_up(out_summand_lo)
    out_summand_hi = shift_dn(out_summand_hi)
    out_summand_lo += lhs_blocks((idx - i - 1) % size, 0) @ rhs_block
    out_summand_hi += lhs_blocks((idx + i + 1) % size, 1) @ rhs_block
  return jnp.concatenate([out_summand_lo, out_summand_hi])

out = matmul_psumscatter_overlapped_bidi(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))

True

Neural networks#

We can use shard_map to parallelize computation in neural networks, either by itself or in combination with the automatic partitioning in jax.jit. This section has a few examples based on this toy neural network and random data:

import jax
import jax.numpy as jnp

def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jax.nn.relu(outputs)
  return outputs

def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))

def init_layer(key, n_in, n_out):
  k1, k2 = jax.random.split(key)
  W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
  b = jax.random.normal(k2, (n_out,))
  return W, b

def init(key, layer_sizes, batch_size):
  key, *keys = jax.random.split(key, len(layer_sizes))
  params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))

  key, *keys = jax.random.split(key, 3)
  inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
  targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))

  return params, (inputs, targets)

layer_sizes = [784, 128, 128, 128, 128, 128, 8]
batch_size = 32

params, batch = init(jax.random.key(0), layer_sizes, batch_size)

Compare these examples with the purely automatic partitioning examples in the “Distributed arrays and automatic partitioning” doc. While in those automatic partitioning examples we don’t need to edit the model functions to use different parallelization strategies, with shard_map we often do.

8-way batch data parallelism#

The simplest multi-device parallelism strategy is to shard the batch of inputs and targets over multiple devices, replicate the parameters over those devices, and apply the model in parallel to those shards of data. To evaluate the total loss, the devices need only communicate with a scalar-sized all-reduce-sum at the end. (To evaluate the gradient of the loss, the devices must perform all-reduce-sums of parameter gradients in the backward pass.)

from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((8,), ('batch',))
jax.set_mesh(mesh)

# replicate initial params on all devices, shard data batch over devices
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P()))

# adapt the loss function to sum the losses across devices
@jax.shard_map(out_specs=P())
def loss_dp(params, local_batch):
  inputs, targets = local_batch
  predictions = predict(params, inputs)  # use reference `predict`
  local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
  return jax.lax.pmean(local_loss, 'batch')

We can check that the loss and its gradients match the reference (base) model:

print(jax.jit(loss)(params, batch))
print(jax.jit(loss_dp)(params, batch))

11.9203
11.9203

from jax.tree_util import tree_map, tree_all  # helpers for comparing pytrees

def allclose(a, b):
  return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))

print(allclose(jax.jit(jax.grad(loss))(params, batch),
               jax.jit(jax.grad(loss_dp))(params, batch)))

True
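The agreement between loss and loss_dp relies on a small arithmetic fact: when every device holds the same number of examples, the mean of the per-device means equals the mean over the whole batch. A plain-Python sketch with made-up per-example losses (the values and names below are ours, purely for illustration):

```python
# Stand-in per-example losses for a batch of 32 (arbitrary small integers as floats).
batch_losses = [float(i * i % 7) for i in range(32)]
n_dev = 8
per_dev = len(batch_losses) // n_dev  # 4 examples on each device

# Each device computes the mean over its own shard (the local jnp.mean) ...
local_means = [sum(batch_losses[d * per_dev:(d + 1) * per_dev]) / per_dev
               for d in range(n_dev)]
# ... and pmean averages those local means across devices.
pmean = sum(local_means) / n_dev

global_mean = sum(batch_losses) / len(batch_losses)
print(abs(pmean - global_mean) < 1e-12)  # True
```

If the shards had unequal sizes, the local means would need to be weighted by shard size before averaging.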

We can print the compiler IR to inspect the gradient computation and verify that the collective all-reduce-sum operations happen where we’d expect: at the end of the forward pass to compute the loss value, and in the backward pass to compute the total parameter gradients.

8-way fully sharded data parallelism (FSDP)#

Another strategy is to additionally shard the parameters over the devices, all-gathering each one when the full value is needed for the jnp.dot or bias addition. Since we only have one full parameter in local device memory at a time, rather than keeping all parameters in all device memories as in the preceding DP example, we free up significant memory that we can use for larger models or larger batch sizes. And because XLA will overlap computation and inter-device communication, the wall-clock time doesn’t suffer.

So now we need collectives in two places: the model prediction function predict needs to all-gather the parameters before they’re used, and as in the DP case the loss function needs to sum the local losses to compute the total loss.

There’s one other ingredient we need: we don’t want to store the fully gathered parameters from the forward pass for use on the backward pass. Instead, we want to gather them again on the backward pass. We can express that by using jax.remat with a custom policy (or a custom_vjp), though XLA typically does that rematerialization automatically.

This general FSDP approach is similar to weight update sharding (WUS) and ZeRO-3.
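To put rough numbers on the memory argument for the toy network above, the sketch below counts parameter elements per device. It is back-of-the-envelope arithmetic under our own simplifying assumptions: it ignores activations, gradients and optimizer state, and it treats the peak as the resident shards plus the single largest gathered parameter.

```python
layer_sizes = [784, 128, 128, 128, 128, 128, 8]
n_dev = 8

# (weight elements, bias elements) for each layer of the toy network.
layer_params = [(n_in * n_out, n_out)
                for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])]
total = sum(w + b for w, b in layer_params)

dp_per_device = total                       # DP: every device holds every parameter
fsdp_resident = total // n_dev              # FSDP: each device keeps a 1/8 shard of each one
largest = max(w for w, _ in layer_params)   # biggest single parameter gathered at once
fsdp_peak = fsdp_resident + largest         # rough bound while that parameter is live

print(dp_per_device, fsdp_resident, fsdp_peak)  # 167560 20945 121297
```

For this tiny network the first layer dominates, so the peak saving is modest; the benefit grows when a model has many layers of comparable size.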

# shard data batch *and params* over devices
mesh = jax.make_mesh((8,), ('batch',))  # 8-way, matching the section title
jax.set_mesh(mesh)
batch = jax.device_put(batch, P('batch'))
params = jax.device_put(params, P('batch'))

# adapt the prediction function to gather weights just before their use,
# and to re-gather them on the backward pass (rather than saving them)
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def predict_fsdp(params_frag, inputs):
  for W_frag, b_frag in params_frag:
    W = jax.lax.all_gather(W_frag, 'batch', tiled=True)
    b = jax.lax.all_gather(b_frag, 'batch', tiled=True)
    outputs = jnp.dot(inputs, W) + b
    inputs = jax.nn.relu(outputs)
  return outputs

@jax.shard_map(out_specs=P())
def loss_fsdp(local_params, local_batch):
  inputs, targets = local_batch
  predictions = predict_fsdp(local_params, inputs)
  local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
  return jax.lax.pmean(local_loss, 'batch')

Again we can check that the loss and its gradients match the reference model:

repl_params = jax.device_put(params, P())
repl_batch = jax.device_put(batch, P())

print(jax.jit(loss)(repl_params, repl_batch))
print(jax.jit(loss_fsdp)(params, batch))

print(allclose(jax.jit(jax.grad(loss))(repl_params, repl_batch),
               jax.jit(jax.grad(loss_fsdp))(params, batch)))

11.920298
11.920298
True

8-way tensor parallelism (TP)#

Usually we don’t use tensor model parallelism by itself, but seeing it in isolation is a good warmup on parallel matrix multiplication. It’s also a good example of using shard_map in a library function, called in a larger jit-based computation.

The parallelization idea is that we’ll keep the data/activations sharded over its feature axis (rather than its batch axis), and we’ll similarly shard weight matrices over their input-feature axis (and biases over their feature axis). Then to perform the parallel matrix multiplication, we’ll perform local matrix multiplications followed by a psum_scatter to sum the local results and efficiently scatter the result’s shards.

mesh = jax.make_mesh((8,), ('feats',))
jax.set_mesh(mesh)

batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats')))
params = jax.device_put(params, NamedSharding(mesh, P('feats')))

def predict_tp(params, inputs):
  for W, b in params:
    outputs = gemm_tp(inputs, W, b)
    inputs = jax.nn.relu(outputs)
  return outputs

@jax.shard_map(in_specs=(P(None, 'feats'), P('feats', None), P('feats')),
               out_specs=P(None, 'feats'))
def gemm_tp(inputs, W, b):
  block_result = jnp.dot(inputs, W)
  return jax.lax.psum_scatter(block_result, 'feats',
                              scatter_dimension=1, tiled=True) + b

def loss_tp(params, batch):
  inputs, targets = batch
  predictions = predict_tp(params, inputs)
  return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))  # NOTE psum!

FSDP + TP, with shard_map at the top level#

We can compose these strategies together, using multiple axes of parallelism.

mesh = jax.make_mesh((4, 2), ('batch', 'feats'))
jax.set_mesh(mesh)

batch = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats')))
params = jax.device_put(params, NamedSharding(mesh, P(('feats', 'batch'))))

# mostly same as previous predict_fsdp definition, except we call gemm_tp
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def predict_fsdp_tp(params_frag, inputs):
  for W_frag, b_frag in params_frag:
    W = jax.lax.all_gather(W_frag, 'batch', tiled=True)
    b = jax.lax.all_gather(b_frag, 'batch', tiled=True)
    block_result = jnp.dot(inputs, W)
    outputs = jax.lax.psum_scatter(block_result, 'feats',
                                   scatter_dimension=1, tiled=True) + b
    inputs = jax.nn.relu(outputs)
  return outputs

@jax.shard_map(in_specs=(P(('feats', 'batch')), P('batch', 'feats')),
               out_specs=P())
def loss_fsdp_tp(local_params, local_batch):
  inputs, targets = local_batch
  predictions = predict_fsdp_tp(local_params, inputs)
  sq_err = jax.lax.psum(jnp.sum((predictions - targets)**2, axis=-1), 'feats')
  return jax.lax.pmean(jnp.mean(sq_err), 'batch')

Notice how we have to do two collective reductions: one over 'feats' and one over 'batch'. In the pure TP example, we didn’t write the 'feats' reduction explicitly because we only used shard_map within gemm_tp; in the caller loss_tp, the compiler automatically translated our use of jnp.sum to perform a psum as needed given the sharded result returned by predict_tp.
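The two reductions can be checked by hand on a small table of squared errors. In this plain-Python sketch the table sq, the helper local_sum, and the 8x4 problem size are our own stand-ins; the (4, 2) device grid mirrors the ('batch', 'feats') mesh, and each device (b, f) sees only its slice of rows and feature columns.

```python
# Stand-in squared errors: 8 examples x 4 output features.
sq = [[float((r + 1) * (c + 2) % 5) for c in range(4)] for r in range(8)]
n_batch, n_feats = 4, 2                  # mesh shape ('batch', 'feats') = (4, 2)
rows, cols = 8 // n_batch, 4 // n_feats  # slice held by each device

def local_sum(b, f):
  """Per-example sums over the feature slice held by device (b, f)."""
  return [sum(sq[r][f * cols:(f + 1) * cols])
          for r in range(b * rows, (b + 1) * rows)]

per_batch_means = []
for b in range(n_batch):
  # psum over 'feats': add the partial per-example sums from the feature shards.
  full = [sum(vals) for vals in zip(*(local_sum(b, f) for f in range(n_feats)))]
  per_batch_means.append(sum(full) / rows)       # the local jnp.mean
loss_two_stage = sum(per_batch_means) / n_batch  # pmean over 'batch'

loss_reference = sum(sum(row) for row in sq) / len(sq)
print(abs(loss_two_stage - loss_reference) < 1e-12)  # True
```

Dropping the sum over the feature shards would leave each device with the loss of only its own feature slice, which is why the explicit psum over 'feats' is needed here.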

repl_params = jax.device_put(params, P())
repl_batch = jax.device_put(batch, P())

print(jax.jit(loss)(repl_params, repl_batch))
print(jax.jit(loss_fsdp_tp)(params, batch))

print(allclose(jax.jit(jax.grad(loss))(repl_params, repl_batch),
               jax.jit(jax.grad(loss_fsdp_tp))(params, batch)))

11.920298
11.920298
True

SPMD pipeline parallelism (PP)#

With pipeline parallelism we aim to parallelize the evaluation of layers at different depths in our network. For example, one device might compute the application of the first layer while another device computes the application of the second; when they finish, the first device passes its results to the second while the second passes its results to the device responsible for the third layer, and the process repeats. In general the number of pipeline stages may be different from the number of layers, as each stage may be responsible for multiple layers.

With SPMD pipelining, we exploit the fact that most layers in the network apply the same computation, just with different parameter values. In particular, we can stack together all the parameters except for those for the first and last layers, then use a shard_map to map over blocks of those layer parameters, where each block of parameters corresponds to a pipeline stage. We then use the jax.lax.ppermute collective to shift data down the parallel pipeline.

This particular pipelining strategy is essentially the GPipe strategy. There are several variants, as well as quite different strategies, and which is appropriate can depend on the speed of the networking between stages and batch sizes. But for this tutorial we’ll focus on just one strategy.

First, we choose some pipeline parameters:

L = len(params) - 2        # num layers, excluding first and last
N = batch_size             # batch size
F = params[0][0].shape[1]  # num features

# choose some pipeline parameters
S = 2  # number of stages
B = 8  # size of each microbatch
assert L % S == 0, "S (number of stages) must divide L (number of inner layers)"

# compute some useful quantities
M, ragged = divmod(N, B)  # M is number of microbatches
assert not ragged, "B (size of each microbatch) must divide total batch size"
K, ragged = divmod(M, S)  # K is microbatches per stage
assert not ragged, "S (number of stages) must divide number of microbatches"
print(f'{S} stages, {L // S} layer(s) per stage, {L} pipelined layers total')
print(f'{B} examples per microbatch, {M} microbatches total')

2 stages, 2 layer(s) per stage, 4 pipelined layers total
8 examples per microbatch, 4 microbatches total
mesh = Mesh(jax.devices()[:S], ('stages',))

def predict_pp(params, inputs):
  (W_first, b_first), inner_params, (W_last, b_last) = params
  inputs = jax.nn.relu(jnp.dot(inputs, W_first) + b_first)
  inputs = spmd_pipeline(lambda Wb, x: jax.nn.relu(x @ Wb[0] + Wb[1]),
                         inner_params, inputs)
  outputs = jnp.dot(inputs, W_last) + b_last
  return outputs

@jax.shard_map(in_specs=((P(), P('stages'), P()), P('stages')),
               out_specs=P())
def loss_pp(params, batch):
  inputs, targets = batch
  predictions = predict_pp(params, inputs.reshape(K, B, -1)).reshape(K * B, -1)
  local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
  return jax.lax.pmean(local_loss, 'stages')

def spmd_pipeline(fn, stage_params, inputs):
  stage = jax.lax.axis_index('stages')
  outputs = jnp.zeros_like(inputs) * jnp.nan
  state = jnp.zeros((L // S, B, F)) * jnp.nan
  for i in range(M+L-1):
    state = state.at[0].set(jnp.where(stage == 0, inputs[i % K], state[0]))
    state = jax.vmap(fn)(stage_params, state)
    outputs = outputs.at[(i-L+1) % K].set(jnp.where(stage == S-1, state[-1], outputs[(i-L+1) % K]))
    state, inputs, outputs = shift(i, state, inputs, outputs)
  outputs = jax.lax.ppermute(outputs, 'stages', [(i, (i+1) % S) for i in range(S)])
  return outputs

def shift(i, state, inputs, outputs):
  sh = lambda x, d: jax.lax.ppermute(x, 'stages', [(i, (i+d) % S) for i in range(S)])
  state = jnp.roll(state, +1, axis=0).at[0].set(sh(state[-1], +1))
  if (i % K) == (-1 % K):
    inputs = sh(inputs, +1)
  if ((i-L+1) % K) == (-1 % K):
    outputs = sh(outputs, +1)
  return state, inputs, outputs

first_params, *inner_params, last_params = params
Ws, bs = zip(*inner_params)
params_stacked = jnp.stack(Ws), jnp.stack(bs)
first_params = jax.device_put(first_params, NamedSharding(mesh, P()))
params_stacked = jax.device_put(params_stacked, NamedSharding(mesh, P('stages')))
last_params = jax.device_put(last_params, NamedSharding(mesh, P()))
params_ = first_params, params_stacked, last_params

batch_ = jax.device_put(batch, NamedSharding(mesh, P('stages')))

jax.set_mesh(mesh)
print(jax.jit(loss_pp)(params_, batch_))

11.920298

_ = jax.jit(jax.grad(loss_pp))(params_, batch_)  # don't crash
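To see why the loop in spmd_pipeline runs for M + L - 1 iterations, it can help to tabulate an idealized schedule. The sketch below is a plain-Python model under our own assumptions: it treats the L pipelined layers as a chain of slots, advances each microbatch by one slot per iteration, and ignores how inputs and outputs are buffered across stages. The names schedule and busy are ours.

```python
L, M = 4, 4  # pipelined layers and microbatches, matching the values printed above
n_iters = M + L - 1

schedule = []
for t in range(n_iters):
  # Layer slot l works on microbatch t - l when that index is a real microbatch.
  row = [t - l if 0 <= t - l < M else None for l in range(L)]
  schedule.append(row)
  print(t, ['--' if m is None else f'mb{m}' for m in row])

busy = sum(m is not None for row in schedule for m in row)
print(f'{busy} useful layer applications out of {n_iters * L} slots')
```

The idle slots at the start and end are the pipeline bubble; in this model, using more microbatches relative to the pipeline depth shrinks the idle fraction.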
