shmap (shard_map) for simple per-device code
sholto@, sharadmv@, jekbradbury@, zhangqiaorjc@, mattjj@
January 2023
This was the design doc proposing shard_map. You may instead want the up-to-date user docs.
Motivation
JAX supports two schools of thought for multi-device programming:
1. Compiler, take the wheel! Let the compiler automatically partition bulk array functions over devices.
2. Just let me write what I mean, damnit! Give me per-device code and explicit communication collectives.

We need great APIs for both, and rather than being mutually exclusive alternatives, they need to compose with each other.
With pjit (now just jit) we have a next-gen API for the first school. But we haven’t quite leveled up the second school. pmap follows the second school, but over time we found it has fatal flaws. xmap solved those flaws, but it doesn’t quite give us per-device shapes, and it includes several other big ideas too. Meanwhile, new demands for per-device explicit-collectives programming have emerged, like in Efficiently Scaling Transformer Inference.
We can level up the second school with shmap. shmap is:
- a simple multi-device parallelism API which lets us write per-device code with explicit collectives, where logical shapes match per-device physical buffer shapes and collectives correspond exactly to cross-device communication;
- a specialization of xmap with scaled-back features and a few tweaks;
- a fairly direct surfacing of the XLA SPMD Partitioner’s ‘manual’ mode;
- a fun-to-say Seussian name which could stand for shard_map, shpecialized_xmap, sholto_map, or sharad_map.
For pjit users, shmap is a complementary tool. It can be used inside a pjit computation to drop temporarily into a “manual collectives” mode, like an escape hatch from the compiler’s automatic partitioning. That way, users get the convenience and familiar just-NumPy programming model of pjit for most of their code, along with the ability to hand-optimize collective communication with shmap wherever it’s needed. It’s the best of both worlds!
For pmap users, shmap is a strict upgrade. It’s more expressive, performant, and composable with other JAX APIs, without making basic batch data parallelism any harder.
For more on practical use, you can jump to “When should you use shmap and when should you use pjit?”. If you’re wondering why we need a new thing at all, or what the problems with pmap are, jump to “Why don’t pmap or xmap already solve this?”. Or keep reading the next section to see some shmap examples and the API spec.
So, let’s see shmap!
TL;DR example (with a more detailed explanation to follow)
Sho shick:
```python
from functools import partial

import numpy as np

import jax
import jax.numpy as jnp

from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = jax.make_mesh((4, 2), ('i', 'j'))

a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 32.).reshape(16, 32)

@partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P('j', None)),
         out_specs=P('i', None))
def matmul_basic(a_block, b_block):
  # a_block: f32[2, 8]
  # b_block: f32[8, 32]
  z_partialsum = jnp.dot(a_block, b_block)
  z_block = jax.lax.psum(z_partialsum, 'j')
  return z_block

c = matmul_basic(a, b)  # c: f32[8, 32]
```
Notice:
- no nesting needed (or axis_index_groups) for multiple axes of parallelism, unlike pmap;
- no reshapes in the caller, unlike pmap and hard-xmap, and logical shapes correspond to per-device physical shapes, unlike (non-hard) xmap;
- precise device placement control by using mesh, unlike pmap;
- there’s only one set of axis names for logical and physical, unlike xmap;
- the result is a jax.Array which could be efficiently passed to a pjit, unlike pmap;
- this same code works efficiently inside a pjit/jit, unlike pmap;
- this code works eagerly, so we can pdb in the middle and print values, unlike xmap’s current implementation (though by design xmap without the sequential schedule can in principle work eagerly too).
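To see why a psum over 'j' is exactly what matmul_basic needs, here is a NumPy-only model of it that loops over the 4x2 mesh positions instead of running on devices. It is a sketch of the semantics under our own assumptions, not how shard_map executes. The names a_blocks, b_blocks and blocks are ours. Each device multiplies its slice of the contracting axis, the psum adds those partial products across 'j', and out_specs=P('i', None) keeps one copy per row.

```python
import numpy as np

# NumPy-only model of matmul_basic on a 4x2 mesh with axes ('i', 'j').
# Each (i, j) pair stands in for one device; nothing here runs on accelerators.
I, J = 4, 2
a = np.arange(8 * 16.).reshape(8, 16)
b = np.arange(16 * 32.).reshape(16, 32)

# in_specs=(P('i', 'j'), P('j', None)):
a_blocks = [np.split(row, J, axis=1) for row in np.split(a, I, axis=0)]  # a_blocks[i][j]: (2, 8)
b_blocks = np.split(b, J, axis=0)  # b_blocks[j]: (8, 32), the same for every i

blocks = {}
for i in range(I):
  # Each device computes a partial sum over its slice of the contracting axis.
  partials = [a_blocks[i][j] @ b_blocks[j] for j in range(J)]  # each (2, 32)
  # psum over 'j': every device in row i ends up holding the same total.
  total = sum(partials)
  for j in range(J):
    blocks[(i, j)] = total

# out_specs=P('i', None): concatenate over 'i', keep a single copy along 'j'.
c = np.concatenate([blocks[(i, 0)] for i in range(I)], axis=0)
```

Dropping the psum would leave each device with only half of the contraction. In that case the blocks along 'j' would differ, and out_specs=P('i', None) would be an invalid promise.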
Here’s another matmul variant with a fully sharded result:
```python
@partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P('j', None)),
         out_specs=P('i', 'j'))
def matmul_reduce_scatter(a_block, b_block):
  # c_partialsum: f32[8/X, 32]
  c_partialsum = jnp.matmul(a_block, b_block)
  # c_block: f32[8/X, 32/Y]
  c_block = jax.lax.psum_scatter(c_partialsum, 'j', scatter_dimension=1, tiled=True)
  return c_block

c = matmul_reduce_scatter(a, b)
```
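The only change from matmul_basic is the collective. One way to read psum_scatter with tiled=True is as a psum whose result is then split along scatter_dimension, with device k along 'j' keeping piece k. Each device therefore ends up with an f32[8/X, 32/Y] block (here f32[2, 16]). The NumPy sketch below models it that way. psum_scatter_model is a hypothetical name, and the model says nothing about how the communication is actually scheduled.

```python
import numpy as np

def psum_scatter_model(per_device, scatter_dimension):
  # Hypothetical model of lax.psum_scatter(..., tiled=True) over one mesh axis:
  # sum the per-device values, then hand piece k of the result to device k.
  total = sum(per_device)
  return np.split(total, len(per_device), axis=scatter_dimension)

I, J = 4, 2
a = np.arange(8 * 16.).reshape(8, 16)
b = np.arange(16 * 32.).reshape(16, 32)
a_blocks = [np.split(row, J, axis=1) for row in np.split(a, I, axis=0)]  # (2, 8) each
b_blocks = np.split(b, J, axis=0)                                         # (8, 32) each

rows = []
for i in range(I):
  partials = [a_blocks[i][j] @ b_blocks[j] for j in range(J)]   # each (2, 32)
  c_blocks = psum_scatter_model(partials, scatter_dimension=1)  # each (2, 16)
  rows.append(np.concatenate(c_blocks, axis=1))  # out_specs=P('i', 'j'): concat over 'j'
c = np.concatenate(rows, axis=0)                 # ...and over 'i'
```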
Slow down, start with the basics!
Rank-reducing vs rank-preserving maps over array axes
We can think of pmap (and vmap and xmap) as unstacking each array input along an axis (e.g. unpacking a 2D matrix into its 1D rows), applying its body function to each piece, and stacking the results back together, at least when collectives aren’t involved:
```python
pmap(f, in_axes=[0], out_axes=0)(xs) == jnp.stack([f(x) for x in xs])
```
For example, if xs had shape f32[8,5] then each x has shape f32[5], and if each f(x) has shape f32[3,7] then the final stacked result pmap(f)(xs) has shape f32[8,3,7]. That is, each application of the body function f takes as argument inputs with one fewer axis than the corresponding argument to pmap(f). We can say these are rank-reducing maps with unstacking/stacking of inputs/outputs.
The number of logical applications of f is determined by the size of the input axis being mapped over: for example, if we map over an input axis of size 8, semantically we get 8 logical applications of the function, which for pmap always correspond to 8 devices physically computing them.
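The stacking rule above can be checked with a small NumPy model. rank_reducing_map is a hypothetical helper, not part of JAX. It counts how many times f runs and what shape it sees, mirroring the f32[8,5] to f32[8,3,7] example.

```python
import numpy as np

def rank_reducing_map(f, xs):
  # Model of pmap/vmap semantics without collectives: unstack along axis 0,
  # apply f to each slice, and stack the results back together.
  return np.stack([f(x) for x in xs])

calls = []
def f(x):
  calls.append(x.shape)             # record what each application sees
  return np.zeros((3, 7)) + x.sum() # any function returning a (3, 7) array

xs = np.ones((8, 5))
out = rank_reducing_map(f, xs)
```

Here out.shape is (8, 3, 7), f runs 8 times (once per entry of the mapped axis), and each call sees a (5,) array, one rank lower than xs.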
In contrast, shmap does not have this rank-reducing behavior. Instead, we can think of it as slicing (or “unconcatenating”) along input axes into blocks, applying the body function, and concatenating the results back together (again when collectives aren’t involved):
```python
devices = np.array(jax.devices()[:4])
m = Mesh(devices, ('i',))  # mesh.shape['i'] = 4

shard_map(f, m, in_specs=P('i'), out_specs=P('i'))(y) == jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, 4)])
```
Recall that jnp.split slices its input into equally-sized blocks with the same rank, so that if in the above example y has shape f32[8,5] then each y_blk has shape f32[2,5], and if each f(y_blk) has shape f32[3,7] then the final concatenated result shard_map(f, ...)(y) has shape f32[12,7]. So shmap (shard_map) maps over shards, or blocks, of its inputs. We can say it’s a rank-preserving map with unconcatenating/concatenating of its inputs/outputs.
The number of logical applications of f is determined by the mesh size, not by any input axis size: for example, if we have a mesh of total size 4 (i.e. over 4 devices) then semantically we get 4 logical applications of the function, corresponding to the 4 devices physically computing them.
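For contrast with the rank-reducing case, here is the same kind of NumPy model for the rank-preserving map. rank_preserving_map is a hypothetical helper that plays the role of shard_map over a 1D mesh of size num_blocks with in_specs=P('i') and out_specs=P('i'), with no collectives. The number of applications now follows the mesh size (4), not the length of the input axis (8).

```python
import numpy as np

def rank_preserving_map(f, y, num_blocks):
  # Model of shard_map over a 1D mesh: split into equal blocks of the same
  # rank, apply f to each block, and concatenate the results along axis 0.
  return np.concatenate([f(y_blk) for y_blk in np.split(y, num_blocks)])

calls = []
def f(y_blk):
  calls.append(y_blk.shape)             # record what each application sees
  return np.zeros((3, 7)) + y_blk.sum() # any function returning a (3, 7) array

y = np.ones((8, 5))
out = rank_preserving_map(f, y, num_blocks=4)
```

Here out.shape is (12, 7), f runs 4 times, and each call sees a (2, 5) block with the same rank as y.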
Controlling how each input is split (unconcatenated) and tiled with in_specs
Each of the in_specs identifies some of the corresponding input array’s axes with mesh axes by name using PartitionSpecs, representing how to split (or unconcatenate) that input into the blocks to which the body function is applied. That identification determines the shard sizes; when an input axis is identified with a mesh axis, the input is split (unconcatenated) along that logical axis into a number of pieces equal to the corresponding mesh axis size. (It’s an error if the corresponding mesh axis size does not evenly divide the input array axis size.) If an input’s pspec does not mention a mesh axis name, then there’s no splitting over that mesh axis. For example:
```python
devices = np.array(jax.devices())
m = Mesh(devices.reshape(4, 2), ('i', 'j'))

@partial(shard_map, mesh=m, in_specs=P('i', None), out_specs=P('i', 'j'))
def f1(x_block):
  print(x_block.shape)
  return x_block

x1 = np.arange(12 * 12).reshape(12, 12)
y = f1(x1)  # prints (3,12)
```
Here, because the input pspec did not mention the mesh axis name 'j', no input array axis is split over that mesh axis; similarly, because the second axis of the input array is not identified with (and hence split over) any mesh axis, application of f1 gets a full view of the input along that axis.
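The splitting rule can be written out as a NumPy model that computes, for every device coordinate, the block the body function would receive. split_input and the dict-based mesh_shape are hypothetical illustrations, not JAX APIs. The model only handles spec entries that are None or a single mesh axis name. It shows that with P('i', None) each device gets a (3, 12) block and that devices differing only along 'j' receive identical blocks.

```python
import itertools
import numpy as np

def split_input(x, spec, mesh_shape):
  """Hypothetical model of how one in_spec assigns a block to each device.

  spec: one entry per leading array axis, each None or a mesh axis name.
  mesh_shape: dict mapping mesh axis name -> size, e.g. {'i': 4, 'j': 2}.
  Returns {device coordinate tuple: block}.
  """
  names = list(mesh_shape)
  blocks = {}
  for coord in itertools.product(*(range(mesh_shape[n]) for n in names)):
    pos = dict(zip(names, coord))
    idx = []
    for dim, name in enumerate(spec):
      if name is None:
        idx.append(slice(None))  # not split: full view along this axis
        continue
      n = mesh_shape[name]
      if x.shape[dim] % n:
        raise ValueError(f"mesh axis {name!r} of size {n} does not divide "
                         f"array axis {dim} of size {x.shape[dim]}")
      size = x.shape[dim] // n
      idx.append(slice(pos[name] * size, (pos[name] + 1) * size))
    blocks[coord] = x[tuple(idx)]
  return blocks

mesh_shape = {'i': 4, 'j': 2}
x1 = np.arange(12 * 12).reshape(12, 12)
blocks = split_input(x1, ('i', None), mesh_shape)
```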
When a mesh axis is not mentioned in an input pspec, we can always rewrite to a less efficient program where all mesh axes are mentioned but the caller performs a jnp.tile, for example:
```python
@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P('i', 'j'))
def f2(x_block):
  print(x_block.shape)
  return x_block

x = np.arange(12 * 12).reshape(12, 12)
x_ = jnp.tile(x, (1, m.shape['j']))  # x_ has shape (12, 24)
y = f2(x_)  # prints (3,12), and f1(x) == f2(x_)
```
In other words, because each input pspec can mention each mesh axis name zero or one times, rather than having to mention each name exactly once, we can say that in addition to the jnp.split built into its input, shard_map also has a jnp.tile built into its input, at least logically (though the tiling may not need to be carried out physically, depending on the arguments’ physical sharding layout). The tiling to use is not unique; we could also have tiled along the first axis, and used the pspec P(('j', 'i'), None).
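Both tiling rewrites can be checked against f1’s view with a NumPy model that also understands tuple entries such as ('j', 'i'). We assume here that the first name in a tuple is the major (slowest-varying) index when the two mesh axes are flattened together. Under that assumption, tiling along the first axis with P(('j', 'i'), None) gives every device the same block as P('i', None) on the untiled input. split_input is a hypothetical helper, not a JAX API.

```python
import itertools
import numpy as np

def split_input(x, spec, mesh_shape):
  # Hypothetical model: each spec entry is None, a mesh axis name, or a tuple
  # of names flattened row-major (first name = major index).
  names = list(mesh_shape)
  blocks = {}
  for coord in itertools.product(*(range(mesh_shape[n]) for n in names)):
    pos = dict(zip(names, coord))
    idx = []
    for dim, entry in enumerate(spec):
      if entry is None:
        idx.append(slice(None))
        continue
      group = entry if isinstance(entry, tuple) else (entry,)
      shard, count = 0, 1
      for n in group:
        shard = shard * mesh_shape[n] + pos[n]
        count *= mesh_shape[n]
      size = x.shape[dim] // count
      idx.append(slice(shard * size, (shard + 1) * size))
    blocks[coord] = x[tuple(idx)]
  return blocks

mesh_shape = {'i': 4, 'j': 2}
x = np.arange(12 * 12).reshape(12, 12)

b1 = split_input(x, ('i', None), mesh_shape)                          # f1's view
b2 = split_input(np.tile(x, (1, 2)), ('i', 'j'), mesh_shape)          # f2's view
b3 = split_input(np.tile(x, (2, 1)), (('j', 'i'), None), mesh_shape)  # tile rows instead
```

If the tuple were flattened with 'i' as the major index instead, b3 would no longer match b1, so the order of names inside the tuple matters.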
Physical data movement is possible on inputs, as each device needs to have a copy of the appropriate data.
Controlling how each output is assembled by concatenation, block transposition, and untiling using out_specs
Analogously to the input side, each of the out_specs identifies some of the corresponding output array’s axes with mesh axes by name, representing how the output blocks (one for each application of the body function, or equivalently one for each physical device) should be assembled back together to form the final output value. For example, in both the f1 and f2 examples above the out_specs indicate we should form the final output by concatenating together the block results along both axes, resulting in both cases in an array y of shape (12,24). (It’s an error if an output shape of the body function, i.e. an output block shape, has a rank too small for the concatenation described by the corresponding output pspec.)
When a mesh axis name is not mentioned in an output pspec, it represents an un-tiling: when the user writes an output pspec which does not mention one of the mesh axis names, they promise that the output blocks are equal along that mesh axis, and so only one block along that axis is used in the output (rather than concatenating all the blocks together along that mesh axis). For example, using the same mesh as above:
```python
x = jnp.array([[3.]])

z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P('i', 'j'))()
print(z)  # prints the same as jnp.tile(x, (4, 2))

z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P('i', None))()
print(z)  # prints the same as jnp.tile(x, (4, 1))

z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P(None, None))()
print(z)  # prints the same as jnp.tile(x, (1, 1)), or just x
```
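The three results above can be reproduced with a NumPy model of output assembly. assemble_output is a hypothetical helper, not a JAX API, and it only supports out_spec entries that are None or a single mesh axis name. Mesh axes named in out_spec are concatenated along the matching array axis. Unnamed mesh axes are untiled: only block index 0 is kept along them. Unlike shard_map itself, the model also asserts the replication promise.

```python
import itertools
import numpy as np

def assemble_output(blocks, out_spec, mesh_shape):
  """Hypothetical model of out_specs; blocks maps device coordinate -> block."""
  names = list(mesh_shape)
  mentioned = [n for n in out_spec if n is not None]
  # Check the promise behind untiling: blocks must agree along unmentioned axes.
  for coord, blk in blocks.items():
    rep = tuple(c if n in mentioned else 0 for n, c in zip(names, coord))
    assert np.array_equal(blk, blocks[rep]), "not replicated along an untiled axis"

  def build(dim, pos):
    if dim == len(out_spec):
      return blocks[tuple(pos.get(n, 0) for n in names)]  # untiled axes use index 0
    name = out_spec[dim]
    if name is None:
      return build(dim + 1, pos)
    return np.concatenate(
        [build(dim + 1, {**pos, name: k}) for k in range(mesh_shape[name])], axis=dim)

  return build(0, {})

mesh_shape = {'i': 4, 'j': 2}
x = np.array([[3.]])
# The body `lambda: x` returns the same block on every device.
blocks = {c: x for c in itertools.product(range(4), range(2))}
z1 = assemble_output(blocks, ('i', 'j'), mesh_shape)
z2 = assemble_output(blocks, ('i', None), mesh_shape)
z3 = assemble_output(blocks, (None, None), mesh_shape)
```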
Notice that the body function closing over an array value is equivalent to passing it as an argument with a corresponding input pspec of P(None, None). As another example, following more closely the other examples above:
```python
@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P('i', None))
def f3(x_block):
  return jax.lax.psum(x_block, 'j')

x = np.arange(12 * 12).reshape(12, 12)
y3 = f3(x)
print(y3.shape)  # (12,6)
```
Notice that the result has a second axis size of 6, half the size of the input’s second axis. In this case, the un-tile expressed by not mentioning the mesh axis name 'j' in the output pspec was safe because of the collective psum, which ensures each output block is equal along the corresponding mesh axis. Here are two more examples where we vary which mesh axes are mentioned in the output pspec:
```python
@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f4(x_block):
  return jax.lax.psum(x_block, 'i')

x = np.arange(12 * 12).reshape(12, 12)
y4 = f4(x)
print(y4.shape)  # (3,12)


@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P(None, None))
def f5(x_block):
  return jax.lax.psum(x_block, ('i', 'j'))

y5 = f5(x)
print(y5.shape)  # (3,6)
```
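The shapes printed for f3, f4 and f5 follow from the psum plus the untiling. The NumPy model below simulates all three on the 4x2 mesh. psum_model is a hypothetical helper that sums over the devices differing only along the named axes. It is a semantic sketch, not how XLA performs the reduction.

```python
import numpy as np

I, J = 4, 2
AXIS_POS = {'i': 0, 'j': 1}  # position of each mesh axis in a device coordinate (i, j)
x = np.arange(12 * 12).reshape(12, 12)

# in_specs=P('i', 'j'): device (i, j) gets the (3, 6) block at row-block i, column-block j.
blocks = {(i, j): np.split(np.split(x, I, axis=0)[i], J, axis=1)[j]
          for i in range(I) for j in range(J)}

def psum_model(blocks, axes):
  # Hypothetical model of lax.psum: sum over devices that differ only along `axes`.
  summed = {AXIS_POS[a] for a in axes}
  out = {}
  for c in blocks:
    group = [d for d in blocks
             if all(d[p] == c[p] for p in range(2) if p not in summed)]
    out[c] = sum(blocks[d] for d in group)
  return out

# f3: psum over 'j', out_specs=P('i', None): concatenate over 'i', one copy along 'j'.
o3 = psum_model(blocks, ('j',))
y3 = np.concatenate([o3[(i, 0)] for i in range(I)], axis=0)
# f4: psum over 'i', out_specs=P(None, 'j'): concatenate over 'j', one copy along 'i'.
o4 = psum_model(blocks, ('i',))
y4 = np.concatenate([o4[(0, j)] for j in range(J)], axis=1)
# f5: psum over both axes, out_specs=P(None, None): every block is equal, keep one.
o5 = psum_model(blocks, ('i', 'j'))
y5 = o5[(0, 0)]
```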
On the physical side, not mentioning a mesh axis name in an output pspec assembles an Array from the output device buffers with replicated layout along that mesh axis.
There is no runtime check that the output blocks are actually equal along a mesh axis to be un-tiled along, or equivalently that the corresponding physical buffers have equal values and thus can be interpreted as a replicated layout for a single logical array. But we can provide a static check mechanism which raises an error on all potentially-incorrect programs.
Because the out_specs can mention mesh axis names zero or one times, and because they can be mentioned in any order, we can say that in addition to the jnp.concatenate built into its output, shard_map also has both an untile and a block transpose built into its output.
Physical data movement is not possible on outputs, no matter the output pspec. Instead, out_specs just encodes how to assemble the block outputs into Arrays, or physically how to interpret the buffers across devices as the physical layout of a single logical Array.
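The block transpose comes from the order of names in an output pspec. The NumPy sketch below assembles the same per-device blocks twice, once for a hypothetical out_specs=P('i', 'j') and once for P('j', 'i'). Each device (i, j) holds a 1x1 block containing 10*i + j, so the swap shows up directly as a transpose.

```python
import itertools
import numpy as np

I, J = 4, 2
# Hypothetical per-device outputs on a 4x2 mesh: a 1x1 block holding 10*i + j.
blocks = {(i, j): np.array([[10 * i + j]])
          for i, j in itertools.product(range(I), range(J))}

# out_specs=P('i', 'j'): array axis 0 walks mesh axis 'i', axis 1 walks 'j'.
z_ij = np.block([[blocks[(i, j)] for j in range(J)] for i in range(I)])
# out_specs=P('j', 'i'): the same blocks, but axis 0 walks 'j' and axis 1 walks 'i'.
z_ji = np.block([[blocks[(i, j)] for i in range(I)] for j in range(J)])
```

With larger blocks only the block positions are transposed. The contents of each block stay as the body function produced them, which is why this is a block transpose rather than an array transpose.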
API Specification
```python
from typing import Callable

from jax.sharding import Mesh, PartitionSpec

Specs = PyTree[PartitionSpec]  # notation: any pytree whose leaves are PartitionSpecs

def shard_map(f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs
              ) -> Callable:
  ...
```
where:
- mesh encodes devices arranged in an array and with associated axis names, just like it does for xmap and for sharding.NamedSharding;
- in_specs and out_specs are PartitionSpecs which can affinely mention (i.e. mention at most once) axis names from mesh (not separate logical names as in xmap) to express slicing/unconcatenation and concatenation of inputs and outputs, respectively (not unstacking and stacking like pmap and xmap do), with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively;
- the shapes of the arguments passed to f have the same ranks as the arguments passed to shard_map-of-f (unlike pmap and xmap where the ranks are reduced), and the shape of an argument to f is computed from the shape shape of the corresponding argument to shard_map-of-f and the corresponding PartitionSpec spec as roughly tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec));
- the body of f can apply collectives using names from mesh.
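The “roughly” in the shape rule covers details such as tuple entries and divisibility. Below is a sketch of a fuller version. block_shape is a hypothetical helper, and it takes a plain dict of mesh axis sizes in place of a real Mesh. It extends the formula to tuple entries like ('j', 'i') by dividing by the product of the named axis sizes, pads a short spec with None, and rejects sizes that do not divide evenly.

```python
from math import prod

def block_shape(shape, spec, mesh_shape):
  """Per-device block shape for one argument (a sketch of the rule above).

  mesh_shape: dict mapping mesh axis name -> size, standing in for mesh.shape.
  spec may be shorter than shape; missing trailing entries mean "not split".
  """
  spec = tuple(spec) + (None,) * (len(shape) - len(spec))
  out = []
  for sz, n in zip(shape, spec):
    names = () if n is None else (n if isinstance(n, tuple) else (n,))
    k = prod(mesh_shape[a] for a in names)  # prod(()) == 1 for unsplit axes
    if sz % k:
      raise ValueError(f"axis of size {sz} is not divisible by {k}")
    out.append(sz // k)
  return tuple(out)
```

For example, on the 4x2 mesh used above, block_shape((8, 16), ('i', 'j'), {'i': 4, 'j': 2}) gives (2, 8), matching the a_block shape in the TL;DR example.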
shmap is eager by default, meaning that we dispatch computations primitive-by-primitive, so that the user can employ Python control flow on fully replicated values and interactive pdb debugging to print any values. To stage out and end-to-end compile a shmapped function, just put a jit around it. A consequence is that shmap doesn’t have its own dispatch and compilation paths like xmap and pmap currently do; it’s just the jit path.
When it’s staged out by e.g. an enclosing jit, the lowering of shmap to StableHLO is trivial: it just involves switching into ‘manual SPMD mode’ on the inputs, and switching back on the outputs. (We don’t currently plan to support partially-manual-partially-automatic modes.)
The interaction with effects is the same as with pmap.
The interaction with autodiff is also just like pmap (rather than attempting the new semantics that xmap did, corresponding to having unmapped intermediates and hence grad’s reduce_axes as well as making psum transpose to pbroadcast rather than psum). But it thus inherits an unsolved problem from pmap: in some cases, instead of transposing psum to psum, and thus performing a backward pass psum corresponding to the forward pass psum, it can be beneficial to move the backward pass psum to elsewhere in the backward pass, exploiting linearity. Many advanced pmap users addressed this challenge by using custom_vjp to implement psum_idrev and id_psumrev functions, but since it’s easy to accidentally leave those imbalanced, that technique is a foot-cannon. We have some ideas on how to provide this functionality in a safer way.
When should you use shmap and when should you use pjit?
One philosophy is: it is almost always simpler to write a program in jit==pjit, but if a given part of the program is less optimized by the compiler than it could be, drop into shmap!
A realistic example
Here’s how shmap might look in a transformer layer pass with a 2D weight gathered pattern (paper, Sec 3.2.3 on p. 5):
```python
# Excerpt from the authors' inference code: matmul_reducescatter, intermediate_dtype,
# _with_sharding_constraint, _layernorm, _rope, attention, special2, weights, HParams,
# mesh, and the `manual` flag are defined elsewhere and not shown.
from functools import partial
from typing import Sequence, Tuple

import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental.shard_map import shard_map

from partitioning import logical_to_physical as l2phys


def matmul_2D_wg_manual(xnorm, q_wi, layer):
  '''Calls a custom manual implementation of matmul_reducescatter'''
  # [batch, maxlen, embed.X] @ [heads.YZ, embed.X, q_wi_per_head]
  # -> (matmul)
  # -> [batch, maxlen, heads.YZ, q_wi_per_head]{x unreduced}
  # -> (reducescatter over x into X heads, B batches)
  # -> [batch, maxlen, heads.YZX, q_wi_per_head]
  with jax.named_scope('q_wi'):
    xnorm = intermediate_dtype(xnorm)
    q_wi = matmul_reducescatter(
        'bte,hed->bthd',
        xnorm,
        q_wi,
        scatter_dimension=(0, 2),
        axis_name='i',
        layer=layer)
  return q_wi


def pjit_transformer_layer(
    hparams: HParams, layer: int, params: weights.Layer, sin: jnp.ndarray,
    cos: jnp.ndarray, kv_caches: Sequence[attention.KVCache],
    x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
  """Forward pass through a single layer, returning output, K, V."""

  def my_layer(t, axis=0):
    """Gets the parameters corresponding to a given layer."""
    return lax.dynamic_index_in_dim(t, layer, axis=axis, keepdims=False)

  # 2D: [batch.Z, time, embed.XY]
  x = _with_sharding_constraint(
      x, ('residual_batch', 'residual_time', 'residual_embed'))
  xnorm = _layernorm(x)
  # 2D: [batch, time, embed.X]
  xnorm = _with_sharding_constraint(
      xnorm, ('post_norm_batch', 'time', 'post_norm_embed'))
  # jump into manual mode where you want to optimise
  if manual:
    # `layer` is a Python int, so bind it with partial; in_specs describes
    # only the two array arguments (xnorm and the q_wi weights).
    q_wi = shard_map(
        partial(matmul_2D_wg_manual, layer=layer), mesh,
        in_specs=(l2phys('post_norm_batch', 'time', 'post_norm_embed'),
                  l2phys('layers', 'heads', 'embed', 'q_wi_per_head')),
        out_specs=l2phys('post_norm_batch', 'time', 'heads', 'q_wi_per_head'))(
            xnorm, params.q_wi)
  else:
    q_wi = jnp.einsum('bte,hed->bthd', xnorm, my_layer(params.q_wi))
  # 2D: [batch, time, heads.YZX, None]
  q_wi = _with_sharding_constraint(q_wi,
                                   ('post_norm_batch', 'time', 'heads', 'qkv'))
  q = q_wi[:, :, :, :hparams.qkv]
  q = _rope(sin, cos, q)
  # unlike in https://arxiv.org/pdf/2002.05202.pdf, PaLM implements
  # swiGLU with full d_ff dimension, rather than 2/3 scaled
  wi0 = q_wi[:, :, :, hparams.qkv:hparams.qkv + (hparams.ff // hparams.heads)]
  wi1 = q_wi[:, :, :, hparams.qkv + (hparams.ff // hparams.heads):]
  kv = jnp.einsum('bte,ezd->btzd', xnorm, my_layer(params.kv))
  k = kv[:, :, 0, :hparams.qkv]
  v = kv[:, :, 0, hparams.qkv:]
  k = _rope(sin, cos, k)

  y_att = jnp.bfloat16(attention.attend(q, k, v, kv_caches, layer))

  y_mlp = special2.swish2(wi0) * wi1
  # 2D: [batch, time, heads.YZX, None]
  y_mlp = _with_sharding_constraint(y_mlp,
                                    ('post_norm_batch', 'time', 'heads', None))

  y_fused = jnp.concatenate([y_att, y_mlp], axis=-1)
  # do the second half of the mlp and the self-attn projection in parallel
  y_out = jnp.einsum('bthd,hde->bte', y_fused, my_layer(params.o_wo))
  # 2D: [batch.Z, time, embed.XY]
  y_out = _with_sharding_constraint(
      y_out, ('residual_batch', 'residual_time', 'residual_embed'))
  z = y_out + x
  z = _with_sharding_constraint(
      z, ('residual_batch', 'residual_time', 'residual_embed'))
  return z, k, v
```
In the profile below, both the first and second matmul were replaced by manually lowered versions, where the compute (fusions) is fully overlapped with the communication (ppermute)! One fun hint that we are using a latency-optimised variant is that the ppermute pixels are jittered, because there are two overlapping ppermutes using opposite ICI axes at the same time!
All-to-all is much harder to overlap, so was left on the table.

Why don’t pmap or xmap already solve this?
pmap was our first multi-device parallelism API. It follows the per-device-code-and-explicit-collectives school. But it had major shortcomings which make it unsuitable for today’s programs:
- Mapping multiple axes required nested pmaps. Not only are nested pmaps cumbersome to write, but also they make it difficult to control (or even predict) the device placement of data and computation, and difficult to preserve data sharding (see the next two bullets). Today’s programs require multiple axes of parallelism.
- Controlling device placement was impossible. Especially with multiple axes of parallelism, programmers need to control how those axes are aligned with hardware resources and their communication topologies. But (nested) pmap doesn’t offer control over how mapped program instances are placed on hardware; there’s just an automatic device order which the user can’t control. (Gopher’s use of axis_index_groups and a single un-nested pmap was essentially a hack to get around this by flattening multiple axes of parallelism down to one.)
- jit/pjit composability. jit-of-pmap is a performance footgun, as is nesting pmaps, as is e.g. scan-of-pmap, because sharding is not preserved when returning from an inner pmap. To preserve sharding we would need pattern matching on jaxprs to ensure we’re working with perfectly nested pmaps, or a pmap just inside a jit. Moreover, pjit was no help here because pmap targets XLA replicas while pjit targets the XLA SPMD Partitioner, and composing those two is hard.
- jax.Array compatibility (and hence pjit compatibility). Because the sharding of pmap outputs can’t be expressed as Shardings/OpShardings, due to pmap’s stacking rather than concatenative semantics, the output of a pmap computation can’t currently be passed to a pjit computation without bouncing to host (or dispatching a reshaping computation).
- Multi-controller semantics (and hence pjit compatibility). Multi-controller pmap concatenates values across controllers, which works well but differs from single-controller pmap’s stacking semantics. More practically, it precludes the use of non-fully-addressable jax.Array inputs and outputs as we use with multi-controller pjit.
- Eager mode. We didn’t make pmap eager-first, and though we eventually (after 4+ years!) added eager operation with disable_jit(), the fact that pmap has jit fused into it means it has its own compilation and dispatch path (actually two dispatch paths: in Python for handling Tracers, and in C++ for performance on raw Array inputs!), a heavy implementation burden.
- Reshapes needed in the caller. A typical use case with pmap on 8 devices might look like starting with a batch axis of size 128, reshaping it to split into two axes with sizes (8, 16), and then pmapping over the first. These reshapes are awkward and the compiler often interprets them as copies instead of views, increasing memory and time usage.
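The last bullet can be made concrete with NumPy. The sketch below assumes a batch of 128 rows with a hypothetical feature size of 4. The caller-side reshape that pmap needs and the split that shard_map performs with in_specs=P('i') on an 8-device mesh hand each device the same 16 rows. The difference is that the shard_map input keeps its original rank and needs no reshape at the call site.

```python
import numpy as np

batch = np.arange(128 * 4).reshape(128, 4)  # batch axis of size 128, feature size 4

# pmap style: the caller reshapes to (8, 16, 4) and device k sees pmap_view[k].
pmap_view = batch.reshape(8, 16, 4)
# shmap style (in_specs=P('i') on an 8-device mesh): device k sees block k, same rank as batch.
shmap_view = np.split(batch, 8, axis=0)
```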
These shortcomings aren’t so bad when only doing batch data parallelism. But when more parallelism is involved, pmap just can’t cut it!
xmap paved the way as a next-gen evolution of pmap and solved (almost) all these issues. shmap follows in xmap’s footsteps and solves these problems in essentially the same ways; indeed, shmap is like a specialized subset of xmap (what some call the “hard xmap” subset), with a few tweaks.
For the initial prototype, we chose to implement shmap as a separate primitive from xmap, because limiting the set of features it supports makes it easier to focus on the core functionality. For example, shmap doesn’t allow unmapped intermediates, making it easier not to worry about the interactions between named axes and autodiff. Furthermore, not having to reason about interactions of all pairs of features makes it easier to add capabilities beyond what’s implemented in xmap today, such as support for eager mode.
Both shmap and xmap share significant portions of the lowering code. We could consider merging both in the future, or even focusing solely on shmap, depending on how the usage will evolve.
