Fault Tolerant Distributed JAX
Recall that multi-controller JAX allows you to run a JAX program distributed across multiple machines. By default, if any of these machines fails, then every machine will fail. That is, multi-controller JAX is not fault-tolerant by default.
This article has three parts. In the first part, we’ll explain the basics of how to write fault-tolerant multi-controller JAX programs. In the second part, we’ll show some example fault-tolerant multi-controller JAX programs. In the third part, we’ll take a look under the covers at how multi-controller JAX implements fault tolerance.
Warning
JAX’s support for fault tolerance is still experimental. It currently only works fully on GPUs. It has rough edges, is probably buggy, and is subject to change. Use at your own risk.
Part 1: Fault Tolerance Basics
Fault Intolerant By Default
By default, multi-controller JAX programs are not fault tolerant. If any process crashes, then all other processes will also intentionally crash. To make this concrete, consider the following trivial script, example.py, that initializes multi-controller JAX by calling jax.distributed.initialize and then enters an infinite loop.
example.py

```python
from absl import app
from absl import flags
from collections.abc import Sequence
import jax
import time

_PROCESS_ID = flags.DEFINE_integer("i", -1, "Process id")
_NUM_PROCESSES = flags.DEFINE_integer("n", -1, "Number of processes")


def main(_: Sequence[str]) -> None:
  jax.distributed.initialize(
      coordinator_address="localhost:9000",
      num_processes=_NUM_PROCESSES.value,
      process_id=_PROCESS_ID.value,
      local_device_ids=[_PROCESS_ID.value],
      heartbeat_timeout_seconds=10,
  )
  print(f'{jax.devices()=}')
  print(f'{jax.local_devices()=}')
  while True:
    print(time.time())
    time.sleep(1)


if __name__ == "__main__":
  app.run(main)
```
Run example.py across four processes on a VM with four GPUs by running the following four commands, each in a different terminal. The local_device_ids argument to jax.distributed.initialize ensures each process is assigned only one of the four GPUs. We’ll explain the heartbeat_timeout_seconds argument in just a second.

```shell
python example.py --i=0 --n=4  # in terminal 1
python example.py --i=1 --n=4  # in terminal 2
python example.py --i=2 --n=4  # in terminal 3
python example.py --i=3 --n=4  # in terminal 4
```
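If juggling four terminals is inconvenient, the same four commands can be started from one terminal. The sketch below is a convenience helper, not part of JAX; it assumes example.py is in the current directory and that `python` is on the PATH, and the names `launch_commands` and `launch` are hypothetical. It keeps the literal `python example.py ...` command line so that the `pkill -f` pattern used below still matches.

```python
import subprocess


def launch_commands(num_processes: int, script: str = "example.py") -> list[list[str]]:
  """Builds one command line per process, mirroring the commands above."""
  return [
      ["python", script, f"--i={i}", f"--n={num_processes}"]
      for i in range(num_processes)
  ]


def launch(num_processes: int = 4) -> list[subprocess.Popen]:
  """Starts every process in the background and returns their handles."""
  return [subprocess.Popen(cmd) for cmd in launch_commands(num_processes)]
```

Usage: `procs = launch(4)` followed by `for p in procs: p.wait()`. All four processes then share one terminal's output, which is noisier than separate terminals but easier to script.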
When you run these commands, you’ll see the processes dutifully printing out the current time every second. Next, fail the fourth process: pkill -9 -f 'python example.py --i=3 --n=4'. After about ten seconds, the other processes will also terminate and spit out error messages that look something like this:

```
E0926 17:26:32.075402 157988 coordination_service_agent.cc:332] Polled an error from coordination service (this can be an error from this or another task).
F0926 17:26:32.075587 157988 client.h:77] Terminating process because the JAX distributed service detected fatal errors. This most likely indicates that another task died; see the other task logs for more details. Disable Python buffering, i.e. `python -u`, to be sure to see all the previous output. absl::Status: UNAVAILABLE: The following tasks are unhealthy (stopped sending heartbeats):
/job:jax_worker/replica:0/task:3
The tasks have crashed. Check the task logs for an earlier error, or scheduler events (e.g. preemption, eviction) to debug further.
RPC: /tensorflow.CoordinationService/PollForError [type.googleapis.com/tensorflow.CoordinationServiceError='']
```
When a process in a multi-controller JAX program notices that a peer process has crashed, it decides to crash as well. The processes share fate. The heartbeat_timeout_seconds argument to jax.distributed.initialize determines how long a process waits before concluding a peer process has died. The first three processes crash about ten seconds after you kill the fourth because we passed heartbeat_timeout_seconds=10 as an argument to jax.distributed.initialize.
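To make the timeout concrete, here is a toy model of a heartbeat-based failure detector in plain Python. It is not JAX’s coordination service, only an illustration of the rule described above: a task is presumed dead once its most recent heartbeat is older than `heartbeat_timeout_seconds`. The function name and the timestamps are hypothetical.

```python
def unhealthy_tasks(last_heartbeat: dict[int, float], now: float,
                    heartbeat_timeout_seconds: float = 10.0) -> list[int]:
  """Returns the ids of tasks whose last heartbeat is older than the timeout."""
  return sorted(
      task for task, t in last_heartbeat.items()
      if now - t > heartbeat_timeout_seconds
  )


# Task 3 was killed at t=100 and stopped sending heartbeats; the others keep beating.
beats = {0: 110.8, 1: 111.0, 2: 110.5, 3: 100.0}
print(unhealthy_tasks(beats, now=111.0))  # [3]
print(unhealthy_tasks(beats, now=111.0, heartbeat_timeout_seconds=20.0))  # []
```

A longer timeout tolerates slower or briefly stalled peers, at the cost of noticing real failures later.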
Surviving Faults
We can disable fate-sharing by adding the --xla_gpu_nccl_terminate_on_error=false flag and the jax_enable_recoverability configuration option to example.py, as shown below:
```python
import os
# New: don't terminate the process when a NCCL error occurs.
os.environ['XLA_FLAGS'] = '--xla_gpu_nccl_terminate_on_error=false'

from absl import app
from absl import flags
from collections.abc import Sequence
import jax
import time

_PROCESS_ID = flags.DEFINE_integer("i", -1, "Process id")
_NUM_PROCESSES = flags.DEFINE_integer("n", -1, "Number of processes")


def main(_: Sequence[str]) -> None:
  # New: enable recoverability, which disables fate-sharing.
  jax.config.update("jax_enable_recoverability", True)
  jax.distributed.initialize(
      coordinator_address="localhost:9000",
      num_processes=_NUM_PROCESSES.value,
      process_id=_PROCESS_ID.value,
      local_device_ids=[_PROCESS_ID.value],
      heartbeat_timeout_seconds=10,
  )
  print(f'{jax.devices()=}')
  print(f'{jax.local_devices()=}')
  while True:
    print(time.time())
    time.sleep(1)


if __name__ == "__main__":
  app.run(main)
```
Again run the script across four processes and then kill the fourth. Notice that now, the other three processes happily continue executing.
Next, try failing process 0. Notice that all four processes terminate with error messages that look something like the following:
```
E0929 17:42:48.594192 1044529 coordination_service_agent.cc:332] Polled an error from coordination service (this can be an error from this or another task).
F0929 17:42:48.594200 1044529 client.h:77] Terminating process because the JAX distributed service detected fatal errors. This most likely indicates that another task died; see the other task logs for more details. Disable Python buffering, i.e. `python -u`, to be sure to see all the previous output. absl::Status: UNAVAILABLE: Failed to send RPC to coordination service. Either the leader task was preempted/died/restarted unexpectedly or this task is experiencing network issues. Check earlier logs from 1) this task, 2) the leader (usually slice 0 task 0), and 3) cluster scheduler to debug further.
Additional GRPC error information from remote target coordination_service while calling /tensorflow.CoordinationService/PollForError::UNKNOWN:Error received from peer {grpc_message:"Socket closed", grpc_status:14}
```

Process 0 is special. If process 0 fails, every process will fail, even with fate-sharing disabled. Why? Process 0 runs an RPC service called the coordination service that all processes use to coordinate with each other. If the coordination service fails, all other processes have no choice but to fail. See Part 3: Implementation Details for more details.
Getting Stuck in Collectives
example.py is now able to survive faults, but the processes do not communicate with each other at all. Any realistic multi-controller JAX program would involve communication between the processes (otherwise, what’s the point of using multi-controller JAX?). Let’s edit example.py so that the processes perform a collective jnp.sum in every iteration of the loop.
```python
import os
os.environ['XLA_FLAGS'] = '--xla_gpu_nccl_terminate_on_error=false'

from absl import app
from absl import flags
from collections.abc import Sequence
import jax
import jax.numpy as jnp
import time

_PROCESS_ID = flags.DEFINE_integer("i", -1, "Process id")
_NUM_PROCESSES = flags.DEFINE_integer("n", -1, "Number of processes")


def main(_: Sequence[str]) -> None:
  jax.config.update("jax_enable_recoverability", True)
  jax.distributed.initialize(
      coordinator_address="localhost:9000",
      num_processes=_NUM_PROCESSES.value,
      process_id=_PROCESS_ID.value,
      local_device_ids=[_PROCESS_ID.value],
      heartbeat_timeout_seconds=10,
  )
  print(f'{jax.devices()=}')
  print(f'{jax.local_devices()=}')

  # New: shard an array across all devices and sum it in every iteration.
  n = jax.device_count()
  jax.set_mesh(jax.make_mesh((n,), ("i",)))
  x = jax.device_put(jnp.arange(n), jax.P("i"))
  while True:
    print(jnp.sum(x))  # A collective across all four processes.
    time.sleep(1)


if __name__ == "__main__":
  app.run(main)
```
In the new code above, the processes create an array x sharded across the four processes and then perform a distributed jnp.sum. Again run the program and fail the fourth process. You’ll notice that the first three processes do not crash, but they do get stuck. By default, if a process fails while participating in a distributed computation (like jnp.sum), then the rest of the processes participating in the computation will get stuck forever.
Cancelling Collectives
We can avoid getting stuck by cancelling collectives with a failed participant. We can enable collective cancelling by providing a few more flags and environment variables, shown below.
```python
import os
# New: extra XLA flags and environment variables that enable collective cancelling.
os.environ['XLA_FLAGS'] = ' '.join([
    '--xla_gpu_nccl_terminate_on_error=false',
    '--xla_gpu_nccl_async_execution=true',
    '--xla_gpu_nccl_blocking_communicators=false',
])
os.environ['XLA_PYTHON_CLIENT_ABORT_COLLECTIVES_ON_FAILURE'] = '1'
os.environ['XLA_PYTHON_CLIENT_USE_TFRT_GPU_CLIENT'] = '1'

from absl import app
from absl import flags
from collections.abc import Sequence
import jax
import jax.numpy as jnp
import time

_PROCESS_ID = flags.DEFINE_integer("i", -1, "Process id")
_NUM_PROCESSES = flags.DEFINE_integer("n", -1, "Number of processes")


def main(_: Sequence[str]) -> None:
  jax.config.update("jax_enable_recoverability", True)
  jax.distributed.initialize(
      coordinator_address="localhost:9000",
      num_processes=_NUM_PROCESSES.value,
      process_id=_PROCESS_ID.value,
      local_device_ids=[_PROCESS_ID.value],
      heartbeat_timeout_seconds=10,
  )
  print(f'{jax.devices()=}')
  print(f'{jax.local_devices()=}')

  # Don't do this. Use live_devices instead.
  from jax.experimental.multihost_utils import _live_devices
  _live_devices(jax._src.distributed.global_state.client, jax.devices())

  n = jax.device_count()
  jax.set_mesh(jax.make_mesh((n,), ("i",)))
  x = jax.device_put(jnp.arange(n), jax.P("i"))
  while True:
    print(jnp.sum(x))
    time.sleep(1)


if __name__ == "__main__":
  app.run(main)
```
We also need to insert a call to jax.experimental.multihost_utils._live_devices to make the script work. You should normally not do this. You should instead use the live_devices API that we’ll introduce momentarily. For now, _live_devices is a hack to get the script working before we explain the proper API.
Again run the script and fail the fourth process. The first three processes will be stuck in their call to jnp.sum, but after about ten seconds, the call will be cancelled and jnp.sum will raise an exception that looks something like this:

```
jaxlib._jax.XlaRuntimeError: FAILED_PRECONDITION: Task with incarnation id 3446767950926952685 is not connected
```
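The difference between getting stuck and having a collective cancelled can be mimicked on a single machine with Python threads. In this sketch, which is only an analogy and uses neither NCCL nor JAX, each thread stands in for a process and a `threading.Barrier` stands in for a four-party collective. Process 3 never arrives, as if it had been killed. With no timeout the other parties would wait forever; with a timeout, they all receive an exception instead, much like the cancelled `jnp.sum`.

```python
import threading

NUM_PROCESSES = 4
# Stand-in for a collective: every party must arrive before any can proceed.
collective = threading.Barrier(NUM_PROCESSES)
results: dict[int, str] = {}


def participant(pid: int, timeout: float) -> None:
  try:
    collective.wait(timeout=timeout)
    results[pid] = "ok"
  except threading.BrokenBarrierError:
    # A wait timed out, which breaks the barrier for every waiter: "cancelled".
    results[pid] = "cancelled"


# Process 3 is "dead": it never calls participant(), so the barrier can't complete.
threads = [threading.Thread(target=participant, args=(pid, 0.5)) for pid in range(3)]
for t in threads:
  t.start()
for t in threads:
  t.join()
print(sorted(results.items()))
```

Calling `collective.wait()` without a timeout would hang the three threads indefinitely, which is the behavior of the previous version of example.py.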
Knowing Who’s Alive
After a process dies, the remaining alive processes need to learn who is dead and who is alive. For this, we can use the core JAX fault tolerance API: live_devices. live_devices is a context manager that takes a list of devices as an argument and, through its with ... as clause, gives you the subset of these devices that are alive. Below, we edit example.py to call live_devices.
```python
import os
os.environ['XLA_FLAGS'] = ' '.join([
    '--xla_gpu_nccl_terminate_on_error=false',
    '--xla_gpu_nccl_async_execution=true',
    '--xla_gpu_nccl_blocking_communicators=false',
])
os.environ['XLA_PYTHON_CLIENT_ABORT_COLLECTIVES_ON_FAILURE'] = '1'
os.environ['XLA_PYTHON_CLIENT_USE_TFRT_GPU_CLIENT'] = '1'

from absl import app
from absl import flags
from collections.abc import Sequence
from jax.experimental.multihost_utils import live_devices
import jax
import jax.numpy as jnp
import time

_PROCESS_ID = flags.DEFINE_integer("i", -1, "Process id")
_NUM_PROCESSES = flags.DEFINE_integer("n", -1, "Number of processes")


def main(_: Sequence[str]) -> None:
  jax.config.update("jax_enable_recoverability", True)
  jax.distributed.initialize(
      coordinator_address="localhost:9000",
      num_processes=_NUM_PROCESSES.value,
      process_id=_PROCESS_ID.value,
      local_device_ids=[_PROCESS_ID.value],
      heartbeat_timeout_seconds=10,
  )
  print(f'{jax.devices()=}')
  print(f'{jax.local_devices()=}')

  while True:
    try:
      # New: learn which devices are alive, then compute only over those.
      with live_devices(jax.devices()) as devices:
        print(f'{devices=}')
        n = len(devices)
        jax.set_mesh(jax.make_mesh((n,), ("i",), devices=devices))
        x = jax.device_put(jnp.arange(n), jax.P("i"))
        print(jnp.sum(x))
    except Exception as e:
      print('FAIL:', e)
    else:
      print('PASS')
    time.sleep(1)


if __name__ == "__main__":
  app.run(main)
```
In the new code above, we call live_devices with all devices (jax.devices()) to get the set devices of live devices. We then shard array x over these devices and perform a jnp.sum. If a process fails while executing the jnp.sum, then jnp.sum will be cancelled and raise an exception on the remaining live devices. Technically, the collective is not guaranteed to fail. We’ll revisit this in Atomicity. For now, assume it will fail.
Note
jax.devices() always returns the set of all devices, even if some of these devices are on failed processes. Use jax.experimental.multihost_utils.live_devices to learn which of these devices are live.
Again run the script and fail the fourth process. Notice that the remaining three alive processes catch the exception raised by jnp.sum and continue to the next iteration of the while loop. In this next iteration, devices does not include the device on the failed fourth process. The three alive processes continue to execute correctly even though the fourth process is dead.
Next, restart the fourth process. Notice that after the fourth process restarts, its device is again included in the set of alive devices returned by live_devices. All four processes then continue executing normally.
At first blush, live_devices seems trivial. You give it a list of devices, and it returns the ones that are alive. How complicated can that be? Unfortunately, as with many things in distributed systems, there are a lot of subtleties to iron out. Next, we explain the barrier semantics and atomicity properties of live_devices.
Barrier Semantics
Recall that every process in a multi-controller JAX program should run in lockstep. The processes should execute the same instructions in the same order. Failing to do so will almost certainly lead to deadlocks, crashes, or anomalous behavior.
In the context of live_devices, we need to ensure that every process agrees on which processes are currently alive. This is difficult to ensure because every process is executing independently at potentially different speeds, and processes can fail at any time. Consider again the example.py script from above running on four processes, numbered 0 through 3. Imagine processes 0 and 1 call live_devices, then process 3 fails, and then process 2 calls live_devices. Processes 0 and 1 might think process 3 is alive while process 2 thinks it is dead.
To avoid situations like these, live_devices guarantees that it returns the same set of live devices to every process. It accomplishes this using a barrier. A call to live_devices(devices) blocks until every live process hosting a device in devices has also called live_devices. Once every live process is in the live_devices barrier, live_devices returns the same set of live devices to every process.
Important
live_devices uses a barrier to ensure that it will always return the same set of live devices to every live process.
Because live_devices implements a barrier, it is susceptible to deadlock if used improperly. We recommend having only a single with live_devices block in a program. Multiple calls to live_devices are hard to reason about and can lead to deadlock.
See Part 3: Implementation Details for details on how the live_devices barrier is implemented as well as a formal semantics based on linearizability.
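The barrier semantics can be illustrated with a single-machine simulation. This is a hypothetical sketch: the class `LiveBarrier` is not a JAX API and does not reflect how JAX implements live_devices (Part 3 covers that). Each live process blocks in `arrive` until every process still believed alive has arrived, and then every process receives the same frozen snapshot of the live set. A process that dies before arriving is dropped from that snapshot, and the processes already waiting see the same result.

```python
import threading
from typing import Optional


class LiveBarrier:
  """Toy barrier that hands every arriving process the same set of live ids."""

  def __init__(self, processes: set[int]):
    self._cv = threading.Condition()
    self._alive = set(processes)  # processes currently believed to be alive
    self._arrived: set[int] = set()  # processes waiting in the barrier
    self._result: Optional[frozenset] = None

  def mark_dead(self, pid: int) -> None:
    """Simulates the failure detector noticing that process `pid` crashed."""
    with self._cv:
      self._alive.discard(pid)
      self._maybe_release()

  def arrive(self, pid: int) -> frozenset:
    """Blocks until every live process has arrived; returns the agreed live set."""
    with self._cv:
      self._arrived.add(pid)
      self._maybe_release()
      self._cv.wait_for(lambda: self._result is not None)
      return self._result

  def _maybe_release(self) -> None:
    # Release only once every process still believed alive is in the barrier,
    # and freeze the live set so that all waiters see the same snapshot.
    if self._result is None and self._alive <= self._arrived:
      self._result = frozenset(self._alive)
      self._cv.notify_all()


barrier = LiveBarrier({0, 1, 2, 3})
views: dict[int, frozenset] = {}


def process(pid: int) -> None:
  views[pid] = barrier.arrive(pid)


# Processes 0 and 1 enter the barrier, process 3 dies, then process 2 enters.
early = [threading.Thread(target=process, args=(pid,)) for pid in (0, 1)]
for t in early:
  t.start()
barrier.mark_dead(3)
late = threading.Thread(target=process, args=(2,))
late.start()
for t in early + [late]:
  t.join()
print(sorted(views.items()))  # every process sees frozenset({0, 1, 2})
```

Because the snapshot is taken once, at release time, no process can observe process 3 as alive while another observes it as dead.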
Atomicity
A distributed computation is atomic if every participant in the computation agrees on whether the operation succeeds or fails. In the example.py script above, we saw that when a process failed during the execution of a jnp.sum, then jnp.sum would abort and raise an exception on the remaining live processes. So jnp.sum is atomic?
Unfortunately, it’s not.
When a process fails during the execution of a collective operation (like jnp.sum), the remaining processes may cancel the operation and raise an exception or they may complete the operation successfully. Collective operations in JAX do not have any inherent atomicity properties.
If collective operations are not atomic, however, then multi-controller JAXprocesses might diverge. For example, if a process fails during a training stepof a machine learning model, some processes might detect the failure and rollthe model back to a checkpoint while other processes might think the stepsucceeded and keep training.
To avoid the complexities of non-atomic execution, live_devices provides its own atomicity guarantees despite the fact that collectives are not atomic. Specifically, the body of a with live_devices block is guaranteed to either complete successfully on all processes or raise an exception on all processes. More concretely, if we consider the code snippet below, either every process executes branch A or every process executes branch B. It is impossible for some processes to execute A while others execute B.

```python
try:
  with live_devices(jax.devices()) as devices:
    ...
except Exception as e:
  ...  # Branch A
else:
  ...  # Branch B
```
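One way to picture this all-or-nothing guarantee is as an agreement step at the end of the block: every process reports whether its part succeeded, and the block succeeds everywhere only if it succeeded everywhere. The sketch below is a toy model of that idea in plain Python; `agree` is a hypothetical name, and this is not a description of how live_devices is implemented.

```python
def agree(local_ok: dict[int, bool]) -> dict[int, str]:
  """Maps every process to the same branch: "B" only if all of them succeeded."""
  all_ok = all(local_ok.values())
  return {pid: ("B" if all_ok else "A") for pid in local_ok}


# The collective happened to complete on processes 0 and 1 but raised on process 2.
print(agree({0: True, 1: True, 2: False}))  # {0: 'A', 1: 'A', 2: 'A'}
print(agree({0: True, 1: True, 2: True}))   # {0: 'B', 1: 'B', 2: 'B'}
```

Without such a step, the first case would leave processes 0 and 1 in branch B and process 2 in branch A, which is exactly the divergence described above.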
Warning
A with live_devices block does not guarantee atomicity if the code block non-deterministically raises exceptions for reasons other than collectives that fail because of a crashed process. For example, if one process raises an exception because it runs out of memory, this exception will not be propagated to the other processes.
Recall that JAX uses asynchronous dispatch. Operations like jnp.sum do not block until the operation is complete. Instead, they return jax.Arrays that act as futures. This asynchrony can interact with live_devices in unexpected ways. For example, consider the following code that performs a jnp.sum, assigns the result to y, and then prints y:

```python
x = ...
y = ...
try:
  with live_devices(jax.devices()) as devices:
    y = jnp.sum(x)
except Exception as e:
  ...  # Branch A
else:
  ...  # Branch B
print(y)
```
Imagine that the with live_devices block executes successfully on all processes. That is, all processes execute branch B. This only guarantees that every process successfully created a future and assigned it to y. The actual computation of the jnp.sum may be delayed until outside the block. Thus, some processes might successfully complete the jnp.sum and print the value of y while other processes fail to complete the jnp.sum and raise an exception when trying to print y.
To avoid this, use jax.block_until_ready to ensure that computations are performed within the with live_devices block. The code snippet below, which now calls jax.block_until_ready when assigning to y, guarantees that either every process will successfully execute the jnp.sum or every process will raise an exception.

```python
x = ...
y = ...
try:
  with live_devices(jax.devices()) as devices:
    y = jax.block_until_ready(jnp.sum(x))
except Exception as e:
  ...  # Branch A
else:
  ...  # Branch B
print(y)
```

See Part 3: Implementation Details for details on how atomicity is implemented.
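The asynchronous-dispatch pitfall has a close analogue in Python’s `concurrent.futures`, used here purely as a stand-in for JAX’s asynchronous dispatch (this is not JAX code). Submitting work succeeds immediately, and an error raised by that work only surfaces when the result is requested. Requesting the result inside the `try` block, the role jax.block_until_ready plays above, moves the error into the `except` branch. The function names are hypothetical.

```python
from concurrent.futures import ThreadPoolExecutor


def failing_sum() -> int:
  # Stands in for a collective that is cancelled because a peer died.
  raise RuntimeError("peer is not connected")


def run(block: bool) -> str:
  """Returns which branch was taken: "A" (except) or "B" (else)."""
  with ThreadPoolExecutor(max_workers=1) as pool:
    try:
      y = pool.submit(failing_sum)  # like y = jnp.sum(x): returns a future at once
      if block:
        y.result()  # like jax.block_until_ready(y): waits and re-raises errors
    except RuntimeError:
      return "A"
    else:
      return "B"


print(run(block=False))  # B
print(run(block=True))   # A
```

In the non-blocking case, the stored error would only be raised later, for example when the caller inspects the future, which corresponds to the `print(y)` outside the block.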
Part 2: Examples
live_devices is not a panacea; it is a tool. It does not magically make multi-controller JAX programs fault tolerant. Rather, it allows you to implement fault tolerance yourself in the way that is best for your application.
The exact details of how you implement fault tolerance will vary greatly based on the nature of your application. In this section, we present some examples of how to use live_devices. The examples are meant to be illustrative but not prescriptive. There are many other ways to implement fault tolerance.
Example 1: Fault Tolerant Data Parallel Training
In this example, we train a trivial single-parameter linear model (\(y = \alpha x\)) with data parallelism across four processes. The example is contrived (you would never train a model with a single parameter across four machines), but we intentionally keep the model simple to focus on fault tolerance.
Data parallelism makes implementing fault tolerance relatively straightforward. Because every process has a full copy of the model weights, if a process fails, we can simply ignore it and continue training. This example tolerates an arbitrary number of process failures (excluding process 0), but once a process fails, we assume it does not recover. The next example shows how to handle process recovery.
First, we set some flags to disable fate-sharing and enable collective cancelling. We also make the necessary imports and define some flags.
```python
import os
os.environ['XLA_FLAGS'] = ' '.join([
    '--xla_gpu_nccl_terminate_on_error=false',
    '--xla_gpu_nccl_async_execution=true',
    '--xla_gpu_nccl_blocking_communicators=false',
])
os.environ['XLA_PYTHON_CLIENT_ABORT_COLLECTIVES_ON_FAILURE'] = '1'
os.environ['XLA_PYTHON_CLIENT_USE_TFRT_GPU_CLIENT'] = '1'

from absl import app
from absl import flags
from collections.abc import Sequence
from jax.experimental.multihost_utils import live_devices
import jax
import jax.numpy as jnp
import time

_PROCESS_ID = flags.DEFINE_integer("i", -1, "Process id")
_NUM_PROCESSES = flags.DEFINE_integer("n", -1, "Number of processes")
```
Next, we define a replicated function that returns an array replicated across a set of devices. Note that replicated doesn’t actually move any data. It assumes the argument x already has equal value across all processes. It simply returns a new view of that data, in a process-spanning jax.Array with a replicated sharding.

```python
def replicated(x: jax.Array, devices: list[jax.Device]):
  """Return x replicated across the provided devices.

  Note that replicated(x) doesn't actually move any data. It simply creates a
  logically replicated array with x as the local replica.
  """
  n = len(devices)
  mesh = jax.make_mesh((n,), ("i",), devices=devices)
  spec = jax.sharding.PartitionSpec(None)
  sharding = jax.sharding.NamedSharding(mesh, spec)
  shards = [
      jax.device_put(x.addressable_shards[0].data, d) for d in devices
      if d.process_index == jax.process_index()
  ]
  return jax.make_array_from_single_device_arrays(x.shape, sharding, shards)
```
We define a similar sharded function that returns an array sharded across a set of devices. Again, sharded is not actually moving any data between processes.

```python
def sharded(x: jax.Array, devices: list[jax.Device]):
  """Return x sharded across the provided devices.

  Note that sharded(x) doesn't actually move any data. It simply creates a
  logically sharded array. x should have the same shape as the global array.
  """
  n = len(devices)
  mesh = jax.make_mesh((n,), ("i",), devices=devices)
  spec = jax.sharding.PartitionSpec("i")
  sharding = jax.sharding.NamedSharding(mesh, spec)
  m = sharding.addressable_devices_indices_map(x.shape)
  shards = [jax.device_put(x[m[d]], d) for d in jax.local_devices()]
  return jax.make_array_from_single_device_arrays(x.shape, sharding, shards)
```
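For a one-dimensional array split evenly along mesh axis "i", each device’s slice can be worked out by hand, which may help when reading sharded. The plain-Python sketch below assumes the global length is divisible by the number of devices (true in this example, because the batch size is always a multiple of the device count). Under that assumption, slice i should correspond to the index that `addressable_devices_indices_map` reports for the device at position i of the mesh (the map gives a tuple with one slice per dimension). The helper `shard_slices` is hypothetical.

```python
def shard_slices(global_len: int, num_devices: int) -> list[slice]:
  """Slice of a 1-D array held by each device position along mesh axis "i"."""
  assert global_len % num_devices == 0, "assumes an even split"
  k = global_len // num_devices
  return [slice(i * k, (i + 1) * k) for i in range(num_devices)]


# A batch of 40 examples (device_batch_size=10) over 4 live devices.
print(shard_slices(40, 4))
# [slice(0, 10, None), slice(10, 20, None), slice(20, 30, None), slice(30, 40, None)]
```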
Now, we’re ready to start writing our training loop. We begin by initializing multi-controller JAX by calling jax.distributed.initialize.

```python
def main(_: Sequence[str]) -> None:
  # Parse command line arguments and initialize multi-controller JAX.
  jax.config.update("jax_enable_recoverability", True)
  jax.distributed.initialize(coordinator_address="localhost:8000",
                             process_id=_PROCESS_ID.value,
                             num_processes=_NUM_PROCESSES.value,
                             local_device_ids=[_PROCESS_ID.value],
                             heartbeat_timeout_seconds=10)
  print(f'{jax.devices()=}')
  print(f'{jax.local_devices()=}')
```
Then, we define our simple linear model, generate some random training data, and initialize some basic hyperparameters.

```python
  # Initialize the model's weights.
  keys = iter(jax.random.split(jax.random.key(seed=42), num=3))
  weights = jax.random.normal(next(keys), shape=(1,))

  # We'll learn a trivial linear model: a*x.
  def predict(weights, X):
    return weights * X

  # We'll use mean squared error loss.
  def loss(weights, X, Y):
    return jnp.mean((predict(weights, X) - Y)**2)

  # Initialize the (noisy) training data with a=10.
  X = jax.random.permutation(next(keys), jnp.arange(-300., 300.))
  Y = 10 * X + jax.random.normal(next(keys), X.shape)

  # Hyperparameters.
  loss_and_grad = jax.jit(jax.value_and_grad(loss))
  learning_rate = 1e-6
  device_batch_size = 10
```
Finally, we enter the main training loop.

```python
  step = 0
  while True:
    try:
      with live_devices(jax.devices()) as devices:
        print(f'=== Running step {step} with live devices = {devices} ===')

        # Replicate the model weights.
        weights = replicated(weights, devices)

        # Shard the batch.
        batch_size = device_batch_size * len(devices)
        start = (step * batch_size) % len(X)
        stop = start + batch_size
        X_batch = sharded(X[start:stop], devices)
        Y_batch = sharded(Y[start:stop], devices)

        # Compute gradients and update weights.
        l, grad = loss_and_grad(weights, X_batch, Y_batch)
        new_weights = jax.block_until_ready(weights - learning_rate * grad)
    except Exception as e:
      print(f'Step {step} failed: {e}')
    else:
      print(f'Step {step} succeeded: loss = {l}')
      step += 1
      weights = new_weights

    time.sleep(1)
```
Every iteration of the loop, we call live_devices to learn which devices are currently alive. We then ensure that the model weights are replicated across these devices and ensure that the training data is sharded across these devices. Note that this doesn’t actually move any data between the devices; it simply creates JAX arrays with the appropriate replication and sharding metadata.
We call loss_and_grad to compute the gradient of the weights with respect to the current batch of data and then compute the new weights. Notice that we assign the new weights to new_weights rather than assigning to weights in case the training step fails. We also call jax.block_until_ready to ensure that every process has computed the new weights when we exit the live_devices block.
If no processes failed during the execution of the training step, then the else branch is taken. The step is incremented, and weights is updated. Otherwise, an exception will be raised and the except branch is taken. In this case, we do not update step or weights, and we retry the step on the next iteration with the new set of live devices.
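The batch bookkeeping in the loop can be checked by hand. This plain-Python sketch reproduces the three lines that compute batch_size, start, and stop (the function name `batch_bounds` is hypothetical; `num_examples=600` matches the length of `jnp.arange(-300., 300.)`). It shows that when a device drops out, the global batch shrinks while each device still processes `device_batch_size` examples.

```python
def batch_bounds(step: int, num_live_devices: int, num_examples: int = 600,
                 device_batch_size: int = 10) -> tuple[int, int]:
  """Mirrors the batch computation in the training loop above."""
  batch_size = device_batch_size * num_live_devices
  start = (step * batch_size) % num_examples
  stop = start + batch_size
  return start, stop


print(batch_bounds(step=2, num_live_devices=4))  # (80, 120)
print(batch_bounds(step=2, num_live_devices=3))  # (60, 90)
```

Note that stop is not wrapped around. With 600 examples and one to four devices, the batch size (10 to 40) always divides 600, so every slice is full; with other sizes the last slice before wrapping could come up short.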
Here is the full example:
```python
import os
os.environ['XLA_FLAGS'] = ' '.join([
    '--xla_gpu_nccl_terminate_on_error=false',
    '--xla_gpu_nccl_async_execution=true',
    '--xla_gpu_nccl_blocking_communicators=false',
])
os.environ['XLA_PYTHON_CLIENT_ABORT_COLLECTIVES_ON_FAILURE'] = '1'
os.environ['XLA_PYTHON_CLIENT_USE_TFRT_GPU_CLIENT'] = '1'

from absl import app
from absl import flags
from collections.abc import Sequence
from jax.experimental.multihost_utils import live_devices
import jax
import jax.numpy as jnp
import time

_PROCESS_ID = flags.DEFINE_integer("i", -1, "Process id")
_NUM_PROCESSES = flags.DEFINE_integer("n", -1, "Number of processes")

def replicated(x: jax.Array, devices: list[jax.Device]):
  """Return x replicated across the provided devices.

  Note that replicated(x) doesn't actually move any data. It simply creates a
  logically replicated array with x as the local replica.
  """
  n = len(devices)
  mesh = jax.make_mesh((n,), ("i",), devices=devices)
  spec = jax.sharding.PartitionSpec(None)
  sharding = jax.sharding.NamedSharding(mesh, spec)
  shards = [
      jax.device_put(x.addressable_shards[0].data, d) for d in devices
      if d.process_index == jax.process_index()
  ]
  return jax.make_array_from_single_device_arrays(x.shape, sharding, shards)


def sharded(x: jax.Array, devices: list[jax.Device]):
  """Return x sharded across the provided devices.

  Note that sharded(x) doesn't actually move any data. It simply creates a
  logically sharded array. x should have the same shape as the global array.
  """
  n = len(devices)
  mesh = jax.make_mesh((n,), ("i",), devices=devices)
  spec = jax.sharding.PartitionSpec("i")
  sharding = jax.sharding.NamedSharding(mesh, spec)
  m = sharding.addressable_devices_indices_map(x.shape)
  shards = [jax.device_put(x[m[d]], d) for d in jax.local_devices()]
  return jax.make_array_from_single_device_arrays(x.shape, sharding, shards)


def main(_: Sequence[str]) -> None:
  # Parse command line arguments and initialize multi-controller JAX.
  jax.config.update("jax_enable_recoverability", True)
  jax.distributed.initialize(coordinator_address="localhost:8000",
                             process_id=_PROCESS_ID.value,
                             num_processes=_NUM_PROCESSES.value,
                             local_device_ids=[_PROCESS_ID.value],
                             heartbeat_timeout_seconds=10)
  print(f'{jax.devices()=}')
  print(f'{jax.local_devices()=}')

  # Initialize the model's weights.
  keys = iter(jax.random.split(jax.random.key(seed=42), num=3))
  weights = jax.random.normal(next(keys), shape=(1,))

  # We'll learn a trivial linear model: a*x.
  def predict(weights, X):
    return weights * X

  # We'll use mean squared error loss.
  def loss(weights, X, Y):
    return jnp.mean((predict(weights, X) - Y)**2)

  # Initialize the (noisy) training data with a=10.
  X = jax.random.permutation(next(keys), jnp.arange(-300., 300.))
  Y = 10 * X + jax.random.normal(next(keys), X.shape)

  # Hyperparameters.
  loss_and_grad = jax.jit(jax.value_and_grad(loss))
  learning_rate = 1e-6
  device_batch_size = 10

  step = 0
  while True:
    try:
      with live_devices(jax.devices()) as devices:
        print(f'=== Running step {step} with live devices = {devices} ===')

        # Replicate the model weights.
        weights = replicated(weights, devices)

        # Shard the batch.
        batch_size = device_batch_size * len(devices)
        start = (step * batch_size) % len(X)
        stop = start + batch_size
        X_batch = sharded(X[start:stop], devices)
        Y_batch = sharded(Y[start:stop], devices)

        # Compute gradients and update weights.
        l, grad = loss_and_grad(weights, X_batch, Y_batch)
        new_weights = jax.block_until_ready(weights - learning_rate * grad)
    except Exception as e:
      print(f'Step {step} failed: {e}')
    else:
      print(f'Step {step} succeeded: loss = {l}')
      step += 1
      weights = new_weights

    time.sleep(1)


if __name__ == "__main__":
  app.run(main)
```
Example 2: Fault Tolerant Data Parallel Training With Recovery
Now, we modify the example above to allow failed processes to recover. When a process recovers, it needs to receive the current step and model weights. Because we assume process 0 never fails (recall that if process 0 fails, every process will fail), we have process 0 send the current step and weights to recovering processes.
First, we define send and recv functions that use a shard_map to send data from one device to another. The sender calls send, and the receiver calls recv.
```python
def send(x: jax.Array, from_device: jax.Device, to_device: jax.Device):
  """Sends x from one device to another."""
  assert isinstance(x, jax.Array)
  devices = [from_device, to_device]
  psum = lambda x: jax.lax.psum(x, "i")
  mesh = jax.make_mesh((2,), ("i",), devices=devices)
  spec = jax.sharding.PartitionSpec(None)
  x = replicated(x, [from_device, to_device])
  # The result is discarded on the sender; the psum delivers x to to_device.
  shard_map.shard_map(psum, mesh=mesh, in_specs=spec, out_specs=spec)(x)


def recv(x: jax.Array, from_device: jax.Device, to_device: jax.Device):
  """Receives x from a matching send."""
  assert isinstance(x, jax.Array)
  to_device = jax.local_devices()[0]
  devices = [from_device, to_device]
  psum = lambda x: jax.lax.psum(x, "i")
  mesh = jax.make_mesh((2,), ("i",), devices=devices)
  spec = jax.sharding.PartitionSpec(None)
  # The receiver contributes zeros, so the psum equals the sender's x.
  x = jnp.zeros_like(x)
  x = replicated(x, [from_device, to_device])
  return shard_map.shard_map(psum, mesh=mesh, in_specs=spec, out_specs=spec)(x)
```
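The trick behind send and recv is that a sum over two participants, where the receiver contributes zeros, leaves both participants holding the sender’s value. Here is that arithmetic in plain Python, with a hypothetical helper `psum2` standing in for jax.lax.psum over the two-device mesh and made-up weight values.

```python
def psum2(sender_value: list[float], receiver_value: list[float]) -> list[float]:
  """Elementwise sum across the two participants, like psum over axis "i"."""
  return [a + b for a, b in zip(sender_value, receiver_value)]


weights = [0.25, -1.5]  # the sender's (process 0's) copy
placeholder = [0.0] * len(weights)  # recv replaces its stale copy with zeros
print(psum2(weights, placeholder))  # [0.25, -1.5]
```

This is also why recv calls jnp.zeros_like: if the receiver contributed its stale weights instead of zeros, the result would be the sum of both copies rather than the sender’s weights.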
allgather performs an AllGather of a single float across a set of devices.

```python
def allgather(x: float, devices: list[jax.Device]) -> list[float]:
  """Performs an AllGather across the provided devices."""
  n = len(devices)
  mesh = jax.make_mesh((n,), ("i",), devices=devices)
  spec = jax.sharding.PartitionSpec('i')
  p = lambda x: jax.lax.all_gather(x, "i", tiled=True)
  f = jax.shard_map(p, mesh=mesh, in_specs=spec, out_specs=spec)
  out = jax.block_until_ready(f(np.array([x] * len(devices))))
  return out.addressable_shards[0].data
```
Finally, we modify the training loop to handle recovering processes, as shown in the code below. The new recovery logic sits at the top of the with live_devices block.

```python
  step = 0
  while True:
    try:
      with live_devices(jax.devices()) as devices:
        print(f'=== Running step {step} with live devices = {devices} ===')

        # Handle recovering devices. A device is recovering if its step doesn't
        # match process 0's step. We assume process 0 never fails.
        print('all gathering steps...')
        steps = allgather(step, devices)
        print(f'{steps=}')
        recovering = [d for d, s in zip(devices, steps) if s != steps[0]]
        for d in recovering:
          # Process 0 sends weights and step to the recovering devices.
          if jax.process_index() == 0:
            print('sending...')
            send(weights, jax.devices()[0], d)
            send(jnp.array([step]), jax.devices()[0], d)
          elif d.process_index == jax.process_index():
            print('receiving...')
            weights = recv(weights, jax.devices()[0], d)
            step = recv(jnp.array([step]), jax.devices()[0], d)[0]

        # Replicate the model weights.
        weights = replicated(weights, devices)

        # Shard the batch.
        batch_size = device_batch_size * len(devices)
        start = (step * batch_size) % len(X)
        stop = start + batch_size
        X_batch = sharded(X[start:stop], devices)
        Y_batch = sharded(Y[start:stop], devices)

        # Compute gradients and update weights.
        l, grad = loss_and_grad(weights, X_batch, Y_batch)
        new_weights = jax.block_until_ready(weights - learning_rate * grad)
    except Exception as e:
      print(f'Step {step} failed: {e}')
    else:
      print(f'Step {step} succeeded: loss = {l}')
      step += 1
      weights = new_weights

    time.sleep(1)
```
Recovery is a two-step process. First, we need to detect which processes are recovering. Second, we need process 0 to send the step and weights to the recovering processes.
To detect which processes are recovering, we perform an AllGather on all live processes’ steps. When a failed process recovers, its step will be 0, while the step on process 0 will be some positive number. So if a process’ step is not equal to process 0’s step, then it is recovering. Then, we call the send and recv functions we defined above to transfer the current step and model weights from process 0 to the recovering processes.
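The send and recv helpers above both launch the same two-device psum. The sender contributes the real value and the receiver contributes jnp.zeros_like(x), so the sum that both devices get back is the sender’s value. The pure-Python sketch below models that trick together with the step-based detection rule. It assumes each device holds a plain list of floats. The psum and recovering functions and the sample numbers are illustrative stand-ins, not JAX APIs.

```python
def psum(contributions: list[list[float]]) -> list[list[float]]:
  """Toy all-reduce: every participant gets the elementwise sum."""
  total = [sum(column) for column in zip(*contributions)]
  return [list(total) for _ in contributions]


def recovering(devices: list[int], steps: list[int]) -> list[int]:
  """A device is recovering if its step differs from process 0's step."""
  return [d for d, s in zip(devices, steps) if s != steps[0]]


# Hypothetical state: device 2 crashed and restarted, so its step is back to 0.
devices = [0, 1, 2, 3]
steps = [7, 7, 0, 7]
weights = {0: [9.5], 1: [9.5], 2: [0.25], 3: [9.5]}

for d in recovering(devices, steps):
  # "send" on device 0 and "recv" on device d join one psum. The receiver
  # contributes zeros, so both sides end up with device 0's values.
  weights[d] = psum([weights[0], [0.0] * len(weights[d])])[1]
  steps[d] = int(psum([[steps[0]], [0]])[1][0])

print(steps)       # [7, 7, 7, 7]
print(weights[2])  # [9.5]
```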
Here is the full example:
```python
import os
os.environ['XLA_FLAGS'] = ' '.join([
  '--xla_gpu_nccl_terminate_on_error=false',
  '--xla_gpu_nccl_async_execution=true',
  '--xla_gpu_nccl_blocking_communicators=false',
])
os.environ['XLA_PYTHON_CLIENT_ABORT_COLLECTIVES_ON_FAILURE'] = '1'
os.environ['XLA_PYTHON_CLIENT_USE_TFRT_GPU_CLIENT'] = '1'

from absl import app
from absl import flags
from collections.abc import Sequence
from jax.experimental.multihost_utils import live_devices
from jax.experimental import shard_map
import jax
import jax.numpy as jnp
import numpy as np
import time

_PROCESS_ID = flags.DEFINE_integer("i", -1, "Process id")
_NUM_PROCESSES = flags.DEFINE_integer("n", -1, "Number of processes")

def replicated(x: jax.Array, devices: list[jax.Device]):
  """Return x replicated across the provided devices.

  Note that replicated(x) doesn't actually move any data. It simply creates a
  logically replicated array with x as the local replica.
  """
  n = len(devices)
  mesh = jax.make_mesh((n,), ("i",), devices=devices)
  spec = jax.sharding.PartitionSpec(None)
  sharding = jax.sharding.NamedSharding(mesh, spec)
  shards = [
      jax.device_put(x.addressable_shards[0].data, d) for d in devices
      if d.process_index == jax.process_index()
  ]
  return jax.make_array_from_single_device_arrays(x.shape, sharding, shards)


def sharded(x: jax.Array, devices: list[jax.Device]):
  """Return x sharded across the provided devices.

  Note that sharded(x) doesn't actually move any data. It simply creates a
  logically sharded array. x should have the same shape as the global array.
  """
  n = len(devices)
  mesh = jax.make_mesh((n,), ("i",), devices=devices)
  spec = jax.sharding.PartitionSpec("i")
  sharding = jax.sharding.NamedSharding(mesh, spec)
  m = sharding.addressable_devices_indices_map(x.shape)
  shards = [jax.device_put(x[m[d]], d) for d in jax.local_devices()]
  return jax.make_array_from_single_device_arrays(x.shape, sharding, shards)


def send(x: jax.Array, from_device: jax.Device, to_device: jax.Device):
  """Sends x from one device to another."""
  assert isinstance(x, jax.Array)
  devices = [from_device, to_device]
  psum = lambda x: jax.lax.psum(x, "i")
  mesh = jax.make_mesh((2,), ("i",), devices=devices)
  spec = jax.sharding.PartitionSpec(None)
  x = replicated(x, [from_device, to_device])
  shard_map.shard_map(psum, mesh=mesh, in_specs=spec, out_specs=spec)(x)


def recv(x: jax.Array, from_device: jax.Device, to_device: jax.Device):
  """Receives x from a matching send."""
  assert isinstance(x, jax.Array)
  to_device = jax.local_devices()[0]
  devices = [from_device, to_device]
  psum = lambda x: jax.lax.psum(x, "i")
  mesh = jax.make_mesh((2,), ("i",), devices=devices)
  spec = jax.sharding.PartitionSpec(None)
  x = jnp.zeros_like(x)
  x = replicated(x, [from_device, to_device])
  return shard_map.shard_map(psum, mesh=mesh, in_specs=spec, out_specs=spec)(x)


def allgather(x: float, devices: list[jax.Device]) -> list[float]:
  """Performs an AllGather across the provided devices."""
  n = len(devices)
  mesh = jax.make_mesh((n,), ("i",), devices=devices)
  spec = jax.sharding.PartitionSpec('i')
  p = lambda x: jax.lax.all_gather(x, "i", tiled=True)
  f = jax.shard_map(p, mesh=mesh, in_specs=spec, out_specs=spec)
  return jax.block_until_ready(f(np.array([x] * len(devices)))).addressable_shards[0].data


def main(_: Sequence[str]) -> None:
  # Parse command line arguments and initialize multi-controller JAX.
  jax.config.update("jax_enable_recoverability", True)
  jax.distributed.initialize(coordinator_address="localhost:8000",
                             process_id=_PROCESS_ID.value,
                             num_processes=_NUM_PROCESSES.value,
                             local_device_ids=[_PROCESS_ID.value],
                             heartbeat_timeout_seconds=10)
  print(f'{jax.devices()=}')
  print(f'{jax.local_devices()=}')

  # Initialize the model's weights.
  keys = iter(jax.random.split(jax.random.key(seed=42), num=3))
  weights = jax.random.normal(next(keys), shape=(1,))

  # We'll learn a trivial linear model: a*x.
  def predict(weights, X):
    return weights * X

  # We'll use mean squared error loss.
  def loss(weights, X, Y):
    return jnp.mean((predict(weights, X) - Y)**2)

  # Initialize the (noisy) training data with a=10.
  X = jax.random.permutation(next(keys), jnp.arange(-300., 300.))
  Y = 10 * X + jax.random.normal(next(keys), X.shape)

  # Hyperparameters.
  loss_and_grad = jax.jit(jax.value_and_grad(loss))
  learning_rate = 1e-6
  device_batch_size = 10

  step = 0
  while True:
    try:
      with live_devices(jax.devices()) as devices:
        print(f'=== Running step {step} with live devices = {devices} ===')

        # Handle recovering devices. A device is recovering if its step doesn't
        # match process 0's step. We assume process 0 never fails.
        print('all gathering steps...')
        steps = allgather(step, devices)
        print(f'{steps=}')
        recovering = [d for d, s in zip(devices, steps) if s != steps[0]]
        for d in recovering:
          # Process 0 sends weights and step to the recovering devices.
          if jax.process_index() == 0:
            print('sending...')
            send(weights, jax.devices()[0], d)
            send(jnp.array([step]), jax.devices()[0], d)
          elif d.process_index == jax.process_index():
            print('receiving...')
            weights = recv(weights, jax.devices()[0], d)
            step = recv(jnp.array([step]), jax.devices()[0], d)[0]

        # Replicate the model weights.
        weights = replicated(weights, devices)

        # Shard the batch.
        batch_size = device_batch_size * len(devices)
        start = (step * batch_size) % len(X)
        stop = start + batch_size
        X_batch = sharded(X[start:stop], devices)
        Y_batch = sharded(Y[start:stop], devices)

        # Compute gradients and update weights.
        l, grad = loss_and_grad(weights, X_batch, Y_batch)
        new_weights = jax.block_until_ready(weights - learning_rate * grad)
    except Exception as e:
      print(f'Step {step} failed: {e}')
    else:
      print(f'Step {step} succeeded: loss = {l}')
      step += 1
      weights = new_weights

    time.sleep(1)


if __name__ == "__main__":
  app.run(main)
```
Part 3: Implementation Details
We now take a deep dive into the architecture of multi-controller JAX and the semantics and implementation of live_devices. If you’re only interested in writing fault-tolerant multi-controller JAX programs, the first two parts of this article suffice.
The Coordination Service
When you launch a multi-controller JAX program, the first process (i.e. process 0) runs a standalone RPC server called the coordination service. Moreover, all processes (including process 0) create an RPC client to the coordination service. Concretely, the coordinator_address argument of jax.distributed.initialize() is the address of the coordination service. This argument tells process 0 which address to run the server on, and it tells all processes which address to connect to.
The coordination service implements the multi-controller JAX control plane. For example, it can perform a distributed barrier across all processes, and it implements a key-value store that processes can use to exchange small amounts of metadata. Note, however, that the data plane (e.g., all collective operations on program data) is implemented directly between the processes and does not involve the coordination service.
One of the most important functions of the coordination service is health checking. Every process periodically sends a heartbeat to the coordination service. If a process fails, it stops sending heartbeats. If the coordination service hasn’t received a heartbeat from a process for longer than a timeout (configurable via the heartbeat_timeout_seconds argument to jax.distributed.initialize), it assumes the process has failed.
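To make the health check concrete, here is a toy, single-threaded model of heartbeat-based failure detection. It is a sketch, not the coordination service’s actual implementation. HeartbeatMonitor and its injectable clock are hypothetical names. Its timeout_s parameter plays a role analogous to the heartbeat_timeout_seconds=10 passed to jax.distributed.initialize in the examples.

```python
import time
from typing import Callable


class HeartbeatMonitor:
  """Toy failure detector: a process is presumed dead once no heartbeat has
  arrived from it for more than timeout_s seconds."""

  def __init__(self, num_processes: int, timeout_s: float,
               now: Callable[[], float] = time.monotonic):
    self._timeout_s = timeout_s
    self._now = now
    start = now()
    self._last_beat = {p: start for p in range(num_processes)}

  def heartbeat(self, process_id: int) -> None:
    self._last_beat[process_id] = self._now()

  def alive(self) -> set[int]:
    t = self._now()
    return {p for p, last in self._last_beat.items()
            if t - last <= self._timeout_s}


# Simulated clock: processes 0-2 send a heartbeat every second, but process 2
# crashes right after its heartbeat at t = 2.
clock = 0.0
monitor = HeartbeatMonitor(3, timeout_s=10, now=lambda: clock)
history = {}
for t in range(1, 16):
  clock = float(t)
  for p in (0, 1, 2):
    if not (p == 2 and t > 2):
      monitor.heartbeat(p)
  history[t] = monitor.alive()

print(history[12])  # {0, 1, 2}: process 2 is silent, but not for long enough
print(history[13])  # {0, 1}: 11 s without a heartbeat exceeds the timeout
```

Note that detection is inherently delayed. For the 10 seconds after process 2 crashes, the monitor still reports it as alive.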
This is shown in the interactive visualization below. The coordination service is shown at the top and three multi-controller JAX processes are shown at the bottom. Note how the processes periodically send heartbeats to the coordination service, and the coordination service keeps track of the health of each process based on when it last received a heartbeat. Try failing process 2 by clicking the “Fail” button. Observe how the process stops sending heartbeats and the coordination service eventually considers the process dead.
By default, when the coordination service detects that a process has failed, it sends a message to all other processes requesting that they self-terminate. In other words, all processes in a multi-controller JAX program share fate. Again, fail process 2 in the visualization below by clicking the “Fail” button and observe how the coordination service notifies the other processes to fail.
This fate sharing means that multi-controller JAX programs are not at all fault-tolerant. They are fault-intolerant. To enable fault tolerance, we need to do two things:
First, we need to remove fate sharing and allow processes to continue executing even when a peer process has died. This can be enabled using the jax_enable_recoverability option, as described in Part 1: Fault Tolerance Basics. We’ll assume that this option is set.
Second, we need to provide an API that processes can use to learn which processes are alive and which have failed. This is the live_devices API introduced in Part 1: Fault Tolerance Basics.
There is a surprising amount of technical depth and subtlety in implementing the live_devices API. We’ll walk through the design and implementation of the API step by step. We’ll begin by introducing a simpler live_processes API and slowly improve it until we arrive at the live_devices API.
Live Processes
Let’s try to design a new hypothetical JAX API: jax.live_processes. As the name suggests, we want jax.live_processes() to return the set of all currently alive processes. Here is a naive but (as we’ll see momentarily) incorrect implementation. When a process calls jax.live_processes(), it sends an RPC request to the coordination service. Remember that the coordination service already uses heartbeats to keep track of which processes are dead and which are alive. So when it receives a jax.live_processes request, it responds with the set of processes it thinks are alive.
This is illustrated below. Below each process is a “Call live_processes” button. You can click this button to make the process call jax.live_processes. Note how the coordination service replies to a live_processes request with the set of alive processes. Fail process 2 by clicking the “Fail” button and see how it affects later calls to jax.live_processes.
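The naive design fits in a few lines. In the toy model below, the NaiveCoordinator class and its mark_dead method are hypothetical stand-ins for the coordination service and its heartbeat-based failure detector. Each request is answered immediately with the coordinator’s current belief. As a result, two requests that arrive on either side of the moment a failure is detected receive different answers.

```python
class NaiveCoordinator:
  """Toy model of the naive jax.live_processes: every request is answered
  immediately with whatever the failure detector currently believes."""

  def __init__(self, processes: set[int]):
    self.believed_alive = set(processes)

  def mark_dead(self, process_id: int) -> None:
    # Called when heartbeats from process_id stop arriving.
    self.believed_alive.discard(process_id)

  def live_processes(self, caller: int) -> frozenset[int]:
    return frozenset(self.believed_alive)


coord = NaiveCoordinator({0, 1, 2})
reply_to_0 = coord.live_processes(caller=0)  # arrives just before detection
coord.mark_dead(2)                           # process 2's failure is noticed
reply_to_1 = coord.live_processes(caller=1)  # arrives just after detection
print(sorted(reply_to_0), sorted(reply_to_1))  # [0, 1, 2] [0, 1]
```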
This naive implementation is simple but incorrect. It is crucial that all processes in a multi-controller JAX job execute the same instructions in the same order. If the processes start to diverge by executing different code paths in the JAX program, the job will behave erratically. Most likely, it will crash, hang, or produce garbage values, and it will certainly be very hard to reason about.
Our naive implementation of jax.live_processes can very easily lead to divergence. For example, consider a multi-controller JAX job with three processes. Suppose processes 0 and 1 both call jax.live_processes around the same time that process 2 fails. The coordination service might report to process 0 that all processes are alive, but report to process 1 that only processes 0 and 1 are alive. Try to produce this scenario in the visualization below:
If processes disagree on which processes are alive, they will almost certainly diverge. Thankfully, we can avoid this divergence by augmenting jax.live_processes with barrier semantics.
Barrier Semantics
Let’s change the implementation of jax.live_processes so that when the coordination service receives a jax.live_processes() request, it does not reply right away. Instead, the coordination service only replies once every live process has called jax.live_processes(). Once every live process has entered the jax.live_processes() barrier, the coordination service returns the set of live processes. Crucially, the coordination service returns the same set of live processes to all processes, which prevents the processes from diverging.
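The same toy model can be extended with barrier semantics. Requests are parked until every process the coordinator believes alive has entered the barrier. Then all of those processes are released with one shared reply. This is a single-threaded sketch under our own assumptions: BarrierCoordinator, request, and mark_dead are hypothetical names. It also ignores message loss and the repeated rounds that a real implementation must handle.

```python
class BarrierCoordinator:
  """Toy jax.live_processes with barrier semantics."""

  def __init__(self, processes: set[int]):
    self.believed_alive = set(processes)
    self.waiting: set[int] = set()
    self.replies: dict[int, frozenset[int]] = {}

  def _maybe_release(self) -> None:
    # Reply only once every believed-alive process is in the barrier, and
    # give every waiting process the same set.
    if self.waiting and self.believed_alive <= self.waiting:
      reply = frozenset(self.believed_alive)
      for p in self.waiting:
        self.replies[p] = reply
      self.waiting.clear()

  def request(self, process_id: int) -> None:
    if process_id in self.believed_alive:
      self.waiting.add(process_id)
      self._maybe_release()

  def mark_dead(self, process_id: int) -> None:
    self.believed_alive.discard(process_id)
    self.waiting.discard(process_id)
    self._maybe_release()


coord = BarrierCoordinator({0, 1, 2})
coord.request(0)
coord.request(1)
pending = dict(coord.replies)  # empty: still waiting for process 2
coord.mark_dead(2)             # process 2 missed too many heartbeats
print(coord.replies)           # processes 0 and 1 both receive {0, 1}
```

Unlike the naive version, the failure of process 2 can no longer split processes 0 and 1. Both are blocked until the coordinator’s view settles, and then both get the identical set.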
This is illustrated below. Note that the coordination service now keeps track of which processes are in the live_processes barrier. Try calling live_processes from every process. Notice how the coordination service doesn’t respond until every process has entered the barrier. Then fail process 2 and call live_processes from process 0 and process 1.
Formal Semantics
Distributed systems are notoriously complex. Machines can fail at arbitrary times, and network messages can be dropped, delayed, and reordered. In this section, we introduce a formal semantics of the jax.live_processes API to help tame this complexity. Thinking rigorously about the semantics of jax.live_processes will help us understand the behavior of the API even in pathological executions.
We’ll base the formal semantics of jax.live_processes on linearizability, a popular formalism used to define the semantics of many distributed APIs. Concretely, we model our distributed system as a number of processes. Each process serially performs a number of events. There are four types of events:
A process can start (👶). We’ll assume that when a process starts, it connects to the coordination service, so the coordination service is aware that it has started.
A process can fail (💀). Unlike a start, a failure may not be immediately noticed by the coordination service.
A process can send a jax.live_processes request to the coordination service.
A process can receive a reply to a jax.live_processes request from the coordination service.
Below is a diagram of an execution of three processes: 0, 1, and 2. Time progresses from left to right. First, all three processes start, as shown by the baby emojis. Then all three processes send jax.live_processes requests to the coordination service, shown as the start of the thick colored regions. Later, all three processes receive a reply from the coordination service with 0,1,2 as the set of live processes.
In this simple execution, it is clear that jax.live_processes is behaving correctly. We can formalize this intuition with the following formal semantics.
Attention
An execution is valid if, whenever jax.live_processes returns a set P of live processes, there exists an instantaneous moment in time at which every process in P was in the live_processes barrier and every other process was dead. An implementation of live_processes is correct if it only allows valid executions.
Later, we will amend these formal semantics to cover some subtle corner cases,but assume this simplified semantics for now.
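The definition can be checked mechanically for small executions. The sketch below is a toy checker of our own, not part of JAX; Proc, alive, in_barrier, and witness are hypothetical names. It models each process by the times at which it starts, sends its request, receives its reply, and fails. It then searches for a moment that witnesses a returned set P. Because the predicate can only change at event times, it suffices to test three kinds of candidates: each event time, each midpoint between consecutive events, and one point before the first event and one after the last.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Proc:
  start: float | None          # when the process starts (None: never)
  send: float | None = None    # when it sends its live_processes request
  recv: float | None = None    # when it receives the reply (None: never)
  fail: float | None = None    # when it fails (None: never)


def alive(p: Proc, t: float) -> bool:
  return p.start is not None and p.start <= t and (p.fail is None or t < p.fail)


def in_barrier(p: Proc, t: float) -> bool:
  # A live process is in the barrier from its send until its reply arrives.
  return (alive(p, t) and p.send is not None and p.send <= t
          and (p.recv is None or t <= p.recv))


def witness(procs: dict[int, Proc], P: set[int]) -> float | None:
  """Returns a moment at which every process in P is in the barrier and every
  other process is dead, or None if there is no such moment."""
  times = sorted({x for p in procs.values()
                  for x in (p.start, p.send, p.recv, p.fail) if x is not None})
  candidates = sorted(times + [(a + b) / 2 for a, b in zip(times, times[1:])]
                      + [times[0] - 1, times[-1] + 1])
  for t in candidates:
    if (all(in_barrier(procs[i], t) for i in P) and
        all(not alive(p, t) for i, p in procs.items() if i not in P)):
      return t
  return None


# Process 0 gets a reply; process 1 called but died before its reply.
execution = {0: Proc(start=0, send=1, recv=4), 1: Proc(start=0, send=2, fail=3)}
print(witness(execution, {0}))     # 3: process 1 is already dead
print(witness(execution, {0, 1}))  # 2: both are in the barrier
```

Both replies, {0} and {0, 1}, are valid for this execution, which matches the discussion of the execution in which process 1 dies before receiving its reply.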
In the example above, live_processes returns 0,1,2. In the visualization below, we show that there does exist an instantaneous moment of time in which processes 0, 1, and 2 are all in the barrier and all other processes (there are none) are dead. The moment in time is drawn as a vertical red bar.
There is nothing special about the specific moment in time we chose in the visualization above. All that matters is that there exists some moment in time where all processes in P are in the barrier and all other processes are dead. Many moments in time satisfy this property, as shown below.
In the next example, processes 0 and 1 start, call jax.live_processes, and receive 0,1 as a reply. Process 2 is dead throughout the execution.
This is a valid execution under our formal semantics, because there exists a moment in time in which processes 0 and 1 are in the barrier and process 2 is dead.
In the following execution, process 0 calls jax.live_processes and receives a reply of 0. Process 1 calls jax.live_processes, but dies before receiving a reply.
Is this a valid execution? Yes. There exists a moment in time at which process 0 is in the barrier and process 1 is dead, as shown below. Even though process 1 called jax.live_processes, it is not guaranteed that process 1 will be included in the coordination service’s response.
For example, process 1’s jax.live_processes request may have been dropped by the network and never received by the coordination service. So from the coordination service’s perspective, process 1 is thoroughly dead and never even entered the live_processes barrier.
What about the same exact execution, except that process 0 now receives the reply 0,1 from the coordination service?
Again, this is a valid execution, as witnessed below. Intuitively, the coordination service could have received jax.live_processes requests from both processes 0 and 1 and sent the reply 0,1 to both. While this reply was in the network, process 1 failed. Thus, even though process 1 is dead when process 0 receives a reply, the execution is still valid.
This point bears repeating. If jax.live_processes returns a set P of processes, it does not mean that all processes in P are currently alive and all other processes are currently dead. It only means that there existed a point in time when this was true.
In the following execution, process 1 calls jax.live_processes and fails. Later, process 0 starts, calls jax.live_processes, and receives 0,1 as a reply.
Using the formal semantics described thus far, this is not a valid execution. There is never a point in time where processes 0 and 1 are both alive. However, this should be a valid execution.
The reason has to do with an unavoidable fact: in a distributed system, it is impossible to detect failures with 100% accuracy. If the coordination service hasn’t received heartbeats from a process in a while, it considers the process dead. But the coordination service cannot determine with 100% certainty when the process died, or whether the process is actually dead at all. Maybe the process died a long time ago, or maybe it died very recently, or maybe it is alive but on the other side of a network partition.
Let’s return to the execution above for a concrete example. Imagine the coordination service successfully received process 1’s live_processes request. Then process 1 failed, but the coordination service didn’t detect the failure immediately. In the meantime, the coordination service received process 0’s live_processes request. At this point, the coordination service thought both processes were alive and saw that both processes were in the barrier. So it naturally returned 0,1 to both processes, though only process 0 received the reply because process 1 was dead.
The coordination service thought process 1 was alive when it was dead. Conversely, the coordination service might sometimes think a process is dead when it is alive. Though not ideal, we need to accommodate executions like these because they are unavoidable.
We amend our formal semantics and allow ourselves to move a failure either earlier or later in time, though we cannot move a failure past a different event from the same process. Intuitively, we can move a failure from when it actually happened to the point in time when the coordination service thought it happened. Continuing the example above, we can delay the failure of process 1 to create a moment in time in which both processes 0 and 1 are in the barrier, witnessing the fact that the execution is valid.
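The amended rule can be built into a toy checker that never consults the recorded failure time directly. Consider a process that fails in this model. It has no events after its failure, so the failure may slide anywhere after its last other event. Two consequences follow. First, a started process can count as alive at any later moment. Second, a failed process can count as dead at any moment strictly after its last non-failure event. This sketch is our own simplification; Proc and witness_movable are hypothetical names. It checks one returned set at a time. A complete checker would need a single placement of every failure that witnesses all replies at once.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Proc:
  start: float | None          # when the process starts (None: never)
  send: float | None = None    # when it sends its live_processes request
  recv: float | None = None    # when it receives the reply (None: never)
  fail: float | None = None    # when it fails (None: never)

  def last_event(self) -> float:
    return max(x for x in (self.start, self.send, self.recv) if x is not None)


def can_be_in_barrier(p: Proc, t: float) -> bool:
  # Any failure can be delayed past t, so only start/send/recv matter.
  return (p.start is not None and p.start <= t and p.send is not None
          and p.send <= t and (p.recv is None or t <= p.recv))


def can_be_dead(p: Proc, t: float) -> bool:
  if p.start is None or t < p.start:
    return True
  # A failure can be moved as early as just after the last other event.
  return p.fail is not None and p.last_event() < t


def witness_movable(procs: dict[int, Proc], P: set[int]) -> float | None:
  times = sorted({x for p in procs.values()
                  for x in (p.start, p.send, p.recv) if x is not None})
  candidates = sorted(times + [(a + b) / 2 for a, b in zip(times, times[1:])]
                      + [times[0] - 1, times[-1] + 1])
  for t in candidates:
    if (all(can_be_in_barrier(procs[i], t) for i in P) and
        all(can_be_dead(p, t) for i, p in procs.items() if i not in P)):
      return t
  return None


# Process 1 calls and fails; later process 0 starts, calls, and gets {0, 1}.
late = {0: Proc(start=3, send=4, recv=5), 1: Proc(start=0, send=1, fail=2)}
# Process 1 got its reply before failing, so its failure cannot move earlier
# than that reply, and it is no longer in the barrier afterwards.
invalid = {0: Proc(start=4, send=5, recv=6),
           1: Proc(start=0, send=1, recv=2, fail=3)}
print(witness_movable(late, {0, 1}))     # 4: valid once the failure is delayed
print(witness_movable(invalid, {0, 1}))  # None: no placement works
```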
Consider a similar execution below.
As is, there is no moment in time in which process 0 is alive and process 1 is dead. However, if we move the failure of process 1 leftwards, there is. How might such an execution arise? Imagine process 1 is partitioned from the coordination service. The coordination service doesn’t receive any messages from process 1, including its heartbeats. This leads the coordination service to conclude that process 1 is dead, even though it isn’t. Then, the coordination service receives process 0’s live_processes request and responds with 0.
We cannot move a process failure past the process’ other events, however. For example, the following execution is invalid, because no matter where we move the failure of process 1, there is never a moment in time where both processes are in the barrier.
With these formal semantics, we can make sense of even complex executions. For example, consider the following execution.
After moving some process failures, we see the execution is valid.
The following execution, on the other hand, is invalid.
Atomicity
Equipped with jax.live_processes, let’s try to write some fault-tolerant multi-controller JAX code.
```python
import jax
import jax.numpy as jnp
import numpy as np

step = 0
while True:
  # Get the devices on all live processes. (jax.live_processes is the
  # hypothetical API being designed in this section.)
  procs = jax.live_processes()
  devices = [d for d in jax.devices() if d.process_index in procs]

  # Shard array x over these devices.
  mesh = jax.make_mesh((len(devices),), ("i",), devices=devices)
  spec = jax.sharding.PartitionSpec("i")
  sharding = jax.sharding.NamedSharding(mesh, spec)
  x = jax.make_array_from_process_local_data(sharding, np.ones(1))

  # Try to perform a jnp.sum.
  try:
    print(jnp.sum(x))
  except Exception:
    # jnp.sum failed.
    pass
  else:
    # jnp.sum succeeded.
    step += 1
```
The code repeatedly calls jax.live_processes to learn which processes are alive, computes the set of devices on the healthy processes, shards an array across these healthy devices, performs a jnp.sum (i.e. an AllReduce) on the array, and increments step if the jnp.sum succeeds.
This code looks correct, but it has a very subtle bug. Assume the jnp.sum is being performed across a set of processes P. If one (or more) of the processes in P fails during the execution of the jnp.sum, then jnp.sum can behave differently on different processes. Some processes in P might see jnp.sum return the correct result. Other processes might see jnp.sum raise an exception. Still others might see jnp.sum return an incorrect result.
Warning
If a process fails during a collective operation, the operation may behave differently on different processes.
This means that the processes executing the code example above might diverge: some might increment step, and some might not. In the trivial code example above, this divergence is benign, but in a real program, the divergence would likely lead to a crash, a deadlock, or garbage outputs. For example, suppose a multi-controller JAX program is training a model with data parallelism and starts to diverge. Some processes might roll back their model weights to a previous checkpoint while others continue training, leading to a “franken-model” where nobody agrees on what the model weights are supposed to be.
To write fault-tolerant code that does not diverge, we want atomicity. When executing a block of code (like the jnp.sum above), we want either every process to run the code successfully, or every process to learn that the code failed to execute successfully. We don’t want some processes succeeding and others failing.
Thankfully, we can achieve atomicity with a very simple trick: call live_processes twice, once before a code block and once after. If all the processes that were alive before the block are also alive after the block, then the code block executed successfully on all live processes. On the other hand, if any process died, then all remaining processes can agree the code block failed to execute properly. Here’s a sketch of what that might look like:
```python
# Get the set of live processes before the code block.
procs_before = jax.live_processes()

# Execute the code block.
...

# Get the set of live processes after the code block.
procs_after = jax.live_processes()

if procs_before == procs_after:
  # The code block executed successfully on all processes in
  # procs_before.
  pass
else:
  # The code block did not execute successfully. All processes will
  # agree it failed.
  pass
```
The code above should give you a rough idea of how to use two calls to live_processes to achieve atomicity, but there are still a handful of small issues we need to address before it is fully correct. For example:
What if the code block throws an exception? We need to catch the exception, still call live_processes the second time, and then re-raise the exception.
What if a process fails after the first call to live_processes and recovers before the second call? Wouldn’t the code block fail even though the processes before and after are the same? To handle this, every time a process starts, it generates a random incarnation id. In addition to checking that the set of processes hasn’t changed, we also check that their incarnation ids haven’t changed.
What if a process recovers and its first call to live_processes matches up with a different process’ second call to live_processes? Couldn’t this lead to a deadlock? Yes. We can avoid the problem by only calling live_processes at a single program point. We can be clever and use a single call to live_processes for two purposes. It can be used to check that the set of processes hasn’t changed since the previous call to live_processes, and it can be used to generate the set of live processes that should be used the next time the atomic code block is executed.
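The incarnation-id fix is easiest to see in a toy model. Here, each live-process snapshot maps a process id to the random incarnation id the process drew when it last started. The representation and the helper names (View, new_incarnation, block_succeeded) are hypothetical. The point is only that comparing process ids alone misses a crash-and-restart, while comparing incarnation ids catches it.

```python
import uuid

# Toy "view": live process id -> incarnation id (hypothetical representation).
View = dict[int, str]


def new_incarnation() -> str:
  """Every time a process starts, it draws a fresh random incarnation id."""
  return uuid.uuid4().hex


def block_succeeded(before: View, after: View) -> bool:
  """A block succeeded only if the same processes are alive afterwards AND
  none of them restarted in between."""
  return before == after


# Process 1 fails after the first call and recovers before the second one.
inc0, inc1 = new_incarnation(), new_incarnation()
before = {0: inc0, 1: inc1}
after = {0: inc0, 1: new_incarnation()}

print(set(before) == set(after))       # True: process ids alone look unchanged
print(block_succeeded(before, after))  # False: process 1's incarnation changed
```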
All these details are handled and abstracted away by the live_devices API introduced in Part 1: Fault Tolerance Basics (imported from jax.experimental.multihost_utils in the examples). live_devices is a context manager that guarantees the atomic execution of a block of code. In the code snippet below, devices is a list of the devices on all live processes. The code block A will execute atomically across these processes. That is, either every process will see the code raise an exception (branch B) or every process will see the code succeed (branch C).
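```python
from jax.experimental.multihost_utils import live_devices
import jax

try:
  with live_devices(jax.devices()) as devices:
    pass  # A
except Exception as e:
  pass  # B
else:
  pass  # C
```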
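To see how the three fixes fit together, consider the toy context manager below. It is not how live_devices is implemented. AtomicBlocks, BlockFailed, and the scripted live_processes callback are hypothetical. The constructor’s barrier call produces the first view. After that, the single barrier call on exit does two jobs: it validates the block that just ran by comparing process ids and incarnation ids, and it fixes the view for the next block. An exception raised inside the block is re-raised only after that call, so a failing process still takes part in the barrier.

```python
import contextlib
from typing import Callable, Iterator

View = dict[int, str]  # live process id -> incarnation id


class BlockFailed(Exception):
  """Raised on every survivor when the live processes changed during a block."""


class AtomicBlocks:
  """Toy model of atomic blocks on top of a barrier-style live_processes.

  live_processes is assumed to return the same View to every process that
  takes part in the same barrier round.
  """

  def __init__(self, live_processes: Callable[[], View]):
    self._live_processes = live_processes
    self._view = live_processes()  # The first view.

  @contextlib.contextmanager
  def block(self) -> Iterator[View]:
    view = self._view
    error = None
    try:
      yield view
    except Exception as e:
      error = e  # Remember it, but still join the barrier below.
    # One call validates this block and fixes the view for the next one.
    self._view = self._live_processes()
    if error is not None:
      raise error
    if self._view != view:
      raise BlockFailed(f"live processes changed: {view} -> {self._view}")


# Scripted barrier results: process 1 dies while the second block runs.
views = iter([{0: "a", 1: "b"}, {0: "a", 1: "b"}, {0: "a"}, {0: "a"}])
blocks = AtomicBlocks(lambda: next(views))
outcomes = []
for _ in range(3):
  try:
    with blocks.block() as view:
      pass  # A
  except BlockFailed:
    outcomes.append("B")
  else:
    outcomes.append("C")
print(outcomes)  # ['C', 'B', 'C']
```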
Cancelling Collectives
As mentioned in Cancelling Collectives, if a process participating in a collective fails, then the other participating processes get stuck forever. We need to explicitly cancel these collectives to allow the alive participants to make progress. While the live_devices API is supported on all JAX backends (i.e. CPU, GPU, TPU), cancelling collectives is only supported by the GPU backend. Here, we briefly explain some of the implementation details behind collective cancelling.
The GPU backend implements collectives using NCCL, NVIDIA’s collective communication library. When a set of processes wants to perform a collective, they form a NCCL communicator. Processes can then repeatedly perform collectives using this communicator. Creating a communicator is expensive (it requires network communication), so the JAX backend caches communicators keyed by the set of participating processes and their incarnation ids.
Internally, a JAX client polls the coordination service for the current status of every process. If a client ever detects that a process is dead or has restarted with a new incarnation id, then the client aborts all communicators with the failed incarnation id in their cache key.
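The cache-and-abort logic can be sketched as a toy. The names CommunicatorCache, Communicator, and on_status are hypothetical. No real NCCL calls are made, and aborting is reduced to setting a flag. The sketch shows why keying the cache by incarnation ids matters. A process that restarts gets a new incarnation id, so every communicator built with its old incarnation is aborted. Communicators whose participants all kept their incarnations are left alone.

```python
from __future__ import annotations

from dataclasses import dataclass

# Cache key: the participating processes together with their incarnation ids.
Key = frozenset[tuple[int, str]]


@dataclass
class Communicator:
  key: Key
  aborted: bool = False


class CommunicatorCache:
  def __init__(self) -> None:
    self._cache: dict[Key, Communicator] = {}

  def get(self, participants: dict[int, str]) -> Communicator:
    key = frozenset(participants.items())
    if key not in self._cache:
      # Creating a real communicator is expensive, hence the cache.
      self._cache[key] = Communicator(key)
    return self._cache[key]

  def on_status(self, status: dict[int, str | None]) -> list[Communicator]:
    """Handles one poll result (process id -> current incarnation id, or None
    if dead). Aborts every communicator whose key mentions a dead or replaced
    incarnation and returns the newly aborted ones."""
    newly_aborted = []
    for comm in self._cache.values():
      if not comm.aborted and any(status.get(p) != inc for p, inc in comm.key):
        comm.aborted = True
        newly_aborted.append(comm)
    return newly_aborted


cache = CommunicatorCache()
c01 = cache.get({0: "a", 1: "b"})
c02 = cache.get({0: "a", 2: "c"})
same = cache.get({1: "b", 0: "a"}) is c01             # reused from the cache
aborted = cache.on_status({0: "a", 1: "b2", 2: "c"})  # process 1 restarted
print(same, c01.aborted, c02.aborted)                 # True True False
```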
