Writing Distributed Applications with PyTorch#

Created On: Oct 06, 2017 | Last Updated: Sep 05, 2025 | Last Verified: Nov 05, 2024

Author: Séb Arnold

Prerequisites:

In this short tutorial, we will be going over the distributed package of PyTorch. We’ll see how to set up the distributed setting, use the different communication strategies, and go over some of the internals of the package.

Setup#

The distributed package included in PyTorch (i.e., torch.distributed) enables researchers and practitioners to easily parallelize their computations across processes and clusters of machines. To do so, it leverages message passing semantics allowing each process to communicate data to any of the other processes. As opposed to the multiprocessing (torch.multiprocessing) package, processes can use different communication backends and are not restricted to being executed on the same machine.

In order to get started, we need the ability to run multiple processes simultaneously. If you have access to a compute cluster, you should check with your local sysadmin or use your favorite coordination tool (e.g., pdsh, clustershell, or slurm). For the purpose of this tutorial, we will use a single machine and spawn multiple processes using the following template.

"""run.py:"""#!/usr/bin/env pythonimportosimportsysimporttorchimporttorch.distributedasdistimporttorch.multiprocessingasmpdefrun(rank,size):""" Distributed function to be implemented later. """passdefinit_process(rank,size,fn,backend='gloo'):""" Initialize the distributed environment. """os.environ['MASTER_ADDR']='127.0.0.1'os.environ['MASTER_PORT']='29500'dist.init_process_group(backend,rank=rank,world_size=size)fn(rank,size)if__name__=="__main__":world_size=2processes=[]if"google.colab"insys.modules:print("Running in Google Colab")mp.get_context("spawn")else:mp.set_start_method("spawn")forrankinrange(world_size):p=mp.Process(target=init_process,args=(rank,world_size,run))p.start()processes.append(p)forpinprocesses:p.join()

The above script spawns two processes that will each set up the distributed environment, initialize the process group (dist.init_process_group), and finally execute the given run function.

Let’s have a look at the init_process function. It ensures that every process will be able to coordinate through a master, using the same IP address and port. Note that we used the gloo backend but other backends are available. (c.f. Section 5.1) We will go over the magic happening in dist.init_process_group at the end of this tutorial, but it essentially allows processes to communicate with each other by sharing their locations.

Point-to-Point Communication#

Send and Recv#

A transfer of data from one process to another is called a point-to-point communication. These are achieved through the send and recv functions or their immediate counterparts, isend and irecv.

"""Blocking point-to-point communication."""defrun(rank,size):tensor=torch.zeros(1)ifrank==0:tensor+=1# Send the tensor to process 1dist.send(tensor=tensor,dst=1)else:# Receive tensor from process 0dist.recv(tensor=tensor,src=0)print('Rank ',rank,' has data ',tensor[0])

In the above example, both processes start with a zero tensor, then process 0 increments the tensor and sends it to process 1 so that they both end up with 1.0. Notice that process 1 needs to allocate memory in order to store the data it will receive.

Also notice that send/recv are blocking: both processes block until the communication is completed. On the other hand, immediates are non-blocking; the script continues its execution and the methods return a Work object upon which we can choose to wait().

"""Non-blocking point-to-point communication."""defrun(rank,size):tensor=torch.zeros(1)req=Noneifrank==0:tensor+=1# Send the tensor to process 1req=dist.isend(tensor=tensor,dst=1)print('Rank 0 started sending')else:# Receive tensor from process 0req=dist.irecv(tensor=tensor,src=0)print('Rank 1 started receiving')req.wait()print('Rank ',rank,' has data ',tensor[0])

When using immediates we have to be careful about how we use the sent and received tensors. Since we do not know when the data will be communicated to the other process, we should not modify the sent tensor nor access the received tensor before req.wait() has completed. In other words,

  • writing to tensor after dist.isend() will result in undefined behaviour.

  • reading from tensor after dist.irecv() will result in undefined behaviour, until req.wait() has been executed.

However, after req.wait() has been executed we are guaranteed that the communication took place, and that the value stored in tensor[0] is 1.0.

Point-to-point communication is useful when we want more fine-grained control over the communication of our processes. It can be used to implement fancy algorithms, such as the one used in Baidu’s DeepSpeech or Facebook’s large-scale experiments. (c.f. Section 4.1)

Collective Communication#

Scatter#

Gather#

Reduce#

All-Reduce#

Broadcast#

All-Gather#

As opposed to point-to-point communication, collectives allow for communication patterns across all processes in a group. A group is a subset of all our processes. To create a group, we can pass a list of ranks to dist.new_group(group). By default, collectives are executed on all processes, also known as the world. For example, in order to obtain the sum of all tensors on all processes, we can use the dist.all_reduce(tensor, op, group) collective.

""" All-Reduce example."""defrun(rank,size):""" Simple collective communication. """group=dist.new_group([0,1])tensor=torch.ones(1)dist.all_reduce(tensor,op=dist.ReduceOp.SUM,group=group)print('Rank ',rank,' has data ',tensor[0])

Since we want the sum of all tensors in the group, we use dist.ReduceOp.SUM as the reduce operator. Generally speaking, any commutative mathematical operation can be used as an operator. Out of the box, PyTorch comes with many such operators, all working at the element-wise level (a short example follows the list below):

  • dist.ReduceOp.SUM,

  • dist.ReduceOp.PRODUCT,

  • dist.ReduceOp.MAX,

  • dist.ReduceOp.MIN,

  • dist.ReduceOp.BAND,

  • dist.ReduceOp.BOR,

  • dist.ReduceOp.BXOR,

  • dist.ReduceOp.PREMUL_SUM.

The full list of supported operators is here.
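
To illustrate the element-wise behaviour, here is a minimal sketch that reuses the run(rank, size) template from above but reduces with dist.ReduceOp.MAX instead of SUM. The specific tensor values are just an illustrative assumption; it expects two processes.

""" All-Reduce with MAX (illustrative sketch). """

def run(rank, size):
    # Each rank holds a different vector; MAX is applied element-wise.
    tensor = torch.tensor([float(rank), float(size - rank)])
    dist.all_reduce(tensor, op=dist.ReduceOp.MAX)
    # With two processes the ranks hold [0., 2.] and [1., 1.],
    # so every rank ends up with [1., 2.].
    print('Rank ', rank, ' has data ', tensor)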

In addition to dist.all_reduce(tensor, op, group), there are many additional collectives currently implemented in PyTorch. Here are a few supported collectives; a usage sketch follows the list.

  • dist.broadcast(tensor, src, group): Copies tensor from src to all other processes.

  • dist.reduce(tensor, dst, op, group): Applies op to every tensor and stores the result in dst.

  • dist.all_reduce(tensor, op, group): Same as reduce, but the result is stored in all processes.

  • dist.scatter(tensor, scatter_list, src, group): Copies the \(i^{\text{th}}\) tensor scatter_list[i] to the \(i^{\text{th}}\) process.

  • dist.gather(tensor, gather_list, dst, group): Gathers tensor from all processes into gather_list on dst.

  • dist.all_gather(tensor_list, tensor, group): Copies tensor from all processes to tensor_list, on all processes.

  • dist.barrier(group): Blocks all processes in group until each one has entered this function.

  • dist.all_to_all(output_tensor_list, input_tensor_list, group): Scatters the list of input tensors to all processes in a group and returns the gathered list of tensors in the output list.

The full list of supported collectives can be found by looking at the latest documentation for PyTorch Distributed (link).
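
As a small sketch of how a couple of these calls fit into the run(rank, size) template from above (the tensor values are illustrative assumptions; note that gather_list is only provided on the destination rank):

""" Broadcast and gather sketch. """

def run(rank, size):
    # Rank 0 broadcasts its value; afterwards every rank holds 0.0.
    tensor = torch.full((1,), float(rank))
    dist.broadcast(tensor, src=0)

    # Every rank contributes its rank id; rank 0 gathers all contributions.
    contribution = torch.full((1,), float(rank))
    gather_list = [torch.zeros(1) for _ in range(size)] if rank == 0 else None
    dist.gather(contribution, gather_list=gather_list, dst=0)
    if rank == 0:
        print('Rank 0 gathered ', [t.item() for t in gather_list])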

Distributed Training#

Note: You can find the example script of this section in this GitHub repository.

Now that we understand how the distributed module works, let us write something useful with it. Our goal will be to replicate the functionality of DistributedDataParallel. Of course, this will be a didactic example and in a real-world situation you should use the official, well-tested and well-optimized version linked above.

Quite simply, we want to implement a distributed version of stochastic gradient descent. Our script will let all processes compute the gradients of their model on their batch of data and then average their gradients. In order to ensure similar convergence results when changing the number of processes, we will first have to partition our dataset. (You could also use torch.utils.data.random_split, instead of the snippet below.)

""" Dataset partitioning helper """classPartition(object):def__init__(self,data,index):self.data=dataself.index=indexdef__len__(self):returnlen(self.index)def__getitem__(self,index):data_idx=self.index[index]returnself.data[data_idx]classDataPartitioner(object):def__init__(self,data,sizes=[0.7,0.2,0.1],seed=1234):self.data=dataself.partitions=[]rng=Random()# from random import Randomrng.seed(seed)data_len=len(data)indexes=[xforxinrange(0,data_len)]rng.shuffle(indexes)forfracinsizes:part_len=int(frac*data_len)self.partitions.append(indexes[0:part_len])indexes=indexes[part_len:]defuse(self,partition):returnPartition(self.data,self.partitions[partition])

With the above snippet, we can now simply partition any dataset using the following few lines:

""" Partitioning MNIST """defpartition_dataset():dataset=datasets.MNIST('./data',train=True,download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))]))size=dist.get_world_size()bsz=128//sizepartition_sizes=[1.0/sizefor_inrange(size)]partition=DataPartitioner(dataset,partition_sizes)partition=partition.use(dist.get_rank())train_set=torch.utils.data.DataLoader(partition,batch_size=bsz,shuffle=True)returntrain_set,bsz

Assuming we have 2 replicas, then each process will have a train_set of 60000 / 2 = 30000 samples. We also divide the batch size by the number of replicas in order to maintain the overall batch size of 128.

We can now write our usual forward-backward-optimize training code, and add a function call to average the gradients of our models. (The following is largely inspired by the official PyTorch MNIST example.)

""" Distributed Synchronous SGD Example """defrun(rank,size):torch.manual_seed(1234)train_set,bsz=partition_dataset()model=Net()optimizer=optim.SGD(model.parameters(),lr=0.01,momentum=0.5)num_batches=ceil(len(train_set.dataset)/float(bsz))forepochinrange(10):epoch_loss=0.0fordata,targetintrain_set:optimizer.zero_grad()output=model(data)loss=F.nll_loss(output,target)epoch_loss+=loss.item()loss.backward()average_gradients(model)optimizer.step()print('Rank ',dist.get_rank(),', epoch ',epoch,': ',epoch_loss/num_batches)

It remains to implement the average_gradients(model) function, which simply takes in a model and averages its gradients across the whole world.

""" Gradient averaging. """defaverage_gradients(model):size=float(dist.get_world_size())forparaminmodel.parameters():dist.all_reduce(param.grad.data,op=dist.ReduceOp.SUM)param.grad.data/=size

Et voilà! We successfully implemented distributed synchronous SGD and could train any model on a large computer cluster.

Note: While the last sentence is technically true, there are a lot more tricks required for a production-level implementation of synchronous SGD. Again, use what has been tested and optimized.

Our Own Ring-Allreduce#

As an additional challenge, imagine that we wanted to implement DeepSpeech’s efficient ring allreduce. This is fairly easy to implement using point-to-point operations.

""" Implementation of a ring-reduce with addition. """defallreduce(send,recv):rank=dist.get_rank()size=dist.get_world_size()send_buff=send.clone()recv_buff=send.clone()accum=send.clone()left=((rank-1)+size)%sizeright=(rank+1)%sizeforiinrange(size-1):ifi%2==0:# Send send_buffsend_req=dist.isend(send_buff,right)dist.recv(recv_buff,left)accum[:]+=recv_buff[:]else:# Send recv_buffsend_req=dist.isend(recv_buff,right)dist.recv(send_buff,left)accum[:]+=send_buff[:]send_req.wait()recv[:]=accum[:]

In the above script, the allreduce(send, recv) function has a slightly different signature than the ones in PyTorch. It takes a recv tensor and will store the sum of all send tensors in it. As an exercise left to the reader, there is still one difference between our version and the one in DeepSpeech: their implementation divides the gradient tensor into chunks, so as to optimally utilize the communication bandwidth. (Hint: torch.chunk)
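
As a hint only (the full chunked ring-allreduce is left as the exercise), here is a minimal sketch of what torch.chunk does. The tensor and world size are illustrative assumptions; note that the last chunk may be smaller when the length is not divisible by the requested count.

""" Hint: splitting a tensor into chunks with torch.chunk. """
import torch

world_size = 4
grad = torch.arange(10, dtype=torch.float32)
chunks = torch.chunk(grad, world_size)  # tuple of views, roughly equal sizes
print([c.tolist() for c in chunks])
# [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0], [9.0]]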

Advanced Topics#

We are now ready to discover some of the more advanced functionalities of torch.distributed. Since there is a lot to cover, this section is divided into two subsections:

  1. Communication Backends: where we learn how to use MPI and Gloo for GPU-GPU communication.

  2. Initialization Methods: where we understand how to best set up the initial coordination phase in dist.init_process_group().

Communication Backends#

One of the most elegant aspects of torch.distributed is its ability to abstract and build on top of different backends. As mentioned before, there are multiple backends implemented in PyTorch. These backends can be easily selected using the Accelerator API, which provides an interface for working with different accelerator types. Some of the most popular backends are Gloo, NCCL, and MPI. They each have different specifications and tradeoffs, depending on the desired use case. A comparative table of supported functions can be found here.
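
As a quick sanity check, a small sketch like the one below can report which backends the installed PyTorch build was compiled with (these query helpers live in torch.distributed):

""" Sketch: checking which backends this PyTorch build supports. """
import torch.distributed as dist

print('distributed available:', dist.is_available())
print('gloo available:', dist.is_gloo_available())
print('nccl available:', dist.is_nccl_available())
print('mpi available:', dist.is_mpi_available())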

Gloo Backend

So far we have made extensive usage of the Gloo backend. It is quite handy as a development platform, as it is included in the pre-compiled PyTorch binaries and works on both Linux (since 0.2) and macOS (since 1.3). It supports all point-to-point and collective operations on CPU, and all collective operations on GPU. The implementation of the collective operations for CUDA tensors is not as optimized as the ones provided by the NCCL backend.

As you have surely noticed, our distributed SGD example does not work if you put the model on the GPU. In order to use multiple GPUs, let us also make the following modifications (sketched in full after the list):

  1. Use the Accelerator API: device_type = torch.accelerator.current_accelerator()

  2. Use device = torch.device(f"{device_type}:{rank}")

  3. model = Net() \(\rightarrow\) model = Net().to(device)

  4. Use data, target = data.to(device), target.to(device)

With these modifications, your model will now train across two GPUs. You can monitor GPU utilization using watch nvidia-smi if you are running on NVIDIA hardware.
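
Putting the four steps together, a sketch of the modified run function might look as follows. It assumes the same Net, partition_dataset, and average_gradients definitions as before, and a backend whose collectives support your accelerator's tensors.

""" Sketch: the distributed SGD example moved onto accelerators. """

def run(rank, size):
    torch.manual_seed(1234)
    train_set, bsz = partition_dataset()

    device_type = torch.accelerator.current_accelerator()  # step 1
    device = torch.device(f"{device_type}:{rank}")          # step 2
    model = Net().to(device)                                 # step 3

    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    num_batches = ceil(len(train_set.dataset) / float(bsz))
    for epoch in range(10):
        epoch_loss = 0.0
        for data, target in train_set:
            data, target = data.to(device), target.to(device)  # step 4
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            epoch_loss += loss.item()
            loss.backward()
            average_gradients(model)
            optimizer.step()
        print('Rank ', dist.get_rank(), ', epoch ',
              epoch, ': ', epoch_loss / num_batches)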

MPI Backend

The Message Passing Interface (MPI) is a standardized tool from the field of high-performance computing. It allows point-to-point and collective communications and was the main inspiration for the API of torch.distributed. Several implementations of MPI exist (e.g. Open-MPI, MVAPICH2, Intel MPI), each optimized for different purposes. The advantage of using the MPI backend lies in MPI’s wide availability - and high level of optimization - on large computer clusters. Some recent implementations are also able to take advantage of CUDA IPC and GPU Direct technologies in order to avoid memory copies through the CPU.

Unfortunately, PyTorch’s binaries cannot include an MPI implementation and we’ll have to recompile it by hand. Fortunately, this process is fairly simple given that upon compilation, PyTorch will look by itself for an available MPI implementation. The following steps install the MPI backend, by installing PyTorch from source.

  1. Create and activate your Anaconda environment, install all the pre-requisites following the guide, but do not run python setup.py install yet.

  2. Choose and install your favorite MPI implementation. Note that enabling CUDA-aware MPI might require some additional steps. In our case, we’ll stick to Open-MPI without GPU support: conda install -c conda-forge openmpi

  3. Now, go to your cloned PyTorch repo and execute python setup.py install.

In order to test our newly installed backend, a few modifications are required.

  1. Replace the content under if __name__ == '__main__': with init_process(0, 0, run, backend='mpi') (see the sketch below).

  2. Run mpirun -n 4 python myscript.py.

The reason for these changes is that MPI needs to create its own environment before spawning the processes. MPI will also spawn its own processes and perform the handshake described in Initialization Methods, making the rank and size arguments of init_process_group superfluous. This is actually quite powerful as you can pass additional arguments to mpirun in order to tailor computational resources for each process. (Things like number of cores per process, hand-assigning machines to specific ranks, and some more.) Doing so, you should obtain the same familiar output as with the other communication backends.
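
For reference, a minimal sketch of the modified entry point of run.py (everything else stays the same; mpirun takes care of spawning the processes):

""" Sketch: run.py entry point for the MPI backend. """
if __name__ == "__main__":
    # mpirun has already spawned the processes and MPI discovers the rank
    # and world size itself, so the arguments passed here are superfluous.
    init_process(0, 0, run, backend='mpi')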

NCCL Backend

The NCCL backend provides an optimized implementation of collective operations against CUDA tensors. If you only use CUDA tensors for your collective operations, consider using this backend for best-in-class performance. The NCCL backend is included in the pre-built binaries with CUDA support.

XCCL Backend

The XCCL backend offers an optimized implementation of collective operations for XPU tensors. If your workload uses only XPU tensors for collective operations, this backend provides best-in-class performance. The XCCL backend is included in the pre-built binaries with XPU support.

Initialization Methods#

To conclude this tutorial, let’s examine the initial function we invoked: dist.init_process_group(backend, init_method). Specifically, we will discuss the various initialization methods responsible for the preliminary coordination step between each process. These methods enable you to define how this coordination is accomplished.

The choice of initialization method depends on your hardware setup, and one method may be more suitable than others. In addition to the following sections, please refer to the official documentation for further information.

Environment Variable

We have been using the environment variable initialization method throughout this tutorial. By setting the following four environment variables on all machines, all processes will be able to properly connect to the master, obtain information about the other processes, and finally handshake with them (a short sketch follows the list).

  • MASTER_PORT: A free port on the machine that will host the process with rank 0.

  • MASTER_ADDR: IP address of the machine that will host the process with rank 0.

  • WORLD_SIZE: The total number of processes, so that the master knows how many workers to wait for.

  • RANK: Rank of each process, so that it knows whether it is the master or a worker.
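
For illustration, a minimal sketch using the same address and port as run.py above: when all four variables are set, init_process_group can be called with init_method='env://' (which is also the default), and the rank and world size are read from the environment. The particular values are assumptions for a two-process run; each process must export its own RANK.

""" Sketch: environment-variable initialization. """
import os
import torch.distributed as dist

os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29500'
os.environ['WORLD_SIZE'] = '2'
os.environ['RANK'] = '0'  # each process sets its own rank

# 'env://' is the default init_method; rank and world_size are taken
# from the environment variables above.
dist.init_process_group(backend='gloo', init_method='env://')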

Shared File System

The shared filesystem method requires all processes to have access to a shared file system, and will coordinate them through a shared file. This means that each process will open the file, write its information, and wait until everybody has done so. After that, all required information will be readily available to all processes. In order to avoid race conditions, the file system must support locking through fcntl.

dist.init_process_group(
    init_method='file:///mnt/nfs/sharedfile',
    rank=args.rank,
    world_size=4)

TCP

Initializing via TCP can be achieved by providing the IP address of the process with rank 0 and a reachable port number. Here, all workers will be able to connect to the process with rank 0 and exchange information on how to reach each other.

dist.init_process_group(
    init_method='tcp://10.1.1.20:23456',
    rank=args.rank,
    world_size=4)

Acknowledgements

I’d like to thank the PyTorch developers for doing such a good job on their implementation, documentation, and tests. When the code was unclear, I could always count on the docs or the tests to find an answer. In particular, I’d like to thank Soumith Chintala, Adam Paszke, and Natalia Gimelshein for providing insightful comments and answering questions on early drafts.