@@ -1,5 +1,6 @@
 # Owner(s): ["oncall: distributed"]
 
+import contextlib
 import sys
 from enum import Enum
 
@@ -31,28 +32,46 @@
 
 
 class Model(nn.Module):
-    def __init__(self, with_fsdp, freeze_after_wrap_fsdp):
+    def __init__(
+        self,
+        with_fsdp,
+        freeze_after_wrap_fsdp,
+        disable_autograd,
+        fsdp_kwargs,
+    ):
         super().__init__()
         self.trunk = nn.Sequential(
             nn.Conv2d(3, 64, kernel_size=3),
             nn.ReLU(inplace=True),
             nn.AdaptiveAvgPool2d(output_size=(1, 1)),
             nn.Flatten(),
         )
+        self.device = torch.cuda.current_device()
         self.head = nn.Linear(64, 10)
         if with_fsdp and freeze_after_wrap_fsdp:
-            self.fsdp_wrap()
+            self.fsdp_wrap(fsdp_kwargs)
+        self.autograd_ctx = (
+            torch.no_grad if disable_autograd else contextlib.nullcontext
+        )
 
-    def fsdp_wrap(self):
-        self.trunk = FSDP(self.trunk)
-        self.head = FSDP(self.head)
+    def fsdp_wrap(self, fsdp_kwargs):
+        self.trunk = FSDP(self.trunk, **fsdp_kwargs)
+        self.head = FSDP(self.head, **fsdp_kwargs)
 
     def forward(self, x):
-        return self.head(self.trunk(x))
+        with self.autograd_ctx():
+            x = self.trunk(x)
+        return self.head(x)
 
 
 class NestedTrunkModel(nn.Module):
-    def __init__(self, with_fsdp, freeze_after_wrap_fsdp):
+    def __init__(
+        self,
+        with_fsdp,
+        freeze_after_wrap_fsdp,
+        disable_autograd,
+        fsdp_kwargs,
+    ):
         super().__init__()
         self.trunk = nn.Sequential(
             self._create_block(3, 64, with_fsdp, freeze_after_wrap_fsdp),
@@ -64,17 +83,22 @@ def __init__(self, with_fsdp, freeze_after_wrap_fsdp):
             nn.Linear(64, 10),
         )
         if with_fsdp and freeze_after_wrap_fsdp:
-            self.fsdp_wrap()
+            self.fsdp_wrap(fsdp_kwargs)
+        self.autograd_ctx = (
+            torch.no_grad if disable_autograd else contextlib.nullcontext
+        )
 
-    def fsdp_wrap(self):
+    def fsdp_wrap(self, fsdp_kwargs):
         for name, child in self.trunk.named_children():
-            wrapped_child = FSDP(child)
+            wrapped_child = FSDP(child, **fsdp_kwargs)
             setattr(self.trunk, name, wrapped_child)
-        self.trunk = FSDP(self.trunk)
-        self.head = FSDP(self.head)
+        self.trunk = FSDP(self.trunk, **fsdp_kwargs)
+        self.head = FSDP(self.head, **fsdp_kwargs)
 
     def forward(self, x):
-        return self.head(self.trunk(x))
+        with self.autograd_ctx():
+            x = self.trunk(x)
+        return self.head(x)
 
     def _create_block(
         self, in_channels, out_channels, with_fsdp, freeze_after_wrap_fsdp
@@ -92,20 +116,53 @@ class FreezingMethod(str, Enum):
 
 
 class TestFreezingWeights(FSDPTest):
-    def _create_model(self, with_fsdp, with_nested_trunk, freeze_after_wrap_fsdp):
+    def _create_model(
+        self,
+        with_fsdp,
+        with_nested_trunk,
+        freeze_after_wrap_fsdp,
+        disable_autograd,
+        fsdp_kwargs,
+    ):
         if with_nested_trunk:
-            model = NestedTrunkModel(with_fsdp, freeze_after_wrap_fsdp)
+            model = NestedTrunkModel(
+                with_fsdp, freeze_after_wrap_fsdp, disable_autograd, fsdp_kwargs
+            )
         else:
-            model = Model(with_fsdp, freeze_after_wrap_fsdp)
+            model = Model(
+                with_fsdp, freeze_after_wrap_fsdp, disable_autograd, fsdp_kwargs
+            )
         return model
 
     def _dist_train(
-        self, with_nested_trunk, freezing_method, freeze_after_wrap_fsdp, with_fsdp
+        self,
+        with_nested_trunk,
+        freezing_method,
+        freeze_after_wrap_fsdp,
+        with_fsdp,
+        disable_autograd,
+        forward_prefetch,
     ):
         torch.manual_seed(0)
         batch = torch.randn(size=(2, 3, 224, 224)).cuda()
 
-        model = self._create_model(with_fsdp, with_nested_trunk, freeze_after_wrap_fsdp)
+        fsdp_kwargs = {
+            "device_id": self.rank,
+            "forward_prefetch": forward_prefetch,
+        }
+
+        ddp_kwargs = {
+            "device_ids": [self.rank],
+            "find_unused_parameters": disable_autograd,
+        }
+
+        model = self._create_model(
+            with_fsdp,
+            with_nested_trunk,
+            freeze_after_wrap_fsdp,
+            disable_autograd,
+            fsdp_kwargs,
+        )
         model = model.cuda()
 
         # freezing the trunk using requires_grad.
@@ -115,10 +172,10 @@ def _dist_train(
 
         if with_fsdp:
             if not freeze_after_wrap_fsdp:
-                model.fsdp_wrap()
-            model = FSDP(model)
+                model.fsdp_wrap(fsdp_kwargs)
+            model = FSDP(model, **fsdp_kwargs)
         else:
-            model = DistributedDataParallel(model, device_ids=[self.rank])
+            model = DistributedDataParallel(model, **ddp_kwargs)
 
         target = torch.tensor([0, 1], dtype=torch.long).cuda()
         criterion = nn.CrossEntropyLoss()
@@ -145,17 +202,34 @@ def _dist_train(
         "freezing_method", [FreezingMethod.RequiresGrad, FreezingMethod.GradToNone]
     )
     @parametrize("freeze_after_wrap_fsdp", [True, False])
+    @parametrize("disable_autograd", [True, False])
+    @parametrize("forward_prefetch", [True, False])
     def test_freezing_weights(
-        self, with_nested_trunk, freezing_method, freeze_after_wrap_fsdp
+        self,
+        with_nested_trunk,
+        freezing_method,
+        freeze_after_wrap_fsdp,
+        disable_autograd,
+        forward_prefetch,
     ):
         # DDP
         ddp_state = self._dist_train(
-            with_nested_trunk, freezing_method, freeze_after_wrap_fsdp, with_fsdp=False
+            with_nested_trunk,
+            freezing_method,
+            freeze_after_wrap_fsdp,
+            with_fsdp=False,
+            disable_autograd=disable_autograd,
+            forward_prefetch=False,  # does not apply to DDP
         )
 
         # FSDP
         fsdp_state = self._dist_train(
-            with_nested_trunk, freezing_method, freeze_after_wrap_fsdp, with_fsdp=True
+            with_nested_trunk,
+            freezing_method,
+            freeze_after_wrap_fsdp,
+            with_fsdp=True,
+            disable_autograd=disable_autograd,
+            forward_prefetch=forward_prefetch,
        )
 
        self.assertEqual(
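
Note on the disable_autograd path (explanatory sketch, not part of the diff): running the trunk under torch.no_grad() means the trunk's parameters never enter the autograd graph, so the DDP baseline has to pass find_unused_parameters=True, otherwise DDP's reducer would complain about parameters that never received gradients. A minimal single-process illustration of that interaction, using stand-in nn.Linear modules rather than the test's real trunk/head:

    import torch
    import torch.nn as nn

    trunk = nn.Linear(4, 4)  # stand-in for the frozen trunk
    head = nn.Linear(4, 2)   # stand-in for the trainable head

    with torch.no_grad():  # what self.autograd_ctx() does when disable_autograd=True
        h = trunk(torch.randn(2, 4))
    loss = head(h).sum()
    loss.backward()

    assert trunk.weight.grad is None      # the trunk never receives gradients,
    assert head.weight.grad is not None   # which DDP must be told to tolerate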