Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit422d151

Browse files
committed
[precompile] Support external data for serialization.
In some corner cases, users don't directly pass us a free function or module.forward. Instead, there are multiple levels of wrappers and captured closures applied to module.forward, so we end up serializing quite a lot of captured data from the compiled function's closure. This is not great when things like torch.nn.Module end up being captured in the function's closure, and we shouldn't serialize the whole module in this case.

To solve this issue, we specially handle the common cases when the following objects are being saved:
1. torch.nn.Module
2. Nested functions.

When these data are being serialized, we mark them as "external data", since they may contain unserializable and untrackable state that is usually easier to maintain on the user side. So a call to save_compiled_function() looks like the following:
```
result: AOTCompileSaveResult = compiled_fn.save_compiled_function("/path")
```
On the loading side, users are supposed to provide the external data dict. If any key is missing, load_compiled_function will throw an error.
```
torch.compiler.load_compiled_function(f, external_data={...})
```
Note that it's not ideal for users to compile a typical function like this, so this is only meant for power users for whom rewriting the model is harder than maintaining this data map.

ghstack-source-id: a393674
Pull Request resolved: #170846
1 parent 363e8e1, commit 422d151

File tree

3 files changed

+179
-12
lines changed

3 files changed

+179
-12
lines changed

‎test/dynamo/test_aot_compile.py‎

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,14 @@ def forward(self, x, d_x, mesh):
299299
returnx,y
300300

301301

302+
defwrap_forward_function(fn):
303+
@functools.wraps(fn,assigned=("__doc__","__annotations__","__type_params__"))
304+
defwrapped(*args,**kwargs):
305+
returnfn(*args,**kwargs)
306+
307+
returnwrapped
308+
309+
302310
@torch._dynamo.config.patch("enable_aot_compile",True)
303311
@instantiate_parametrized_tests
304312
classTestAOTCompile(torch._inductor.test_case.TestCase):
@@ -929,6 +937,73 @@ def test_aot_compile_with_redistribute(self):
929937
finally:
930938
torch.distributed.destroy_process_group()
931939

940+
deftest_aot_compile_with_captured_module(self):
941+
mod=SimpleLinearModule()
942+
943+
fn=mod.forward
944+
945+
defwith_processing(f,*args,**kwargs):
946+
returnf(*args,**kwargs)
947+
948+
fn=functools.partial(with_processing,fn)
949+
950+
fn=wrap_forward_function(fn)
951+
mod.forward=fn
952+
953+
compiled_fn=torch.compile(fn,fullgraph=True).aot_compile(
954+
((torch.randn(4,3),), {})
955+
)
956+
mod.forward=compiled_fn
957+
result=compiled_fn.save_compiled_function(self.path())
958+
values=list(result.external_data.values())
959+
self.assertIn(with_processing,values)
960+
self.assertIn(mod,values)
961+
withopen(self.path(),"rb")asf:
962+
withself.assertRaisesRegex(RuntimeError,"Missing required external ref"):
963+
torch.compiler.load_compiled_function(f)
964+
965+
withopen(self.path(),"rb")asf:
966+
compiled_fn=torch.compiler.load_compiled_function(
967+
f,external_data=result.external_data
968+
)
969+
test_inputs= (torch.randn(4,3),)
970+
expected=fn(*test_inputs)
971+
actual=compiled_fn(*test_inputs)
972+
self.assertEqual(expected,actual)
973+
974+
deftest_aot_compile_with_captured_module_2(self):
975+
mod=SimpleLinearModule()
976+
977+
fn=mod.forward
978+
979+
defwith_processing(f,*args,**kwargs):
980+
returnf(*args,**kwargs)
981+
982+
fn=functools.partial(with_processing,fn)
983+
984+
fn=wrap_forward_function(fn)
985+
986+
compiled_fn=torch.compile(fn,fullgraph=True).aot_compile(
987+
((torch.randn(4,3),), {})
988+
)
989+
mod.forward=compiled_fn
990+
result=compiled_fn.save_compiled_function(self.path())
991+
values=list(result.external_data.values())
992+
self.assertIn(with_processing,values)
993+
self.assertIn(mod,values)
994+
withopen(self.path(),"rb")asf:
995+
withself.assertRaisesRegex(RuntimeError,"Missing required external ref"):
996+
torch.compiler.load_compiled_function(f)
997+
998+
withopen(self.path(),"rb")asf:
999+
compiled_fn=torch.compiler.load_compiled_function(
1000+
f,external_data=result.external_data
1001+
)
1002+
test_inputs= (torch.randn(4,3),)
1003+
expected=fn(*test_inputs)
1004+
actual=compiled_fn(*test_inputs)
1005+
self.assertEqual(expected,actual)
1006+
9321007
deftest_aot_compile_with_checkpoint(self):
9331008
fromtorch.utils.checkpointimportcheckpoint
9341009

‎torch/_dynamo/aot_compile.py‎

Lines changed: 95 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,35 @@ def check_compatibility(self) -> None:
6060

6161

6262
classAOTCompilePickler(pickle.Pickler):
63+
def__init__(self,*args:Any,**kwargs:Any)->None:
64+
super().__init__(*args,**kwargs)
65+
self.external_data= {}
66+
self.key_map= {}
67+
self.prefix_counters= {}
68+
69+
defgenerate_key(self,prefix:str)->str:
70+
ret=f"{prefix}:{self.prefix_counters.setdefault(prefix,0)}"
71+
self.prefix_counters[prefix]+=1
72+
returnret
73+
74+
defpersistent_id(self,obj):
75+
key=None
76+
ifid(obj)inself.key_map:
77+
returnself.key_map[id(obj)]
78+
79+
ifisinstance(obj,torch.nn.Module):
80+
ty=type(obj)
81+
key=self.generate_key(
82+
f"torch_nn_module:{ty.__module__}.{ty.__qualname__}"
83+
)
84+
elifinspect.isfunction(obj)andobj.__name__!=obj.__qualname__:
85+
key=self.generate_key(f"function:{obj.__qualname__}")
86+
87+
ifkeyisnotNone:
88+
self.external_data[key]=obj
89+
self.key_map[id(obj)]=key
90+
returnkey
91+
6392
@classmethod
6493
def_unpickle_cell(cls,val:Any)->Any:
6594
def_()->Any:
@@ -68,6 +97,10 @@ def _() -> Any:
6897
assert_.__closure__isnotNone
6998
return_.__closure__[0]
7099

100+
@classmethod
101+
def_unpickle_bound_method(cls,func:Any,base:Any)->Any:
102+
returntypes.MethodType(func,base)
103+
71104
@classmethod
72105
def_unpickle_module(cls,name:str)->Any:
73106
returnimportlib.import_module(name)
@@ -78,16 +111,53 @@ def reducer_override(self, obj: Any) -> Any:
78111
returntype(self)._unpickle_cell, (obj.cell_contents,)
79112
elifinspect.ismodule(obj):
80113
returntype(self)._unpickle_module, (obj.__name__,)
114+
elifinspect.ismethod(obj):
115+
"""
116+
By default, pickle will call getattr() directly on the self object
117+
for pickling bound methods. This is not what we want; instead we
118+
always want to serialize the original function and the self object
119+
in their original form.
120+
"""
121+
func=obj.__func__
122+
method_self=obj.__self__
123+
inner_func=getattr(method_self,func.__name__)
124+
ifinspect.ismethod(inner_func):
125+
inner_func=inner_func.__func__
126+
iffuncisnotinner_func:
127+
returntype(self)._unpickle_bound_method, (func,method_self)
128+
81129
returnNotImplemented
82130

83131

132+
classAOTCompileUnpickler(pickle.Unpickler):
133+
def__init__(self,external_data:dict[str,Any],file:io.BytesIO):
134+
super().__init__(file)
135+
self.external_data=external_data
136+
137+
defpersistent_load(self,key:str):
138+
ifkeynotinself.external_data:
139+
raiseRuntimeError(
140+
f"Missing required external reference to data:{key}. "
141+
"Please load AOT compiled function with "
142+
"`external_data=<external data dictionary>`"
143+
f"{self.external_data}"
144+
)
145+
returnself.external_data[key]
146+
147+
148+
@dataclass
149+
classAOTCompileSaveResult:
150+
serialized_data:bytes
151+
external_data:dict[str,Any]
152+
153+
84154
@dataclass
85155
classAOTCompiledFunction:
86156
_artifacts:CompileArtifacts
87157
_guard_check_enabled:bool=True
88158
_extra_globals:Optional[dict[str,object]]=None
89159

90-
defguard_check(self,*args:Any,**kwargs:Any)->bool:
160+
defprepare_f_locals(self,*args,**kwargs):
91161
f_locals:dict[str,Any]= {}
92162
env=self._artifacts.runtime_env
93163
ifenv.closure:
@@ -99,6 +169,10 @@ def guard_check(self, *args: Any, **kwargs: Any) -> bool:
99169
forname,cellinzip(env.bytecode.co_freevars,env.closure)
100170
}
101171
f_locals.update(bind_locals(self._artifacts.signature,*args,**kwargs))
172+
returnf_locals
173+
174+
defguard_check(self,*args:Any,**kwargs:Any)->bool:
175+
f_locals=self.prepare_f_locals(*args,**kwargs)
102176
assertself._artifacts.guard_managerisnotNone
103177
returnself._artifacts.guard_manager.check(f_locals)
104178

@@ -125,20 +199,22 @@ def __post_init__(self) -> None:
125199
def__call__(self,*args:Any,**kwargs:Any)->Any:
126200
assertself._artifacts.guard_managerisnotNone
127201
ifself._guard_check_enabledandnotself.guard_check(*args,**kwargs):
128-
f_locals=bind_locals(self._artifacts.signature,*args,**kwargs)
202+
f_locals=self.prepare_f_locals(*args,**kwargs)
129203
reason=str(self._artifacts.guard_manager.check_verbose(f_locals))
130204
raiseRuntimeError(f"GuardManager check failed, reason:{reason}")
131205
returnself.fn(*args,**kwargs)
132206

133207
defsource_info(self)->"SourceInfo":
134208
returnself._artifacts.source_info
135209

136-
defsave_compiled_function(self,path:str)->None:
210+
defsave_compiled_function(self,path:str)->AOTCompileSaveResult:
137211
withopen(path,"wb")asf:
138-
f.write(type(self).serialize(self))
212+
result=type(self).serialize(self)
213+
f.write(result.serialized_data)
214+
returnresult
139215

140216
@classmethod
141-
defserialize(cls,fn:"AOTCompiledFunction")->bytes:
217+
defserialize(cls,fn:"AOTCompiledFunction")->AOTCompileSaveResult:
142218
fromtorch._dynamo.packageimportSerializedCode
143219

144220
state=fn._artifacts.__dict__.copy()
@@ -156,15 +232,25 @@ def serialize(cls, fn: "AOTCompiledFunction") -> bytes:
156232
buf=io.BytesIO()
157233
pickler=AOTCompilePickler(buf)
158234
pickler.dump(state)
159-
returnbuf.getvalue()
235+
returnAOTCompileSaveResult(
236+
serialized_data=buf.getvalue(),
237+
external_data=pickler.external_data,
238+
)
160239

161240
@classmethod
162241
defdeserialize(
163-
cls,data:bytes,f_globals:Optional[dict[str,object]]=None
242+
cls,
243+
data:bytes,
244+
f_globals:dict[str,object]|None=None,
245+
external_closure_data:dict[str,Any]|None=None,
164246
)->"AOTCompiledFunction":
165247
fromtorch._dynamo.packageimportSerializedCode
166248

167-
state=pickle.loads(data)
249+
f=io.BytesIO(data)
250+
f.seek(0)
251+
unpickler=AOTCompileUnpickler(external_closure_dataor {},f)
252+
state=unpickler.load()
253+
f.close()
168254
state["runtime_env"]=dataclasses.replace(
169255
state["runtime_env"],
170256
bytecode=SerializedCode.to_code_object(state["runtime_env"].bytecode),
@@ -339,7 +425,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
339425
defserialize(self)->bytes:
340426
data:list[bytes]= []
341427
forresultinself.compiled_results:
342-
data.append(AOTCompiledFunction.serialize(result))
428+
data.append(AOTCompiledFunction.serialize(result).serialized_data)
343429
returnpickle.dumps(data)
344430

345431
@classmethod

‎torch/compiler/__init__.py‎

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -712,7 +712,10 @@ def nested_compile_region(
712712

713713

714714
defload_compiled_function(
715-
file:io.IOBase,*,f_globals:Optional[dict[str,object]]=None
715+
file:io.IOBase,
716+
*,
717+
f_globals:dict[str,object]|None=None,
718+
external_data:dict[str,Any]|None=None,
716719
)->Callable[...,Any]:
717720
"""
718721
Load an aot-compiled function from a file.
@@ -723,12 +726,15 @@ def load_compiled_function(
723726
724727
Args:
725728
file: A file-like object containing the serialized compiled function.
726-
f_globals: Optional globals to be loaded into the compiled function.
729+
f_globals: Optional global scope enclosing the compiled function.
730+
external_data: Optional data to be loaded into the runtime environment
731+
of the compiled function. This should contain the same
732+
data as AOTCompileSaveResult.external_data returned from the save_compiled_function() call.
727733
728734
Returns:
729735
A torch-compiled function with compilation preloaded from disk.
730736
"""
731737
fromtorch._dynamo.aot_compileimportAOTCompiledFunction
732738

733739
data=file.read()
734-
returnAOTCompiledFunction.deserialize(data,f_globals)
740+
returnAOTCompiledFunction.deserialize(data,f_globals,external_data)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp