Distributed Data Parallel#
Created On: Jan 15, 2020 | Last Updated On: Jan 25, 2024
Warning
The implementation of torch.nn.parallel.DistributedDataParallel evolves over time. This design note is written based on the state as of v1.4.
torch.nn.parallel.DistributedDataParallel (DDP) transparently performs distributed data parallel training. This page describes how it works and reveals implementation details.
Example#
Let us start with a simple torch.nn.parallel.DistributedDataParallel example. This example uses a torch.nn.Linear as the local model, wraps it with DDP, and then runs one forward pass, one backward pass, and an optimizer step on the DDP model. After that, parameters on the local model will be updated, and all models on different processes should be exactly the same.
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
import os
from torch.nn.parallel import DistributedDataParallel as DDP


def example(rank, world_size):
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # create local model
    model = nn.Linear(10, 10).to(rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    # forward pass
    outputs = ddp_model(torch.randn(20, 10).to(rank))
    labels = torch.randn(20, 10).to(rank)
    # backward pass
    loss_fn(outputs, labels).backward()
    # update parameters
    optimizer.step()


def main():
    world_size = 2
    mp.spawn(example,
             args=(world_size,),
             nprocs=world_size,
             join=True)


if __name__ == "__main__":
    # Environment variables which need to be
    # set when using c10d's default "env"
    # initialization mode.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    main()
DDP works with TorchDynamo. When used with TorchDynamo, apply the DDP model wrapper before compiling the model, such that torchdynamo can apply DDPOptimizer (graph-break optimizations) based on DDP bucket sizes. (See TorchDynamo DDPOptimizer for more information.)
ddp_model = DDP(model, device_ids=[rank])
ddp_model = torch.compile(ddp_model)
Internal Design#
This section reveals how torch.nn.parallel.DistributedDataParallel works under the hood by diving into the details of every step in one iteration.
Prerequisite: DDP relies on c10d ProcessGroup for communications. Hence, applications must create ProcessGroup instances before constructing DDP.

Construction: The DDP constructor takes a reference to the local module, and broadcasts state_dict() from the process with rank 0 to all other processes in the group to make sure that all model replicas start from the exact same state. Then, each DDP process creates a local Reducer, which later will take care of the gradient synchronization during the backward pass. To improve communication efficiency, the Reducer organizes parameter gradients into buckets, and reduces one bucket at a time. Bucket size can be configured by setting the bucket_cap_mb argument in the DDP constructor (see the construction sketch after this list). The mapping from parameter gradients to buckets is determined at construction time, based on the bucket size limit and parameter sizes. Model parameters are allocated into buckets in (roughly) the reverse order of Model.parameters() from the given model. The reason for using the reverse order is that DDP expects gradients to become ready during the backward pass in approximately that order. The figure below shows an example. Note that grad0 and grad1 are in bucket1, and the other two gradients are in bucket0. Of course, this assumption might not always be true, and when that happens it could hurt DDP backward speed as the Reducer cannot kick off the communication at the earliest possible time. Besides bucketing, the Reducer also registers autograd hooks during construction, one hook per parameter. These hooks will be triggered during the backward pass when the gradient becomes ready.

Forward Pass: DDP takes the input and passes it to the local model, and then analyzes the output from the local model if find_unused_parameters is set to True. This mode allows running backward on a subgraph of the model: DDP finds out which parameters are involved in the backward pass by traversing the autograd graph from the model output, and marks all unused parameters as ready for reduction. During the backward pass, the Reducer only waits for unready parameters, but it still reduces all buckets. Marking a parameter gradient as ready does not help DDP skip buckets for now, but it prevents DDP from waiting for absent gradients forever during the backward pass. Note that traversing the autograd graph introduces extra overhead, so applications should only set find_unused_parameters to True when necessary.

Backward Pass: The backward() function is invoked directly on the loss Tensor, which is out of DDP's control, so DDP uses the autograd hooks registered at construction time to trigger gradient synchronization. When one gradient becomes ready, its corresponding DDP hook on that grad accumulator fires, and DDP marks that parameter gradient as ready for reduction. When all gradients in one bucket are ready, the Reducer kicks off an asynchronous allreduce on that bucket to calculate the mean of the gradients across all processes. When all buckets are ready, the Reducer blocks waiting for all allreduce operations to finish. When this is done, averaged gradients are written to the param.grad field of all parameters. So after the backward pass, the grad field of the same corresponding parameter across different DDP processes should be the same.

Optimizer Step: From the optimizer's perspective, it is optimizing a local model. Model replicas on all DDP processes can keep in sync because they all start from the same state and they have the same averaged gradients in every iteration.
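The construction-time options discussed above are ordinary arguments of the DDP constructor. Below is a minimal sketch that reuses the model and rank from the earlier example; the values shown (a 25 MB bucket cap and unused-parameter detection turned off) are simply the defaults spelled out explicitly, not recommendations.

ddp_model = DDP(
    model,
    device_ids=[rank],
    bucket_cap_mb=25,              # approximate size of each gradient bucket, in MB
    find_unused_parameters=False,  # set to True only when parts of the model may be skipped
)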

Note
DDP requires Reducer instances on all processes to invoke allreduce in exactly the same order, which is done by always running allreduce in the bucket index order instead of the actual bucket ready order. Mismatched allreduce order across processes can lead to wrong results or a DDP backward hang.
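To make the backward-pass behavior concrete, the sketch below spells out what the synchronized result is equivalent to, using a plain local model instead of DDP. This is an illustration only: the real Reducer buckets gradients and overlaps the asynchronous allreduce calls with the ongoing backward computation rather than looping over parameters afterwards. Walking model.parameters() in the same fixed order on every process plays the role of the fixed bucket-index order required by the note above. The names inputs and labels stand in for each process's local data shard.

# Hand-rolled equivalent of DDP's gradient averaging. `model`, `loss_fn`,
# `inputs`, `labels`, and `optimizer` are assumed to be set up as in the
# earlier example, but without the DDP wrapper.
loss_fn(model(inputs), labels).backward()        # local backward pass
world_size = dist.get_world_size()
for param in model.parameters():                 # same fixed order on every process
    if param.grad is not None:
        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
        param.grad /= world_size                 # sum / world_size == mean
optimizer.step()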
Implementation#
Below are pointers to the DDP implementation components. The stacked graph shows the structure of the code.
ProcessGroup#
ProcessGroup.hpp: contains the abstract API of all process group implementations. The c10d library provides 3 implementations out of the box, namely, ProcessGroupGloo, ProcessGroupNCCL, and ProcessGroupMPI. DistributedDataParallel uses ProcessGroup::broadcast() to send model states from the process with rank 0 to others during initialization, and ProcessGroup::allreduce() to sum gradients (a Python-level sketch of the initialization broadcast follows this list).

Store.hpp: assists the rendezvous service for process group instances to find each other.
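A rough Python-level picture of that initialization broadcast is sketched below. This is a simplified illustration: DDP performs this step in C++ through ProcessGroup::broadcast() and coalesces tensors into larger messages, and broadcast_module_state is a hypothetical helper for this note, not a torch.distributed API.

# Simplified stand-in for DDP's construction-time broadcast: copy rank 0's
# parameters and buffers to every other process. Assumes the default process
# group has already been initialized, as in the example at the top of the page.
def broadcast_module_state(module, src=0):
    for tensor in list(module.parameters()) + list(module.buffers()):
        dist.broadcast(tensor.detach(), src=src)  # writes in place on non-src ranks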
DistributedDataParallel#
distributed.py: is the Python entry point for DDP. It implements the initialization steps and the forward function for the nn.parallel.DistributedDataParallel module, which call into C++ libraries. Its _sync_param function performs intra-process parameter synchronization when one DDP process works on multiple devices, and it also broadcasts model buffers from the process with rank 0 to all other processes. The inter-process parameter synchronization happens in Reducer.cpp.

comm.h: implements the coalesced broadcast helper function which is invoked to broadcast model states during initialization and synchronize model buffers before the forward pass.
reducer.h: provides the core implementation for gradient synchronization in the backward pass. It has three entry point functions (a simplified Python sketch of the hook mechanism follows this list):

Reducer: The constructor is called in distributed.py, which registers Reducer::autograd_hook() to gradient accumulators.

autograd_hook(): this function will be invoked by the autograd engine when a gradient becomes ready.

prepare_for_backward(): called at the end of the DDP forward pass in distributed.py. It traverses the autograd graph to find unused parameters when find_unused_parameters is set to True in the DDP constructor.
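The hook mechanism above is implemented in C++, but its shape can be sketched with the public Tensor.register_hook API. The sketch below is not the Reducer: it fires one blocking allreduce per parameter and does no bucketing, overlapping, or unused-parameter handling; it only shows how per-parameter hooks can average gradients as they become ready.

# Toy, per-parameter version of what Reducer::autograd_hook() achieves.
# Every process must produce gradients for the same parameters in the same
# order, otherwise the collectives below will not match up across ranks.
world_size = dist.get_world_size()

def make_allreduce_hook():
    def hook(grad):
        # called by the autograd engine when this parameter's gradient is computed
        averaged = grad.clone()
        dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
        return averaged / world_size   # replaces the gradient that gets accumulated
    return hook

for param in model.parameters():
    if param.requires_grad:
        param.register_hook(make_allreduce_hook())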

TorchDynamo DDPOptimizer#
DDP’s performance advantage comes from overlapping allreduce collectives with computations during backwards. AotAutograd prevents this overlap when used with TorchDynamo for compiling a whole forward and whole backward graph, because allreduce ops are launched by autograd hooks after the whole optimized backwards computation finishes.
TorchDynamo’s DDPOptimizer helps by breaking the forward graph at the logical boundaries of DDP’s allreduce buckets during backwards. Note: the goal is to break the graph during backwards, and the simplest implementation is to break the forward graphs and then call AotAutograd and compilation on each section. This allows DDP’s allreduce hooks to fire in-between sections of backwards, and schedule communications to overlap with compute.
See this blog post for a more in-depth explanation and experimental results, or read the docs and code at torch/_dynamo/optimizations/distributed.py.
To debug DDPOptimizer, set TORCH_LOGS='ddp_graphs' for full graph dumps. For logs without graphs, add any of 'dynamo', 'distributed', or 'dist_ddp' to TORCH_LOGS (for basic info about bucket boundaries). To disable DDPOptimizer, set torch._dynamo.config.optimize_ddp=False. DDP and TorchDynamo should still work correctly without DDPOptimizer, but with performance degradation.
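Putting those knobs together, a minimal debugging recipe might look like the sketch below. The environment variable is set in the launching shell, train_script.py is a placeholder for your own entry point, and torch._dynamo.config.optimize_ddp is the flag mentioned above.

# In the shell, before launching, for full DDPOptimizer graph dumps:
#   TORCH_LOGS=ddp_graphs python train_script.py
# (or use 'dynamo', 'distributed', or 'dist_ddp' for bucket-boundary info only)

import torch._dynamo

# Fall back to compiling without DDPOptimizer's graph breaks; still correct,
# but the allreduce/compute overlap during backwards is lost.
torch._dynamo.config.optimize_ddp = False

ddp_model = DDP(model, device_ids=[rank])
ddp_model = torch.compile(ddp_model)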