Distributed Optimizers

Created On: Mar 01, 2021 | Last Updated On: Jun 16, 2025

Warning

DistributedOptimizer is not currently supported when using CUDA tensors.

torch.distributed.optim exposes DistributedOptimizer, which takes a list of remote parameters (RRef) and runs the optimizer locally on the workers where the parameters live. The distributed optimizer can use any local optimizer derived from the torch.optim.Optimizer base class to apply the gradients on each worker.

class torch.distributed.optim.DistributedOptimizer(optimizer_class, params_rref, *args, **kwargs)

DistributedOptimizer takes remote references to parameters scattered across workers and applies the given optimizer locally for each parameter.

This class uses get_gradients() in order to retrieve the gradients for specific parameters.

Concurrent calls to step(), either from the same or different clients, will be serialized on each worker, as each worker's optimizer can only work on one set of gradients at a time. However, there is no guarantee that the full forward-backward-optimizer sequence will execute for one client at a time. This means that the gradients being applied may not correspond to the latest forward pass executed on a given worker. Also, there is no guaranteed ordering across workers.

DistributedOptimizer creates the local optimizer with TorchScript enabled by default, so that optimizer updates are not blocked by the Python Global Interpreter Lock (GIL) in the case of multithreaded training (e.g. Distributed Model Parallel). This feature is currently enabled for most optimizers. You can also follow the recipe in PyTorch tutorials to enable TorchScript support for your own custom optimizers.

Parameters
  • optimizer_class (optim.Optimizer) – the class of optimizer to instantiate on each worker.

  • params_rref (list[RRef]) – list of RRefs to local or remote parameters to optimize.

  • args – arguments to pass to the optimizer constructor on each worker.

  • kwargs – keyword arguments to pass to the optimizer constructor on each worker.

Example:
>>> import torch
>>> import torch.distributed.autograd as dist_autograd
>>> import torch.distributed.rpc as rpc
>>> from torch import optim
>>> from torch.distributed.optim import DistributedOptimizer
>>>
>>> # Assumes RPC has already been initialized and a worker named "worker1" exists.
>>> with dist_autograd.context() as context_id:
>>>     # Forward pass.
>>>     rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
>>>     rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
>>>     loss = rref1.to_here() + rref2.to_here()
>>>
>>>     # Backward pass.
>>>     dist_autograd.backward(context_id, [loss.sum()])
>>>
>>>     # Optimizer.
>>>     dist_optim = DistributedOptimizer(
>>>         optim.SGD,
>>>         [rref1, rref2],
>>>         lr=0.05,
>>>     )
>>>     dist_optim.step(context_id)
step(context_id)

Performs a single optimization step.

This will call torch.optim.Optimizer.step() on each worker containing parameters to be optimized, and will block until all workers return. The provided context_id will be used to retrieve the corresponding context that contains the gradients that should be applied to the parameters.

Parameters

context_id – the autograd context id for which we should run the optimizer step.
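To make the serialization and blocking behavior described above concrete, here is a minimal pure-Python sketch. It uses no RPC or distributed autograd: HypotheticalWorker and distributed_step are illustrative stand-ins assumed only for this example (threads instead of RPC workers, plain floats instead of parameters). Each worker holds its own lock so that concurrent steps on it are serialized, and distributed_step sends one update to every worker and blocks until all of them return.

```python
import threading
from concurrent.futures import ThreadPoolExecutor


class HypotheticalWorker:
    """Illustrative stand-in for an RPC worker that owns some parameters."""

    def __init__(self, name, params):
        self.name = name
        self.params = dict(params)      # parameter name -> float value
        self._lock = threading.Lock()   # one optimizer step at a time per worker
        self.steps_applied = 0

    def local_step(self, grads, lr):
        # Concurrent calls on the same worker are serialized by this lock.
        with self._lock:
            for key, grad in grads.items():
                self.params[key] -= lr * grad   # plain SGD update
            self.steps_applied += 1


def distributed_step(workers, grads_per_worker, lr):
    """Send one step to every worker and block until all of them return."""
    with ThreadPoolExecutor(max_workers=len(workers)) as pool:
        futures = [
            pool.submit(w.local_step, grads_per_worker[w.name], lr) for w in workers
        ]
        for future in futures:
            future.result()  # waits for this worker; re-raises any error it hit


workers = [
    HypotheticalWorker("worker1", {"w": 1.0}),
    HypotheticalWorker("worker2", {"b": 0.5}),
]
grads = {"worker1": {"w": 0.2}, "worker2": {"b": -0.1}}

# Two "clients" step at the same time. Each worker still applies their updates
# one at a time, but nothing orders the updates across workers.
clients = [
    threading.Thread(target=distributed_step, args=(workers, grads, 0.5))
    for _ in range(2)
]
for t in clients:
    t.start()
for t in clients:
    t.join()
```

Because the lock lives on each worker rather than on the client, the two clients' updates may interleave differently on different workers, which mirrors the ordering caveat in the text.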

class torch.distributed.optim.PostLocalSGDOptimizer(optim, averager)

Wraps an arbitrary torch.optim.Optimizer and runs post-local SGD. This optimizer runs the local optimizer at every step. After the warm-up stage, it averages parameters periodically after the local optimizer is applied.

Parameters
  • optim (Optimizer) – The local optimizer.

  • averager (ModelAverager) – A model averager instance that runs the post-local SGD algorithm.

Example:

>>> import torch
>>> import torch.distributed as dist
>>> import torch.distributed.algorithms.model_averaging.averagers as averagers
>>> import torch.nn as nn
>>> from torch.distributed.optim import PostLocalSGDOptimizer
>>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
>>>     PostLocalSGDState,
>>>     post_localSGD_hook,
>>> )
>>>
>>> # Assumes the default process group is initialized and that `module`, `rank`,
>>> # `inputs`, `labels`, and `loss_fn` are defined by the surrounding training script.
>>> model = nn.parallel.DistributedDataParallel(
>>>     module, device_ids=[rank], output_device=rank
>>> )
>>>
>>> # Register a post-localSGD communication hook.
>>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
>>> model.register_comm_hook(state, post_localSGD_hook)
>>>
>>> # Create a post-localSGD optimizer that wraps a local optimizer.
>>> # Note that ``warmup_steps`` used in ``PostLocalSGDOptimizer`` must be the same as
>>> # ``start_localSGD_iter`` used in ``PostLocalSGDState``.
>>> local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01)
>>> opt = PostLocalSGDOptimizer(
>>>     optim=local_optim,
>>>     averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100)
>>> )
>>>
>>> # In the first 100 steps, DDP runs global gradient averaging at every step.
>>> # After 100 steps, DDP runs gradient averaging within each subgroup (intra-node by default),
>>> # and post-localSGD optimizer runs global model averaging every 4 steps after applying the local optimizer.
>>> for step in range(0, 200):
>>>     opt.zero_grad()
>>>     output = model(inputs)
>>>     loss = loss_fn(output, labels)
>>>     loss.backward()
>>>     opt.step()
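The schedule in the example above can be simulated without any distributed setup. This toy sketch tracks one scalar parameter per rank; should_average and simulate are hypothetical helpers written for illustration, and the averaging condition (average once the step counter reaches warmup_steps, then every period steps) is an assumption chosen to match the example's comments, not a copy of the library code. During warm-up the gradients are averaged globally, as DDP would do; afterwards each rank applies its own gradient (assuming one rank per subgroup), and every period steps the parameters are averaged after the local update.

```python
from statistics import fmean


def should_average(step, warmup_steps, period):
    # Hypothetical rule: average once warm-up is over, then every `period` steps.
    return step >= warmup_steps and (step - warmup_steps) % period == 0


def simulate(num_steps, warmup_steps, period, grads, lr=0.1):
    """Toy post-local SGD on one scalar parameter per rank.

    grads[r] is the constant gradient that rank r sees at every step.
    Returns the final per-rank parameters and the steps at which averaging ran.
    """
    params = [0.0] * len(grads)
    averaged_at = []
    for step in range(num_steps):
        if step < warmup_steps:
            # Warm-up: gradients are averaged globally, so all ranks stay identical.
            g = fmean(grads)
            params = [p - lr * g for p in params]
        else:
            # Local SGD: each rank applies only its own gradient.
            params = [p - lr * g for p, g in zip(params, grads)]
        if should_average(step, warmup_steps, period):
            # Global model averaging after the local optimizer step.
            mean = fmean(params)
            params = [mean] * len(params)
            averaged_at.append(step)
    return params, averaged_at


# Same schedule as the example: 200 steps, warm-up of 100, averaging period of 4.
final_params, averaged_at = simulate(200, warmup_steps=100, period=4, grads=[1.0, 3.0])
```

Between averaging rounds the ranks drift apart because they see different gradients; each averaging round pulls them back to a common value.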
load_state_dict(state_dict)

This is the same as torch.optim.Optimizer.load_state_dict(), but also restores the model averager's step value to the one saved in the provided state_dict.

If there is no "step" entry in state_dict, it will issue a warning and initialize the model averager's step to 0.

state_dict()

This is the same as torch.optim.Optimizer.state_dict(), but adds an extra entry that records the model averager's step in the checkpoint, so that reloading does not trigger an unnecessary second warm-up.
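The extra checkpoint entry can be illustrated with plain dictionaries. In this sketch HypotheticalAverager, save_with_averager_step, and load_with_averager_step are hypothetical names; they model only the bookkeeping described above (an extra "step" entry on save, and a warning plus a reset to 0 when that entry is missing on load), not the wrapped optimizer's own state handling.

```python
import copy
import warnings


class HypotheticalAverager:
    """Illustrative stand-in for a model averager that counts steps."""

    def __init__(self, warmup_steps):
        self.warmup_steps = warmup_steps
        self.step = 0


def save_with_averager_step(optim_state, averager):
    """Return a copy of the optimizer state with the averager's step added."""
    checkpoint = copy.deepcopy(optim_state)
    checkpoint["step"] = averager.step
    return checkpoint


def load_with_averager_step(checkpoint, averager):
    """Restore the averager's step, or warn and fall back to 0 if it is absent."""
    if "step" in checkpoint:
        averager.step = checkpoint["step"]
    else:
        warnings.warn("checkpoint has no 'step' entry; averager step reset to 0")
        averager.step = 0
    # A real wrapper would also pass `checkpoint` to the local optimizer here.


averager = HypotheticalAverager(warmup_steps=100)
averager.step = 150  # e.g. a checkpoint taken after warm-up has finished
checkpoint = save_with_averager_step({"state": {}, "param_groups": []}, averager)
```

Restoring a step of 150 with warmup_steps=100 lets the averager resume in the post-warm-up phase instead of repeating the first 100 steps.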

step()

Performs a single optimization step (parameter update).

class torch.distributed.optim.ZeroRedundancyOptimizer(params, optimizer_class, process_group=None, parameters_as_bucket_view=False, overlap_with_ddp=False, **defaults)

Wrap an arbitrary optim.Optimizer and shard its states across ranks in the group.

The sharding is done as described by ZeRO.

The local optimizer instance in each rank is only responsible for updating approximately 1/world_size of the parameters and hence only needs to keep 1/world_size of the optimizer states. After parameters are updated locally, each rank broadcasts its parameters to all other peers to keep all model replicas in the same state. ZeroRedundancyOptimizer can be used in conjunction with torch.nn.parallel.DistributedDataParallel to reduce per-rank peak memory consumption.
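As a rough worked example of the saving, consider the model used in this class's example (20 nn.Linear(2000, 2000) layers) trained with Adam. The sketch assumes Adam keeps two fp32 buffers (exp_avg and exp_avg_sq) per parameter element, ignores the small per-tensor step counters, and uses a hypothetical world_size of 4; real per-rank numbers differ slightly because whole tensors, not individual elements, are assigned to ranks.

```python
BYTES_PER_FP32 = 4


def adam_state_bytes(num_params):
    # Assumption: two fp32 state buffers (exp_avg, exp_avg_sq) per parameter element.
    return 2 * num_params * BYTES_PER_FP32


layer_params = 2000 * 2000 + 2000    # weight + bias of nn.Linear(2000, 2000)
num_params = 20 * layer_params       # 80,040,000 parameters in total
world_size = 4                       # hypothetical number of ranks

full_state = adam_state_bytes(num_params)   # held by every rank without sharding
sharded_state = full_state // world_size    # roughly what each rank holds with ZeRO

print(f"{full_state / 1e6:.1f} MB -> {sharded_state / 1e6:.1f} MB per rank")
```

Under these assumptions the optimizer state drops from about 640 MB on every rank to about 160 MB per rank.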

ZeroRedundancyOptimizer uses a sorted-greedy algorithm to pack a number of parameters at each rank. Each parameter belongs to a single rank and is not divided among ranks. The partition is arbitrary and might not match the parameter registration or usage order.
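The sorted-greedy idea can be sketched in a few lines of Python. This is a simplified model, not the library's code: each parameter is represented by a hypothetical (name, numel) pair, parameters are visited largest first, and each one goes whole to the rank with the smallest total so far.

```python
def sorted_greedy_partition(param_sizes, world_size):
    """Assign whole parameters to ranks, largest first, each to the least-loaded rank.

    param_sizes: list of (name, numel) pairs standing in for tensors.
    Returns (assignment, loads) where assignment[r] lists the names owned by
    rank r and loads[r] is the total number of elements on rank r.
    """
    assignment = [[] for _ in range(world_size)]
    loads = [0] * world_size
    for name, numel in sorted(param_sizes, key=lambda p: p[1], reverse=True):
        rank = loads.index(min(loads))  # ties go to the lowest rank index
        assignment[rank].append(name)
        loads[rank] += numel
    return assignment, loads


sizes = [("a", 10), ("b", 8), ("c", 5), ("d", 4), ("e", 3)]
assignment, loads = sorted_greedy_partition(sizes, world_size=2)
```

Here rank 0 ends up with "a" and "d" while rank 1 owns "b", "c", and "e": the grouping follows parameter size rather than registration order, which is why the partition may not match the order in which parameters were registered or used.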

Parameters

params (Iterable) – an Iterable of torch.Tensors or dicts giving all parameters, which will be sharded across ranks.

Keyword Arguments
  • optimizer_class (torch.optim.Optimizer) – the class of the local optimizer.

  • process_group (ProcessGroup, optional) – torch.distributed ProcessGroup (default: dist.group.WORLD initialized by torch.distributed.init_process_group()).

  • parameters_as_bucket_view (bool, optional) – if True, parameters are packed into buckets to speed up communication, and param.data fields point to bucket views at different offsets; if False, each individual parameter is communicated separately, and each param.data stays intact (default: False).

  • overlap_with_ddp (bool, optional) – if True, step() is overlapped with DistributedDataParallel's gradient synchronization; this requires (1) either a functional optimizer for the optimizer_class argument or one with a functional equivalent and (2) registering a DDP communication hook constructed from one of the functions in ddp_zero_hook.py; parameters are packed into buckets matching those in DistributedDataParallel, meaning that the parameters_as_bucket_view argument is ignored. If False, step() runs disjointly after the backward pass (as usual). (default: False)

  • **defaults – any trailing arguments, which are forwarded to the local optimizer.
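The bucket-view idea behind parameters_as_bucket_view=True can be mimicked with the standard library. In this sketch, pack_into_bucket is a hypothetical helper, and array/memoryview stand in for tensors and param.data: values are copied into one flat buffer, and each "parameter" becomes a view at its own offset, so a single write into the buffer (standing in for one collective per bucket instead of one per parameter) is visible through every view.

```python
from array import array


def pack_into_bucket(params):
    """Copy parameters into one flat bucket and return a view per parameter.

    params: list of lists of floats standing in for tensors.
    Returns (bucket, views); views[i] is a memoryview slice of the bucket at
    parameter i's offset, so updates written to the bucket show through it.
    """
    bucket = array("d", [x for p in params for x in p])
    flat = memoryview(bucket)
    views, offset = [], 0
    for p in params:
        views.append(flat[offset:offset + len(p)])
        offset += len(p)
    return bucket, views


bucket, views = pack_into_bucket([[1.0, 2.0], [3.0, 4.0, 5.0]])

# Simulate receiving new values for the whole bucket in one operation.
for i, value in enumerate([10.0, 20.0, 30.0, 40.0, 50.0]):
    bucket[i] = value
```

After the single bucket update, both views already hold the new values without any per-parameter copy.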

Example:

>>> import torch
>>> import torch.nn as nn
>>> from torch.distributed.optim import ZeroRedundancyOptimizer
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> # Assumes the default process group is initialized and `rank` and `inputs` are defined.
>>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
>>> ddp = DDP(model, device_ids=[rank])
>>> opt = ZeroRedundancyOptimizer(
>>>     ddp.parameters(),
>>>     optimizer_class=torch.optim.Adam,
>>>     lr=0.01
>>> )
>>> ddp(inputs).sum().backward()
>>> opt.step()

Warning

Currently, ZeroRedundancyOptimizer requires that all of the passed-in parameters are the same dense type.

Warning

If you pass overlap_with_ddp=True, be wary of the following: given the way that overlapping DistributedDataParallel with ZeroRedundancyOptimizer is currently implemented, the first two or three training iterations do not perform parameter updates in the optimizer step, depending on whether static_graph=False or static_graph=True, respectively. This is because it needs information about the gradient bucketing strategy used by DistributedDataParallel, which is not finalized until the second forward pass if static_graph=False or until the third forward pass if static_graph=True. To adjust for this, one option is to prepend dummy inputs.

Warning

ZeroRedundancyOptimizer is experimental and subject to change.

add_param_group(param_group)

Add a parameter group to the Optimizer's param_groups.

This can be useful when fine-tuning a pre-trained network, as frozen layers can be made trainable and added to the Optimizer as training progresses.

Parameters

param_group (dict) – specifies the parameters to be optimized and group-specific optimization options.

Warning

This method handles updating the shards on all partitions but needs to be called on all ranks. Calling this on a subset of the ranks will cause the training to hang, because communication primitives are called depending on the managed parameters and expect all the ranks to participate on the same set of parameters.

consolidate_state_dict(to=0)

Consolidate a list of state_dicts (one per rank) on the target rank.

Parameters

to (int) – the rank that receives the optimizer states (default: 0).

Raises

RuntimeError – if overlap_with_ddp=True and this method is called before this ZeroRedundancyOptimizer instance has been fully initialized, which happens once DistributedDataParallel gradient buckets have been rebuilt.

Warning

This needs to be called on all ranks.

property join_device: device

Return default device.

join_hook(**kwargs)

Return the ZeRO join hook.

It enables training on uneven inputs by shadowing the collective communications in the optimizer step.

Gradients must be properly set before this hook is called.

Parameters

kwargs (dict) – a dict containing any keyword arguments to modify the behavior of the join hook at run time; all Joinable instances sharing the same join context manager are forwarded the same value for kwargs.

This hook does not support any keyword arguments; i.e. kwargs is unused.

property join_process_group: Any

Return process group.

load_state_dict(state_dict)

Load the state pertaining to the given rank from the input state_dict, updating the local optimizer as needed.

Parameters

state_dict (dict) – optimizer state; should be an object returned from a call to state_dict().

Raises

RuntimeError – if overlap_with_ddp=True and this method is called before this ZeroRedundancyOptimizer instance has been fully initialized, which happens once DistributedDataParallel gradient buckets have been rebuilt.

state_dict()

Return the last global optimizer state known to this rank.

Raises

RuntimeError – if overlap_with_ddp=True and this method is called before this ZeroRedundancyOptimizer instance has been fully initialized, which happens once DistributedDataParallel gradient buckets have been rebuilt; or if this method is called without a preceding call to consolidate_state_dict().

Return type

dict[str, Any]
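A common checkpointing pattern follows from the requirements above: every rank calls consolidate_state_dict(), and only the target rank then calls state_dict(). This sketch assumes PyTorch with the gloo backend and, so that it runs in a single process, uses world_size=1 with a file-based rendezvous in a temporary directory; in a real job each rank runs the same code in its own process, and how the checkpoint is saved and shared between ranks is left to the training script.

```python
import os
import tempfile

import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer


def main():
    # Single-process group for illustration; real jobs launch one process per rank.
    init_file = os.path.join(tempfile.mkdtemp(), "zero_rendezvous")
    dist.init_process_group(
        "gloo", init_method=f"file://{init_file}", rank=0, world_size=1
    )
    rank = dist.get_rank()

    model = torch.nn.Linear(4, 2)
    opt = ZeroRedundancyOptimizer(
        model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01
    )
    model(torch.randn(8, 4)).sum().backward()
    opt.step()

    # Collective: every rank must call this, not only the target rank.
    opt.consolidate_state_dict(to=0)
    # Only the target rank holds the consolidated state, so only it calls state_dict().
    checkpoint = opt.state_dict() if rank == 0 else None

    # Loaded straight back here; in a multi-rank job every rank would first
    # obtain the saved checkpoint (e.g. from disk) and then call load_state_dict().
    if checkpoint is not None:
        opt.load_state_dict(checkpoint)

    dist.destroy_process_group()
    return checkpoint


if __name__ == "__main__":
    main()
```

Calling state_dict() on a rank other than the target, or skipping consolidate_state_dict(), hits the RuntimeError documented above, so the rank check is what keeps this pattern safe.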

step(closure=None, **kwargs)

Perform a single optimizer step and sync parameters across all ranks.

Parameters

closure (Callable) – a closure that re-evaluates the model and returns the loss; optional for most optimizers.

Returns

Optional loss depending on the underlying local optimizer.

Return type

Optional[float]

Note

Any extra parameters are passed to the base optimizer as-is.