Introduction to parallel programming#
This tutorial serves as an introduction to device parallelism for Single-Program Multi-Data (SPMD) code in JAX. SPMD is a parallelism technique where the same computation, such as the forward pass of a neural network, can be run on different input data (for example, different inputs in a batch) in parallel on different devices, such as several GPUs or Google TPUs.
The tutorial covers three modes of parallel computation:
- Automatic sharding via `jax.jit()`: The compiler chooses the optimal computation strategy (a.k.a. "the compiler takes the wheel").
- Explicit sharding (*new*) is similar to automatic sharding in that you're writing a global-view program. The difference is that the sharding of each array is part of the array's JAX-level type, making it an explicit part of the programming model. These shardings are propagated at the JAX level and queryable at trace time. It's still the compiler's responsibility to turn the whole-array program into per-device programs (turning `jnp.sum` into `psum`, for example) but the compiler is heavily constrained by the user-supplied shardings.
- Fully manual sharding with manual control using `jax.shard_map()`: `shard_map` enables per-device code and explicit communication collectives.
A summary table:
| Mode | View? | Explicit sharding? | Explicit collectives? |
|---|---|---|---|
| Auto | Global | ❌ | ❌ |
| Explicit | Global | ✅ | ❌ |
| Manual | Per-device | ✅ | ✅ |
Using these schools of thought for SPMD, you can transform a function written for one device into a function that can run in parallel on multiple devices.
```python
import jax

jax.config.update('jax_num_cpu_devices', 8)
```
```python
jax.devices()
```
[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]
Key concept: Data sharding#
Key to all of the distributed computation approaches below is the concept of *data sharding*, which describes how data is laid out on the available devices.

How can JAX understand how the data is laid out across devices? JAX's datatype, the `jax.Array` immutable array data structure, represents arrays with physical storage spanning one or multiple devices, and helps make parallelism a core feature of JAX. The `jax.Array` object is designed with distributed data and computation in mind. Every `jax.Array` has an associated `jax.sharding.Sharding` object, which describes which shard of the global data is required by each global device. When you create a `jax.Array` from scratch, you also need to create its `Sharding`.
In the simplest cases, arrays are sharded on a single device, as demonstrated below:
```python
import numpy as np
import jax.numpy as jnp

arr = jnp.arange(32.0).reshape(4, 8)
arr.devices()
```
{CpuDevice(id=0)}
```python
arr.sharding
```
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)
For a more visual representation of the storage layout, the `jax.debug` module provides some helpers to visualize the sharding of an array. For example, `jax.debug.visualize_array_sharding()` displays how the array is stored in the memory of a single device:
```python
jax.debug.visualize_array_sharding(arr)
```
CPU 0

To create an array with a non-trivial sharding, you can define a `jax.sharding` specification for the array and pass this to `jax.device_put()`.
Here, define a `NamedSharding`, which specifies an N-dimensional grid of devices with named axes, where `jax.sharding.Mesh` allows for precise device placement:
```python
from jax.sharding import PartitionSpec as P

mesh = jax.make_mesh((2, 4), ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
print(sharding)
```
NamedSharding(mesh=Mesh('x': 2, 'y': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('x', 'y'), memory_kind=device)

Passing this `Sharding` object to `jax.device_put()`, you can obtain a sharded array:
```python
arr_sharded = jax.device_put(arr, sharding)

print(arr_sharded)
jax.debug.visualize_array_sharding(arr_sharded)
```
[[ 0.  1.  2.  3.  4.  5.  6.  7.]
 [ 8.  9. 10. 11. 12. 13. 14. 15.]
 [16. 17. 18. 19. 20. 21. 22. 23.]
 [24. 25. 26. 27. 28. 29. 30. 31.]]
  CPU 0  CPU 1  CPU 2  CPU 3
  CPU 4  CPU 5  CPU 6  CPU 7
1. Automatic parallelism via jit#
Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a `jax.jit()`-compiled function! In JAX, you only need to specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications.
The XLA compiler behind `jit` includes heuristics for optimizing computations across multiple devices. In the simplest of cases, those heuristics boil down to *computation follows data*.

To demonstrate how auto-parallelization works in JAX, below is an example that uses a `jax.jit()`-decorated staged-out function: it's a simple element-wise function, where the computation for each shard will be performed on the device associated with that shard, and the output is sharded in the same way:
```python
@jax.jit
def f_elementwise(x):
    return 2 * jnp.sin(x) + 1

result = f_elementwise(arr_sharded)

print("shardings match:", result.sharding == arr_sharded.sharding)
```
shardings match: True
As computations get more complex, the compiler makes decisions about how to best propagate the sharding of the data.
Here, you sum along the leading axis of `x`, and visualize how the result values are stored across multiple devices (with `jax.debug.visualize_array_sharding()`):
```python
@jax.jit
def f_contract(x):
    return x.sum(axis=0)

result = f_contract(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)
```
CPU 0,4 CPU 1,5 CPU 2,6 CPU 3,7
[48. 52. 56. 60. 64. 68. 72. 76.]
The result is partially replicated: that is, the first two elements of the array are replicated on devices 0 and 4, the second two on 1 and 5, and so on.
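If you prefer a programmatic check over the visualization, the sharding object can report which slice of the global array lives on each device. Here is a small sketch using `Sharding.devices_indices_map()`:

```python
# Map each device to the slice of the global result it holds. Devices 0 and 4
# should report the same slice, devices 1 and 5 the next one, and so on.
for device, index in result.sharding.devices_indices_map(result.shape).items():
    print(device, index)
```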
2. Explicit sharding#
The main idea behind explicit shardings (a.k.a. sharding-in-types) is that the JAX-level *type* of a value includes a description of how the value is sharded. We can query the JAX-level type of any JAX value (or NumPy array, or Python scalar) using `jax.typeof`:
```python
some_array = np.arange(8)
print(f"JAX-level type of some_array: {jax.typeof(some_array)}")
```
JAX-level type of some_array: int32[8]
Importantly, we can query the type even while tracing under a `jit` (the JAX-level type is almost defined as "the information about a value we have access to while under a `jit`").
```python
@jax.jit
def foo(x):
    print(f"JAX-level type of x during tracing: {jax.typeof(x)}")
    return x + x

foo(some_array)
```
JAX-level type of x during tracing: int32[8]
Array([ 0, 2, 4, 6, 8, 10, 12, 14], dtype=int32)
To start seeing shardings in the type, we need to set up an explicit-sharding mesh.
```python
from jax.sharding import AxisType

mesh = jax.make_mesh((2, 4), ("X", "Y"),
                     axis_types=(AxisType.Explicit, AxisType.Explicit))
```
Now we can create some sharded arrays:
```python
replicated_array = np.arange(8).reshape(4, 2)
sharded_array = jax.device_put(replicated_array, jax.NamedSharding(mesh, P("X", None)))

print(f"replicated_array type: {jax.typeof(replicated_array)}")
print(f"sharded_array type: {jax.typeof(sharded_array)}")
```
replicated_array type: int32[4,2]
sharded_array type: int32[4@X,2]
We should read the type `int32[4@X,2]` as "a 4-by-2 array of 32-bit ints whose first dimension is sharded along mesh axis 'X'. The array is replicated along all other mesh axes".
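The same notation covers any dimension. As a small sketch, here is the trailing dimension sharded along mesh axis "Y" instead (`col_sharded` is a name introduced here; the last dimension of size 8 splits into chunks of 2 across the four devices of "Y"):

```python
# Shard columns along mesh axis "Y"; rows stay replicated across "X".
col_sharded = jax.device_put(np.arange(32).reshape(4, 8),
                             jax.NamedSharding(mesh, P(None, "Y")))
print(f"col_sharded type: {jax.typeof(col_sharded)}")  # expected: int32[4,8@Y]
```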
These shardings associated with JAX-level types propagate through operations. For example:
```python
arg0 = jax.device_put(np.arange(4).reshape(4, 1),
                      jax.NamedSharding(mesh, P("X", None)))
arg1 = jax.device_put(np.arange(8).reshape(1, 8),
                      jax.NamedSharding(mesh, P(None, "Y")))

@jax.jit
def add_arrays(x, y):
    ans = x + y
    print(f"x sharding: {jax.typeof(x)}")
    print(f"y sharding: {jax.typeof(y)}")
    print(f"ans sharding: {jax.typeof(ans)}")
    return ans

with jax.set_mesh(mesh):
    add_arrays(arg0, arg1)
```
x sharding: int32[4@X,1]
y sharding: int32[1,8@Y]
ans sharding: int32[4@X,8@Y]
That's the gist of it. Shardings propagate deterministically at trace time and we can query them at trace time.
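Propagation covers contractions as well, not just elementwise operations. Below is a minimal sketch reusing `mesh` and `sharded_array` from above: summing over the dimension sharded along "X" yields a result whose type no longer carries that mesh axis (in this mode, it is still the compiler that inserts the necessary collective):

```python
@jax.jit
def sum_rows(x):
    total = x.sum(axis=0)  # contracts the dimension sharded along "X"
    print(f"input type: {jax.typeof(x)}")
    print(f"output type: {jax.typeof(total)}")
    return total

with jax.set_mesh(mesh):
    sum_rows(sharded_array)
```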
3. Manual parallelism with shard_map#
In the automatic parallelism methods explored above, you can write a function as if you're operating on the full dataset, and `jit` will split that computation across multiple devices. By contrast, with `jax.shard_map()` you write the function that will handle a single shard of data, and `shard_map` will construct the full function.

`shard_map` works by mapping a function across a particular *mesh* of devices (`shard_map` maps over shards). In the example below:
- As before, `jax.sharding.Mesh` allows for precise device placement, with the axis names parameter for logical and physical axis names.
- The `in_specs` argument determines the shard sizes. The `out_specs` argument identifies how the blocks are assembled back together.
Note: `jax.shard_map()` code can work inside `jax.jit()` if you need it.
```python
mesh = jax.make_mesh((8,), ('x',))

f_elementwise_sharded = jax.shard_map(
    f_elementwise,
    mesh=mesh,
    in_specs=P('x'),
    out_specs=P('x'))

arr = jnp.arange(32)
f_elementwise_sharded(arr)
```
Array([ 1.        ,  2.682942  ,  2.818595  ,  1.28224   , -0.513605  ,
       -0.9178486 ,  0.44116902,  2.3139732 ,  2.9787164 ,  1.824237  ,
       -0.08804226, -0.99998045, -0.07314587,  1.840334  ,  2.9812148 ,
        2.3005757 ,  0.42419338, -0.92279494, -0.50197446,  1.2997544 ,
        2.8258905 ,  2.6733112 ,  0.98229736, -0.69244087, -0.81115675,
        0.7352965 ,  2.525117  ,  2.912752  ,  1.5418116 , -0.32726777,
       -0.97606325,  0.19192469], dtype=float32)
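As the note above says, `shard_map` composes with `jit`. Here is a minimal sketch wrapping the same mapped function in a jitted caller (`jitted_elementwise` is a name introduced here):

```python
@jax.jit
def jitted_elementwise(x):
    # The shard_map runs per-device inside the jit-compiled program.
    return jax.shard_map(f_elementwise, mesh=mesh,
                         in_specs=P('x'), out_specs=P('x'))(x)

jitted_elementwise(arr)
```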
The function you write only "sees" a single batch of the data, which you can check by printing the device-local shape:
```python
x = jnp.arange(32)
print(f"global shape: {x.shape=}")

def f(x):
    print(f"device local shape: {x.shape=}")
    return x * 2

y = jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
```
global shape: x.shape=(32,)
device local shape: x.shape=(4,)
Because each of your functions only "sees" the device-local part of the data, aggregation-like functions require some extra thought.
For example, here's what a `shard_map` of a `jax.numpy.sum()` looks like:
```python
def f(x):
    return jnp.sum(x, keepdims=True)

jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
```
Array([ 6, 22, 38, 54, 70, 86, 102, 118], dtype=int32)
Your function `f` operates separately on each shard, and the resulting summation reflects this.

If you want to sum across shards, you need to explicitly request it using collective operations like `jax.lax.psum()`:
```python
def f(x):
    sum_in_shard = x.sum()
    return jax.lax.psum(sum_in_shard, 'x')

jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x)
```
Array(496, dtype=int32)
Because the output no longer has a sharded dimension, set `out_specs=P()` (recall that the `out_specs` argument identifies how the blocks are assembled back together in `shard_map`).
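Other collectives follow the same pattern. For example, here is a sketch that uses `jax.lax.pmean` to average the per-shard sums instead of adding them (`f_mean` is a name introduced here):

```python
def f_mean(x):
    # Per-shard sum, then the mean of those sums across all shards along 'x'.
    return jax.lax.pmean(x.sum().astype(jnp.float32), 'x')

jax.shard_map(f_mean, mesh=mesh, in_specs=P('x'), out_specs=P())(x)
```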
Comparing the three approaches#
With these concepts fresh in our mind, let’s compare the three approaches for a simple neural network layer.
Start by defining your canonical function like this:
```python
@jax.jit
def layer(x, weights, bias):
    return jax.nn.sigmoid(x @ weights + bias)
```
```python
import numpy as np

rng = np.random.default_rng(0)

x = rng.normal(size=(32,))
weights = rng.normal(size=(32, 4))
bias = rng.normal(size=(4,))

layer(x, weights, bias)
```
Array([0.02138916, 0.8931118 , 0.5989196 , 0.9774251 ], dtype=float32)
You can automatically run this in a distributed manner using `jax.jit()` and passing appropriately sharded data.

If you shard the leading axis of `x` and make `weights` fully replicated, then the matrix multiplication will automatically happen in parallel:
```python
mesh = jax.make_mesh((8,), ('x',))
x_sharded = jax.device_put(x, jax.NamedSharding(mesh, P('x')))
weights_sharded = jax.device_put(weights, jax.NamedSharding(mesh, P()))

layer(x_sharded, weights_sharded, bias)
```
Array([0.02138916, 0.8931118 , 0.5989196 , 0.9774251 ], dtype=float32)
Alternatively, you can use explicit sharding mode too:
```python
explicit_mesh = jax.make_mesh((8,), ('X',), axis_types=(AxisType.Explicit,))

x_sharded = jax.device_put(x, jax.NamedSharding(explicit_mesh, P('X')))
weights_sharded = jax.device_put(weights, jax.NamedSharding(explicit_mesh, P()))

@jax.jit
def layer_auto(x, weights, bias):
    print(f"x sharding: {jax.typeof(x)}")
    print(f"weights sharding: {jax.typeof(weights)}")
    print(f"bias sharding: {jax.typeof(bias)}")
    out = layer(x, weights, bias)
    print(f"out sharding: {jax.typeof(out)}")
    return out

with jax.set_mesh(explicit_mesh):
    layer_auto(x_sharded, weights_sharded, bias)
```
x sharding: float32[32@X]
weights sharding: float32[32,4]
bias sharding: float32[4]
out sharding: float32[4]
Finally, you can do the same thing with `shard_map`, using `jax.lax.psum()` to indicate the cross-shard collective required for the matrix product:
```python
from functools import partial

@jax.jit
@partial(jax.shard_map, mesh=mesh,
         in_specs=(P('x'), P('x', None), P(None)),
         out_specs=P(None))
def layer_sharded(x, weights, bias):
    return jax.nn.sigmoid(jax.lax.psum(x @ weights, 'x') + bias)

layer_sharded(x, weights, bias)
```
Array([0.02138916, 0.8931118 , 0.5989196 , 0.9774251 ], dtype=float32)
Controlling data and computation placement on devices#
Let’s look at the principles of data and computation placement in JAX.
In JAX, the computation follows data placement. JAX arrays have two placement properties: 1) the device where the data resides; and 2) whether it is *committed* to the device or not (the data is sometimes referred to as being *sticky* to the device).

By default, JAX arrays are placed uncommitted on the default device (`jax.devices()[0]`), which is the first GPU or TPU by default. If no GPU or TPU is present, `jax.devices()[0]` is the CPU. The default device can be temporarily overridden with the `jax.default_device()` context manager, or set for the whole process by setting the environment variable `JAX_PLATFORMS` or the absl flag `--jax_platforms` to "cpu", "gpu", or "tpu" (`JAX_PLATFORMS` can also be a list of platforms, which determines which platforms are available in priority order).
```python
>>> from jax import numpy as jnp
>>> print(jnp.ones(3).devices())
{CudaDevice(id=0)}
```
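And a short sketch of the context-manager override (assuming a CPU backend is available; `w` is a name introduced here):

```python
>>> with jax.default_device(jax.devices("cpu")[0]):
...     w = jnp.ones(3)  # created uncommitted on the chosen default device
>>> print(w.devices())
{CpuDevice(id=0)}
```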
Computations involving uncommitted data are performed on the default device and the results are uncommitted on the default device.

Data can also be placed explicitly on a device using `jax.device_put()` with a `device` parameter, in which case the data becomes *committed* to the device:
```python
>>> import jax
>>> from jax import device_put
>>> arr = device_put(1, jax.devices()[2])
>>> print(arr.devices())
{CudaDevice(id=2)}
```
Computations involving some committed inputs will happen on the committed device and the result will be committed on the same device. Invoking an operation on arguments that are committed to more than one device will raise an error.
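A quick sketch of that failure mode (`a` and `b` are names introduced here):

```python
>>> a = device_put(1.0, jax.devices()[1])
>>> b = device_put(2.0, jax.devices()[2])
>>> a + b  # raises an error: the inputs are committed to different devices
```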
You can also use `jax.device_put()` without a `device` parameter. If the data is already on a device (committed or not), it's left as-is. If the data isn't on any device (that is, it's a regular Python or NumPy value), it's placed uncommitted on the default device.
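For example (a small sketch; `u` and `v` are names introduced here):

```python
>>> import numpy as np
>>> u = device_put(np.arange(3))         # no device argument: uncommitted, default device
>>> v = device_put(u, jax.devices()[2])  # explicit device: committed to device 2
```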
Jitted functions behave like any other primitive operations: they will follow the data and will show errors if invoked on data committed to more than one device.
(Before PR #6002 in March 2021 there was some laziness in the creation of array constants, so that `jax.device_put(jnp.zeros(...), jax.devices()[1])` or similar would actually create the array of zeros on `jax.devices()[1]`, instead of creating the array on the default device and then moving it. But this optimization was removed so as to simplify the implementation.)

(As of April 2020, `jax.jit()` has a `device` parameter that affects the device placement. That parameter is experimental, is likely to be removed or changed, and its use is not recommended.)

For a worked-out example, we recommend reading through `test_computation_follows_data` in `multi_device_test.py`.
Next steps#
This tutorial serves as a brief introduction to sharded and parallel computation in JAX.
To learn about each SPMD method in-depth, check out these docs:
