 
 import logging
 from contextlib import contextmanager
-from typing import Dict
 
 import torch
 
 from ..transformer.transformer_config import TransformerConfig
 from ..utils import is_float8tensor, log_single_rank
 from .distributed_data_parallel_config import DistributedDataParallelConfig
-from .param_and_grad_buffer import BucketGroup, ParamAndGradBuffer, partition_buckets
+from .param_and_grad_buffer import _ParamAndGradBuffer, partition_buckets
 
 logger = logging.getLogger(__name__)
 
@@ -77,7 +76,6 @@ def __init__(
         if disable_bucketing:
             self.bucket_size = None
 
-        self.module = module
         self.param_to_bucket_group = {}
 
         # Group parameters by their gradient type.
@@ -101,7 +99,7 @@ def __init__(
             else:
                 expert_parallel_params.append(param)
 
-        def allocate_buffers_for_parameters(
+        def _allocate_buffers_for_parameters(
             input_params, data_parallel_group, gradient_scaling_factor
         ):
             param_and_grad_dtype_to_params = {}
@@ -110,8 +108,7 @@ def allocate_buffers_for_parameters(
 
             # Group parameters by their gradient type.
             for param in input_params:
-                if not param.requires_grad:
-                    continue
+                assert param.requires_grad
 
                 param_dtype = param.dtype
                 if is_float8tensor(param):
@@ -167,7 +164,7 @@ def allocate_buffers_for_parameters(
             buffers = []
             for (param_dtype, grad_dtype), params in param_and_grad_dtype_to_params.items():
                 buffers.append(
-                    ParamAndGradBuffer(
+                    _ParamAndGradBuffer(
                         self.ddp_config,
                         param_dtype,
                         grad_dtype,
@@ -187,9 +184,20 @@ def allocate_buffers_for_parameters(
             # because of the use of CUDA_DEVICE_MAX_CONNECTIONS=1, having multiple back-to-back
             # communications will prevent the overlap of the communication kernels with computation
             # kernels.
-            bucket_groups = partition_buckets(buffers)
+            # If bucketing is explicitly disabled, then put all buckets in a buffer into a single
+            # bucket group.
+            bucket_groups = partition_buckets(buffers, force_single_bucket_group=disable_bucketing)
+
+            # Set `next_param_gather_bucket_group` for different bucket groups by iterating through
+            # buckets in reverse order (since all-gathers happen in reverse order of buckets).
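+            # With three bucket groups, for example, the chain is bucket_groups[2] ->
+            # bucket_groups[1] -> bucket_groups[0], matching the order in which their
+            # all-gathers are dispatched.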
+            if self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather:
+                num_bucket_groups = len(bucket_groups)
+                for i in range(1, num_bucket_groups):
+                    bucket_groups[num_bucket_groups - i].next_param_gather_bucket_group = (
+                        bucket_groups[num_bucket_groups - i - 1]
+                    )
 
-            # Create map from param to BucketGroup, used in pre_hook.
+            # Create map from param to bucket group, used in pre_hook.
             for bucket_group in bucket_groups:
                 for bucket in bucket_group.buckets:
                     for param in bucket.params_list:
@@ -214,15 +222,15 @@ def allocate_buffers_for_parameters(
             expert_gradient_scaling_factor = 1.0 / data_parallel_world_size
 
         # Allocate the param+grad buffers for dense params' grads.
-        self.buffers, self.bucket_groups = allocate_buffers_for_parameters(
+        self.buffers, self.bucket_groups = _allocate_buffers_for_parameters(
             dense_params,
             parallel_state.get_data_parallel_group(with_context_parallel=True),
             gradient_scaling_factor=gradient_scaling_factor,
         )
 
         # Allocate separate param+grad buffers for expert parallel params' grads.
         self.expert_parallel_buffers, self.expert_parallel_bucket_groups = (
-            allocate_buffers_for_parameters(
+            _allocate_buffers_for_parameters(
                 expert_parallel_params,
                 parallel_state.get_data_modulo_expert_parallel_group(with_context_parallel=True),
                 gradient_scaling_factor=expert_gradient_scaling_factor,
@@ -252,26 +260,93 @@ def unmap_weight_tensor(m):
                 param_tmp = param.expand_as(param)
                 # Get the gradient accumulator function.
                 grad_acc = param_tmp.grad_fn.next_functions[0][0]
-                grad_acc.register_hook(self._make_param_hook(param, self.param_to_bucket_group))
+                grad_acc.register_hook(self._make_backward_post_hook(param))
                 self.grad_accs.append(grad_acc)
 
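+        # Forward pre-hooks are only needed when the distributed optimizer is in use and
+        # param all-gather is overlapped with forward compute.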
+        self.use_forward_hook = (
+            self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather
+        )
+        self.remove_forward_pre_hook_handles = {}
+        if self.use_forward_hook:
+            self.enable_forward_pre_hook()
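+        # Toggled to True when param all-gather is overlapped with the optimizer step;
+        # checked in start_param_sync() and in the forward pre-hook below.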
+        self.overlap_param_gather_with_optimizer_step = False
+
+    def enable_forward_pre_hook(self):
+        """
+        Enable forward pre-hooks needed for param all-gather overlap with forward compute.
+        """
+        assert self.use_forward_hook
+        assert len(self.remove_forward_pre_hook_handles) == 0
+        # Register forward pre-hook for all sub-modules.
+        for module in self.module.modules():
+            self.remove_forward_pre_hook_handles[module] = module.register_forward_pre_hook(
+                self._make_forward_pre_hook()
+            )
+
+    def disable_forward_pre_hook(self):
+        """
+        Disable forward pre-hooks needed for param all-gather overlap with forward compute.
+        """
+        assert self.use_forward_hook
+        # De-register forward pre-hook for all sub-modules.
+        for module in self.module.modules():
+            assert self.remove_forward_pre_hook_handles[module] is not None
+            self.remove_forward_pre_hook_handles[module].remove()
+            del self.remove_forward_pre_hook_handles[module]
+        assert len(self.remove_forward_pre_hook_handles) == 0
+
+        # Force synchronize parameters.
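+        # With the pre-hooks removed, nothing else would trigger the outstanding
+        # all-gathers, so run them synchronously here.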
+        self.start_param_sync(force_sync=True)
+
     def forward(self, *inputs, **kwargs):
         """
         Calls the wrapped module's forward() method.
         """
         return self.module(*inputs, **kwargs)
 
-    def _make_param_hook(
-        self,
-        param: torch.nn.Parameter,
-        param_to_bucket_group: Dict[torch.nn.Parameter, BucketGroup],
-    ):
+    def _make_forward_pre_hook(self):
         """
-        Creates the all-reduce / reduce-scatter hook for backprop.
+        Create a forward pre-hook to wait on all-gather handles when necessary (i.e.,
+        when a module uses a parameter in a bucket with a still incomplete all-gather).
         """
 
-        def param_hook(*unused):
-            if param.requires_grad:
+        def hook(module, *unused):
+            assert (
+                self.use_forward_hook
+            ), "Should use pre-hook only when overlap_param_gather is True"
+
+            # Make sure all parameters in this module have been all-gathered as necessary.
+            for param in module.parameters(recurse=False):
+                # Skip parameters without an associated buffer (such parameters have a
+                # .requires_grad field equal to False).
+                if param not in self.param_to_bucket_group:
+                    continue
+                assert param.requires_grad
+
+                # If aligning param all-gather across pipeline stages, all-gather is dispatched
+                # by start_param_sync calls in core/pipeline_parallelism/schedules.py.
+                # If overlapping param all-gather with optimizer step, then all-gather has
+                # already been dispatched in optimizer step.
+                skip_next_bucket_dispatch = (
+                    self.ddp_config.align_param_gather
+                    or self.overlap_param_gather_with_optimizer_step
+                )
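+                # Wait on this bucket group's all-gather and, unless skipped above,
+                # dispatch the all-gather for the next bucket group in the chain.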
+                self.param_to_bucket_group[param].finish_param_sync(
+                    skip_next_bucket_dispatch=skip_next_bucket_dispatch
+                )
+
+        return hook
+
+    def _make_backward_post_hook(self, param: torch.nn.Parameter):
+        """
+        Creates a backward post-hook to dispatch an all-reduce / reduce-scatter when
+        ready (i.e., when all grads in a bucket have been computed in all microbatches
+        in a batch).
+        """
+
+        def hook(*unused):
+            if param in self.param_to_bucket_group:
+                assert param.requires_grad
                 if self.ddp_config.overlap_grad_reduce:
                     assert (
                         param.grad is not None
@@ -283,9 +358,9 @@ def param_hook(*unused):
                 param.grad = None
 
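+                # Once all params in the bucket group have registered ready grads, this
+                # launches the bucket group's grad reduce-scatter / all-reduce.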
                 if self.ddp_config.overlap_grad_reduce:
-                    param_to_bucket_group[param].register_grad_ready(param)
+                    self.param_to_bucket_group[param].register_grad_ready(param)
 
-        return param_hook
+        return hook
 
     @contextmanager
     def no_sync(self):
@@ -300,6 +375,28 @@ def no_sync(self):
             for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
                 bucket_group.is_last_microbatch = True
 
+    def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bool = False):
+        """
+        Initiates param sync (all-gather) communication operations for all model parameters.
+
+        By default, when overlap_param_gather is set to True, dispatches asynchronous communication
+        calls; when overlap_param_gather is set to False, calls synchronous communication
+        ops. Can override this default behavior using flags below.
+
+        Args:
+            force_sync (bool, optional): force synchronous collective regardless of
+                other settings.
+            force_dispatch (bool, optional): force dispatch regardless of other settings.
+        """
+        if not force_sync:
+            # If overlapping param AG with optimizer step, AG should not be dispatched again
+            # in forward_backward_step.
+            if self.overlap_param_gather_with_optimizer_step and not force_dispatch:
+                return
+
+        for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
+            bucket_group.start_param_sync(force_sync=force_sync)
+
     def start_grad_sync(self, *unused):
         """
         Initiates grad sync (all-reduce or reduce-scatter) communication operations
@@ -312,11 +409,6 @@ def start_grad_sync(self, *unused):
         for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
             bucket_group.start_grad_sync()
 
-    def scale_gradients(self, scaling_factor: float) -> None:
-        """Scale all gradients inside the buffers by `scaling_factor`."""
-        for buffer in self.buffers + self.expert_parallel_buffers:
-            buffer.scale_gradients(scaling_factor)
-
     def finish_grad_sync(self):
         """
         Finishes grad sync (all-reduce or reduce-scatter) communication operations
@@ -329,6 +421,11 @@ def finish_grad_sync(self):
         for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
             bucket_group.finish_grad_sync()
 
+    def scale_gradients(self, scaling_factor: float):
+        """Scale all gradients inside the buffers by `scaling_factor`."""
+        for buffer in self.buffers + self.expert_parallel_buffers:
+            buffer.scale_gradients(scaling_factor)
+
     def zero_grad_buffer(self):
         """
         Zeros out all grad buffers. Needs to be called at the beginning of each