Commit f70bd71

Andrew Gu authored and pytorchmergebot committed
[FSDP2] Computed grad divide factors at runtime (#125484)
**Context**

We are interested in supporting the case where HSDP reduce-scatters but does not all-reduce in a microbatch backward. This saves communication while still saving memory. Only on the last microbatch do we need to both reduce-scatter and all-reduce. This is not implemented yet and will hopefully come in a future PR.

There is one notable part of doing this. On the last microbatch, we need to perform an accumulation step after reduce-scatter and before all-reduce. If we do not, then the preceding microbatch's gradients will not be contributed across the replica group. (In other words, we cannot simply accumulate _after_ all-reduce.)

Consider 32 GPUs with 4-way replication and 8-way sharding and 2 microbatches, and focus on global rank 0.
- After the first microbatch, rank 0 will have its shard of $\frac{1}{8} \sum_{i \in S(0)} g_i^{(1)}$, where we define $S(0) = \{0, 1, \dots, 7\}$ to be the ranks in its shard group and we define the $(1)$ superscript to denote the first microbatch.
- Upon the second microbatch, rank 0 after its reduce-scatter will additionally have its shard of $\frac{1}{8} \sum_{i \in S(0)} g_i^{(2)}$. If we only all-reduce this, then this second microbatch's gradients become $\frac{1}{32} \sum_{i=0, 1, \dots, 31} g_i^{(2)}$, so in total, rank 0 has $\frac{1}{8} \sum_{i \in S(0)} g_i^{(1)} + \frac{1}{32} \sum_{i=0, 1, \dots, 31} g_i^{(2)}$, which is wrong.
- Importantly, we must accumulate $\frac{1}{8} \sum_{i \in S(0)} g_i^{(1)} + \frac{1}{8} \sum_{i \in S(0)} g_i^{(2)} = \frac{1}{8}\sum_{i \in S(0)} (g_i^{(1)} + g_i^{(2)})$ first before all-reducing to get $\frac{1}{32} \sum_{i=0, 1, \dots, 31} (g_i^{(1)} + g_i^{(2)})$.

Now, note how under this approach, we want a factor of $\frac{1}{8}$ only (i.e. the reciprocal of the shard group size), not $\frac{1}{32}$, for the first microbatch's gradients.
- For bf16/fp32, since we use `ReduceOp.AVG` and we only reduce-scatter on the first microbatch, we correctly have a factor of $\frac{1}{8}$ on the first microbatch.
- For fp16, since we precompute the gradient divide factors at init time assuming that we always reduce over both the shard and replica groups, we incorrectly have a factor of $\frac{1}{32}$ on the first microbatch, deviating from the bf16/fp32 case.

We can address this issue and match the bf16/fp32 and fp16 semantics by computing the divide factors at runtime based on which process groups were passed into the reduction function (`foreach_reduce`).

**Additional Notes**

How to implement the HSDP reduce-scatter without all-reduce is not entirely clear yet. (What is the cleanest way to do this?) We need to store the partial reduce-scatter output and check for it upon the next backward. We should also be sure to error if the set of parameters receiving gradients changes, in which case we cannot support this easily. Anyway, we will implement this in a follow-up.

Pull Request resolved: #125484
Approved by: https://github.com/wanchaol
ghstack dependencies: #125431, #125479
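For intuition, here is a minimal single-process sketch (illustrative only, not FSDP2 API; all names are made up for this example) that mimics the 32-rank, 4-way-replicate / 8-way-shard, 2-microbatch scenario above with plain Python numbers, showing why the shard-local accumulation must happen before the all-reduce:

```python
import random

random.seed(0)
world_size, shard_size, num_replicas = 32, 8, 4
g1 = [random.random() for _ in range(world_size)]  # per-rank grads, microbatch 1
g2 = [random.random() for _ in range(world_size)]  # per-rank grads, microbatch 2

def shard_group(r):
    # Shard group of replica r: ranks {8r, ..., 8r + 7}; S(0) = {0, ..., 7}
    return range(r * shard_size, (r + 1) * shard_size)

# Desired result on every rank: (1/32) * sum_i (g1_i + g2_i)
expected = (sum(g1) + sum(g2)) / world_size

# Correct order: reduce-scatter each microbatch (factor 1/8), accumulate the
# shard-local partials, and only then average across the 4 replicas.
partials = [
    sum(g1[i] + g2[i] for i in shard_group(r)) / shard_size
    for r in range(num_replicas)
]
correct = sum(partials) / num_replicas

# Wrong order: all-reduce only the second microbatch's reduce-scatter output
# and then add the first microbatch's shard-local result, giving
# (1/8) * sum_{S(0)} g1 + (1/32) * sum_i g2 on rank 0.
wrong = sum(g1[i] for i in shard_group(0)) / shard_size + sum(g2) / world_size

assert abs(correct - expected) < 1e-12
print(f"correct={correct:.6f}  wrong={wrong:.6f}  expected={expected:.6f}")
```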
1 parent dba689b · commit f70bd71

File tree

3 files changed: +59 −64 lines

test/distributed/_composable/fsdp/test_fully_shard_comm.py

Lines changed: 24 additions & 3 deletions
@@ -18,6 +18,8 @@
     OffloadPolicy,
 )
 from torch.distributed._composable.fsdp._fsdp_collectives import (
+    _div_if_needed,
+    _get_gradient_divide_factors,
     foreach_all_gather,
     foreach_all_gather_copy_out,
     foreach_reduce,
@@ -207,6 +209,18 @@ def test_reduce_scatter_fp32(self):
                reduce_scatter_dtype=torch.float32,
            )

+    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    def test_reduce_scatter_fp16(self):
+        param_sizes = self._get_param_sizes()
+        default_stream = torch.cuda.current_stream()
+        stream = torch.cuda.Stream()
+        for reduce_scatter_stream in (default_stream, stream):
+            self._test_reduce_scatter(
+                param_sizes,
+                reduce_scatter_stream=reduce_scatter_stream,
+                reduce_scatter_dtype=torch.float16,
+            )
+
     def _test_reduce_scatter(
         self,
         param_sizes: List[torch.Size],
@@ -238,17 +252,24 @@ def _test_reduce_scatter(
            orig_dtype=orig_params[0].dtype,
            reduce_dtype=reduce_scatter_dtype,
            device=self.device,
-            divide_factors=fsdp_param_group._grad_divide_factors,
            all_reduce_group=None,
            all_reduce_stream=all_reduce_stream,
        )
        torch.cuda.current_stream().wait_event(view_out_event)

        # Check reduce-scatter correctness
+        predivide_factor, postdivide_factor = _get_gradient_divide_factors(
+            group, None, reduce_scatter_dtype
+        )
        reduced_grads = [grad.detach().clone() for grad in unsharded_grads]
        for grad in reduced_grads:
-            dist.all_reduce(grad, group=group)
-            grad /= self.world_size
+            _div_if_needed(grad, predivide_factor)
+            dist.all_reduce(
+                grad,
+                group=group,
+                op=dist.ReduceOp.AVG if predivide_factor is None else dist.ReduceOp.SUM,
+            )
+            _div_if_needed(grad, postdivide_factor)
        for fsdp_param, reduced_grad in zip(fsdp_params, reduced_grads):
            sharded_grad = fsdp_param.sharded_param.grad
            self.assertIsInstance(sharded_grad, DTensor)

torch/distributed/_composable/fsdp/_fsdp_collectives.py

Lines changed: 34 additions & 32 deletions
@@ -125,7 +125,6 @@ def foreach_reduce(
     orig_dtype: torch.dtype,
     reduce_dtype: Optional[torch.dtype],
     device: torch.device,
-    divide_factors: Union[Tuple[None, None], Tuple[float, float]],
     all_reduce_group: Optional[dist.ProcessGroup],
     all_reduce_stream: torch.cuda.Stream,
 ) -> torch.cuda.Event:
@@ -142,7 +141,9 @@ def foreach_reduce(
     )
     grad_dtype = unsharded_grads[0].dtype
     reduce_dtype = reduce_dtype or grad_dtype
-    predivide_factor, postdivide_factor = divide_factors
+    predivide_factor, postdivide_factor = _get_gradient_divide_factors(
+        reduce_scatter_group, all_reduce_group, reduce_dtype
+    )
     world_size = reduce_scatter_group.size()
     padded_unsharded_sizes = tuple(
         _get_dim0_padded_size(grad.size(), world_size) for grad in unsharded_grads
@@ -166,18 +167,22 @@
            (reduce_scatter_output_numel,)
        )
        _div_if_needed(reduce_scatter_input, predivide_factor)
-        _reduce_scatter(
-            post_reduce_output,
-            reduce_scatter_input,
-            reduce_scatter_group,
-            divide_factors,
+        dist.reduce_scatter_tensor(
+            output=post_reduce_output,
+            input=reduce_scatter_input,
+            group=reduce_scatter_group,
+            op=ReduceOp.AVG if predivide_factor is None else ReduceOp.SUM,
        )
        view_out_stream = reduce_scatter_stream
        if all_reduce_group is not None:
            view_out_stream = all_reduce_stream
            all_reduce_stream.wait_stream(reduce_scatter_stream)
            with torch.cuda.stream(all_reduce_stream):
-                _all_reduce(post_reduce_output, all_reduce_group, divide_factors)
+                dist.all_reduce(
+                    post_reduce_output,
+                    group=all_reduce_group,
+                    op=ReduceOp.AVG if predivide_factor is None else ReduceOp.SUM,
+                )
        with torch.cuda.stream(view_out_stream):
            _div_if_needed(post_reduce_output, postdivide_factor)
            post_reduce_output = _to_dtype_if_needed(post_reduce_output, orig_dtype)
@@ -257,30 +262,27 @@ def _get_all_gather_input_metadatas(
     )


-def _reduce_scatter(
-    output: torch.Tensor,
-    input: torch.Tensor,
-    group: dist.ProcessGroup,
-    divide_factors: Union[Tuple[None, None], Tuple[float, float]],
-) -> None:
-    if divide_factors[0]:
-        dist.reduce_scatter_tensor(output, input, group=group)
-    else:
-        # Using NCCL's reduce-scatter to do the division by world size saves
-        # extra memory read/write from a separate division kernel
-        dist.reduce_scatter_tensor(output, input, op=ReduceOp.AVG, group=group)
-
-
-def _all_reduce(
-    tensor: torch.Tensor,
-    group: dist.ProcessGroup,
-    divide_factors: Union[Tuple[None, None], Tuple[float, float]],
-) -> None:
-    if divide_factors[0]:
-        dist.all_reduce(tensor, group=group)
-    else:
-        # saves extra memory read/write from a separate division kernel
-        dist.all_reduce(tensor, op=ReduceOp.AVG, group=group)
+def _get_gradient_divide_factors(
+    reduce_scatter_group: dist.ProcessGroup,
+    all_reduce_group: Optional[dist.ProcessGroup],
+    reduce_dtype: torch.dtype,
+) -> Union[Tuple[None, None], Tuple[float, float]]:
+    # For fp32/bf16, we do not need to worry about overflow/underflow, so we
+    # use NCCL's built-in division to avoid separate div kernels
+    if reduce_dtype in (torch.float32, torch.bfloat16):
+        return None, None
+    data_parallel_size = reduce_scatter_group.size()
+    if all_reduce_group is not None:
+        data_parallel_size *= all_reduce_group.size()
+    # Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid
+    # overflow/underflow. For N data parallel workers, each worker computes
+    # g_i, and they collectively reduce (g_1 + ... + g_N) / N. To avoid
+    # overflow/underflow, we divide by ~sqrt(N) before/after the reduction.
+    factor: int = 1
+    while data_parallel_size % factor == 0 and data_parallel_size / factor > factor:
+        factor *= 2
+    factor = float(factor)
+    return (factor, data_parallel_size / factor)


 def _div_if_needed(tensor: torch.Tensor, div_factor: Optional[float]) -> None:
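As a quick check on the fp16 path of the new `_get_gradient_divide_factors` helper, the following standalone sketch (plain ints instead of process groups; the `split_divide_factor` name is just for illustration) reproduces the pre/post split around ~sqrt(N). With `ReduceOp.SUM` plus these two divisions, the fp16 path ends up with the same overall 1/N scaling that `ReduceOp.AVG` gives for bf16/fp32.

```python
from typing import Tuple

def split_divide_factor(data_parallel_size: int) -> Tuple[float, float]:
    # Mirrors the fp16 branch above: pick a power-of-two predivide factor near
    # sqrt(N); the postdivide factor is whatever remains so their product is N.
    factor = 1
    while data_parallel_size % factor == 0 and data_parallel_size / factor > factor:
        factor *= 2
    return float(factor), data_parallel_size / factor

# 8-way sharding x 4-way replication (N = 32): divide by 8 before the
# reduction and by 4 after, so gradients are still divided by 32 overall.
assert split_divide_factor(32) == (8.0, 4.0)
# Sharding only (N = 8): divide by 4 before and by 2 after.
assert split_divide_factor(8) == (4.0, 2.0)
```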

torch/distributed/_composable/fsdp/_fsdp_param_group.py

Lines changed: 1 addition & 29 deletions
@@ -1,6 +1,6 @@
 import contextlib

-from typing import Any, cast, Dict, List, NamedTuple, Optional, Set, Tuple, Union
+from typing import Any, cast, Dict, List, NamedTuple, Optional, Set, Tuple

 import torch
 import torch.distributed as dist
@@ -164,32 +164,6 @@ def _init_mp_dtypes(self) -> None:
        )
        self._reduce_dtype = next(iter(reduce_dtypes))

-    def _init_grad_divide_factors(self):
-        data_parallel_world_size = 1
-        data_parallel_world_size *= self.mesh_info.shard_mesh_size
-        if self._is_hsdp:
-            data_parallel_world_size *= self.mesh_info.replicate_mesh_size
-        if self._reduce_dtype in (torch.float32, torch.bfloat16):
-            # Use NCCL's AVG op to divide after reduction since it is more
-            # performant and fp32 has sufficient precision
-            self._grad_divide_factors: Union[Tuple[None, None], Tuple[float, float]] = (
-                None,
-                None,
-            )
-            return
-        # Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid
-        # overflow/underflow. For N data parallel workers, each worker computes
-        # g_i, and they collectively reduce (g_1 + ... + g_N) / N. To avoid
-        # overflow/underflow, we divide by ~sqrt(N) before/after the reduction.
-        factor: int = 1
-        while (
-            data_parallel_world_size % factor == 0
-            and data_parallel_world_size / factor > factor
-        ):
-            factor *= 2
-        factor = float(factor)
-        self._grad_divide_factors = (factor, data_parallel_world_size / factor)
-
     def lazy_init(self):
        # Lazy init should be idempotent
        param_names_on_meta = [
@@ -207,7 +181,6 @@ def lazy_init(self):
        # Initialize mixed precision attributes lazily in case the user changes
        # the parameter dtypes after construction time but before forward
        self._init_mp_dtypes()
-        self._init_grad_divide_factors()
        self._register_state_dict_hooks()

    # Runtime #
@@ -346,7 +319,6 @@ def post_backward(self, *unused: Any):
                self._orig_dtype,
                self._reduce_dtype,
                self.device,
-                self._grad_divide_factors,
                self._all_reduce_process_group
                if self._is_hsdp and self.all_reduce_grads
                else None,
