DDP Communication Hooks

A DDP communication hook is a generic interface to control how gradients are communicated across workers by overriding the vanilla allreduce in DistributedDataParallel. A few built-in communication hooks are provided, and users can easily apply any of these hooks to optimize communication. In addition, the hook interface supports user-defined communication strategies for more advanced use cases.

How to Use a Communication Hook?

To use a communication hook, the user just needs to let the DDP model register the hook before the training loop, as shown below.

torch.nn.parallel.DistributedDataParallel.register_comm_hook()
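
For example, a minimal sketch of registering the built-in allreduce hook on a DDP model (the process-group initialization, model, and rank below are assumed to exist already):

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks

# Assumes dist.init_process_group(...) has been called,
# and that `model` is an nn.Module placed on this rank's device.
ddp_model = DistributedDataParallel(model, device_ids=[rank])
ddp_model.register_comm_hook(state=None, hook=default_hooks.allreduce_hook)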

What Does a Communication Hook Operate On?

A communication hook provides a flexible way to allreduce gradients. Therefore, it mainly operates on the gradients on each replica before allreduce, which are bucketized to increase the overlap between communication and computation. Particularly, torch.distributed.GradBucket represents a bucket of gradient tensors to be allreduced.

class torch.distributed.GradBucket

This class mainly passes a flattened gradient tensor (returned by buffer()) to the DDP communication hook. This tensor can be further decomposed into a list of per-parameter tensors within this bucket (returned by get_per_parameter_tensors()) to apply layer-wise operations.
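
For instance, a hypothetical custom hook could inspect a bucket as follows (inspect_bucket_hook is an illustrative name, not part of PyTorch; the fallback mirrors the default mean-allreduce behavior):

import torch
import torch.distributed as dist

def inspect_bucket_hook(state, bucket):
    flat = bucket.buffer()          # the flattened 1D gradient tensor
    per_param = bucket.gradients()  # per-parameter tensors within this bucket
    # ... layer-wise operations on per_param could go here ...
    # Fall back to a plain asynchronous allreduce of the mean.
    flat.div_(dist.get_world_size())
    fut = dist.all_reduce(flat, async_op=True).get_future()
    return fut.then(lambda f: f.value()[0])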

torch.distributed.GradBucket.index(self: torch._C._distributed_c10d.GradBucket) → int

Warning

Since the buckets are rebuilt after the first iteration, the bucket indices at the beginning of training should not be relied upon.

Returns:

The index of a bucket that stores gradients of a few contiguous layers. All the gradients are bucketized.

torch.distributed.GradBucket.buffer(self: torch._C._distributed_c10d.GradBucket) → torch.Tensor
Returns:

A flattened 1D torch.Tensor buffer, which can be further decomposed into a list of per-parameter tensors within this bucket.

torch.distributed.GradBucket.gradients(self: torch._C._distributed_c10d.GradBucket) → list[torch.Tensor]
Returns:

A list of torch.Tensor. Each tensor in the list corresponds to a gradient.

torch.distributed.GradBucket.is_last(self: torch._C._distributed_c10d.GradBucket) → bool
Returns:

Whether this bucket is the last bucket to allreduce in an iteration. This also means that this bucket corresponds to the first few layers in the forward pass.

torch.distributed.GradBucket.set_buffer(self: torch._C._distributed_c10d.GradBucket, buffer: torch.Tensor) → None

Replaces the tensor in the bucket with the input tensor buffer.

torch.distributed.GradBucket.parameters(self: torch._C._distributed_c10d.GradBucket) → list[torch.Tensor]
Returns:

A list of torch.Tensor. Each tensor in the list corresponds to a model parameter.

Default Communication Hooks

Default communication hooks are simple stateless hooks, so the input state in register_comm_hook is either a process group or None. The input bucket is a torch.distributed.GradBucket object.

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.allreduce_hook(process_group, bucket)[source]

Call allreduce using GradBucket tensors.

Once gradient tensors are aggregated across all workers, its then callback takes the mean and returns the result.

If a user registers this DDP communication hook, the DDP results are expected to be the same as in the case where no hook was registered. Hence, this hook does not change the behavior of DDP, and users can use it as a reference, or modify it to log useful information or serve any other purposes, without affecting DDP behavior.

Example::
>>> ddp_model.register_comm_hook(process_group, allreduce_hook)
Return type:

Future[Tensor]
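
As a sketch of "modify this hook to log useful information", a hypothetical wrapper could delegate to allreduce_hook (log_allreduce_hook is an illustrative name, not part of PyTorch):

import logging
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks

def log_allreduce_hook(process_group, bucket):
    # Log which bucket is being reduced and how large it is,
    # then delegate to the built-in allreduce behavior.
    logging.info("allreducing bucket %d with %d elements",
                 bucket.index(), bucket.buffer().numel())
    return default_hooks.allreduce_hook(process_group, bucket)

# ddp_model.register_comm_hook(process_group, log_allreduce_hook)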

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook(process_group, bucket)[source]

Compress by casting GradBucket to torch.float16 divided by process group size.

This DDP communication hook implements a simple gradient compression approach that casts the GradBucket tensor to half-precision floating-point format (torch.float16) and then divides it by the process group size. It allreduces those float16 gradient tensors. Once the compressed gradient tensors are allreduced, the chained callback decompress casts them back to the input data type (such as float32).

Example::
>>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)
Return type:

Future[Tensor]
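
To illustrate the compress / allreduce / decompress pattern this hook follows, here is a minimal sketch of a comparable user-defined hook (a simplified illustration, not the library's actual implementation):

import torch
import torch.distributed as dist

def fp16_style_hook(process_group, bucket):
    group = process_group if process_group is not None else dist.group.WORLD
    world_size = group.size()
    input_dtype = bucket.buffer().dtype

    # Cast to float16 and pre-divide by the world size,
    # so the allreduce sum directly yields the mean.
    compressed = bucket.buffer().to(torch.float16).div_(world_size)
    fut = dist.all_reduce(compressed, group=group, async_op=True).get_future()

    def decompress(fut):
        # Cast the aggregated tensor back to the original dtype.
        return fut.value()[0].to(input_dtype)

    return fut.then(decompress)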

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.bf16_compress_hook(process_group, bucket)[source]

Warning: This API is experimental, and it requires NCCL version later than 2.9.6.

This DDP communication hook implements a simple gradient compression approach that casts the GradBucket tensor to half-precision Brain floating point format (torch.bfloat16) and then divides it by the process group size. It allreduces those bfloat16 gradient tensors. Once the compressed gradient tensors are allreduced, the chained callback decompress casts them back to the input data type (such as float32).

Example::
>>> ddp_model.register_comm_hook(process_group, bf16_compress_hook)
Return type:

Future[Tensor]

Additionally, a communication hook wrapper is provided to apply fp16_compress_hook() or bf16_compress_hook() as a wrapper, which can be combined with other communication hooks.

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_wrapper(hook)[source]

Cast input tensor to torch.float16, cast result of hook back to input dtype.

This wrapper casts the input gradient tensor of a given DDP communication hook to half-precision floating point format (torch.float16), and casts the resulting tensor of the given hook back to the input data type, such as float32. Therefore, fp16_compress_hook is equivalent to fp16_compress_wrapper(allreduce_hook).

Example::
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
>>> ddp_model.register_comm_hook(state, fp16_compress_wrapper(powerSGD_hook))
Return type:

Callable[[Any,GradBucket],Future[Tensor]]

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.bf16_compress_wrapper(hook)[source]

Warning: This API is experimental, and it requires NCCL version later than 2.9.6.

This wrapper casts the input gradient tensor of a given DDP communication hook to half-precision Brain floating point format (torch.bfloat16), and casts the resulting tensor of the given hook back to the input data type, such as float32.

Therefore, bf16_compress_hook is equivalent to bf16_compress_wrapper(allreduce_hook).

Example::
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
>>> ddp_model.register_comm_hook(state, bf16_compress_wrapper(powerSGD_hook))
Return type:

Callable[[Any,GradBucket],Future[Tensor]]

PowerSGD Communication Hook

PowerSGD (Vogels et al., NeurIPS 2019) is a gradient compression algorithm which can provide very high compression rates and accelerate bandwidth-bound distributed training. This algorithm needs to maintain both some hyperparameters and the internal state. Therefore, the PowerSGD communication hook is a stateful hook, and the user needs to provide a state object defined as below.

PowerSGD State

class torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.PowerSGDState(process_group, matrix_approximation_rank=1, start_powerSGD_iter=1000, min_compression_rate=2, use_error_feedback=True, warm_start=True, orthogonalization_epsilon=0, random_seed=0, compression_stats_logging_frequency=10000, batch_tensors_with_same_shape=False)[source]

Store both the algorithm’s hyperparameters and internal state for all gradients during training.

Particularly, matrix_approximation_rank and start_powerSGD_iter are the main hyperparameters that should be tuned by the user. For performance, we suggest keeping the binary hyperparameters use_error_feedback and warm_start on.

  1. matrix_approximation_rank controls the size of compressed low-rank tensors, which determines the compression rate. The lower the rank, the stronger the compression.

    1.1. If matrix_approximation_rank is too low, the full model quality may need more training steps to reach, or may never be reached, yielding a loss in accuracy.

    1.2. Increasing matrix_approximation_rank can substantially increase the computation costs of the compression, and the accuracy may not be further improved beyond a certain matrix_approximation_rank threshold.

To tune matrix_approximation_rank, we suggest starting from 1 and increasing by factors of 2 (like an exponential grid search: 1, 2, 4, …) until a satisfactory accuracy is reached. Typically only a small value 1-4 is used. For some NLP tasks (as shown in Appendix D of the original paper), this value has been increased to 32.

  2. start_powerSGD_iter defers PowerSGD compression until step start_powerSGD_iter, and vanilla allreduce runs prior to step start_powerSGD_iter. This hybrid scheme of vanilla allreduce + PowerSGD can effectively improve the accuracy, even when a relatively small matrix_approximation_rank is used. This is because the beginning of the training phase is usually very sensitive to inaccurate gradients, and compressing gradients too early may make the training quickly take a suboptimal trajectory, which can result in an irrecoverable impact on the accuracy.

To tune start_powerSGD_iter, we suggest starting with 10% of the total training steps and increasing it until a satisfactory accuracy is reached. If there is a warm-up stage in the training, start_powerSGD_iter typically should be no less than the number of warm-up steps.

  3. min_compression_rate is the minimum compression rate required when a layer is compressed. Due to the computation overheads incurred by the compression, a tensor is worth compressing only if there can be sufficient saving in bandwidth, i.e., where (num_rows + num_cols) * matrix_approximation_rank * min_compression_rate < num_rows * num_cols (see the sketch after this list). If the specified compression rate threshold cannot be satisfied, the tensor will be directly allreduced without compression.

Compression statistics are logged every compression_stats_logging_frequency iterations once PowerSGD compression starts.

  4. orthogonalization_epsilon can be a very small value (e.g., 1e-8) added to every normalized matrix column in the orthogonalization step, to prevent division-by-zero errors if any column is all zeros. If this can already be prevented (e.g., by batch normalization), an epsilon of 0 is recommended for accuracy.

  5. batch_tensors_with_same_shape controls whether to compress and decompress tensors with the same shape in a batched operation to achieve higher parallelism. Note that you should also increase the bucket size (i.e., the bucket_cap_mb arg in the DDP constructor) to make more same-shaped tensors appear in the same bucket; however, this may reduce the overlap between computation and communication, and increase the memory footprint due to stacking tensors of the same shape. Set to True if the compression / decompression computation is a bottleneck.
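
To make the compression criterion in item 3 concrete, here is a minimal sketch of the bandwidth-saving check (an illustrative standalone function, not the hook's actual code):

def worth_compressing(num_rows, num_cols, matrix_approximation_rank, min_compression_rate):
    # Compressing a (num_rows x num_cols) tensor into rank-r factors P and Q
    # sends (num_rows + num_cols) * r elements instead of num_rows * num_cols.
    uncompressed_size = num_rows * num_cols
    compressed_size = (num_rows + num_cols) * matrix_approximation_rank
    return compressed_size * min_compression_rate < uncompressed_size

# For example, a 1024 x 1024 gradient with rank 1 and min_compression_rate=2:
# (1024 + 1024) * 1 * 2 = 4096 < 1048576, so this tensor is worth compressing.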

Warning

If error feedback or warm-up is enabled, the minimum value of start_powerSGD_iter allowed in DDP is 2. This is because there is another internal optimization that rebuilds buckets at iteration 1 in DDP, and this can conflict with any tensor memorized before the rebuild process.

PowerSGD Hooks

Warning

PowerSGD typically requires extra memory of the same size as the model's gradients to enable error feedback, which can compensate for biased compressed communication and improve accuracy.

Warning

PowerSGD hooks may conflict with the Apex automatic mixed precision package. Please use the PyTorch native automatic mixed precision package instead.

torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.powerSGD_hook(state, bucket)[source]

Implement PowerSGD algorithm.

This DDP communication hook implements the PowerSGD gradient compression algorithm described in the paper. Once gradient tensors are aggregated across all workers, this hook applies compression as follows:

  1. Views the input flattened 1D gradient tensor as a list of per-parameter tensors, and divides all the tensors into two groups:

    1.1 The tensors that should be compressed before allreduce, because the compression can give enough saving in bandwidth.

    1.2 The rest of the tensors will be directly allreduced without compression, including all the vector tensors (for biases).

  2. Handles uncompressed tensors:

    2.1. Allocates contiguous memory for those uncompressed tensors, and allreduces all the uncompressed tensors as a batch, without compression;

    2.2. Copies the individual uncompressed tensors from the contiguous memory back to the input tensor.

  3. Handles the tensors that should be compressed by PowerSGD compression:

    3.1. For each tensor M, creates two low-rank tensors P and Q for decomposing M, such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;

    3.2. Computes each P in Ps, which is equal to MQ;

    3.3. Allreduces Ps as a batch;

    3.4. Orthogonalizes each P in Ps;

    3.5. Computes each Q in Qs, which is approximately equal to M^TP;

    3.6. Allreduces Qs as a batch;

    3.7. Computes each M among all the compressed tensors, which is approximately equal to PQ^T.
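
As an illustration of steps 3.1-3.7, here is a minimal sketch of one rank-r round for a single 2D tensor M (standalone illustrative code; the hook itself batches these steps across tensors, interleaves the allreduces, and uses its own orthogonalization routine):

import torch

def powersgd_round(M, r):
    n, m = M.shape
    # 3.1: Q is initialized from a standard normal distribution and orthogonalized.
    Q = torch.randn(m, r, device=M.device, dtype=M.dtype)
    Q, _ = torch.linalg.qr(Q)
    # 3.2: P = MQ (3.3: in the hook, the Ps are then allreduced as a batch).
    P = M @ Q
    # 3.4: orthogonalize P.
    P, _ = torch.linalg.qr(P)
    # 3.5: Q ≈ M^T P (3.6: in the hook, the Qs are then allreduced as a batch).
    Q = M.t() @ P
    # 3.7: reconstruct M ≈ P Q^T.
    return P @ Q.t()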

Note that this communication hook enforces vanilla allreduce for the first state.start_powerSGD_iter iterations. This not only gives the user more control over the tradeoff between speedup and accuracy, but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers.

Parameters:
  • state (PowerSGDState) – State information to configure the compression rate and support error feedback, warm start, etc. To tune the compression configs, one mainly needs to tune matrix_approximation_rank, start_powerSGD_iter and min_compression_rate.

  • bucket (dist.GradBucket) – Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors. Note that since DDP comm hook only supports single process single device mode, only exactly one tensor is stored in this bucket.

Returns:

Future handler of the communication, which updates the gradients in place.

Return type:

Future[Tensor]

Example::
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1,
...                       start_powerSGD_iter=10, min_compression_rate=0.5)
>>> ddp_model.register_comm_hook(state, powerSGD_hook)

torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.batched_powerSGD_hook(state, bucket)[source]

Implement simplified PowerSGD algorithm.

This DDP communication hook implements a simplified PowerSGD gradient compression algorithm described in the paper. This variant does not compress the gradients layer by layer, but instead compresses the flattened input tensor that batches all the gradients. Therefore, it is faster than powerSGD_hook(), but usually results in a much lower accuracy, unless matrix_approximation_rank is 1.

Warning

Increasing matrix_approximation_rank here may not necessarily increase the accuracy, because batching per-parameter tensors without column/row alignment can destroy low-rank structure. Therefore, the user should always consider powerSGD_hook() first, and only consider this variant when a satisfactory accuracy can be achieved when matrix_approximation_rank is 1.

Once gradient tensors are aggregated across all workers, this hook appliescompression as follows:

  1. Views the input flattened 1D gradient tensor as a square-shaped tensor M with 0 paddings;

  2. Creates two low-rank tensors P and Q for decomposing M, such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;

  3. Computes P, which is equal to MQ;

  4. Allreduces P;

  5. Orthogonalizes P;

  6. Computes Q, which is approximately equal to M^TP;

  7. Allreduces Q;

  8. Computes M, which is approximately equal to PQ^T.

  9. Truncates the input tensor to the original length.
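
As an illustration of steps 1 and 9, a minimal sketch of the zero-padded square view might look like this (an illustrative standalone function, not the hook's actual code):

import math
import torch

def view_as_square(flat):
    # Step 1: pad the flattened 1D tensor with zeros so it can be viewed
    # as a square n x n matrix; step 9 later truncates back to flat.numel().
    total = flat.numel()
    n = math.isqrt(total)
    if n * n < total:
        n += 1
    padded = torch.zeros(n * n, device=flat.device, dtype=flat.dtype)
    padded[:total] = flat
    return padded.view(n, n)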

Note that this communication hook enforces vanilla allreduce for the first state.start_powerSGD_iter iterations. This not only gives the user more control over the tradeoff between speedup and accuracy, but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers.

Parameters:
  • state (PowerSGDState) – State information to configure the compression rate and support error feedback, warm start, etc. To tune the compression configs, one mainly needs to tune matrix_approximation_rank and start_powerSGD_iter.

  • bucket (dist.GradBucket) – Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors. Note that since DDP comm hook only supports single process single device mode, only exactly one tensor is stored in this bucket.

Returns:

Future handler of the communication, which updates the gradients in place.

Return type:

Future[Tensor]

Example::
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
>>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)

Debugging Communication Hooks

As the name implies, debugging communication hooks are only used for debugging and performance optimization purposes.

Warning

Debugging communication hooks do not necessarily output the correct results.

torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks.noop_hook(_, bucket)[source]

Return a future that wraps the input, so it is a no-op that does not incur any communication overheads.

This hook should only be used for headroom analysis of allreduce optimization, instead of for normal gradient synchronization. For example, if less than a 10% speedup of training time can be observed after this hook is registered, it usually implies that allreduce is not a performance bottleneck for this case. Such instrumentation can be particularly useful if GPU traces cannot be easily retrieved, or if the trace analysis is complicated by factors such as the overlap between allreduce and computation or the desynchronization across ranks.
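
Conceptually, a no-op hook just wraps the bucket's flattened gradients in an already-completed future, along these lines (a sketch of the idea, not necessarily the library's exact code):

import torch

def my_noop_hook(_, bucket):
    # Return a future that is already completed with the unmodified
    # flattened gradients, so no communication is performed.
    fut = torch.futures.Future()
    fut.set_result(bucket.buffer())
    return fut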

Example::
>>> ddp_model.register_comm_hook(None, noop_hook)
Return type:

Future[Tensor]

Checkpointing of Communication Hooks

A stateful communication hook can be saved as a part of model checkpointing to enable trainer restarts. To make a hook serializable, __setstate__ and __getstate__ should be defined.

Warning

__getstate__ should exclude non-serializable attributes from the returned dictionary.

Warning

__setstate__ should properly initialize non-serializable attributes that were excluded from the provided state.

PowerSGDState has __setstate__ and __getstate__ implemented and can be used as a reference.
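
As an illustration of this pattern, a user-defined hook state might look as follows (MyHookState is a hypothetical class; process_group stands in for any non-serializable attribute):

class MyHookState:
    def __init__(self, process_group, scale=1.0):
        self.process_group = process_group  # not serializable
        self.scale = scale

    def __getstate__(self):
        # Exclude the non-serializable process group from the pickled state.
        state = self.__dict__.copy()
        del state["process_group"]
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Reinitialize the excluded attribute; here, fall back to the default group.
        self.process_group = None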

class torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.PowerSGDState[source]
__getstate__()[source]

Return a Dict[str, Any] which will be pickled and saved.

process_group is not serializable and is excluded from the returned state.

__setstate__(state)[source]

Take a provided state and set it on this PowerSGDState instance.

process_group is set to default.

Here is a simple, end-to-end example of saving and reloading PowerSGD state and hook.

import os
import sys
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(24, 24)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(24, 12)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn, args=(world_size,), nprocs=world_size, join=True)

def demo_serialization(rank, world_size):
    setup(rank, world_size)

    CHECKPOINT = tempfile.gettempdir() + "/checkpoint.pt"

    model = SimpleModel().to(rank)
    ddp_model = DistributedDataParallel(model, device_ids=[rank])

    powersgd_hook = powerSGD.powerSGD_hook
    powersgd_state = powerSGD.PowerSGDState(process_group=None)

    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    ddp_model.register_comm_hook(powersgd_state, powersgd_hook)

    state = {
        'state_dict': ddp_model.state_dict(),
        'comm_hook': powersgd_hook,
        'comm_hook_state': powersgd_state,
    }

    if rank == 0:
        torch.save(state, CHECKPOINT)

    dist.barrier()
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    checkpoint = torch.load(CHECKPOINT, map_location=map_location)

    new_ddp_model = DistributedDataParallel(SimpleModel().to(rank), device_ids=[rank])
    new_ddp_model.load_state_dict(checkpoint['state_dict'])
    powersgd_hook = checkpoint['comm_hook']
    powersgd_state = checkpoint['comm_hook_state']
    new_ddp_model.register_comm_hook(powersgd_state, powersgd_hook)

    if rank == 0:
        os.remove(CHECKPOINT)
    cleanup()

if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
    world_size = n_gpus
    run_demo(demo_serialization, world_size)

Acknowledgements

Many thanks to PowerSGD paper author Thijs Vogels for the code review on the PowerSGD communication hook, as well as the comparison experiments, which show that the performance of the PowerSGD communication hook is on par with the implementation in the original paper.