Commit 3aee820

yushangdi authored and pytorchmergebot committed

[dynamo] Preserve record function in aot_eager (#167787)

- When `torch._dynamo.config.capture_profiler_record_function = True`, dynamo will insert `torch.ops.profiler._record_function_enter_new.default` and `torch.ops.profiler._record_function_exit._RecordFunction` nodes in the graph
- We preserve these profiler nodes in functionalization and proxy node creation (on by default)
- To prevent problems with inductor fusion, we remove these nodes in pre_grad_passes in inductor (on by default)

[Screenshot attachment omitted]

Test commands:

```
python test/dynamo/test_profiler.py -k test_dynamo_preserve_record_func
python test/inductor/test_profiler.py -k test_inductor_remove_profiler_ops
python test/dynamo/test_regional_inductor.py -k test_invoke_subgraph_inner_serialize_False
python test/test_autograd.py -k test_profiler_propa
PYTORCH_TEST_WITH_DYNAMO=1 python test/test_autograd.py TestAutograd.test_record_function_legacy
python test/dynamo/test_ctx_manager.py -k profiler
```

Pull Request resolved: #167787
Approved by: https://github.com/anijain2305, https://github.com/mlazos

1 parent cd75d6a commit 3aee820
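In practice, the flow looks like this minimal sketch, assembled from the tests added in this commit (`AotEagerAndRecordGraphs` is the test backend the commit uses to capture both the dynamo and AOT forward graphs):

```python
import torch
import torch._dynamo.testing
from torch.profiler import record_function

# New config flag introduced by this PR (off by default).
torch._dynamo.config.capture_profiler_record_function = True

def fn(x):
    with record_function("my_net1"):
        return x.sin()

backend = torch._dynamo.testing.AotEagerAndRecordGraphs()
fn_c = torch.compile(fn, backend=backend)
fn_c(torch.randn(10))

# The profiler enter/exit ops survive functionalization into the AOT forward graph ...
print(backend.fw_graphs[0].code)

# ... and the annotation still shows up in an actual profile.
with torch.profiler.profile() as prof:
    fn_c(torch.randn(10))
print([e.name for e in prof.events() if e.name == "my_net1"])  # ['my_net1']
```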

File tree

11 files changed: +404 −13 lines changed


test/dynamo/test_profiler.py

Lines changed: 144 additions & 0 deletions
```diff
@@ -6,6 +6,7 @@
 import torch._dynamo.testing
 import torch._dynamo.utils
 from torch._dynamo.utils import dynamo_timed
+from torch.profiler import record_function
 from torch.testing._internal.common_utils import TemporaryFileName
 
 
@@ -220,6 +221,149 @@ def fn(x, y):
             ],
         )
 
+    @torch._dynamo.config.patch("capture_profiler_record_function", True)
+    def test_dynamo_preserve_record_func(self):
+        def fn(x):
+            with record_function("my_net1"):
+                a = x.sin()
+            with record_function("my_cos"):
+                b = a.cos()
+            with record_function("my_net2"):
+                c = b + 2
+            return c
+
+        backend = torch._dynamo.testing.AotEagerAndRecordGraphs()
+        fn_c = torch.compile(fn, backend=backend)
+        fn_c(
+            torch.randn(10),
+        )
+        self.assertExpectedInline(
+            backend.graphs[0].code.strip(),
+            """\
+def forward(self, L_x_ : torch.Tensor):
+    l_x_ = L_x_
+    _record_function_enter_new = torch.ops.profiler._record_function_enter_new('my_net1', None)
+    a = l_x_.sin(); l_x_ = None
+    _record_function_exit__record_function = torch.ops.profiler._record_function_exit._RecordFunction(_record_function_enter_new); _record_function_enter_new = _record_function_exit__record_function = None
+    _record_function_enter_new_1 = torch.ops.profiler._record_function_enter_new('my_cos', None)
+    b = a.cos(); a = None
+    _record_function_exit__record_function_1 = torch.ops.profiler._record_function_exit._RecordFunction(_record_function_enter_new_1); _record_function_enter_new_1 = _record_function_exit__record_function_1 = None
+    _record_function_enter_new_2 = torch.ops.profiler._record_function_enter_new('my_net2', None)
+    c = b + 2; b = None
+    _record_function_exit__record_function_2 = torch.ops.profiler._record_function_exit._RecordFunction(_record_function_enter_new_2); _record_function_enter_new_2 = _record_function_exit__record_function_2 = None
+    return (c,)""",  # noqa: B950
+        )
+        self.assertExpectedInline(
+            backend.fw_graphs[0].code.strip(),
+            """\
+def forward(self, arg0_1):
+    _record_function_enter_new = torch.ops.profiler._record_function_enter_new.default('my_net1')
+    sin = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
+    _record_function_exit = torch.ops.profiler._record_function_exit._RecordFunction(_record_function_enter_new); _record_function_enter_new = _record_function_exit = None
+    _record_function_enter_new_1 = torch.ops.profiler._record_function_enter_new.default('my_cos')
+    cos = torch.ops.aten.cos.default(sin); sin = None
+    _record_function_exit_1 = torch.ops.profiler._record_function_exit._RecordFunction(_record_function_enter_new_1); _record_function_enter_new_1 = _record_function_exit_1 = None
+    _record_function_enter_new_2 = torch.ops.profiler._record_function_enter_new.default('my_net2')
+    add = torch.ops.aten.add.Tensor(cos, 2); cos = None
+    _record_function_exit_2 = torch.ops.profiler._record_function_exit._RecordFunction(_record_function_enter_new_2); _record_function_enter_new_2 = _record_function_exit_2 = None
+    return (add,)""",  # noqa: B950
+        )
+        with torch.profiler.profile() as prof:
+            fn_c(
+                torch.randn(10),
+            )
+
+        annotations = [e.name for e in prof.events() if "my_" in e.name]
+        self.assertEqual(
+            annotations,
+            [
+                "my_net1",
+                "my_cos",
+                "my_net2",
+            ],
+        )
+
+    @torch._dynamo.config.patch("capture_profiler_record_function", True)
+    def test_dynamo_preserve_record_func_with_graph_break(self):
+        # Test that record_function works correctly with graph breaks
+        def fn(x):
+            with record_function("pre_graph_break"):
+                a = x.sin()
+            # This causes a graph break
+            torch._dynamo.graph_break()
+            with record_function("post_graph_break"):
+                b = a.cos()
+            return b
+
+        backend = torch._dynamo.testing.AotEagerAndRecordGraphs()
+        fn_c = torch.compile(fn, backend=backend)
+        fn_c(
+            torch.randn(10),
+        )
+
+        # We expect 2 graphs due to the graph break
+        self.assertEqual(len(backend.graphs), 2)
+
+        # First graph should have the pre_graph_break record_function
+        self.assertIn("pre_graph_break", backend.graphs[0].code)
+        self.assertIn("_record_function_enter_new", backend.graphs[0].code)
+        self.assertIn("_record_function_exit", backend.graphs[0].code)
+
+        # Second graph should have the post_graph_break record_function
+        self.assertIn("post_graph_break", backend.graphs[1].code)
+        self.assertIn("_record_function_enter_new", backend.graphs[1].code)
+        self.assertIn("_record_function_exit", backend.graphs[1].code)
+
+        # Verify profiler events work correctly
+        with torch.profiler.profile() as prof:
+            fn_c(
+                torch.randn(10),
+            )
+
+        annotations = [
+            e.name
+            for e in prof.events()
+            if e.name in ["pre_graph_break", "post_graph_break"]
+        ]
+        # Both record_function contexts should appear in profiler events
+        self.assertEqual(
+            annotations,
+            [
+                "pre_graph_break",
+                "post_graph_break",
+            ],
+        )
+
+    @torch._dynamo.config.patch("capture_profiler_record_function", True)
+    def test_dynamo_preserve_record_func_spanning_graph_break(self):
+        # Test that record_function that spans across a graph break raises an error
+        # This prevents the confusing behavior where the context gets duplicated across graphs
+        def fn(x):
+            x = x + 1
+            with record_function("spanning_context"):
+                a = x.sin()
+                torch._dynamo.graph_break()
+                b = a.cos()
+            b = b - 1
+            return b
+
+        fn_c = torch.compile(fn, backend="aot_eager")
+        x = torch.randn(10)
+        fn_c(x)
+        with torch.profiler.profile() as prof:
+            result = fn_c(x)
+
+        self.assertEqual(fn(x), result)
+
+        annotations = [e.name for e in prof.events() if e.name == "spanning_context"]
+        # record_function contexts should appear in profiler events once
+        self.assertEqual(
+            annotations,
+            [
+                "spanning_context",
+            ],
+        )
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
```

test/inductor/test_profiler.py

Lines changed: 35 additions & 1 deletion
```diff
@@ -11,13 +11,14 @@
 import torch._inductor.utils
 from torch import _dynamo as torchdynamo
 from torch._inductor import config
-from torch.profiler import ProfilerActivity
+from torch.profiler import ProfilerActivity, record_function
 from torch.testing._internal.common_utils import skipIfXpu, TemporaryFileName
 from torch.testing._internal.inductor_utils import (
     GPU_TYPE,
     HAS_GPU_AND_TRITON,
     IS_BIG_GPU,
 )
+from torch.testing._internal.logging_utils import logs_to_string
 from torch.torch_version import TorchVersion
 from torch.utils._triton import has_triton
 
@@ -321,6 +322,39 @@ def fn(x, y):
         else:
             self.assertEqual("1", os.environ.get("DISABLE_CUPTI_LAZY_REINIT", "0"))
 
+    @torch._dynamo.config.patch("capture_profiler_record_function", True)
+    def test_inductor_remove_profiler_ops(self):
+        """
+        Test that inductor post_grad graph doesn't have profiler ops even when
+        dynamo produces those ops.
+        """
+
+        def fn(x):
+            with record_function("my_net1"):
+                a = x.sin()
+            with record_function("my_cos"):
+                b = a.cos()
+            with record_function("my_net2"):
+                c = b + 2
+            return c
+
+        log_stream, ctx = logs_to_string(
+            "torch._inductor.compile_fx", "post_grad_graphs"
+        )
+        with ctx():
+            torch.compile(fn, fullgraph=True)(torch.randn(10))
+        aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
+        self.assertExpectedInline(
+            aot_graphs,
+            """\
+sin: "f32[10][1]cpu" = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
+cos: "f32[10][1]cpu" = torch.ops.aten.cos.default(sin); sin = None
+add: "f32[10][1]cpu" = torch.ops.aten.add.Tensor(cos, 2); cos = None
+return (add,)""",  # noqa: B950
+            ignore_comments=True,
+            ignore_empty_lines=True,
+        )
+
 
 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests
```
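Outside the expecttest harness, the same post-grad artifact can be dumped via the logging system. A hedged sketch (it assumes `post_grad_graphs` is accepted by `torch._logging.set_logs` as the programmatic equivalent of `TORCH_LOGS="post_grad_graphs"`):

```python
import torch
from torch.profiler import record_function

torch._dynamo.config.capture_profiler_record_function = True
torch._logging.set_logs(post_grad_graphs=True)  # assumed equivalent of TORCH_LOGS="post_grad_graphs"

@torch.compile(fullgraph=True)  # default inductor backend
def fn(x):
    with record_function("my_net1"):
        return x.sin()

# The logged post-grad graph should show aten.sin but no profiler ops,
# since inductor strips them in its pre-grad passes.
fn(torch.randn(10))
```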

torch/_dynamo/config.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -577,6 +577,12 @@ def default_debug_dir_root() -> str:
 # This flag is ignored and maintained for backwards compatibility.
 capture_func_transforms = True
 
+# Enable capturing torch.profiler.record_function ops in the graph
+# When True, profiler ops are emitted to the graph and preserved through
+# compilation (make_fx, functionalization). When False, profiler ops
+# are treated as nullcontext.
+capture_profiler_record_function: bool = False
+
 # If to log Dynamo compilation metrics into log files (for OSS) and Scuba tables (for fbcode).
 log_compilation_metrics = True
 
```
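For reference, the tests in this commit enable the flag per-test with `torch._dynamo.config.patch`; since it is a plain module-level config value, it can also be set process-wide. A small usage sketch (not part of the diff):

```python
import torch

# Scoped: the style used by the tests in this commit.
@torch._dynamo.config.patch("capture_profiler_record_function", True)
def compile_and_run(fn, *args):
    return torch.compile(fn, backend="aot_eager")(*args)

# Process-wide: flip the module-level default instead.
torch._dynamo.config.capture_profiler_record_function = True
```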
torch/_dynamo/graph_break_registry.json

Lines changed: 10 additions & 0 deletions
```diff
@@ -3849,5 +3849,15 @@
         "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
       ]
     }
+  ],
+  "GB0370": [
+    {
+      "Gb_type": "record_function escaped from compiled region",
+      "Context": "str(self)",
+      "Explanation": "Dynamo doesn't support graph break inside record_function region.",
+      "Hints": [
+        "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
+      ]
+    }
   ]
 }
```
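GB0370 backs the spanning-context case tested above. Condensed into a standalone sketch (illustrative, mirroring the test's expectation that dynamo refuses to duplicate the context across the two graphs and instead falls back, so the annotation is recorded exactly once):

```python
import torch
from torch.profiler import record_function

torch._dynamo.config.capture_profiler_record_function = True

def fn(x):
    with record_function("spanning_context"):
        a = x.sin()
        torch._dynamo.graph_break()  # the context would escape the compiled region
        a = a.cos()
    return a - 1

fn_c = torch.compile(fn, backend="aot_eager")
fn_c(torch.randn(10))  # warm-up / compilation

with torch.profiler.profile() as prof:
    fn_c(torch.randn(10))

# One annotation, not one per sub-graph.
assert [e.name for e in prof.events() if e.name == "spanning_context"] == ["spanning_context"]
```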

torch/_dynamo/variables/ctx_manager.py

Lines changed: 120 additions & 0 deletions
```diff
@@ -19,14 +19,17 @@
 """
 
 import inspect
+import logging
 import sys
 import warnings
 from collections.abc import Callable, Sequence, Sized
 from contextlib import AbstractContextManager, ExitStack
 from typing import Any, Optional, TYPE_CHECKING, Union
 
 import torch._C
+from torch._dynamo import config
 from torch._guards import Guard
+from torch._logging import warning_once
 
 from .. import graph_break_hints, variables
 from ..bytecode_transformation import (
@@ -56,6 +59,8 @@
     from torch._dynamo.codegen import PyCodegen
     from torch._dynamo.symbolic_convert import InstructionTranslator
 
+log = logging.getLogger(__name__)
+
 
 class ContextWrappingVariable(VariableTracker):
     _nonvar_fields = {
@@ -1106,6 +1111,121 @@ def reconstruct(self, cg: "PyCodegen") -> None:
         )
 
 
+class ProfilerRecordFunctionContextVariable(ContextWrappingVariable):
+    """
+    This class represents torch profiler context objects.
+
+    For record_function: emits torch.ops.profiler._record_function_enter_new
+    to the graph on enter, and torch.ops.profiler._record_function_exit on exit.
+    But if emit_profiler_ops=False, behaves like nullcontext.
+
+    For profile: behaves like nullcontext, ignoring all side-effects.
+    """
+
+    _nonvar_fields = {
+        "emit_profiler_ops",
+        *ContextWrappingVariable._nonvar_fields,
+    }
+
+    @staticmethod
+    def create(
+        func: Any,
+        record_args: Sequence[VariableTracker],
+        record_kwargs: "dict[str, VariableTracker]",
+        **kwargs: Any,
+    ) -> "ProfilerRecordFunctionContextVariable":
+        target_values = None
+        if config.capture_profiler_record_function:
+            # Extract name and args for record_function
+            # record_function(name: str, args: Optional[str] = None)
+            name = (
+                record_args[0].as_python_constant()
+                if record_args
+                else kwargs.get(
+                    "name", variables.ConstantVariable("unknown")
+                ).as_python_constant()
+            )
+            record_args_const = None
+            if len(record_args) > 1:
+                record_args_const = record_args[1].as_python_constant()
+            elif "args" in kwargs:
+                record_args_const = kwargs["args"].as_python_constant()
+            target_values = [name, record_args_const]
+        else:
+            warning_once(log, "Profiler record function %s will be ignored", func)
+        return ProfilerRecordFunctionContextVariable(
+            target_values=target_values,
+            initial_values=None,
+            **kwargs,
+        )
+
+    def __init__(
+        self,
+        target_values: Any = None,
+        initial_values: Optional[Any] = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            target_values=target_values, initial_values=initial_values, **kwargs
+        )
+
+    def enter(self, tx: "InstructionTranslator") -> VariableTracker:
+        if config.capture_profiler_record_function:
+            name, args = self.target_values
+            # Create the profiler entry node in the graph
+            self.proxy = tx.output.create_node(
+                "call_function",
+                torch.ops.profiler._record_function_enter_new,
+                (name, args),
+                {},
+            )
+        return self
+
+    def exit(
+        self, tx: "InstructionTranslator", *args: VariableTracker
+    ) -> VariableTracker:
+        if config.capture_profiler_record_function:
+            # Create the profiler exit node in the graph
+            tx.output.create_node(
+                "call_function",
+                torch.ops.profiler._record_function_exit._RecordFunction,
+                (self.proxy,),
+                {},
+            )
+        return variables.ConstantVariable.create(None)
+
+    def module_name(self) -> str:
+        return (
+            "torch.autograd.profiler"
+            if config.capture_profiler_record_function
+            else "contextlib"
+        )
+
+    def fn_name(self) -> str:
+        return (
+            "record_function"
+            if config.capture_profiler_record_function
+            else "nullcontext"
+        )
+
+    def reconstruct_type(self, codegen: "PyCodegen") -> None:
+        # This will be called if we try to reconstruct the record_function type
+        # when there's a graph break. The _set_error_on_graph_break(True) in enter()
+        # will cause the graph break to raise an error before we get here.
+        if config.capture_profiler_record_function:
+            unimplemented(
+                gb_type="record_function escaped from compiled region",
+                context=str(self),
+                explanation="Dynamo doesn't support graph break inside record_function region.",
+                hints=[
+                    *graph_break_hints.SUPPORTABLE,
+                ],
+            )
+        else:
+            # If capture is disabled, allow reconstruction (behaves like nullcontext)
+            super().reconstruct_type(codegen)
+
+
 class PreserveVersionContextVariable(ContextWrappingVariable):
     """
     Wraps torch.autograd._unsafe_preserve_version_counter
```
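To see the context variable's two behaviors side by side, here is a short sketch built on the same test helpers (illustrative, not part of the diff; `torch._dynamo.reset()` clears the compilation cache between the two runs):

```python
import torch
import torch._dynamo.testing
from torch.profiler import record_function

def fn(x):
    with record_function("my_net1"):
        return x.sin()

def dynamo_graph_code(flag: bool) -> str:
    torch._dynamo.reset()  # avoid reusing a graph compiled under the other setting
    with torch._dynamo.config.patch("capture_profiler_record_function", flag):
        backend = torch._dynamo.testing.AotEagerAndRecordGraphs()
        torch.compile(fn, backend=backend)(torch.randn(10))
        return backend.graphs[0].code

assert "_record_function_enter_new" not in dynamo_graph_code(False)  # nullcontext path
assert "_record_function_enter_new" in dynamo_graph_code(True)       # profiler ops emitted
```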
