Explicit sharding (a.k.a. “sharding in types”)#
JAX’s traditional automatic sharding leaves sharding decisions to the compiler. You can provide hints to the compiler using jax.lax.with_sharding_constraint but for the most part you’re supposed to be focused on the math while the compiler worries about sharding.

But what if you have a strong opinion about how you want your program sharded? With enough calls to with_sharding_constraint you can probably guide the compiler’s hand to make it do what you want. But “compiler tickling” is famously not a fun programming model. Where should you put the sharding constraints? You could put them on every single intermediate, but that’s a lot of work and it’s also easy to make mistakes that way because there’s no way to check that the shardings make sense together. More commonly, people add just enough sharding annotations to constrain the compiler. But this is a slow iterative process. It’s hard to know ahead of time what XLA’s GSPMD pass will do (it’s a whole-program optimization), so all you can do is add annotations, inspect XLA’s sharding choices to see what happened, and repeat.
To fix this we’ve come up with a different style of sharding programming we call “explicit sharding” or “sharding in types”. The idea is that sharding propagation happens at the JAX level at trace time. Each JAX operation has a sharding rule that takes the shardings of the op’s arguments and produces a sharding for the op’s result. For most operations these rules are simple and obvious because there’s only one reasonable choice. But for some operations it’s unclear how to shard the result. In that case we ask the programmer to provide an out_sharding argument explicitly and we throw a (trace-time) error otherwise. Since the shardings are propagated at trace time, they can also be queried at trace time. In the rest of this doc we’ll describe how to use explicit sharding mode. Note that this is a new feature so we expect there to be bugs and unimplemented cases. Please let us know when you find something that doesn’t work! Also see The Training Cookbook for a real-world machine learning training example that uses explicit sharding.
import jax
import numpy as np
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, AxisType, get_abstract_mesh, reshard

jax.config.update('jax_num_cpu_devices', 8)
Setting up an explicit mesh#
The main idea behind explicit sharding (a.k.a. sharding-in-types) is that the JAX-level type of a value includes a description of how the value is sharded. We can query the JAX-level type of any JAX value (or NumPy array, or Python scalar) using jax.typeof:
some_array = np.arange(8)
print(f"JAX-level type of some_array: {jax.typeof(some_array)}")
JAX-level type of some_array: int32[8]
Importantly, we can query the type even while tracing under a jit (the JAX-level type is almost defined as “the information about a value we have access to while under a jit”).
@jax.jit
def foo(x):
  print(f"JAX-level type of x during tracing: {jax.typeof(x)}")
  return x + x

foo(some_array)
JAX-level type of x during tracing: int32[8]
Array([ 0, 2, 4, 6, 8, 10, 12, 14], dtype=int32)
These types show the shape and dtype of the array but they don’t appear to show sharding. (Actually, they did show sharding, but the shardings were trivial. See “Concrete array shardings”, below.) To start seeing some interesting shardings we need to set up an explicit-sharding mesh.
jax.set_mesh can be used as a global setter or a context manager. We use jax.set_mesh in this notebook as a global setter. You can use it as a scoped context manager via with jax.set_mesh(mesh).
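For reference, here’s a minimal sketch of the scoped form (the name local_mesh is just for illustration; the rest of this notebook uses the global setter):

local_mesh = jax.make_mesh((2, 4), ("X", "Y"),
                           axis_types=(AxisType.Explicit, AxisType.Explicit))
with jax.set_mesh(local_mesh):
  # The explicit-sharding mesh is active only inside this block.
  print(get_abstract_mesh())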
mesh = jax.make_mesh((2, 4), ("X", "Y"),
                     axis_types=(AxisType.Explicit, AxisType.Explicit))
jax.set_mesh(mesh)
print(f"Current mesh is: {get_abstract_mesh()}")
Current mesh is: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit), device_kind=cpu, num_cores=None)

Now we can create some sharded arrays using reshard:
replicated_array = np.arange(8).reshape(4, 2)
sharded_array = reshard(replicated_array, P("X", None))

print(f"replicated_array type: {jax.typeof(replicated_array)}")
print(f"sharded_array type: {jax.typeof(sharded_array)}")
replicated_array type: int32[4,2]
sharded_array type: int32[4@X,2]
We should read the type int32[4@X,2] as “a 4-by-2 array of 32-bit integers whose first dimension is sharded along mesh axis ‘X’. The array is replicated along all other mesh axes”.
These shardings associated with JAX-level types propagate through operations. For example:
arg0 = reshard(np.arange(4).reshape(4, 1), P("X", None))
arg1 = reshard(np.arange(8).reshape(1, 8), P(None, "Y"))
result = arg0 + arg1

print(f"arg0 sharding: {jax.typeof(arg0)}")
print(f"arg1 sharding: {jax.typeof(arg1)}")
print(f"result sharding: {jax.typeof(result)}")
arg0 sharding: int32[4@X,1]
arg1 sharding: int32[1,8@Y]
result sharding: int32[4@X,8@Y]
We can do the same type querying under a jit:
@jax.jit
def add_arrays(x, y):
  ans = x + y
  print(f"x sharding: {jax.typeof(x)}")
  print(f"y sharding: {jax.typeof(y)}")
  print(f"ans sharding: {jax.typeof(ans)}")
  return ans

add_arrays(arg0, arg1)
x sharding: int32[4@X,1]
y sharding: int32[1,8@Y]
ans sharding: int32[4@X,8@Y]
Array([[ 0,  1,  2,  3,  4,  5,  6,  7],
       [ 1,  2,  3,  4,  5,  6,  7,  8],
       [ 2,  3,  4,  5,  6,  7,  8,  9],
       [ 3,  4,  5,  6,  7,  8,  9, 10]], dtype=int32)
That’s the gist of it. Shardings propagate deterministically at trace time and we can query them at trace time.
Sharding rules and operations with ambiguous sharding#
Each op has a sharding rule which specifies its output sharding given its input shardings. A sharding rule may also throw a (trace-time) error. Each op is free to implement whatever sharding rule it likes, but the usual pattern is the following: for each output axis we identify zero or more corresponding input axes. The output axis is then sharded according to the “consensus” sharding of the corresponding input axes. That is, it’s None if the input shardings are all None, it’s the common non-None input sharding if there’s exactly one of them, and it’s an error (requiring an explicit out_sharding=… kwarg) otherwise.
This procedure is done on an axis-by-axis basis. When it’s done, we might end up with an array sharding that mentions a mesh axis more than once, which is illegal. In that case we raise a (trace-time) sharding error and ask for an explicit out_sharding.
Here are some example sharding rules:
- Nullary ops like jnp.zeros, jnp.arange: These ops create arrays out of whole cloth so they don’t have input shardings to propagate. Their output is unsharded by default unless overridden by the out_sharding kwarg.
- Unary elementwise ops like sin, exp: The output is sharded the same as the input.
- Binary ops (+, -, * etc.): Axis shardings of “zipped” dimensions must match (or be None). “Outer product” dimensions (dimensions that appear in only one argument) are sharded as they are in the input. If the result ends up mentioning a mesh axis more than once it’s an error.
- reshape: Reshape is a particularly tricky op. An output axis can map to more than one input axis (when reshape is used to merge axes) or just a part of an input axis (when reshape is used to split axes), so our usual rules don’t apply. Instead we treat reshape as follows. We first strip away singleton axes (these can’t be sharded anyway). Then we decide whether the reshape is a “split” (splitting a single axis into two or more adjacent axes), a “merge” (merging two or more adjacent axes into a single one), or something else. If we have a split or merge case in which the split/merged axes are sharded as None, then we shard the resulting split/merged axes as None and the other axes according to their corresponding input axis shardings. In all other cases we throw an error and require the user to provide an out_sharding argument.
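To make the first two rules concrete, here’s a small sketch (assuming the explicit mesh set up earlier is still active, and using the out_sharding kwarg described above):

# Nullary ops produce unsharded output by default...
z_default = jnp.zeros((4, 8))
print(jax.typeof(z_default))           # expected: float32[4,8]

# ...unless we request a sharding explicitly via out_sharding.
z_sharded = jnp.zeros((4, 8), out_sharding=P("X", "Y"))
print(jax.typeof(z_sharded))           # expected: float32[4@X,8@Y]

# Unary elementwise ops preserve the input sharding.
print(jax.typeof(jnp.sin(z_sharded)))  # expected: float32[4@X,8@Y]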
JAX transformations and higher-order functions#
The staged-out representation of JAX programs is explicitly typed. (We call the types “avals” but that’s not important.) In explicit-sharding mode, the sharding is part of that type. This means that shardings need to match wherever types need to match. For example, the two sides of a lax.cond need to have results with matching shardings. And the carry of lax.scan needs to have the same sharding at the input and the output of the scan body. And when you construct a jaxpr without concrete arguments using make_jaxpr you need to provide shardings too. Certain JAX transformations perform type-level operations. Automatic differentiation constructs a tangent type for each primal type in the original computation (e.g. TangentOf(float) == float, TangentOf(int) == float0). With sharding in the types, this means that tangent values are sharded in the same way as their primal values. Vmap and scan also do type-level operations: they lift an array shape to a rank-augmented version of that shape. That extra array axis needs a sharding. We can infer it from the arguments to the vmap/scan but they all need to agree. And a nullary vmap/scan needs an explicit sharding argument just as it needs an explicit length argument.
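As a quick illustration of the autodiff point, here’s a minimal sketch (not from the original notebook) checking that a gradient is sharded like its primal:

# Tangents/gradients inherit the sharding of their primal values.
xs = reshard(jnp.arange(8.0), P("X"))
grads = jax.grad(lambda v: jnp.sum(jnp.sin(v)))(xs)

print(jax.typeof(xs))     # float32[8@X]
print(jax.typeof(grads))  # expected: float32[8@X], same sharding as xs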
Working around unimplemented sharding rules using auto_axes#
The implementation of explicit sharding is still a work in progress and there are plenty of ops that are missing sharding rules. For example, scatter and gather (i.e. indexing ops).
Normally we wouldn’t suggest using a feature with so many unimplemented cases, but in this instance there’s a reasonable fallback you can use: auto_axes. The idea is that you can temporarily drop into a context where the mesh axes are “auto” rather than “explicit”. You explicitly specify how you intend the final result of the auto_axes to be sharded as it gets returned to the calling context.
This works as a fallback for ops with unimplemented sharding rules. It also works when you want to override the sharding-in-types type system. For example, suppose we want to add an f32[4@X,4] to an f32[4,4@X]. Our sharding rule for addition would throw an error: the result would need to be f32[4@X,4@X], which uses a mesh axis twice, which is illegal. But say you want to perform the operation anyway, and you want the result to be sharded along the first axis only, like f32[4@X,4]. You can do this as follows:
from jax.sharding import auto_axes, explicit_axes

some_x = reshard(np.arange(16).reshape(4, 4), P("X", None))
some_y = reshard(np.arange(16).reshape(4, 4), P(None, "X"))

try:
  some_x + some_y
except Exception as e:
  print("ERROR!")
  print(e)

print("=== try again with auto_axes ===")

@auto_axes
def add_with_out_sharding_kwarg(x, y):
  print(f"We're in auto-sharding mode here. This is the current mesh: {get_abstract_mesh()}")
  return x + y

result = add_with_out_sharding_kwarg(some_x, some_y, out_sharding=P("X", None))
print(f"Result type: {jax.typeof(result)}")
ERROR!
add operation with inputs: i32[4@X,4], i32[4,4@X] produces an illegally sharded result: i32[4@X,4@X]
=== try again with auto_axes ===
We're in auto-sharding mode here. This is the current mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto), device_kind=cpu, num_cores=None)
Result type: int32[4@X,4]

Using a mixture of sharding modes#
JAX now has three styles of parallelism:
- Automatic sharding is where you treat all the devices as a single logical machine and write a “global view” array program for that machine. The compiler decides how to partition the data and computation across the available devices. You can give hints to the compiler using with_sharding_constraint.
- Explicit sharding (new) is similar to automatic sharding in that you’re writing a global-view program. The difference is that the sharding of each array is part of the array’s JAX-level type, making it an explicit part of the programming model. These shardings are propagated at the JAX level and queryable at trace time. It’s still the compiler’s responsibility to turn the whole-array program into per-device programs (turning jnp.sum into psum, for example) but the compiler is heavily constrained by the user-supplied shardings.
- Manual sharding (shard_map) is where you write a program from the perspective of a single device. Communication between devices happens via explicit collective operations like psum.
A summary table:
| Mode | View? | Explicit sharding? | Explicit collectives? |
|---|---|---|---|
| Auto | Global | ❌ | ❌ |
| Explicit | Global | ✅ | ❌ |
| Manual | Per-device | ✅ | ✅ |
The current mesh tells us which sharding mode we’re in. We can query it with get_abstract_mesh:
print(f"Current mesh is:{get_abstract_mesh()}")
Current mesh is: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit), device_kind=cpu, num_cores=None)

Since axis_types=(Explicit, Explicit), this means we’re in fully-explicit mode. Notice that the sharding mode is associated with a mesh axis, not the mesh as a whole. We can actually mix sharding modes by having a different sharding mode for each mesh axis. Shardings (on JAX-level types) can only mention explicit mesh axes and collective operations like psum can only mention manual mesh axes.
You can use the auto_axes API to be Auto over some mesh axes while being Explicit over others. For example:
import functools

@functools.partial(auto_axes, axes='X')
def g(y):
  print(f'mesh inside g: {get_abstract_mesh()}')
  print(f'y.sharding inside g: {jax.typeof(y) = }', end='\n\n')
  return y * 2

@jax.jit
def f(arr1):
  print(f'mesh inside f: {get_abstract_mesh()}')
  x = jnp.sin(arr1)
  print(f'x.sharding: {jax.typeof(x)}', end='\n\n')
  z = g(x, out_sharding=P("X", "Y"))
  print(f'z.sharding: {jax.typeof(z)}', end="\n\n")
  return z + 1

some_x = reshard(np.arange(16).reshape(4, 4), P("X", "Y"))
f(some_x)
mesh inside f: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit), device_kind=cpu, num_cores=None)
x.sharding: float32[4@X,4@Y]

mesh inside g: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Explicit), device_kind=cpu, num_cores=None)
y.sharding inside g: jax.typeof(y) = ShapedArray(float32[4,4@Y])

z.sharding: float32[4@X,4@Y]

Array([[ 1.        ,  2.682942  ,  2.818595  ,  1.28224   ],
       [-0.513605  , -0.9178486 ,  0.44116902,  2.3139732 ],
       [ 2.9787164 ,  1.824237  , -0.08804226, -0.99998045],
       [-0.07314587,  1.840334  ,  2.9812148 ,  2.3005757 ]], dtype=float32)
As you can see, inside g the type of y is ShapedArray(float32[4,4@Y]), which indicates it’s Explicit over the Y mesh axis while Auto over X.
You can also use the explicit_axes API to drop into Explicit mode over some or all mesh axes.
auto_mesh = jax.make_mesh((2, 4), ("X", "Y"),
                          axis_types=(AxisType.Auto, AxisType.Auto))

@functools.partial(explicit_axes, axes=('X', 'Y'))
def explicit_g(y):
  print(f'mesh inside g: {get_abstract_mesh()}')
  print(f'y.sharding inside g: {jax.typeof(y) = }')
  z = y * 2
  print(f'z.sharding inside g: {jax.typeof(z) = }', end='\n\n')
  return z

@jax.jit
def f(arr1):
  print(f'mesh inside f: {get_abstract_mesh()}', end='\n\n')
  x = jnp.sin(arr1)
  z = explicit_g(x, in_sharding=P("X", "Y"))
  return z + 1

with jax.set_mesh(auto_mesh):
  some_x = jax.device_put(np.arange(16).reshape(4, 4), P("X", "Y"))
  f(some_x)
mesh inside f: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto), device_kind=cpu, num_cores=None)

mesh inside g: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit), device_kind=cpu, num_cores=None)
y.sharding inside g: jax.typeof(y) = ShapedArray(float32[4@X,4@Y])
z.sharding inside g: jax.typeof(z) = ShapedArray(float32[4@X,4@Y])

As you can see, all axes of the mesh inside f are of type Auto, while inside g they are of type Explicit. Because of that, sharding is visible on the type of arrays inside g.
Concrete array shardings can mention Auto mesh axes#
You can query the sharding of a concrete array x with x.sharding. You might expect the result to be the same as the sharding associated with the value’s type, jax.typeof(x).sharding. It might not be! The concrete array sharding, x.sharding, describes the sharding along both Explicit and Auto mesh axes. It’s the sharding that the compiler eventually chose. Whereas the type-specified sharding, jax.typeof(x).sharding, only describes the sharding along Explicit mesh axes. The Auto axes are deliberately hidden from the type because they’re the purview of the compiler. We can think of the concrete array sharding as being consistent with, but more specific than, the type-specified sharding. For example:
def compare_shardings(x):
  print(f"=== with mesh: {get_abstract_mesh()} ===")
  print(f"Concrete value sharding: {x.sharding.spec}")
  print(f"Type-specified sharding: {jax.typeof(x).sharding.spec}")

my_array = jnp.sin(reshard(np.arange(8), P("X")))
compare_shardings(my_array)

@auto_axes
def check_in_auto_context(x):
  compare_shardings(x)
  return x

check_in_auto_context(my_array, out_sharding=P("X"))
=== with mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit), device_kind=cpu, num_cores=None) ===
Concrete value sharding: P('X',)
Type-specified sharding: P('X',)
=== with mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto), device_kind=cpu, num_cores=None) ===
Concrete value sharding: P('X',)
Type-specified sharding: P(None,)

Array([ 0.        ,  0.84147096,  0.9092974 ,  0.14112   , -0.7568025 ,
       -0.9589243 , -0.2794155 ,  0.6569866 ], dtype=float32)
Notice that at the top level, where we’re currently in a fully Explicit mesh context, the concrete array sharding and type-specified sharding agree. But under the auto_axes decorator we’re in a fully Auto mesh context and the two shardings disagree: the type-specified sharding is P(None) whereas the concrete array sharding is P("X") (though it could be anything! It’s up to the compiler).
