Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 6ac4db0

Browse files
committed
Merge branch 'dnarayanan/dist_optimizer_refactor' into 'main'
Refactor distributed optimizer communication code into megatron/core/distributed. See merge request ADLR/megatron-lm!1975
2 parents 5747146 + 655a663 commit 6ac4db0

File tree

13 files changed

+490
-501
lines changed

13 files changed

+490
-501
lines changed

‎megatron/core/distributed/__init__.py‎

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,8 @@
33
from .distributed_data_parallel import DistributedDataParallel
44
from .distributed_data_parallel_config import DistributedDataParallelConfig
55
from .finalize_model_grads import finalize_model_grads
6-
from .param_and_grad_buffer import ParamAndGradBuffer, partition_buckets, shard_buffer
6+
7+
# For backwards compatibility. ParamAndGradBuffer will be deprecated in future release.
8+
# ParamAndGradBuffer (which is an alias of _ParamAndGradBuffer) is not intended to be
9+
# consumed directly by external code.
10+
from .param_and_grad_buffer import ParamAndGradBuffer

‎megatron/core/distributed/distributed_data_parallel.py‎

Lines changed: 124 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import logging
44
from contextlib import contextmanager
5-
from typing import Dict
65

76
import torch
87

@@ -12,7 +11,7 @@
1211
from ..transformer.transformer_config import TransformerConfig
1312
from ..utils import is_float8tensor, log_single_rank
1413
from .distributed_data_parallel_config import DistributedDataParallelConfig
15-
from .param_and_grad_buffer import BucketGroup, ParamAndGradBuffer, partition_buckets
14+
from .param_and_grad_buffer import _ParamAndGradBuffer, partition_buckets
1615

1716
logger=logging.getLogger(__name__)
1817

@@ -77,7 +76,6 @@ def __init__(
7776
ifdisable_bucketing:
7877
self.bucket_size=None
7978

80-
self.module=module
8179
self.param_to_bucket_group= {}
8280

8381
# Group parameters by their gradient type.
@@ -101,7 +99,7 @@ def __init__(
10199
else:
102100
expert_parallel_params.append(param)
103101

104-
defallocate_buffers_for_parameters(
102+
def_allocate_buffers_for_parameters(
105103
input_params,data_parallel_group,gradient_scaling_factor
106104
):
107105
param_and_grad_dtype_to_params= {}
@@ -110,8 +108,7 @@ def allocate_buffers_for_parameters(
110108

111109
# Group parameters by their gradient type.
112110
forparamininput_params:
113-
ifnotparam.requires_grad:
114-
continue
111+
assertparam.requires_grad
115112

116113
param_dtype=param.dtype
117114
ifis_float8tensor(param):
@@ -167,7 +164,7 @@ def allocate_buffers_for_parameters(
167164
buffers= []
168165
for (param_dtype,grad_dtype),paramsinparam_and_grad_dtype_to_params.items():
169166
buffers.append(
170-
ParamAndGradBuffer(
167+
_ParamAndGradBuffer(
171168
self.ddp_config,
172169
param_dtype,
173170
grad_dtype,
@@ -187,9 +184,20 @@ def allocate_buffers_for_parameters(
187184
# because of the use of CUDA_DEVICE_MAX_CONNECTIONS=1, having multiple back-to-back
188185
# communications will prevent the overlap of the communication kernels with computation
189186
# kernels.
190-
bucket_groups=partition_buckets(buffers)
187+
# If bucketing is explicitly disabled, then put all buckets in a buffer into a single
188+
# bucket group.
189+
bucket_groups=partition_buckets(buffers,force_single_bucket_group=disable_bucketing)
190+
191+
# Set `next_param_gather_bucket_group` for different bucket groups by iterating through
192+
# buckets in reverse order (since all-gathers happen in reverse order of buckets).
193+
ifself.ddp_config.use_distributed_optimizerandself.ddp_config.overlap_param_gather:
194+
num_bucket_groups=len(bucket_groups)
195+
foriinrange(1,num_bucket_groups):
196+
bucket_groups[num_bucket_groups-i].next_param_gather_bucket_group= (
197+
bucket_groups[num_bucket_groups-i-1]
198+
)
191199

192-
# Create map from param to BucketGroup, used in pre_hook.
200+
# Create map from param to bucket group, used in pre_hook.
193201
forbucket_groupinbucket_groups:
194202
forbucketinbucket_group.buckets:
195203
forparaminbucket.params_list:
@@ -214,15 +222,15 @@ def allocate_buffers_for_parameters(
214222
expert_gradient_scaling_factor=1.0/data_parallel_world_size
215223

216224
# Allocate the param+grad buffers for dense params' grads.
217-
self.buffers,self.bucket_groups=allocate_buffers_for_parameters(
225+
self.buffers,self.bucket_groups=_allocate_buffers_for_parameters(
218226
dense_params,
219227
parallel_state.get_data_parallel_group(with_context_parallel=True),
220228
gradient_scaling_factor=gradient_scaling_factor,
221229
)
222230

223231
# Allocate separate param+grad buffers for expert parallel params' grads.
224232
self.expert_parallel_buffers,self.expert_parallel_bucket_groups= (
225-
allocate_buffers_for_parameters(
233+
_allocate_buffers_for_parameters(
226234
expert_parallel_params,
227235
parallel_state.get_data_modulo_expert_parallel_group(with_context_parallel=True),
228236
gradient_scaling_factor=expert_gradient_scaling_factor,
@@ -252,26 +260,93 @@ def unmap_weight_tensor(m):
252260
param_tmp=param.expand_as(param)
253261
# Get the gradient accumulator function.
254262
grad_acc=param_tmp.grad_fn.next_functions[0][0]
255-
grad_acc.register_hook(self._make_param_hook(param,self.param_to_bucket_group))
263+
grad_acc.register_hook(self._make_backward_post_hook(param))
256264
self.grad_accs.append(grad_acc)
257265

266+
self.use_forward_hook= (
267+
self.ddp_config.use_distributed_optimizerandself.ddp_config.overlap_param_gather
268+
)
269+
self.remove_forward_pre_hook_handles= {}
270+
ifself.use_forward_hook:
271+
self.enable_forward_pre_hook()
272+
self.overlap_param_gather_with_optimizer_step=False
273+
274+
defenable_forward_pre_hook(self):
275+
"""
276+
Enable forward pre-hooks needed for param all-gather overlap with forward compute.
277+
"""
278+
assertself.use_forward_hook
279+
assertlen(self.remove_forward_pre_hook_handles)==0
280+
# Register forward pre-hook for all sub-modules.
281+
formoduleinself.module.modules():
282+
self.remove_forward_pre_hook_handles[module]=module.register_forward_pre_hook(
283+
self._make_forward_pre_hook()
284+
)
285+
286+
defdisable_forward_pre_hook(self):
287+
"""
288+
Disable forward pre-hooks needed for param all-gather overlap with forward compute.
289+
"""
290+
assertself.use_forward_hook
291+
# De-register forward pre-hook for all sub-modules.
292+
formoduleinself.module.modules():
293+
assertself.remove_forward_pre_hook_handles[module]isnotNone
294+
self.remove_forward_pre_hook_handles[module].remove()
295+
delself.remove_forward_pre_hook_handles[module]
296+
assertlen(self.remove_forward_pre_hook_handles)==0
297+
298+
# Force synchronize parameters.
299+
self.start_param_sync(force_sync=True)
300+
258301
defforward(self,*inputs,**kwargs):
259302
"""
260303
Calls the wrapped module's forward() method.
261304
"""
262305
returnself.module(*inputs,**kwargs)
263306

264-
def_make_param_hook(
265-
self,
266-
param:torch.nn.Parameter,
267-
param_to_bucket_group:Dict[torch.nn.Parameter,BucketGroup],
268-
):
307+
def_make_forward_pre_hook(self):
269308
"""
270-
Creates the all-reduce / reduce-scatter hook for backprop.
309+
Create a forward pre-hook to wait on all-gather handles when necessary (i.e.,
310+
when a module uses a parameter in a bucket with a still incomplete all-gather).
271311
"""
272312

273-
defparam_hook(*unused):
274-
ifparam.requires_grad:
313+
defhook(module,*unused):
314+
assert (
315+
self.use_forward_hook
316+
),"Should use pre-hook only when overlap_param_gather is True"
317+
318+
# Make sure all parameters in this module have been all-gathered as necessary.
319+
forparaminmodule.parameters(recurse=False):
320+
# Skip parameters without an associated buffer (such parameters have a
321+
# .requires_grad field equal to False).
322+
ifparamnotinself.param_to_bucket_group:
323+
continue
324+
assertparam.requires_grad
325+
326+
# If aligning param all-gather across pipeline stages, all-gather is dispatched
327+
# by start_param_sync calls in core/pipeline_parallelism/schedules.py.
328+
# If overlapping param all-gather with optimizer step, then all-gather has
329+
# already been dispatched in optimizer step.
330+
skip_next_bucket_dispatch= (
331+
self.ddp_config.align_param_gather
332+
orself.overlap_param_gather_with_optimizer_step
333+
)
334+
self.param_to_bucket_group[param].finish_param_sync(
335+
skip_next_bucket_dispatch=skip_next_bucket_dispatch
336+
)
337+
338+
returnhook
339+
340+
def_make_backward_post_hook(self,param:torch.nn.Parameter):
341+
"""
342+
Creates a backward post-hook to dispatch an all-reduce / reduce-scatter when
343+
ready (i.e., when all grads in a bucket have been computed in all microbatches
344+
in a batch).
345+
"""
346+
347+
defhook(*unused):
348+
ifparaminself.param_to_bucket_group:
349+
assertparam.requires_grad
275350
ifself.ddp_config.overlap_grad_reduce:
276351
assert (
277352
param.gradisnotNone
@@ -283,9 +358,9 @@ def param_hook(*unused):
283358
param.grad=None
284359

285360
ifself.ddp_config.overlap_grad_reduce:
286-
param_to_bucket_group[param].register_grad_ready(param)
361+
self.param_to_bucket_group[param].register_grad_ready(param)
287362

288-
returnparam_hook
363+
returnhook
289364

290365
@contextmanager
291366
defno_sync(self):
@@ -300,6 +375,28 @@ def no_sync(self):
300375
forbucket_groupinself.bucket_groups+self.expert_parallel_bucket_groups:
301376
bucket_group.is_last_microbatch=True
302377

378+
defstart_param_sync(self,*unused,force_sync:bool=False,force_dispatch:bool=False):
379+
"""
380+
Initiates param sync (all-gather) communication operations for all model parameters.
381+
382+
By default, when overlap_param_gather is set to True, dispatches asynchronous communication
383+
calls; when overlap_param_gather is set to False, calls synchronous communication
384+
ops. Can override this default behavior using flags below.
385+
386+
Args:
387+
force_sync (bool, optional): force synchronous collective regardless of
388+
other settings.
389+
force_dispatch (bool, optional): force dispatch regardless of other settings.
390+
"""
391+
ifnotforce_sync:
392+
# If overlapping param AG with optimizer step, AG should not be dispatched again
393+
# in forward_backward_step.
394+
ifself.overlap_param_gather_with_optimizer_stepandnotforce_dispatch:
395+
return
396+
397+
forbucket_groupinself.bucket_groups+self.expert_parallel_bucket_groups:
398+
bucket_group.start_param_sync(force_sync=force_sync)
399+
303400
defstart_grad_sync(self,*unused):
304401
"""
305402
Initiates grad sync (all-reduce or reduce-scatter) communication operations
@@ -312,11 +409,6 @@ def start_grad_sync(self, *unused):
312409
forbucket_groupinself.bucket_groups+self.expert_parallel_bucket_groups:
313410
bucket_group.start_grad_sync()
314411

315-
defscale_gradients(self,scaling_factor:float)->None:
316-
"""Scale all gradients inside the buffers by `scaling_factor`."""
317-
forbufferinself.buffers+self.expert_parallel_buffers:
318-
buffer.scale_gradients(scaling_factor)
319-
320412
deffinish_grad_sync(self):
321413
"""
322414
Finishes grad sync (all-reduce or reduce-scatter) communication operations
@@ -329,6 +421,11 @@ def finish_grad_sync(self):
329421
forbucket_groupinself.bucket_groups+self.expert_parallel_bucket_groups:
330422
bucket_group.finish_grad_sync()
331423

424+
defscale_gradients(self,scaling_factor:float):
425+
"""Scale all gradients inside the buffers by `scaling_factor`."""
426+
forbufferinself.buffers+self.expert_parallel_buffers:
427+
buffer.scale_gradients(scaling_factor)
428+
332429
defzero_grad_buffer(self):
333430
"""
334431
Zeros out all grad buffers. Needs to be called at the beginning of each

‎megatron/core/distributed/distributed_data_parallel_config.py‎

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@ class DistributedDataParallelConfig:
1414
overlap_grad_reduce:bool=False
1515
"""If true, overlap grad all-reduce / reduce-scatter with backward compute."""
1616

17+
overlap_param_gather:bool=False
18+
"""If true, overlap param all-gather with forward compute."""
19+
20+
align_param_gather:bool=False
21+
"""If true, all PP stages will launch param all-gathers simultaneously. Otherwise, each
22+
PP stage will independently launch as needed.
23+
"""
24+
1725
use_distributed_optimizer:bool=False
1826
"""If true, issue reduce-scatter collectives to aggregate gradients and clean up
1927
originally allocated model parameters, otherwise issue all-reduce collectives.

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp