Commit f70bd71
[FSDP2] Computed grad divide factors at runtime (#125484)
**Context**

We are interested in supporting the case where HSDP reduce-scatters but does not all-reduce in a microbatch backward. This saves communication while still saving memory. Only on the last microbatch do we need to both reduce-scatter and all-reduce. This is not implemented yet and will hopefully come in a future PR.

There is one notable part of doing this. On the last microbatch, we need to perform an accumulation step after reduce-scatter and before all-reduce. Otherwise, the preceding microbatches' gradients will not be reduced across the replica group. (In other words, we cannot simply accumulate _after_ all-reduce.)

Consider 32 GPUs with 4-way replication and 8-way sharding and 2 microbatches, and focus on global rank 0.
- After the first microbatch, rank 0 will have its shard of $\frac{1}{8} \sum_{i \in S(0)} g_i^{(1)}$, where we define $S(0) = \{0, 1, \dots, 7\}$ to be the ranks in its shard group and we define the $(1)$ superscript to denote the first microbatch.
- Upon the second microbatch, rank 0 after its reduce-scatter will additionally have its shard of $\frac{1}{8} \sum_{i \in S(0)} g_i^{(2)}$. If we only all-reduce this, then the second microbatch's gradients become $\frac{1}{32} \sum_{i=0, 1, \dots, 31} g_i^{(2)}$, so in total, rank 0 has $\frac{1}{8} \sum_{i \in S(0)} g_i^{(1)} + \frac{1}{32} \sum_{i=0, 1, \dots, 31} g_i^{(2)}$, which is wrong.
- Importantly, we must accumulate $\frac{1}{8} \sum_{i \in S(0)} g_i^{(1)} + \frac{1}{8} \sum_{i \in S(0)} g_i^{(2)} = \frac{1}{8}\sum_{i \in S(0)} (g_i^{(1)} + g_i^{(2)})$ first before all-reducing to get $\frac{1}{32} \sum_{i=0, 1, \dots, 31} (g_i^{(1)} + g_i^{(2)})$.

Now, note how under this approach, we want a factor of $\frac{1}{8}$ only (i.e. the reciprocal of the shard group size), not $\frac{1}{32}$, for the first microbatch's gradients.
- For bf16/fp32, since we use `ReduceOp.AVG` and we only reduce-scatter on the first microbatch, we correctly get a factor of $\frac{1}{8}$ on the first microbatch.
- For fp16, since we precompute the gradient divide factors at init time assuming we always reduce over both the shard and replica groups, we incorrectly get a factor of $\frac{1}{32}$ on the first microbatch, deviating from the bf16/fp32 case.

We can address this issue and match the bf16/fp32 and fp16 semantics by computing the divide factors at runtime based on which process groups were passed into the reduction function (`foreach_reduce`); see the sketch after the notes below.

**Additional Notes**

How to implement the HSDP reduce-scatter-without-all-reduce is not entirely clear yet. (What is the cleanest way to do this?) We need to store the partial reduce-scatter output and check for it upon the next backward. We should also be sure to error out if the set of parameters receiving gradients changes, in which case we cannot easily support this. Anyway, we will implement this in a follow-up.

Pull Request resolved: #125484
Approved by: https://github.com/wanchaol
ghstack dependencies: #125431, #125479
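To make the idea concrete, here is a minimal sketch of computing the divide factors at runtime from the process groups actually passed to the reduction, rather than fixing them at init time from the full data-parallel world size. This is not the verbatim PyTorch implementation; the function name, signature, and the power-of-two pre/post split for fp16 are assumptions based on the description above.

```python
from typing import Optional, Tuple

import torch
import torch.distributed as dist


def get_gradient_divide_factors(
    reduce_scatter_group: dist.ProcessGroup,
    all_reduce_group: Optional[dist.ProcessGroup],  # None => no all-reduce this backward
    reduce_dtype: torch.dtype,
) -> Tuple[Optional[float], Optional[float]]:
    # Only count the replica group when we actually all-reduce. This keeps the
    # fp16 factor at 1/shard_group_size (e.g. 1/8) for a microbatch backward
    # that reduce-scatters without all-reducing, matching bf16/fp32.
    data_parallel_size = reduce_scatter_group.size()
    if all_reduce_group is not None:
        data_parallel_size *= all_reduce_group.size()
    if reduce_dtype in (torch.float32, torch.bfloat16):
        # ReduceOp.AVG already divides by the participating world size,
        # so no explicit factors are needed.
        return None, None
    # fp16 has limited dynamic range, so split the division into a pre-divide
    # (applied before the reduction, guarding against overflow) and a
    # post-divide (applied after) whose product equals data_parallel_size.
    pre_factor = 1
    while (
        data_parallel_size % pre_factor == 0
        and data_parallel_size // pre_factor > pre_factor
    ):
        pre_factor *= 2
    return float(pre_factor), float(data_parallel_size // pre_factor)
```

For the 32-GPU example above (4-way replication, 8-way sharding), passing only the reduce-scatter group gives `data_parallel_size = 8`, i.e. a total factor of $\frac{1}{8}$, while passing both groups gives `data_parallel_size = 32`, which is exactly the behavior the bullet points call for.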
File tree

3 files changed: +59 −64 lines changed
- test/distributed/_composable/fsdp
- torch/distributed/_composable/fsdp