Distributed Autograd Design
Created On: Nov 12, 2019 | Last Updated On: Sep 03, 2021
This note presents the detailed design for distributed autograd and walks through its internals. Make sure you’re familiar with Autograd mechanics and the Distributed RPC Framework before proceeding.
Background
Let’s say you have two nodes and a very simple model partitioned across them. This can be implemented using torch.distributed.rpc as follows:
```python
import torch
import torch.distributed.rpc as rpc

def my_add(t1, t2):
    return torch.add(t1, t2)

# On worker 0:
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)

# Perform some computation remotely.
t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))

# Perform some computation locally based on remote result.
t4 = torch.rand((3, 3), requires_grad=True)
t5 = torch.mul(t3, t4)

# Compute some loss.
loss = t5.sum()
```
The main motivation behind distributed autograd is to enable running a backward pass on such distributed models with the loss that we’ve computed and record appropriate gradients for all tensors that require gradients.
Autograd recording during the forward pass
PyTorch builds the autograd graph during the forward pass and this graph is used to execute the backward pass. For more details see How autograd encodes the history.
For distributed autograd, we need to keep track of all RPCs during the forward pass to ensure the backward pass is executed appropriately. For this purpose, we attach send and recv functions to the autograd graph when we perform an RPC.
- The send function is attached to the source of the RPC and its output edges point to the autograd function for the input tensors of the RPC. The input for this function during the backward pass is received from the destination as the output of the appropriate recv function.
- The recv function is attached to the destination of the RPC and its inputs are retrieved from operators executed on the destination using the input tensors. The output gradients of this function are sent to the source node to the appropriate send function during the backward pass.
- Each send-recv pair is assigned a globally unique autograd_message_id to uniquely identify the pair. This is useful to look up the corresponding function on a remote node during the backward pass.
- For RRef, whenever we call torch.distributed.rpc.RRef.to_here() we attach an appropriate send-recv pair for the tensors involved.
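The message-id bookkeeping described above can be sketched in plain Python. This is illustrative only: the names record_send and lookup_send are hypothetical, and the real bookkeeping lives in the C++ RPC layer.

```python
import itertools

# Toy stand-in for globally unique autograd_message_id generation.
_message_ids = itertools.count()

send_functions = {}  # autograd_message_id -> send-side backward function


def record_send(backward_fn):
    """Register a send function under a fresh autograd_message_id."""
    autograd_message_id = next(_message_ids)
    send_functions[autograd_message_id] = backward_fn
    return autograd_message_id


def lookup_send(autograd_message_id):
    """During backward, resolve the send function a recv points back to."""
    return send_functions[autograd_message_id]


# The recv side carries the id back over RPC to find its paired send.
msg_id = record_send(backward_fn=lambda grad: grad * 2)
assert lookup_send(msg_id)(3) == 6
```

The key property this models is that the id is generated once per send-recv pair and is sufficient, together with the context id, to locate the right function on the remote node.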
As an example, this is what the autograd graph for our example above would look like (t5.sum() excluded for simplicity):

Distributed Autograd Context
Each forward and backward pass that uses distributed autograd is assigned a unique torch.distributed.autograd.context and this context has a globally unique autograd_context_id. This context is created on each node as needed.
This context serves the following purposes:
1. Multiple nodes running distributed backward passes might accumulate gradients on the same tensor and as a result the .grad field of the tensor would have gradients from a variety of distributed backward passes before we have the opportunity to run the optimizer. This is similar to calling torch.autograd.backward() multiple times locally. In order to provide a way of separating out the gradients for each backward pass, the gradients are accumulated in the torch.distributed.autograd.context for each backward pass.
2. During the forward pass we store the send and recv functions for each autograd pass in this context. This ensures we hold references to the appropriate nodes in the autograd graph to keep it alive. In addition, this makes it easy to look up the appropriate send and recv functions during the backward pass.
3. In general we also use this context to store some metadata for each distributed autograd pass.
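What such a context tracks can be pictured with a minimal Python sketch. This is illustrative only; the real context is implemented in C++ and keys its gradient map by actual Tensors.

```python
import itertools

_context_ids = itertools.count()


class ToyDistAutogradContext:
    """Models what a distributed autograd context keeps per backward pass."""

    def __init__(self):
        # Globally unique id for this forward/backward pass.
        self.autograd_context_id = next(_context_ids)
        # send/recv functions recorded during the forward pass; holding
        # them keeps the relevant parts of the autograd graph alive.
        self.send_functions = {}
        self.recv_functions = {}
        # Gradients are accumulated here per pass, not on tensors' .grad.
        self.gradients = {}
        # Miscellaneous per-pass metadata.
        self.metadata = {}


# Each pass gets its own context with a distinct id.
ctx_a, ctx_b = ToyDistAutogradContext(), ToyDistAutogradContext()
assert ctx_a.autograd_context_id != ctx_b.autograd_context_id
```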
From the user’s perspective the autograd context is set up as follows:

```python
import torch.distributed.autograd as dist_autograd

with dist_autograd.context() as context_id:
    loss = model.forward()
    dist_autograd.backward(context_id, [loss])
```
It is important to note that your model’s forward pass must be invoked within the distributed autograd context manager, as a valid context is needed in order to ensure that all send and recv functions are stored properly to run the backward pass across all participating nodes.
Distributed Backward Pass
In this section we outline the challenge of computing dependencies accurately during a distributed backward pass and describe a couple of algorithms (with tradeoffs) for executing a distributed backward pass.
Computing dependencies
Consider the following piece of code being run on a single machine:

```python
import torch

a = torch.rand((3, 3), requires_grad=True)
b = torch.rand((3, 3), requires_grad=True)
c = torch.rand((3, 3), requires_grad=True)
d = a + b
e = b * c
d.sum().backward()
```
This is what the autograd graph for the code above would look like:

The first step the autograd engine performs as part of the backward pass is computing the number of dependencies for each node in the autograd graph. This helps the autograd engine know when a node in the graph is ready for execution. The numbers in brackets for add (1) and mul (0) denote the number of dependencies. As you can see, this means during the backward pass the add node needs 1 input and the mul node doesn’t need any inputs (in other words, doesn’t need to be executed). The local autograd engine computes these dependencies by traversing the graph from the root nodes (d in this case).
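The dependency count for this toy graph can be reproduced with a small plain-Python sketch. The node names and the dict-based graph encoding here are illustrative, not the real engine’s data structures: we simply count, for each function node, how many gradient edges reach it from the roots.

```python
from collections import deque


def compute_dependencies(graph, roots):
    """Count reachable incoming edges per node, like the autograd engine.

    graph maps each node to the list of its next (input-producing) nodes,
    i.e. the direction gradients flow during the backward pass.
    """
    deps = {node: 0 for node in graph}
    seen = set(roots)
    queue = deque(roots)
    while queue:
        node = queue.popleft()
        for nxt in graph.get(node, []):
            deps[nxt] += 1  # one more gradient must arrive at nxt
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return deps


# Mirror of the example: d = a + b, e = b * c, backward from d.sum().
graph = {
    "sum_backward": ["add"],
    "add": ["accum_a", "accum_b"],
    "mul": ["accum_b", "accum_c"],
    "accum_a": [], "accum_b": [], "accum_c": [],
}
deps = compute_dependencies(graph, ["sum_backward"])
print(deps["add"], deps["mul"])  # prints: 1 0
```

Because mul is never reached from the root, its dependency count stays 0 and it is never executed, matching add (1) and mul (0) above.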
The fact that certain nodes in the autograd graph might not be executed in the backward pass poses a challenge for distributed autograd. Consider this piece of code which uses RPC.
```python
import torch
import torch.distributed.rpc as rpc

a = torch.rand((3, 3), requires_grad=True)
b = torch.rand((3, 3), requires_grad=True)
c = torch.rand((3, 3), requires_grad=True)
d = rpc.rpc_sync("worker1", torch.add, args=(a, b))
e = rpc.rpc_sync("worker1", torch.mul, args=(b, c))
loss = d.sum()
```
The associated autograd graph for the code above would be:

Computing dependencies of this distributed autograd graph is much more challenging and requires some overhead (either in terms of computation or network communication).
For performance sensitive applications we can avoid a lot of overhead by assuming every send and recv function is a valid part of the backward pass (most applications don’t perform RPCs that aren’t used). This simplifies the distributed autograd algorithm and is much more efficient, but at the cost that the application needs to be aware of the limitations. This algorithm is called the FAST mode algorithm and is described in detail below.
In the general case it might not be true that every send and recv function is valid as part of the backward pass. To address this, we have proposed a SMART mode algorithm which is described in a later section. Please note that currently, only the FAST mode algorithm is implemented.
FAST mode algorithm
The key assumption of this algorithm is that each send function has a dependency of 1 when we run a backward pass. In other words, we assume we’ll receive a gradient over RPC from another node.
The algorithm is as follows:
1. We start from the worker which has the roots for the backward pass (all roots must be local).
2. Look up all the send functions for the current Distributed Autograd Context.
3. Compute dependencies locally, starting from the provided roots and all the send functions we retrieved.
4. After computing dependencies, kick off the local autograd engine with the provided roots.
5. When the autograd engine executes the recv function, the recv function sends the input gradients via RPC to the appropriate worker. Each recv function knows the destination worker id since it is recorded as part of the forward pass. The recv function also sends over the autograd_context_id and autograd_message_id to the remote host.
6. When this request is received on the remote host, we use the autograd_context_id and autograd_message_id to look up the appropriate send function.
7. If this is the first time a worker has received a request for the given autograd_context_id, it will compute dependencies locally as described in points 1-3 above.
8. The send function retrieved in 6 is then enqueued for execution on the local autograd engine for that worker.
9. Finally, instead of accumulating the gradients on the .grad field of the Tensor, we accumulate the gradients separately per Distributed Autograd Context. The gradients are stored in a Dict[Tensor, Tensor], which is basically a map from Tensor to its associated gradient, and this map can be retrieved using the get_gradients() API.
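The per-context gradient accumulation in the last step can be sketched as follows. This is a toy Python stand-in keyed by tensor name; the real map is a Dict[Tensor, Tensor] keyed by the Tensor itself and is what dist_autograd.get_gradients(context_id) returns.

```python
class ToyAutogradContext:
    """Toy stand-in for per-context gradient storage."""

    def __init__(self, context_id):
        self.context_id = context_id
        self._grads = {}  # tensor name -> accumulated gradient

    def accumulate_grad(self, name, grad):
        # Multiple gradient contributions for the same tensor within one
        # backward pass are summed inside the context, not written to .grad.
        self._grads[name] = self._grads.get(name, 0.0) + grad

    def get_gradients(self):
        # Mirrors dist_autograd.get_gradients(context_id).
        return dict(self._grads)


ctx = ToyAutogradContext(context_id=1)
ctx.accumulate_grad("t2", 0.5)   # contribution flowing through a local node
ctx.accumulate_grad("t2", 0.25)  # contribution arriving via a recv
print(ctx.get_gradients())       # {'t2': 0.75}
```

Keeping gradients in the context rather than on .grad is what allows several concurrent backward passes to touch the same tensor without clobbering each other.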
As an example, the complete code with distributed autograd would be as follows:

```python
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc

def my_add(t1, t2):
    return torch.add(t1, t2)

# On worker 0:

# Setup the autograd context. Computations that take
# part in the distributed backward pass must be within
# the distributed autograd context manager.
with dist_autograd.context() as context_id:
    t1 = torch.rand((3, 3), requires_grad=True)
    t2 = torch.rand((3, 3), requires_grad=True)

    # Perform some computation remotely.
    t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))

    # Perform some computation locally based on remote result.
    t4 = torch.rand((3, 3), requires_grad=True)
    t5 = torch.mul(t3, t4)

    # Compute some loss.
    loss = t5.sum()

    # Run the backward pass.
    dist_autograd.backward(context_id, [loss])

    # Retrieve the gradients from the context.
    dist_autograd.get_gradients(context_id)
```
The distributed autograd graph with dependencies would be as follows (t5.sum() excluded for simplicity):

The FAST mode algorithm applied to the above example is as follows:

1. On Worker0 we start from the roots loss and send1 to compute dependencies. As a result, send1 is marked with a dependency of 1 and mul on Worker0 is marked with a dependency of 1.
2. Now, we kick off the local autograd engine on Worker0. We first execute the mul function and accumulate its output in the autograd context as the gradient for t4. Then, we execute recv2, which sends the gradients to Worker1.
3. Since this is the first time Worker1 has heard about this backward pass, it starts dependency computation and marks the dependencies for send2, add and recv1 appropriately.
4. Next, we enqueue send2 on the local autograd engine of Worker1, which in turn executes add and recv1.
5. When recv1 is executed, it sends the gradients over to Worker0.
6. Since Worker0 has already computed dependencies for this backward pass, it just enqueues and executes send1 locally.
7. Finally, gradients for t1, t2 and t4 are accumulated in the Distributed Autograd Context.
SMART mode algorithm
Full details of this algorithm are still in the works, but for the general idea you can refer to the Distributed Autograd Algorithm Smart mode section in the RFC.
Distributed Optimizer
The DistributedOptimizer operates as follows:

1. Takes a list of remote parameters (RRef) to optimize. These could also be local parameters wrapped within a local RRef.
2. Takes an Optimizer class as the local optimizer to run on all distinct RRef owners.
3. The distributed optimizer creates an instance of the local Optimizer on each of the worker nodes and holds an RRef to them.
4. When torch.distributed.optim.DistributedOptimizer.step() is invoked, the distributed optimizer uses RPC to remotely execute all the local optimizers on the appropriate remote workers. A distributed autograd context_id must be provided as input to torch.distributed.optim.DistributedOptimizer.step(). This is used by the local optimizers to apply gradients stored in the corresponding context.
5. If multiple concurrent distributed optimizers are updating the same parameters on a worker, these updates are serialized via a lock.
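The serialization in the last point can be sketched like this. It is a simplified, single-process stand-in: ToyLocalOptimizer is hypothetical, and the real optimizer dispatches step() calls to remote workers over RPC.

```python
import threading


class ToyLocalOptimizer:
    """Hypothetical per-worker local optimizer guarded by a lock."""

    _lock = threading.Lock()  # one lock per worker in this sketch

    def __init__(self, params, lr):
        self.params = params  # parameter name -> value
        self.lr = lr

    def step(self, grads):
        # Updates from concurrent distributed optimizers touching the same
        # parameters on this worker are serialized by the lock.
        with ToyLocalOptimizer._lock:
            for name, grad in grads.items():
                self.params[name] -= self.lr * grad


params = {"w": 1.0}
opt = ToyLocalOptimizer(params, lr=0.5)
opt.step({"w": 1.0})  # plain SGD update: w <- w - lr * grad
print(params["w"])    # 0.5
```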
Simple end to end example
Putting it all together, the following is a simple end to end example using distributed autograd and the distributed optimizer. If the code is placed into a file called “dist_autograd_simple.py”, it can be run with the command MASTER_ADDR="localhost" MASTER_PORT=29500 python dist_autograd_simple.py:
```python
import torch
import torch.multiprocessing as mp
import torch.distributed.autograd as dist_autograd
from torch.distributed import rpc
from torch import optim
from torch.distributed.optim import DistributedOptimizer

def random_tensor():
    return torch.rand((3, 3), requires_grad=True)

def _run_process(rank, dst_rank, world_size):
    name = "worker{}".format(rank)
    dst_name = "worker{}".format(dst_rank)

    # Initialize RPC.
    rpc.init_rpc(
        name=name,
        rank=rank,
        world_size=world_size
    )

    # Use a distributed autograd context.
    with dist_autograd.context() as context_id:
        # Forward pass (create references on remote nodes).
        rref1 = rpc.remote(dst_name, random_tensor)
        rref2 = rpc.remote(dst_name, random_tensor)
        loss = rref1.to_here() + rref2.to_here()

        # Backward pass (run distributed autograd).
        dist_autograd.backward(context_id, [loss.sum()])

        # Build DistributedOptimizer.
        dist_optim = DistributedOptimizer(
            optim.SGD,
            [rref1, rref2],
            lr=0.05,
        )

        # Run the distributed optimizer step.
        dist_optim.step(context_id)

def run_process(rank, world_size):
    dst_rank = (rank + 1) % world_size
    _run_process(rank, dst_rank, world_size)
    rpc.shutdown()

if __name__ == '__main__':
    # Run world_size workers
    world_size = 2
    mp.spawn(run_process, args=(world_size,), nprocs=world_size)
```