
DistributedDataParallel

class torch.nn.parallel.DistributedDataParallel(module, device_ids=None, output_device=None, dim=0, broadcast_buffers=True, init_sync=True, process_group=None, bucket_cap_mb=None, find_unused_parameters=False, check_reduction=False, gradient_as_bucket_view=False, static_graph=False, delay_all_reduce_named_params=None, param_to_hook_all_reduce=None, mixed_precision=None, device_mesh=None, skip_all_reduce_unused_params=False)

Implement distributed data parallelism based on torch.distributed at module level.

This container provides data parallelism by synchronizing gradients across each model replica. The devices to synchronize across are specified by the input process_group, which is the entire world by default. Note that DistributedDataParallel does not chunk or otherwise shard the input across participating GPUs; the user is responsible for defining how to do so, for example through the use of a DistributedSampler.

See also: Basics and Use nn.parallel.DistributedDataParallel instead of multiprocessing or nn.DataParallel. The same constraints on input as in torch.nn.DataParallel apply.

Creating this class requires torch.distributed to be already initialized, by calling torch.distributed.init_process_group().

DistributedDataParallel is proven to be significantly faster than torch.nn.DataParallel for single-node multi-GPU data parallel training.

To use DistributedDataParallel on a host with N GPUs, you should spawn N processes, ensuring that each process works exclusively on a single GPU from 0 to N-1. This can be done either by setting CUDA_VISIBLE_DEVICES for every process or by calling the following API for GPUs,

>>> torch.cuda.set_device(i)

or by calling the unified API for accelerators,

>>> torch.accelerator.set_device_index(i)

where i is from 0 to N-1. In each process, you should refer to the following to construct this module:

>>> if torch.accelerator.is_available():
>>>     device_type = torch.accelerator.current_accelerator().type
>>>     vendor_backend = torch.distributed.get_default_backend_for_device(device_type)
>>>
>>> torch.distributed.init_process_group(
>>>     backend=vendor_backend, world_size=N, init_method='...'
>>> )
>>> model = DistributedDataParallel(model, device_ids=[i], output_device=i)

Or you can use the latest API for initialization:

>>> torch.distributed.init_process_group(device_id=i)

To spawn multiple processes per node, you can use either torch.distributed.launch or torch.multiprocessing.spawn.

Note

Please refer to PyTorch Distributed Overview for a brief introduction to all features related to distributed training.

Note

DistributedDataParallel can be used in conjunction with torch.distributed.optim.ZeroRedundancyOptimizer to reduce the per-rank optimizer state memory footprint. Please refer to the ZeroRedundancyOptimizer recipe for more details.

Note

The nccl backend is currently the fastest and the highly recommended backend when using GPUs. This applies to both single-node and multi-node distributed training.

Note

This module also supports mixed-precision distributed training. This means that your model can have parameters of different types, such as a mix of fp16 and fp32, and gradient reduction on these mixed-type parameters will work correctly.

Note

If you use torch.save on one process to checkpoint the module, and torch.load on some other processes to recover it, make sure that map_location is configured properly for every process. Without map_location, torch.load would recover the module to the devices it was saved from.
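As a small sketch of that advice, the helper below builds the dictionary form of map_location, which torch.load accepts to remap storages saved on one device to another. The helper name map_location_for, the assumption that rank 0 wrote the checkpoint, and the path "checkpoint.pt" are all hypothetical.

```python
def map_location_for(rank: int, saved_rank: int = 0) -> dict:
    """Hypothetical helper: remap tensors saved on cuda:<saved_rank> to cuda:<rank>."""
    # torch.load(..., map_location=...) accepts a dict from saved location to target location.
    return {f"cuda:{saved_rank}": f"cuda:{rank}"}


# Usage in each process (needs torch, so shown only as a comment):
# state = torch.load("checkpoint.pt", map_location=map_location_for(rank))
print(map_location_for(3))  # {'cuda:0': 'cuda:3'}
```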

Note

When a model is trained on M nodes with batch=N, the gradient will be M times smaller than for the same model trained on a single node with batch=M*N if the loss is summed (NOT averaged as usual) across instances in a batch, because the gradients from different nodes are averaged. Take this into consideration when you want a training process that is mathematically equivalent to its local training counterpart. In most cases, however, you can treat a DistributedDataParallel-wrapped model, a DataParallel-wrapped model and an ordinary model on a single GPU as the same (e.g. using the same learning rate for an equivalent batch size).
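To make the factor of M concrete, the pure-Python sketch below treats each sample's gradient as a single number (a simplification; no PyTorch is involved). It compares DDP-style averaging across M nodes with single-node training on all M*N samples, once with a summed loss and once with a mean loss. The function names are illustrative only.

```python
def local_gradient(sample_grads, reduction="sum"):
    # Gradient of one node's loss: summed or averaged over its local batch.
    total = sum(sample_grads)
    return total if reduction == "sum" else total / len(sample_grads)


def ddp_gradient(per_node_grads, reduction="sum"):
    # DDP averages the local gradients of the M nodes during all-reduce.
    local = [local_gradient(g, reduction) for g in per_node_grads]
    return sum(local) / len(local)


def single_node_gradient(per_node_grads, reduction="sum"):
    # One node processing all M*N samples as a single batch.
    all_grads = [g for node in per_node_grads for g in node]
    return local_gradient(all_grads, reduction)


# M = 2 nodes, N = 3 samples each (made-up per-sample gradients).
grads = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
print(ddp_gradient(grads), single_node_gradient(grads))                  # 10.5 21.0
print(ddp_gradient(grads, "mean"), single_node_gradient(grads, "mean"))  # 3.5 3.5
```

With a summed loss the DDP gradient is exactly M = 2 times smaller; with a mean loss the two agree.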

Note

Parameters are never broadcast between processes. The module performs an all-reduce step on gradients and assumes that they will be modified by the optimizer in all processes in the same way. Buffers (e.g. BatchNorm stats) are broadcast from the module in the process of rank 0 to all other replicas in the system in every iteration.
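Why replicas stay identical without a parameter broadcast can be seen in a small pure-Python simulation: gradients are averaged, then every replica applies the same SGD update to its own copy. Plain lists stand in for parameter tensors, and ddp_sgd_step is a hypothetical name, not a PyTorch API.

```python
def ddp_sgd_step(replica_params, replica_grads, lr=0.1):
    """One simulated DDP step; each replica holds a list of float parameters."""
    world_size = len(replica_params)
    num_params = len(replica_grads[0])
    # All-reduce: average each gradient element across replicas.
    avg_grads = [
        sum(grads[i] for grads in replica_grads) / world_size
        for i in range(num_params)
    ]
    # Each replica applies the same SGD update to its own parameter copy;
    # no parameters are broadcast, yet the replicas remain identical.
    return [
        [p - lr * g for p, g in zip(params, avg_grads)]
        for params in replica_params
    ]


params = [[1.0, -2.0], [1.0, -2.0]]  # identical initial replicas
grads = [[0.5, 1.0], [1.5, 3.0]]     # different local gradients
new_params = ddp_sgd_step(params, grads)
print(new_params[0] == new_params[1])  # True
```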

Note

If you are using DistributedDataParallel in conjunction with the Distributed RPC Framework, you should always use torch.distributed.autograd.backward() to compute gradients and torch.distributed.optim.DistributedOptimizer for optimizing parameters.

Example:

>>> import torch.distributed.autograd as dist_autograd
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> import torch
>>> from torch import optim
>>> from torch.distributed.optim import DistributedOptimizer
>>> import torch.distributed.rpc as rpc
>>> from torch.distributed.rpc import RRef
>>>
>>> t1 = torch.rand((3, 3), requires_grad=True)
>>> t2 = torch.rand((3, 3), requires_grad=True)
>>> rref = rpc.remote("worker1", torch.add, args=(t1, t2))
>>> ddp_model = DDP(my_model)
>>>
>>> # Setup optimizer
>>> optimizer_params = [rref]
>>> for param in ddp_model.parameters():
>>>     optimizer_params.append(RRef(param))
>>>
>>> dist_optim = DistributedOptimizer(
>>>     optim.SGD,
>>>     optimizer_params,
>>>     lr=0.05,
>>> )
>>>
>>> with dist_autograd.context() as context_id:
>>>     pred = ddp_model(rref.to_here())
>>>     loss = loss_func(pred, target)
>>>     dist_autograd.backward(context_id, [loss])
>>>     dist_optim.step(context_id)

Note

DistributedDataParallel currently offers limited support for gradient checkpointing with torch.utils.checkpoint(). If the checkpoint is done with use_reentrant=False (recommended), DDP will work as expected without any limitations. If, however, the checkpoint is done with use_reentrant=True (the default), DDP will work as expected when there are no unused parameters in the model and each layer is checkpointed at most once (make sure you are not passing find_unused_parameters=True to DDP). We currently do not support the case where a layer is checkpointed multiple times, or where there are unused parameters in the checkpointed model.

Note

To let a non-DDP model load a state dict from a DDP model, consume_prefix_in_state_dict_if_present() needs to be applied to strip the prefix “module.” in the DDP state dict before loading.
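What stripping the prefix amounts to can be sketched on a plain dict standing in for a state dict. strip_prefix is a hypothetical name; the real torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, prefix) modifies the state dict in place instead of returning a copy.

```python
def strip_prefix(state_dict, prefix="module."):
    """Hypothetical helper: return a copy of state_dict with `prefix` removed
    from the start of every key that has it; other keys are kept unchanged."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


ddp_state = {"module.fc.weight": [[0.1, 0.2]], "module.fc.bias": [0.0]}
print(strip_prefix(ddp_state))  # {'fc.weight': [[0.1, 0.2]], 'fc.bias': [0.0]}
```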

Warning

The constructor, the forward method, and differentiation of the output (or of a function of the output) of this module are distributed synchronization points. Take that into account in case different processes might be executing different code.

Warning

This module assumes all parameters are registered in the model by the time it is created. No parameters should be added or removed later. The same applies to buffers.

Warning

This module assumes that the parameters registered in the model of each distributed process are in the same order. The module itself conducts gradient allreduce following the reverse order of the registered parameters of the model. In other words, it is the user’s responsibility to ensure that each distributed process has the exact same model and thus the exact same parameter registration order.
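The sketch below is a greatly simplified, pure-Python illustration of why the order matters: reduction pairs gradients by position, so if two ranks register the same parameters in different orders, different parameters get reduced together. Real DDP reduces buckets of tensors rather than single names; reduction_order and first_mismatch are hypothetical helpers.

```python
def reduction_order(registered_names):
    # Reduction follows the reverse of the registration order.
    return list(reversed(registered_names))


def first_mismatch(per_rank_names):
    """Return the first index where the ranks' registration orders differ,
    or None when every rank registered the same names in the same order."""
    longest = max(len(names) for names in per_rank_names)
    for i in range(longest):
        at_i = {names[i] if i < len(names) else None for names in per_rank_names}
        if len(at_i) > 1:
            return i
    return None


rank0 = ["embed.weight", "fc.weight", "fc.bias"]
rank1 = ["embed.weight", "fc.bias", "fc.weight"]  # same parameters, different order
print(reduction_order(rank0))   # ['fc.bias', 'fc.weight', 'embed.weight']
print(first_mismatch([rank0, rank1]))  # 1
# Positional pairing would reduce fc.bias together with fc.weight:
print(list(zip(reduction_order(rank0), reduction_order(rank1))))
```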

Warning

This module allows parameters with non-row-major-contiguous strides. For example, your model may contain some parameters whose torch.memory_format is torch.contiguous_format and others whose format is torch.channels_last. However, corresponding parameters in different processes must have the same strides.

Warning

This module doesn’t work with torch.autograd.grad() (i.e. it will only work if gradients are to be accumulated in .grad attributes of parameters).

Warning

If you plan on using this module with a nccl backend or a gloo backend (that uses Infiniband), together with a DataLoader that uses multiple workers, please change the multiprocessing start method to forkserver (Python 3 only) or spawn. Unfortunately, Gloo (that uses Infiniband) and NCCL2 are not fork-safe, and you will likely experience deadlocks if you don’t change this setting.

Warning

You should never try to change your model’s parameters after wrapping your model with DistributedDataParallel. When your model is wrapped with DistributedDataParallel, its constructor registers additional gradient reduction functions on all the parameters of the model at the time of construction. If you change the model’s parameters afterwards, the gradient reduction functions no longer match the correct set of parameters.

Warning

Using DistributedDataParallel in conjunction with the Distributed RPC Framework is experimental and subject to change.

Parameters
  • module (Module) – module to be parallelized

  • device_ids (list of int or torch.device) –

    CUDA devices. 1) For single-device modules, device_ids can contain exactly one device id, which represents the only CUDA device where the input module corresponding to this process resides. Alternatively, device_ids can also be None. 2) For multi-device modules and CPU modules, device_ids must be None.

    When device_ids is None for both cases, both the input data for the forward pass and the actual module must be placed on the correct device. (default: None)

  • output_device (int or torch.device) – Device location of output for single-device CUDA modules. For multi-device modules and CPU modules, it must be None, and the module itself dictates the output location. (default: device_ids[0] for single-device modules)

  • broadcast_buffers (bool) – Flag that enables syncing (broadcasting) buffers of the module at the beginning of the forward function. (default: True)

  • init_sync (bool) – Whether to sync during initialization to verify param shapes and broadcast parameters and buffers. WARNING: if this is set to False, the user must ensure that the weights are the same on all ranks. (default: True)

  • process_group – The process group to be used for distributed data all-reduction. If None, the default process group, which is created by torch.distributed.init_process_group(), will be used. (default: None)

  • bucket_cap_mb – DistributedDataParallel will bucket parameters into multiple buckets so that gradient reduction of each bucket can potentially overlap with backward computation. bucket_cap_mb controls the bucket size in MebiBytes (MiB). If None, a default size of 25 MiB will be used. (default: None)

  • find_unused_parameters (bool) – Traverse the autograd graph from all tensors contained in the return value of the wrapped module’s forward function. Parameters that don’t receive gradients as part of this graph are preemptively marked as being ready to be reduced. In addition, parameters that may have been used in the wrapped module’s forward function but were not part of loss computation, and thus would also not receive gradients, are preemptively marked as ready to be reduced. (default: False)

  • check_reduction – This argument is deprecated.

  • gradient_as_bucket_view (bool) – When set to True, gradients will be views pointing to different offsets of allreduce communication buckets. This can reduce peak memory usage, where the saved memory size will be equal to the total gradients size. Moreover, it avoids the overhead of copying between gradients and allreduce communication buckets. When gradients are views, detach_() cannot be called on the gradients. If you hit such errors, refer to the zero_grad() function in torch/optim/optimizer.py for a solution. Note that gradients become views after the first iteration, so the peak memory saving should be checked after the first iteration.

  • static_graph (bool) –

    When set to True, DDP knows the trained graph is static. A static graph means that 1) the set of used and unused parameters will not change during the whole training loop; in this case, it does not matter whether users set find_unused_parameters=True or not; and 2) how the graph is trained will not change during the whole training loop (meaning there is no control flow depending on iterations). When static_graph is set to True, DDP supports cases that could not be supported in the past: 1) reentrant backwards; 2) activation checkpointing multiple times; 3) activation checkpointing when the model has unused parameters; 4) model parameters that are outside of the forward function; 5) potentially improved performance when there are unused parameters, as DDP will not search the graph in each iteration to detect unused parameters when static_graph is set to True. To check whether you can set static_graph to True, one way is to check the DDP logging data at the end of your previous model training: if ddp_logging_data.get("can_set_static_graph") == True, you can most likely set static_graph=True as well.

    Example:
    >>> model_DDP = torch.nn.parallel.DistributedDataParallel(model)
    >>> # Training loop
    >>> ...
    >>> ddp_logging_data = model_DDP._get_ddp_logging_data()
    >>> static_graph = ddp_logging_data.get("can_set_static_graph")

  • delay_all_reduce_named_params (list of tuple of str and torch.nn.Parameter) – a list of named parameters whose all-reduce will be delayed until the gradient of the parameter specified in param_to_hook_all_reduce is ready. Other arguments of DDP do not apply to the named params specified in this argument, as these named params will be ignored by the DDP reducer.

  • param_to_hook_all_reduce (torch.nn.Parameter) – a parameter to hook the delayed all-reduce of the parameters specified in delay_all_reduce_named_params.

  • skip_all_reduce_unused_params – When set to True, DDP will skip reducing unused parameters. This requires that the unused parameters remain the same across all ranks throughout the entire training process. If this condition is not met, it may cause desynchronization and make training hang.
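To get a feel for how bucket_cap_mb shapes the buckets, here is a greedy packing sketch in pure Python. It rests on simplifying assumptions (fp32 parameters, walking parameters in reverse registration order, no grouping by device or dtype) and is not the algorithm DDP's reducer actually uses; assign_buckets and the parameter sizes are hypothetical.

```python
def assign_buckets(params, bucket_cap_mb=25.0, bytes_per_element=4):
    """Greedily pack (name, numel) pairs into buckets of at most bucket_cap_mb MiB,
    walking the parameters in reverse registration order."""
    cap_bytes = bucket_cap_mb * 1024 * 1024  # MiB -> bytes
    buckets, current, current_bytes = [], [], 0
    for name, numel in reversed(params):
        nbytes = numel * bytes_per_element
        # Start a new bucket if this parameter would overflow the current one.
        if current and current_bytes + nbytes > cap_bytes:
            buckets.append(current)
            current, current_bytes = [], 0
        current.append(name)
        current_bytes += nbytes
    if current:
        buckets.append(current)
    return buckets


# Hypothetical fp32 model, parameters listed in registration order.
params = [("embed", 4_000_000), ("fc1", 2_000_000), ("fc2", 1_000_000), ("head", 500_000)]
print(assign_buckets(params))                    # [['head', 'fc2', 'fc1'], ['embed']]
print(assign_buckets(params, bucket_cap_mb=10))  # [['head', 'fc2'], ['fc1'], ['embed']]
```

Smaller buckets let communication start earlier during backward, at the cost of issuing more all-reduce calls.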

Variables

module (Module) – the module to be parallelized.

Example:

>>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
>>> net = torch.nn.parallel.DistributedDataParallel(model)

join(divide_by_initial_world_size=True, enable=True, throw_on_early_termination=False)

Context manager for training with uneven inputs across processes in DDP.

This context manager will keep track of already-joined DDP processes, and “shadow” the forward and backward passes by inserting collective communication operations to match the ones created by non-joined DDP processes. This ensures each collective call has a corresponding call by already-joined DDP processes, preventing hangs or errors that would otherwise happen when training with uneven inputs across processes. Alternatively, if the flag throw_on_early_termination is specified to be True, all trainers will throw an error once one rank runs out of inputs, allowing these errors to be caught and handled according to application logic.

Once all DDP processes have joined, the context manager will broadcast the model corresponding to the last joined process to all processes to ensure the model is the same across all processes (which is guaranteed by DDP).

To use this to enable training with uneven inputs across processes, simply wrap this context manager around your training loop. No further modifications to the model or data loading are required.

Warning

If the model or training loop this context manager is wrapped around has additional distributed collective operations, such as SyncBatchNorm in the model’s forward pass, then the flag throw_on_early_termination must be enabled. This is because this context manager is not aware of non-DDP collective communication. This flag will cause all ranks to throw when any one rank exhausts inputs, allowing these errors to be caught and recovered from across all ranks.

Parameters
  • divide_by_initial_world_size (bool) – If True, will divide gradients by the initial world_size DDP training was launched with. If False, will compute the effective world size (number of ranks that have not depleted their inputs yet) and divide gradients by that during allreduce. Set divide_by_initial_world_size=True to ensure every input sample, including the uneven inputs, has equal weight in terms of how much it contributes to the global gradient. This is achieved by always dividing the gradient by the initial world_size even when we encounter uneven inputs. If you set this to False, we divide the gradient by the remaining number of nodes. This ensures parity with training on a smaller world_size, although it also means the uneven inputs contribute more towards the global gradient. Typically, you would want to set this to True for cases where the last few inputs of your training job are uneven. In extreme cases, where there is a large discrepancy in the number of inputs, setting this to False might provide better results.

  • enable (bool) – Whether to enable uneven input detection or not. Pass in enable=False to disable in cases where you know that inputs are even across participating processes. Default is True.

  • throw_on_early_termination (bool) – Whether to throw an error or continue training when at least one rank has exhausted inputs. If True, will throw upon the first rank reaching end of data. If False, will continue training with a smaller effective world size until all ranks are joined. Note that if this flag is specified, then the flag divide_by_initial_world_size would be ignored. Default is False.
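The arithmetic behind divide_by_initial_world_size can be sketched in pure Python. The sketch assumes, as a simplification, that ranks which have already joined contribute zero to the shadowed all-reduce; joined_allreduce is a hypothetical name, not part of the DDP API.

```python
def joined_allreduce(active_grads, initial_world_size,
                     divide_by_initial_world_size=True):
    """Average one gradient value while some ranks have already joined.

    active_grads holds one value per rank that still has inputs; joined
    ranks are modeled as contributing 0.0 to the summed all-reduce.
    """
    total = sum(active_grads)
    if divide_by_initial_world_size:
        return total / initial_world_size
    # Effective world size: the ranks that have not exhausted their inputs.
    return total / len(active_grads)


# 4 ranks launched; 2 have run out of inputs and joined.
print(joined_allreduce([2.0, 4.0], initial_world_size=4))  # 1.5
print(joined_allreduce([2.0, 4.0], initial_world_size=4,
                       divide_by_initial_world_size=False))  # 3.0
```

With the flag set to False, the two remaining ranks' samples carry twice the weight they would have had with all four ranks active.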

Example:

>>> import torch
>>> import torch.distributed as dist
>>> import os
>>> import torch.multiprocessing as mp
>>> import torch.nn as nn
>>> # On each spawned worker
>>> def worker(rank):
>>>     dist.init_process_group("nccl", rank=rank, world_size=2)
>>>     torch.cuda.set_device(rank)
>>>     model = nn.Linear(1, 1, bias=False).to(rank)
>>>     model = torch.nn.parallel.DistributedDataParallel(
>>>         model, device_ids=[rank], output_device=rank
>>>     )
>>>     # Rank 1 gets one more input than rank 0.
>>>     inputs = [torch.tensor([1]).float() for _ in range(10 + rank)]
>>>     with model.join():
>>>         for _ in range(5):
>>>             for inp in inputs:
>>>                 loss = model(inp).sum()
>>>                 loss.backward()
>>>     # Without the join() API, the below synchronization will hang
>>>     # blocking for rank 1's allreduce to complete.
>>>     torch.cuda.synchronize(device=rank)

join_hook(**kwargs)

DDP join hook enables training on uneven inputs by mirroring communications in forward and backward passes.

Parameters

kwargs (dict) – a dict containing any keyword arguments to modify the behavior of the join hook at run time; all Joinable instances sharing the same join context manager are forwarded the same value for kwargs.

The hook supports the following keyword arguments:
divide_by_initial_world_size (bool, optional):

If True, then gradients are divided by the initial world size that DDP was launched with. If False, then gradients are divided by the effective world size (i.e. the number of non-joined processes), meaning that the uneven inputs contribute more toward the global gradient. Typically, this should be set to True if the degree of unevenness is small but can be set to False in extreme cases for possibly better results. Default is True.

no_sync()

Context manager to disable gradient synchronizations across DDP processes.

Within this context, gradients will be accumulated on module variables, which will later be synchronized in the first forward-backward pass exiting the context.

Example:

>>> ddp = torch.nn.parallel.DistributedDataParallel(model, process_group=pg)
>>> with ddp.no_sync():
>>>     for input in inputs:
>>>         ddp(input).backward()  # no synchronization, accumulate grads
>>> ddp(another_input).backward()  # synchronize grads

Warning

The forward pass should be included inside the context manager; otherwise, gradients will still be synchronized.
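Because averaging is linear, accumulating gradients locally and reducing once yields the same result as reducing after every micro-batch, while issuing a single all-reduce instead of several. The pure-Python sketch below checks that equivalence with scalar gradients; the function names are hypothetical, and the last micro-batch stands for the backward pass run outside no_sync().

```python
def accumulate_then_sync(per_rank_grads):
    """per_rank_grads[r][k]: gradient on rank r for micro-batch k."""
    world_size = len(per_rank_grads)
    # Under no_sync(): each rank only accumulates into its local .grad.
    local_sums = [sum(rank_grads) for rank_grads in per_rank_grads]
    # First backward outside the context: a single all-reduce (average).
    return sum(local_sums) / world_size


def sync_every_step(per_rank_grads):
    """Baseline: all-reduce (average) after every micro-batch, then accumulate."""
    world_size = len(per_rank_grads)
    num_steps = len(per_rank_grads[0])
    return sum(
        sum(rank_grads[k] for rank_grads in per_rank_grads) / world_size
        for k in range(num_steps)
    )


grads = [[1.0, 2.0, 3.0], [5.0, 6.0, 7.0]]  # 2 ranks, 3 micro-batches
print(accumulate_then_sync(grads), sync_every_step(grads))  # 12.0 12.0
```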

register_comm_hook(state, hook)

Register communication hook for user-defined DDP aggregation of gradients across multiple workers.

This hook is very useful for researchers to try out new ideas. For example, it can be used to implement algorithms like GossipGrad and gradient compression, which involve different communication strategies for parameter syncs while running DistributedDataParallel training.

Parameters
  • state (object) –

    Passed to the hook to maintain any state information during the training process. Examples include error feedback in gradient compression, peers to communicate with next in GossipGrad, etc.

    It is locally stored by each worker and shared by all the gradient tensors on the worker.

  • hook (Callable) –

    Callable with the following signature: hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:

    This function is called once the bucket is ready. The hook can perform whatever processing is needed and return a Future indicating completion of any async work (e.g. allreduce). If the hook doesn’t perform any communication, it must still return a completed Future. The Future should hold the new value of the grad bucket’s tensors. Once a bucket is ready, the c10d reducer calls this hook, takes the tensors returned by the Future, and copies the grads to the individual parameters. Note that the future’s return type must be a single tensor.

    We also provide an API called get_future to retrieve a Future associated with the completion of c10d.ProcessGroup.Work. get_future is currently supported for NCCL and for most operations on GLOO and MPI, except for peer-to-peer operations (send/recv).
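The hook contract described above can be modeled with the standard library. In this sketch concurrent.futures.Future stands in for torch.futures.Future and a Python list stands in for the flattened bucket tensor; run_bucket, noop_hook and average_hook are hypothetical names, and none of this is the c10d reducer's actual code.

```python
from concurrent.futures import Future


def noop_hook(state, bucket_buffer):
    # Performs no communication, but must still return a *completed* future
    # holding the new bucket values.
    fut = Future()
    fut.set_result(list(bucket_buffer))
    return fut


def average_hook(state, bucket_buffer):
    # Bucket values are not pre-divided, so this hook divides by the world size
    # kept in `state`; a real hook would also all-reduce across ranks.
    fut = Future()
    fut.set_result([v / state["world_size"] for v in bucket_buffer])
    return fut


def run_bucket(hook, state, param_grads):
    """Simulate the reducer for one bucket: flatten, call hook, copy back."""
    buffer, slices = [], {}
    for name, grad in param_grads.items():
        slices[name] = (len(buffer), len(grad))
        buffer.extend(grad)
    new_buffer = hook(state, buffer).result()  # wait for the hook's future
    if len(new_buffer) != len(buffer):
        raise ValueError("hook result must match the bucket's shape")
    # Copy each parameter's slice of the returned buffer back to its grad.
    return {name: new_buffer[start:start + n] for name, (start, n) in slices.items()}


grads = {"fc.weight": [1.0, 2.0], "fc.bias": [3.0]}
print(run_bucket(noop_hook, None, grads))  # {'fc.weight': [1.0, 2.0], 'fc.bias': [3.0]}
print(run_bucket(average_hook, {"world_size": 2}, grads))  # {'fc.weight': [0.5, 1.0], 'fc.bias': [1.5]}
```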

Warning

Grad bucket’s tensors will not be predivided by world_size. The user is responsible for dividing by world_size in case of operations like allreduce.

Warning

A DDP communication hook can only be registered once and should be registered before calling backward.

Warning

The Future object that the hook returns should contain a single tensor that has the same shape as the tensors inside the grad bucket.

Warning

The get_future API supports NCCL, and partially supports GLOO and MPI backends (no support for peer-to-peer operations like send/recv), and will return a torch.futures.Future.

Example:

Below is an example of a noop hook that returns the same tensor.

>>> def noop(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
>>>     fut = torch.futures.Future()
>>>     fut.set_result(bucket.buffer())
>>>     return fut
>>> ddp.register_comm_hook(state=None, hook=noop)

Example:

Below is an example of a Parallel SGD algorithm where gradients are encoded before allreduce, and then decoded after allreduce.

>>> def encode_and_decode(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
>>>     encoded_tensor = encode(bucket.buffer())  # encode gradients
>>>     fut = torch.distributed.all_reduce(encoded_tensor, async_op=True).get_future()
>>>     # Define the then callback to decode; named so it does not shadow decode().
>>>     def decode_callback(fut):
>>>         decoded_tensor = decode(fut.value()[0])  # decode gradients
>>>         return decoded_tensor
>>>     return fut.then(decode_callback)
>>> ddp.register_comm_hook(state=None, hook=encode_and_decode)
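The encode, all-reduce, decode flow of the example above can be simulated in pure Python to see both the data path and the division by world size that the earlier warning asks for. Here encode and decode are hypothetical stand-ins (integer quantization with an assumed scale of 1000) for the undefined helpers in that example, and all_reduce_sum models a summing all-reduce across ranks with plain lists.

```python
def encode(values, scale=1000):
    # Hypothetical encoder: quantize floats to integers as a stand-in for compression.
    return [round(v * scale) for v in values]


def decode(values, scale=1000):
    # Inverse of encode, up to the quantization error.
    return [v / scale for v in values]


def all_reduce_sum(per_rank_values):
    # Element-wise sum across ranks, like an all-reduce with a SUM op.
    return [sum(column) for column in zip(*per_rank_values)]


def encode_allreduce_decode(per_rank_buckets, scale=1000):
    world_size = len(per_rank_buckets)
    encoded = [encode(bucket, scale) for bucket in per_rank_buckets]
    decoded = decode(all_reduce_sum(encoded), scale)
    # Bucket tensors are not pre-divided by world_size, so average here.
    return [v / world_size for v in decoded]


buckets = [[0.1234, -0.5], [0.3, 0.25]]  # one gradient bucket per rank
print(encode_allreduce_decode(buckets))
```

The exact average of the first element is 0.2117; the sketch yields about 0.2115, and that gap is the quantization error that schemes such as error feedback (mentioned under state) aim to compensate.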