Commit ebedce2

weifengpy authored and pytorchmergebot committed

[FSDP] enable autograd in forward prefetching (#116792)

**problem**
When prefetching for the next forward, the current forward may be annotated with `@torch.no_grad`, so `param.grad_fn` stays `None` during prefetching and `_post_backward_hook` never gets triggered.

repro
```
pytest test/distributed/fsdp/test_fsdp_freezing_weights.py
```

**solution**
This PR enables autograd during prefetching (`_use_unsharded_views`) so that `param.grad_fn` is properly assigned for the next forward. A longer-term fix would be to move `_use_unsharded_views` out of `_prefetch_handle` and put it in `_pre_forward_unshard`.

Pull Request resolved: #116792
Approved by: https://github.com/awgu

1 parent 7f12416 commit ebedce2
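To illustrate the failure mode outside of FSDP, here is a minimal standalone sketch (not part of the PR): a view of a parameter created while grad mode is disabled has no `grad_fn`, while re-enabling grad mode around the view creation restores it, which is the behavior the fix relies on.

```python
import torch

p = torch.nn.Parameter(torch.randn(4, 4))

# View created while grad mode is disabled: it is detached from autograd.
with torch.no_grad():
    v = p.view(-1)
print(v.grad_fn)  # None -> a backward hook keyed off grad_fn would never fire

# Re-enable grad mode just for the view creation, as the fix does.
with torch.no_grad():
    with torch.enable_grad():
        v = p.view(-1)
print(v.grad_fn)  # <ViewBackward0 ...> -> gradients can flow back to p
```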

File tree

2 files changed: +105 −24 lines changed


test/distributed/fsdp/test_fsdp_freezing_weights.py

Lines changed: 98 additions & 24 deletions
@@ -1,5 +1,6 @@
 # Owner(s): ["oncall: distributed"]
 
+import contextlib
 import sys
 from enum import Enum
 
@@ -31,28 +32,46 @@
 
 
 class Model(nn.Module):
-    def __init__(self, with_fsdp, freeze_after_wrap_fsdp):
+    def __init__(
+        self,
+        with_fsdp,
+        freeze_after_wrap_fsdp,
+        disable_autograd,
+        fsdp_kwargs,
+    ):
         super().__init__()
         self.trunk = nn.Sequential(
             nn.Conv2d(3, 64, kernel_size=3),
             nn.ReLU(inplace=True),
             nn.AdaptiveAvgPool2d(output_size=(1, 1)),
             nn.Flatten(),
         )
+        self.device = torch.cuda.current_device()
         self.head = nn.Linear(64, 10)
         if with_fsdp and freeze_after_wrap_fsdp:
-            self.fsdp_wrap()
+            self.fsdp_wrap(fsdp_kwargs)
+        self.autograd_ctx = (
+            torch.no_grad if disable_autograd else contextlib.nullcontext
+        )
 
-    def fsdp_wrap(self):
-        self.trunk = FSDP(self.trunk)
-        self.head = FSDP(self.head)
+    def fsdp_wrap(self, fsdp_kwargs):
+        self.trunk = FSDP(self.trunk, **fsdp_kwargs)
+        self.head = FSDP(self.head, **fsdp_kwargs)
 
     def forward(self, x):
-        return self.head(self.trunk(x))
+        with self.autograd_ctx():
+            x = self.trunk(x)
+        return self.head(x)
 
 
 class NestedTrunkModel(nn.Module):
-    def __init__(self, with_fsdp, freeze_after_wrap_fsdp):
+    def __init__(
+        self,
+        with_fsdp,
+        freeze_after_wrap_fsdp,
+        disable_autograd,
+        fsdp_kwargs,
+    ):
         super().__init__()
         self.trunk = nn.Sequential(
             self._create_block(3, 64, with_fsdp, freeze_after_wrap_fsdp),
@@ -64,17 +83,22 @@ def __init__(self, with_fsdp, freeze_after_wrap_fsdp):
             nn.Linear(64, 10),
         )
         if with_fsdp and freeze_after_wrap_fsdp:
-            self.fsdp_wrap()
+            self.fsdp_wrap(fsdp_kwargs)
+        self.autograd_ctx = (
+            torch.no_grad if disable_autograd else contextlib.nullcontext
+        )
 
-    def fsdp_wrap(self):
+    def fsdp_wrap(self, fsdp_kwargs):
         for name, child in self.trunk.named_children():
-            wrapped_child = FSDP(child)
+            wrapped_child = FSDP(child, **fsdp_kwargs)
             setattr(self.trunk, name, wrapped_child)
-        self.trunk = FSDP(self.trunk)
-        self.head = FSDP(self.head)
+        self.trunk = FSDP(self.trunk, **fsdp_kwargs)
+        self.head = FSDP(self.head, **fsdp_kwargs)
 
     def forward(self, x):
-        return self.head(self.trunk(x))
+        with self.autograd_ctx():
+            x = self.trunk(x)
+        return self.head(x)
 
     def _create_block(
         self, in_channels, out_channels, with_fsdp, freeze_after_wrap_fsdp
@@ -92,20 +116,53 @@ class FreezingMethod(str, Enum):
 
 
 class TestFreezingWeights(FSDPTest):
-    def _create_model(self, with_fsdp, with_nested_trunk, freeze_after_wrap_fsdp):
+    def _create_model(
+        self,
+        with_fsdp,
+        with_nested_trunk,
+        freeze_after_wrap_fsdp,
+        disable_autograd,
+        fsdp_kwargs,
+    ):
         if with_nested_trunk:
-            model = NestedTrunkModel(with_fsdp, freeze_after_wrap_fsdp)
+            model = NestedTrunkModel(
+                with_fsdp, freeze_after_wrap_fsdp, disable_autograd, fsdp_kwargs
+            )
         else:
-            model = Model(with_fsdp, freeze_after_wrap_fsdp)
+            model = Model(
+                with_fsdp, freeze_after_wrap_fsdp, disable_autograd, fsdp_kwargs
+            )
         return model
 
     def _dist_train(
-        self, with_nested_trunk, freezing_method, freeze_after_wrap_fsdp, with_fsdp
+        self,
+        with_nested_trunk,
+        freezing_method,
+        freeze_after_wrap_fsdp,
+        with_fsdp,
+        disable_autograd,
+        forward_prefetch,
     ):
         torch.manual_seed(0)
         batch = torch.randn(size=(2, 3, 224, 224)).cuda()
 
-        model = self._create_model(with_fsdp, with_nested_trunk, freeze_after_wrap_fsdp)
+        fsdp_kwargs = {
+            "device_id": self.rank,
+            "forward_prefetch": forward_prefetch,
+        }
+
+        ddp_kwargs = {
+            "device_ids": [self.rank],
+            "find_unused_parameters": True if disable_autograd else False,
+        }
+
+        model = self._create_model(
+            with_fsdp,
+            with_nested_trunk,
+            freeze_after_wrap_fsdp,
+            disable_autograd,
+            fsdp_kwargs,
+        )
         model = model.cuda()
 
         # freezing the trunk using requires_grad.
@@ -115,10 +172,10 @@ def _dist_train(
 
         if with_fsdp:
             if not freeze_after_wrap_fsdp:
-                model.fsdp_wrap()
-            model = FSDP(model)
+                model.fsdp_wrap(fsdp_kwargs)
+            model = FSDP(model, **fsdp_kwargs)
         else:
-            model = DistributedDataParallel(model, device_ids=[self.rank])
+            model = DistributedDataParallel(model, **ddp_kwargs)
 
         target = torch.tensor([0, 1], dtype=torch.long).cuda()
         criterion = nn.CrossEntropyLoss()
@@ -145,17 +202,34 @@ def _dist_train(
         "freezing_method", [FreezingMethod.RequiresGrad, FreezingMethod.GradToNone]
     )
     @parametrize("freeze_after_wrap_fsdp", [True, False])
+    @parametrize("disable_autograd", [True, False])
+    @parametrize("forward_prefetch", [True, False])
     def test_freezing_weights(
-        self, with_nested_trunk, freezing_method, freeze_after_wrap_fsdp
+        self,
+        with_nested_trunk,
+        freezing_method,
+        freeze_after_wrap_fsdp,
+        disable_autograd,
+        forward_prefetch,
     ):
         # DDP
         ddp_state = self._dist_train(
-            with_nested_trunk, freezing_method, freeze_after_wrap_fsdp, with_fsdp=False
+            with_nested_trunk,
+            freezing_method,
+            freeze_after_wrap_fsdp,
+            with_fsdp=False,
+            disable_autograd=disable_autograd,
+            forward_prefetch=False,  # does not apply to DDP
         )
 
         # FSDP
         fsdp_state = self._dist_train(
-            with_nested_trunk, freezing_method, freeze_after_wrap_fsdp, with_fsdp=True
+            with_nested_trunk,
+            freezing_method,
+            freeze_after_wrap_fsdp,
+            with_fsdp=True,
+            disable_autograd=disable_autograd,
+            forward_prefetch=forward_prefetch,
        )
 
         self.assertEqual(
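The test drives the frozen trunk's forward through a small context-manager selection trick (`torch.no_grad if disable_autograd else contextlib.nullcontext`). A standalone sketch of the same pattern, assuming the helper name `pick_autograd_ctx`, which is illustrative and not from the test:

```python
import contextlib

import torch


def pick_autograd_ctx(disable_autograd: bool):
    # Store the context-manager *class* (as the test's Model.__init__ does)
    # and instantiate it once per forward call.
    return torch.no_grad if disable_autograd else contextlib.nullcontext


x = torch.randn(2, 3, requires_grad=True)

ctx = pick_autograd_ctx(disable_autograd=True)
with ctx():
    y = (x * 2).sum()
print(y.requires_grad)  # False: the block ran under torch.no_grad

ctx = pick_autograd_ctx(disable_autograd=False)
with ctx():
    y = (x * 2).sum()
print(y.requires_grad)  # True: nullcontext leaves grad mode untouched
```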

torch/distributed/fsdp/_flat_param.py

Lines changed: 7 additions & 0 deletions
@@ -1864,6 +1864,7 @@ def _get_unflat_views_aligned(
         return views
 
     @no_type_check
+    @torch.enable_grad()
     def _use_unsharded_views(self, as_params: bool) -> None:
         """
         Unflatten the unsharded flat parameter by setting the original parameter variables to be views into it.
@@ -1874,6 +1875,12 @@ def _use_unsharded_views(self, as_params: bool) -> None:
                 the original parameters only as ``Tensor`` s. ``False`` should
                 be used during forward/backward computation and when hiding the
                 original parameters from :meth:`nn.Module.named_parameters`.
+
+        Note:
+            when prefetching for next forward, current forward may be
+            annotated with `@torch.no_grad()`
+            `@torch.enable_grad()` ensures non-empty `view.grad_fn`
+            otherwise `_post_backward_hook` will not get called
         """
         flat_param = self.flat_param
         self._check_unsharded(flat_param)
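The docstring note above is the whole rationale for the one-line change: the decorator forces grad mode on while the unsharded views are created. A hedged sketch of the decorator behavior it depends on (`make_views` and `flat_param` are illustrative names, not the FSDP internals):

```python
import torch


@torch.enable_grad()  # same decorator form the fix adds to _use_unsharded_views
def make_views(flat: torch.Tensor):
    # Views created inside keep a grad_fn even if the caller disabled grad mode.
    return flat.split(2)


flat_param = torch.randn(6, requires_grad=True)
with torch.no_grad():  # e.g. the current forward is annotated with @torch.no_grad
    views = make_views(flat_param)
print(all(v.grad_fn is not None for v in views))  # True
```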

0 commit comments

