Introduction to multi-controller JAX (aka multi-process/multi-host JAX)
By reading this tutorial, you'll learn how to scale JAX computations to more devices than can fit in a single host machine, e.g. when running on a GPU cluster, Cloud TPU pod, or multiple CPU-only machines.
The main idea
- Run multiple Python processes, which we sometimes call "controllers." We can run one (or more) process per host machine.
- Initialize the cluster with jax.distributed.initialize().
- A jax.Array can span all processes, and if each process applies the same JAX function to it, it's like programming against one big device.
- Use the same unified sharding mechanism as in single-controller JAX to control how data is distributed and computation is parallelized. XLA automatically exploits high-speed networking links like TPU ICI or NVLink between hosts when available, and otherwise uses available host networking (e.g. Ethernet, InfiniBand).
- All processes (usually) run the same Python script. You write this Python code almost exactly the same as you would for a single process: just run multiple instances of it and JAX takes care of the rest. In other words, except for array creation, you can write your JAX code as if there were one giant machine with all devices attached to it.
This tutorial assumes you've read Distributed arrays and automatic parallelization, which is about single-controller JAX.

Illustration of a multi-host TPU pod. Each host in the pod (green) is attached via PCI to a board of four TPU chips (blue). The TPU chips themselves are connected via high-speed inter-chip interconnects (ICI). JAX Python code runs on each host, e.g. via ssh. The JAX processes on each host are aware of each other, allowing you to orchestrate computation across the entire pod's worth of chips. The principle is the same for GPU, CPU, and other platforms with JAX support!
Toy example
Before we define terms and walk through the details, here's a toy example: making a process-spanning jax.Array of values and applying jax.numpy functions to it.
```python
# call this file toy.py, to be run in each process simultaneously

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P
import numpy as np

# in this example, get multi-process parameters from sys.argv
import sys
proc_id = int(sys.argv[1])
num_procs = int(sys.argv[2])

# initialize the distributed system
jax.distributed.initialize('localhost:10000', num_procs, proc_id)

# this example assumes 8 devices total
assert jax.device_count() == 8

# make a 2D mesh that refers to devices from all processes
mesh = jax.make_mesh((4, 2), ('i', 'j'))

# create some toy data
global_data = np.arange(32).reshape((4, 8))

# make a process- and device-spanning array from our toy data
sharding = NamedSharding(mesh, P('i', 'j'))
global_array = jax.device_put(global_data, sharding)
assert global_array.shape == global_data.shape

# each process has different shards of the global array
for shard in global_array.addressable_shards:
  print(f"device {shard.device} has local data {shard.data}")

# apply a simple computation, automatically partitioned
global_result = jnp.sum(jnp.sin(global_array))
print(f'process={proc_id} got result: {global_result}')
```
Here, mesh contains devices from all processes. We use it to create global_array, logically a single shared array, stored distributed across devices from all processes.
Every process must apply the same operations, in the same order, to global_array. XLA automatically partitions those computations, for example inserting communication collectives to compute the jnp.sum over the full array. We can print the final result because its value is replicated across processes.
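To make "each process has different shards" concrete, here is a stdlib-only sketch of how an evenly divided 2D sharding like P('i', 'j') on a (4, 2) mesh splits the (4, 8) toy array into one block per mesh position. The helper block_slices is hypothetical (not a JAX API), and it only models the block layout: which physical device sits at each mesh position is chosen by jax.make_mesh, so the match with the device names printed further down is an observation, not a guarantee.

```python
# Hypothetical helper (not part of JAX): mimics how NamedSharding with
# PartitionSpec('i', 'j') splits a 2D array evenly over a 2D mesh.
def block_slices(global_shape, mesh_shape):
  """Map each mesh position (r, c) to the (row slice, col slice) it holds."""
  rows, cols = global_shape
  mesh_rows, mesh_cols = mesh_shape
  if rows % mesh_rows or cols % mesh_cols:
    raise ValueError("array shape must divide evenly over the mesh")
  br, bc = rows // mesh_rows, cols // mesh_cols
  return {
      (r, c): (slice(r * br, (r + 1) * br), slice(c * bc, (c + 1) * bc))
      for r in range(mesh_rows)
      for c in range(mesh_cols)
  }

# The same toy data as np.arange(32).reshape((4, 8)), built from plain lists.
global_data = [[row * 8 + col for col in range(8)] for row in range(4)]

for (r, c), (rs, cs) in sorted(block_slices((4, 8), (4, 2)).items()):
  block = [row[cs] for row in global_data[rs]]
  print(f"mesh position ({r}, {c}) holds {block}")
```

Each mesh position ends up with a 1x4 block, which is the shape of every shard in the outputs below.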
We can run this code locally on CPU, e.g. using 4 processes and 2 CPU devices per process:
```bash
export JAX_NUM_CPU_DEVICES=2
num_processes=4

range=$(seq 0 $(($num_processes - 1)))

for i in $range; do
  python toy.py $i $num_processes > /tmp/toy_$i.out &
done

wait

for i in $range; do
  echo "=================== process $i output ==================="
  cat /tmp/toy_$i.out
  echo
done
```
Outputs:
```
=================== process 0 output ===================
device TFRT_CPU_0 has local data [[0 1 2 3]]
device TFRT_CPU_1 has local data [[4 5 6 7]]
process=0 got result: -0.12398731708526611

=================== process 1 output ===================
device TFRT_CPU_131072 has local data [[ 8  9 10 11]]
device TFRT_CPU_131073 has local data [[12 13 14 15]]
process=1 got result: -0.12398731708526611

=================== process 2 output ===================
device TFRT_CPU_262144 has local data [[16 17 18 19]]
device TFRT_CPU_262145 has local data [[20 21 22 23]]
process=2 got result: -0.12398731708526611

=================== process 3 output ===================
device TFRT_CPU_393216 has local data [[24 25 26 27]]
device TFRT_CPU_393217 has local data [[28 29 30 31]]
process=3 got result: -0.12398731708526611
```
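If you'd rather drive the same local experiment from Python than from a shell loop, a minimal sketch could look like the following. The helpers build_commands and launch are hypothetical; the sketch assumes toy.py sits in the current directory and, like the shell version, writes each process's output to /tmp/toy_<i>.out. The key point it preserves is that all processes are started before any of them is waited on.

```python
import os
import subprocess
import sys

def build_commands(script, num_processes):
  """One argv list per process: [python, script, process_id, num_processes]."""
  return [[sys.executable, script, str(i), str(num_processes)]
          for i in range(num_processes)]

def launch(script="toy.py", num_processes=4, devices_per_process=2,
           log_dir="/tmp"):
  """Start all processes at once, wait for them, then print each log."""
  env = dict(os.environ, JAX_NUM_CPU_DEVICES=str(devices_per_process))
  log_paths = [os.path.join(log_dir, f"toy_{i}.out")
               for i in range(num_processes)]
  procs = []
  for cmd, path in zip(build_commands(script, num_processes), log_paths):
    # Each child gets its own handle to the log file, so the parent can
    # close its copy right away.
    with open(path, "w") as log:
      procs.append(subprocess.Popen(cmd, env=env, stdout=log,
                                    stderr=subprocess.STDOUT))
  # The processes must run concurrently: each one waits inside
  # jax.distributed.initialize() until its peers have connected.
  for p in procs:
    p.wait()
  for i, path in enumerate(log_paths):
    print(f"=================== process {i} output ===================")
    with open(path) as f:
      print(f.read())
```

Calling launch() from the directory containing toy.py should produce output equivalent to the shell loop.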
This might not look so different from single-controller JAX code, and in fact, this is exactly how you'd write the single-controller version of the same program! (We don't technically need to call jax.distributed.initialize() for single-controller, but it doesn't hurt.) Let's run the same code from a single process:
```bash
JAX_NUM_CPU_DEVICES=8 python toy.py 0 1
```
Outputs:
```
device TFRT_CPU_0 has local data [[0 1 2 3]]
device TFRT_CPU_1 has local data [[4 5 6 7]]
device TFRT_CPU_2 has local data [[ 8  9 10 11]]
device TFRT_CPU_3 has local data [[12 13 14 15]]
device TFRT_CPU_4 has local data [[16 17 18 19]]
device TFRT_CPU_5 has local data [[20 21 22 23]]
device TFRT_CPU_6 has local data [[24 25 26 27]]
device TFRT_CPU_7 has local data [[28 29 30 31]]
process=0 got result: -0.12398731708526611
```
The data is sharded across eight devices on one process rather than eight devices across four processes, but otherwise we're running the same operations over the same data.
Terminology
It’s worth pinning down some terminology.
We sometimes call each Python process running JAX computations a controller, but the two terms are essentially synonymous.
Each process has a set of local devices, meaning it can transfer data to and from those devices' memories and run computation on those devices without involving any other processes. The local devices are usually physically attached to the process's corresponding host, e.g. via PCI. A device can only be local to one process; that is, the local device sets are disjoint. A process's local devices can be queried by evaluating jax.local_devices(). We sometimes use the term addressable to mean the same thing as local.

Illustration of how a process/controller and local devices fit into a larger multi-host cluster. The "global devices" are all devices in the cluster.
The devices across all processes are called the global devices. The list of global devices is queried by jax.devices(). That list of all devices is populated by running jax.distributed.initialize() on all processes, which sets up a simple distributed system connecting the processes.
We often use the terms global and local to describe process-spanning and process-local concepts in general. For example, a "local array" could be a NumPy array that's only visible to a single process, whereas a JAX "global array" is conceptually visible to all processes.
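The relationship between global and local devices can be checked directly: grouping jax.devices() by each device's process_index should reproduce every process's jax.local_devices(), and the groups should be disjoint. A minimal sketch, meant to run on every process after jax.distributed.initialize() (it also runs unchanged in a single process, where everything belongs to process 0):

```python
from collections import defaultdict

import jax

# Group the global device list by the process that owns each device.
devices_by_process = defaultdict(list)
for d in jax.devices():
  devices_by_process[d.process_index].append(d.id)

# This process's local devices are exactly the global devices it owns.
local_ids = sorted(d.id for d in jax.local_devices())
assert local_ids == sorted(devices_by_process[jax.process_index()])

# Each device has exactly one process_index, so the groups are disjoint
# and together they cover every global device.
assert sum(len(ids) for ids in devices_by_process.values()) == jax.device_count()

if jax.process_index() == 0:
  for proc, ids in sorted(devices_by_process.items()):
    print(f"process {proc} owns devices {ids}")
```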
Setting up multiple JAX processes
In practice, setting up multiple JAX processes looks a bit different from the toy example, which is run from a single host machine. We usually launch each process on a separate host, or have multiple hosts with multiple processes each. We can do that directly using ssh, or with a cluster manager like Slurm or Kubernetes. In any case, you must manually run your JAX program on each host! JAX doesn't automatically start multiple processes from a single program invocation.
However they're launched, the Python processes need to run jax.distributed.initialize(). When using Slurm, Kubernetes, or any Cloud TPU deployment, we can run jax.distributed.initialize() with no arguments, as they're automatically populated. Initializing the system means we can run jax.devices() to report all devices across all processes.
Warning
jax.distributed.initialize() must be called before running jax.devices(), jax.local_devices(), or running any computations on devices (e.g. with jax.numpy). Otherwise the JAX process won't be aware of any non-local devices. (Using jax.config or other non-device-accessing functionality is OK.) jax.distributed.initialize() will raise an error if you accidentally call it after accessing any devices.
GPU Example
We can run multi-controller JAX on a cluster of GPU machines. For example, after creating four VMs on Google Cloud with two GPUs per VM, we can run the following JAX program on every VM. In this example, we provide arguments to jax.distributed.initialize() explicitly. The coordinator address, process id, and number of processes are read from the command line.
```python
# In file gpu_example.py...
import jax
import sys

# Get the coordinator_address, process_id, and num_processes from the command line.
coord_addr = sys.argv[1]
proc_id = int(sys.argv[2])
num_procs = int(sys.argv[3])

# Initialize the GPU machines.
jax.distributed.initialize(coordinator_address=coord_addr,
                           num_processes=num_procs,
                           process_id=proc_id)
print("process id =", jax.process_index())
print("global devices =", jax.devices())
print("local devices =", jax.local_devices())
```
For example, if the first VM has address 192.168.0.1, then you would run python3 gpu_example.py 192.168.0.1:8000 0 4 on the first VM, python3 gpu_example.py 192.168.0.1:8000 1 4 on the second VM, and so on. After running the JAX program on all four VMs, the first process prints the following.
```
process id = 0
global devices = [CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3), CudaDevice(id=4), CudaDevice(id=5), CudaDevice(id=6), CudaDevice(id=7)]
local devices = [CudaDevice(id=0), CudaDevice(id=1)]
```
The process successfully sees all eight GPUs as global devices, as well as its two local devices. Similarly, the second process prints the following.
```
process id = 1
global devices = [CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3), CudaDevice(id=4), CudaDevice(id=5), CudaDevice(id=6), CudaDevice(id=7)]
local devices = [CudaDevice(id=2), CudaDevice(id=3)]
```
This VM sees the same global devices, but has a different set of local devices.
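The per-VM command lines differ only in the process id: every VM passes the same coordinator address (the first VM, which runs process 0) and the same process count. A small sketch that generates them; gpu_example_commands is a hypothetical helper, and port 8000 is just the value used above.

```python
def gpu_example_commands(coordinator_ip, port, num_processes):
  """Hypothetical helper: the command line to run on each VM, by process id."""
  # Every VM points at the same coordinator (the first VM); only the
  # process id changes from one VM to the next.
  coordinator_address = f"{coordinator_ip}:{port}"
  return [f"python3 gpu_example.py {coordinator_address} {proc_id} {num_processes}"
          for proc_id in range(num_processes)]

for vm, cmd in enumerate(gpu_example_commands("192.168.0.1", 8000, 4)):
  print(f"VM {vm}: {cmd}")
```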
TPU Example
As another example, we can run on Cloud TPU. After creating a v5litepod-16 (which has 4 host machines), we might want to test that we can connect the processes and list all devices:
```bash
$ TPU_NAME=jax-demo
$ EXTERNAL_IPS=$(gcloud compute tpus tpu-vm describe $TPU_NAME --zone 'us-central1-a' \
                 | grep externalIp | cut -d: -f2)
$ cat << EOF > demo.py
import jax
jax.distributed.initialize()
if jax.process_index() == 0:
  print(jax.devices())
EOF
$ echo $EXTERNAL_IPS | xargs -n 1 -P 0 bash -c '
scp demo.py $0:
ssh $0 "pip -q install -U jax[tpu]"
ssh $0 "python demo.py" '
```
Here we're using xargs to run multiple ssh commands in parallel, each one running the same Python program on one of the TPU host machines. In the Python code, we use jax.process_index() to print only on one process. Here's what it prints:
```
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=2, process_index=1, coords=(2,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=1, coords=(3,0,0), core_on_chip=0),
 TpuDevice(id=6, process_index=1, coords=(2,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=1, coords=(3,1,0), core_on_chip=0),
 TpuDevice(id=8, process_index=2, coords=(0,2,0), core_on_chip=0),
 TpuDevice(id=9, process_index=2, coords=(1,2,0), core_on_chip=0),
 TpuDevice(id=12, process_index=2, coords=(0,3,0), core_on_chip=0),
 TpuDevice(id=13, process_index=2, coords=(1,3,0), core_on_chip=0),
 TpuDevice(id=10, process_index=3, coords=(2,2,0), core_on_chip=0),
 TpuDevice(id=11, process_index=3, coords=(3,2,0), core_on_chip=0),
 TpuDevice(id=14, process_index=3, coords=(2,3,0), core_on_chip=0),
 TpuDevice(id=15, process_index=3, coords=(3,3,0), core_on_chip=0)]
```
Woohoo, look at all those TPU cores!
Kubernetes Example
Running multi-controller JAX on a Kubernetes cluster is almost identical in spirit to the GPU and TPU examples above: every pod runs the same Python program, JAX discovers its peers, and the cluster behaves like one giant machine.
- Container image: start from a JAX-enabled image, e.g. one of the public JAX AI images on Google Artifact Registry (TPU / GPU) or NVIDIA (NGC / JAX-Toolbox).
- Workload type: use either a JobSet or an indexed Job. Each replica corresponds to one JAX process.
- Service Account: JAX needs permission to list the pods that belong to the job so that processes can discover their peers. A minimal RBAC setup is provided in examples/k8s/svc-acct.yaml.
Below is a minimal JobSet that launches two replicas. Replace the placeholders (image, GPU count, and any private registry secrets) with values that match your environment.
```yaml
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: jaxjob
spec:
  replicatedJobs:
  - name: workers
    template:
      spec:
        parallelism: 2
        completions: 2
        backoffLimit: 0
        template:
          spec:
            serviceAccountName: jax-job-sa  # kubectl apply -f svc-acct.yaml
            restartPolicy: Never
            imagePullSecrets:
              # https://k8s.io/docs/tasks/configure-pod-container/pull-image-private-registry/
            - name: null
            containers:
            - name: main
              image: null  # e.g. ghcr.io/nvidia/jax:jax
              imagePullPolicy: Always
              resources:
                limits:
                  cpu: 1
                  # https://k8s.io/docs/tasks/manage-gpus/scheduling-gpus/
                  nvidia.com/gpu: null
              command:
                - python
              args:
                - -c
                - |
                  import jax
                  jax.distributed.initialize()
                  print(jax.devices())
                  print(jax.local_devices())
                  assert jax.process_count() > 1
                  assert len(jax.devices()) > len(jax.local_devices())
```
Apply the manifest and watch the pods complete:
```bash
$ kubectl apply -f example.yaml
$ kubectl get pods -l jobset.sigs.k8s.io/jobset-name=jaxjob
NAME                       READY   STATUS      RESTARTS   AGE
jaxjob-workers-0-0-xpx8l   0/1     Completed   0          8m32s
jaxjob-workers-0-1-ddkq8   0/1     Completed   0          8m32s
```
When the job finishes, inspect the logs to confirm that every process saw all accelerators:
```bash
$ kubectl logs -l jobset.sigs.k8s.io/jobset-name=jaxjob
[CudaDevice(id=0), CudaDevice(id=1)]
[CudaDevice(id=0)]
[CudaDevice(id=0), CudaDevice(id=1)]
[CudaDevice(id=1)]
```
Every pod should have the same set of global devices and a different set of local devices. At this point, you can replace the inline script with your real JAX program.
Once the processes are set up, we can start building global jax.Arrays and running computations. The remaining Python code examples in this tutorial are meant to be run on all processes simultaneously, after running jax.distributed.initialize().
Meshes, shardings, and computations can span processes and hosts
Programming multiple processes from JAX usually looks just like programming a single process, just with more devices! The main exceptions to this are around data coming in or out of JAX, e.g. when loading from external data sources. We'll first go over the basics of multi-process computations here, which largely look the same as their single-process counterparts. The next section goes over some data loading fundamentals, i.e. how to create JAX Arrays from non-JAX sources.
Recall that a jax.sharding.Mesh pairs an array of jax.Devices with a sequence of names, with one name per array axis. By creating a Mesh using devices from multiple processes, then using that mesh in a jax.sharding.Sharding, we can construct jax.Arrays sharded over devices from multiple processes.
Here's an example that directly constructs a Mesh using jax.devices() to get devices from all processes:
```python
from jax.sharding import Mesh
mesh = Mesh(jax.devices(), ('a',))

# in this case, the same as
mesh = jax.make_mesh((jax.device_count(),), ('a',))  # use this in practice
```
In practice you should probably use the jax.make_mesh() helper, not only because it's simpler but also because it can automatically choose more performant device orderings; we spell out the Mesh constructor here only for illustration. By default, jax.make_mesh() includes all devices across processes, just like jax.devices().
Once we have a mesh, we can shard arrays over it. There are a few ways to efficiently build process-spanning arrays, detailed in the next section, but for now we'll stick to jax.device_put for simplicity:
```python
arr = jax.device_put(jnp.ones((32, 32)), NamedSharding(mesh, P('a')))
if jax.process_index() == 0:
  jax.debug.visualize_array_sharding(arr)
```
On process 0, this is printed:
```
┌───────────────────────┐
│         TPU 0         │
├───────────────────────┤
│         TPU 1         │
├───────────────────────┤
│         TPU 4         │
├───────────────────────┤
│         TPU 5         │
├───────────────────────┤
│         TPU 2         │
├───────────────────────┤
│         TPU 3         │
├───────────────────────┤
│         TPU 6         │
├───────────────────────┤
│         TPU 7         │
├───────────────────────┤
│         TPU 8         │
├───────────────────────┤
│         TPU 9         │
├───────────────────────┤
│        TPU 12         │
├───────────────────────┤
│        TPU 13         │
├───────────────────────┤
│        TPU 10         │
├───────────────────────┤
│        TPU 11         │
├───────────────────────┤
│        TPU 14         │
├───────────────────────┤
│        TPU 15         │
└───────────────────────┘
```
Let’s try a slightly more interesting computation!
```python
mesh = jax.make_mesh((jax.device_count() // 2, 2), ('a', 'b'))

def device_put(x, spec):
  return jax.device_put(x, NamedSharding(mesh, spec))

# construct global arrays by sharding over the global mesh
x = device_put(jnp.ones((4096, 2048)), P('a', 'b'))
y = device_put(jnp.ones((2048, 4096)), P('b', None))

# run a distributed matmul
z = jax.nn.relu(x @ y)

# inspect the sharding of the result
if jax.process_index() == 0:
  jax.debug.visualize_array_sharding(z)
  print()
  print(z.sharding)
```
On process 0, this is printed:
```
┌───────────────────────┐
│        TPU 0,1        │
├───────────────────────┤
│        TPU 4,5        │
├───────────────────────┤
│        TPU 8,9        │
├───────────────────────┤
│       TPU 12,13       │
├───────────────────────┤
│        TPU 2,3        │
├───────────────────────┤
│        TPU 6,7        │
├───────────────────────┤
│       TPU 10,11       │
├───────────────────────┤
│       TPU 14,15       │
└───────────────────────┘

NamedSharding(mesh=Mesh('a': 8, 'b': 2), spec=PartitionSpec('a',), memory_kind=device)
```

Here, just from evaluating x @ y on all processes, XLA is automatically generating and running a distributed matrix multiplication. The result is sharded against the mesh like P('a', None), since in this case the matmul included a psum over the 'b' axis.
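Why a psum? With x sharded as P('a', 'b') and y as P('b', None), the contracting dimension (size 2048) is split over 'b', so each device can only compute a partial product from its slice of that dimension; adding the partials across 'b' recovers the full product, which is why the result is replicated over 'b'. The stdlib-only sketch below reproduces just that arithmetic on tiny made-up matrices, ignoring the 'a' axis and any real communication.

```python
def matmul(a, b):
  """Plain-Python matrix multiply of nested lists."""
  return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
           for j in range(len(b[0]))]
          for i in range(len(a))]

def add(a, b):
  """Elementwise sum of two equally shaped nested lists."""
  return [[u + v for u, v in zip(ra, rb)] for ra, rb in zip(a, b)]

x = [[1, 2, 3, 4],
     [5, 6, 7, 8]]   # shape (2, 4): columns (contracting dim) split over 'b'
y = [[1, 0],
     [0, 1],
     [2, 1],
     [1, 3]]         # shape (4, 2): rows (contracting dim) split over 'b'

# 'b' has size 2: the device at 'b' index k holds x[:, 2k:2k+2] and y[2k:2k+2, :]
# and can only compute its own partial product.
partials = [matmul([row[2 * k:2 * k + 2] for row in x], y[2 * k:2 * k + 2])
            for k in range(2)]

# The psum over 'b': summing the partial products gives the full product,
# and every 'b' index ends up holding the same (replicated) result.
total = add(partials[0], partials[1])
assert total == matmul(x, y)
print(total)
```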
Warning
When applying JAX computations to process-spanning arrays, to avoid deadlocks and hangs, it's crucial that all processes with participating devices run the same computation in the same order. That's because the computation may involve collective communication barriers. If a device over which an array is sharded does not join in the collective because its controller didn't issue the same computation, the other devices are left waiting. For example, if only the first three processes evaluated x @ y, while the last process evaluated y @ x, the computation would likely hang indefinitely. This assumption (that computations on process-spanning arrays are run on all participating processes in the same order) is mostly unchecked.
So the easiest way to avoid deadlocks in multi-process JAX is to run the same Python code on every process, and beware of any control flow that depends on jax.process_index() and includes communication.
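As a loose analogy (not how XLA implements collectives), a collective behaves like a barrier that every participating process must reach. The sketch below models that with Python's threading.Barrier: when all four "processes" join, everyone proceeds; when one of them runs something else instead, the rest are stuck. The function run_collective is hypothetical, and the timeout exists only so the sketch terminates; a real mismatched collective would typically just hang.

```python
import threading

def run_collective(num_processes, participating, timeout=0.5):
  """Toy model of a collective as a barrier that all processes must reach.

  Returns a dict mapping each participating process id to "ok" or "timed out".
  """
  barrier = threading.Barrier(num_processes, timeout=timeout)
  results = {}

  def process(pid):
    try:
      barrier.wait()  # like entering a collective: blocks until all arrive
      results[pid] = "ok"
    except threading.BrokenBarrierError:
      results[pid] = "timed out"  # a real collective would keep waiting

  threads = [threading.Thread(target=process, args=(pid,))
             for pid in participating]
  for t in threads:
    t.start()
  for t in threads:
    t.join()
  return results

print(run_collective(4, participating=[0, 1, 2, 3]))  # everyone joins
print(run_collective(4, participating=[0, 1, 2]))     # process 3 did something else
```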
If a process-spanning array is sharded over devices on different processes, it is an error to perform operations on the array that require the data to be available locally to a process, like printing. For example, if we run print(z) in the preceding example, we see
RuntimeError: Fetching value for `jax.Array` that spans non-addressable (non process local) devices is not possible. You can use `jax.experimental.multihost_utils.process_allgather` to print the global array or use `.addressable_shards` method of jax.Array to inspect the addressable (process local) shards.
To print the full array value, we must first ensure it's replicated over processes (but not necessarily over each process's local devices), e.g. using jax.device_put. In the above example, we can write at the end:
```python
w = device_put(z, P(None, None))
if jax.process_index() == 0:
  print(w)
```
Be careful not to write the jax.device_put() under the if process_index() == 0, because that would lead to a deadlock: only process 0 would initiate the collective communication, and it would wait indefinitely for the other processes. The jax.experimental.multihost_utils module has some functions that make it easier to process global jax.Arrays (e.g., jax.experimental.multihost_utils.process_allgather()).
Alternatively, to print or otherwise perform Python operations on only process-local data, we can access z.addressable_shards. Accessing that attribute does not require any communication, so any subset of processes can do it without needing the others. That attribute is not available under a jax.jit().
Making process-spanning arrays from external data
There are three main ways to create process-spanning jax.Arrays from external data sources (e.g. numpy arrays from a data loader):

1. Create or load the full array on all processes, then shard onto devices using jax.device_put();
2. Create or load on each process an array representing just the data that will be locally sharded and stored on that process's devices, then shard onto devices using jax.make_array_from_process_local_data();
3. Create or load on each process's devices separate arrays, each representing the data to be stored on that device, then assemble them without any data movement using jax.make_array_from_single_device_arrays().
The latter two are most often used in practice, since it's often too expensive to materialize the full global data in every process.
The toy example above uses jax.device_put().
jax.make_array_from_process_local_data() is often used for distributed data loading. It's not as general as jax.make_array_from_single_device_arrays(), because it doesn't directly specify which slice of the process-local data goes on each local device. This is convenient when loading data-parallel batches, because it doesn't matter exactly which microbatch goes on each device. For example:
```python
# target (micro)batch size across the whole cluster
batch_size = 1024
# how many examples each process should load per batch
per_process_batch_size = batch_size // jax.process_count()
# how many examples each device will process per batch
per_device_batch_size = batch_size // jax.device_count()

# make a data-parallel mesh and sharding
# (note the trailing comma: axis names must be a tuple, not a bare string)
mesh = jax.make_mesh((jax.device_count(),), ('batch',))
sharding = NamedSharding(mesh, P('batch'))

# our "data loader". each process loads a different set of "examples".
process_batch = np.random.rand(per_process_batch_size, 2048, 42)

# assemble a global array containing the per-process batches from all processes
global_batch = jax.make_array_from_process_local_data(sharding, process_batch)

# sanity check that everything got sharded correctly
assert global_batch.shape[0] == batch_size
assert process_batch.shape[0] == per_process_batch_size
assert global_batch.addressable_shards[0].data.shape[0] == per_device_batch_size
```
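The batch sizes above use floor division, which silently drops leftover examples when batch_size isn't a multiple of the process or device count; the mismatch then tends to surface later as a less direct shape error. A small sketch of an upfront check; batch_sizes is a hypothetical helper, called with plain integers here so it runs anywhere (in a real job you would pass jax.process_count() and jax.device_count()).

```python
def batch_sizes(batch_size, process_count, device_count):
  """Hypothetical helper: return (per-process, per-device) batch sizes.

  Floor division (//) would silently drop leftover examples when the sizes
  don't divide evenly, so check divisibility up front instead.
  """
  for name, count in (("process_count", process_count),
                      ("device_count", device_count)):
    if batch_size % count:
      raise ValueError(f"batch_size={batch_size} is not divisible by "
                       f"{name}={count}")
  return batch_size // process_count, batch_size // device_count

print(batch_sizes(1024, process_count=4, device_count=16))  # (256, 64)
```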
jax.make_array_from_single_device_arrays() is the most general way to build a process-spanning array. It's often used after performing jax.device_put()s to send each device its required data. This is the lowest-level option, since all data movement is performed manually (via e.g. jax.device_put()). Here's an example:
```python
from jax.experimental import multihost_utils

shape = (jax.process_count(), jax.local_device_count())
mesh = jax.make_mesh(shape, ('i', 'j'))
sharding = NamedSharding(mesh, P('i', 'j'))

# manually create per-device data equivalent to np.arange(jax.device_count())
# i.e. each device will get a single scalar value from 0..N
local_arrays = [
    jax.device_put(
        jnp.array([[jax.process_index() * jax.local_device_count() + i]]),
        device)
    for i, device in enumerate(jax.local_devices())
]

# assemble a global array from the local_arrays across all processes
global_array = jax.make_array_from_single_device_arrays(
    shape=shape,
    sharding=sharding,
    arrays=local_arrays)

# sanity check
assert (np.all(
    multihost_utils.process_allgather(global_array) ==
    np.arange(jax.device_count()).reshape(global_array.shape)))
```
