
Common Graph Breaks

Created On: Jul 28, 2025 | Last Updated On: Jul 28, 2025

Below are some common graph breaks and some workarounds.

Incorrect Code

Your code might contain errors (meaning it doesn't execute even without torch.compile). In the example below, there's a typo in the torch.sin call: it is given an extra argument. Always disable torch.compile to check whether the code runs correctly.

    @torch.compile
    def fn(x):
        y = torch.sin(x, x)
        return y

    try:
        fn(torch.ones(3, 3))
    except Exception as e:
        pass
    Graph break in user code at /tmp/ipykernel_142/343837593.py:3
    Graph Break Reason: TypeError when making fake tensor call
      Explanation:

      Developer debug context: TypeError <built-in method sin of type object at 0x7f271f582260>: sin() takes 1 positional argument but 2 were given

     For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0112.html
    User code traceback:
      ...
      File "/tmp/ipykernel_142/343837593.py", line 7, in <module>
        fn(torch.ones(3, 3))
      File "/tmp/ipykernel_142/343837593.py", line 3, in fn
        y = torch.sin(x, x)

Dynamo makes a best-effort attempt to hint whether a graph break is caused by your code. However, it can still be difficult to tell from the logs whether the graph break is caused by an error in your code, is a more complicated graph break, or is a torch.compile bug. To tell these cases apart, run your code without torch.compile and check whether you still get the error reported by the graph break.
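As a minimal sketch of that check, the snippet below runs the function as plain PyTorch before compiling it. The helper `runs_eagerly` is hypothetical, introduced only for illustration; it is not part of the torch API. If the eager call already raises, the problem is in the user code rather than in torch.compile.

```python
import torch


def fn(x):
    # Same intentional bug as above: torch.sin accepts a single tensor.
    return torch.sin(x, x)


def runs_eagerly(f, *args):
    """Hypothetical helper: True if f(*args) runs without torch.compile."""
    try:
        f(*args)
    except Exception as e:
        # An eager failure points at the user code, not at torch.compile.
        print(f"eager run failed: {type(e).__name__}: {e}")
        return False
    return True


x = torch.ones(3, 3)
if runs_eagerly(fn, x):
    # Only reach for torch.compile once plain PyTorch accepts the code.
    out = torch.compile(fn)(x)
```

Here the eager call fails with a TypeError, so the compile step is skipped and the typo can be fixed first.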

Data-dependent operations

torch.compile graph breaks on data-dependent operations, such as data-dependent control flow (if-statements and loops whose conditions depend on tensor values) and direct accesses to tensor data (.item(), .data_ptr()).

    @torch.compile
    def fn(x):
        y = x.sum()
        if y > 0:
            return x + y.item()
        return x - y.item()

    print(fn(torch.ones(3, 3)))

    tensor([[10., 10., 10.],
            [10., 10., 10.],
            [10., 10., 10.]])
    Graph break in user code at /tmp/ipykernel_142/3495555842.py:4
    Graph Break Reason: Data-dependent branching
      Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
      Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
      Hint: Use `torch.cond` to express dynamic control flow.

      Developer debug context: attempted to jump with TensorVariable()

    User code traceback:
      ...
      File "/tmp/ipykernel_142/3495555842.py", line 8, in <module>
        print(fn(torch.ones(3, 3)))
      File "/tmp/ipykernel_142/3495555842.py", line 4, in fn
        if y > 0:

    Graph break from `Tensor.item()`, consider setting:
        torch._dynamo.config.capture_scalar_outputs = True
    or:
        env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
    to include these operations in the captured graph.

    Graph break: from user code at:
      File "/tmp/ipykernel_142/3495555842.py", line 5, in torch_dynamo_resume_in_fn_at_4
        return x + y.item()

    Graph break in user code at /tmp/ipykernel_142/3495555842.py:5
    Graph Break Reason: Unsupported Tensor.item() call with capture_scalar_outputs=False
      Explanation: Dynamo does not support tracing `Tensor.item()` with config.capture_scalar_outputs=False.
      Hint: Set `torch._dynamo.config.capture_scalar_outputs = True` or `export TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` to include these operations in the captured graph.

      Developer debug context: call_method TensorVariable() item () {}

     For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0124.html
    User code traceback:
      ...
      File "/tmp/ipykernel_142/3495555842.py", line 8, in <module>
        print(fn(torch.ones(3, 3)))
      File "/tmp/ipykernel_142/3495555842.py", line 5, in fn
        return x + y.item()

The general workaround for these graph breaks is to avoid doing data-dependent operations. Some specific workarounds are:

  • If your control flow doesn’t actually depend on data values, consider modifying your code to perform control flow on constants.

    # old
    x = torch.randn(3, 3)
    @torch.compile
    def fn(y):
        if x.sum() > 0:
            return y + x
        else:
            return y - x

    print(fn(torch.ones(3, 3)))

    tensor([[ 1.4090,  0.3393,  2.4898],
            [-0.8570,  2.4646,  1.2680],
            [ 0.8819,  1.5492,  2.2382]])
    Graph break in user code at /tmp/ipykernel_142/2410325100.py:5
    Graph Break Reason: Data-dependent branching
      Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
      Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
      Hint: Use `torch.cond` to express dynamic control flow.

      Developer debug context: attempted to jump with TensorVariable()

    User code traceback:
      ...
      File "/tmp/ipykernel_142/2410325100.py", line 10, in <module>
        print(fn(torch.ones(3, 3)))
      File "/tmp/ipykernel_142/2410325100.py", line 5, in fn
        if x.sum() > 0:
    # new
    x = torch.randn(3, 3)
    cond = (x.sum() > 0).item()

    @torch.compile
    def fn(y):
        if cond:
            return y + x
        else:
            return y - x

    print(fn(torch.ones(3, 3)))

    tensor([[1.1334, 1.6834, 1.5118],
            [1.3258, 1.0107, 2.2545],
            [1.6032, 1.2765, 0.6417]])
    # old
    @torch.compile
    def fn(x):
        if x.sum() > 0:
            return x + 1
        return x - 1

    print(fn(torch.ones(3, 3)))

    tensor([[2., 2., 2.],
            [2., 2., 2.],
            [2., 2., 2.]])
    Graph break in user code at /tmp/ipykernel_142/520574912.py:4
    Graph Break Reason: Data-dependent branching
      Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
      Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
      Hint: Use `torch.cond` to express dynamic control flow.

      Developer debug context: attempted to jump with TensorVariable()

    User code traceback:
      ...
      File "/tmp/ipykernel_142/520574912.py", line 8, in <module>
        print(fn(torch.ones(3, 3)))
      File "/tmp/ipykernel_142/520574912.py", line 4, in fn
        if x.sum() > 0:
    # new
    @torch.compile
    def fn(x):
        return torch.cond(
            x.sum() > 0,
            lambda x: x + 1,
            lambda x: x - 1,
            (x,),
        )

    print(fn(torch.ones(3, 3)))

    tensor([[2., 2., 2.],
            [2., 2., 2.],
            [2., 2., 2.]])
  • If you have a .item() call, try setting torch._dynamo.config.capture_scalar_outputs = True or the environment variable TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1.

  • Wrap the problematic parts of the function in a custom operator.

Printing and logging

Printing, logging, or issuing warnings will result in a graph break. You can try working around this by using torch._dynamo.config.reorderable_logging_functions. This config reorders logging functions so that they are called at the end of the traced function, thus avoiding a graph break. However, the logged contents may differ from eager execution, for example if a mutation occurs.

    torch._dynamo.config.reorderable_logging_functions.add(print)

    @torch.compile
    def fn(x):
        x += 1
        print("log!")
        return torch.sin(x)

    print(fn(torch.ones(3, 3)))

    log!
    tensor([[0.9093, 0.9093, 0.9093],
            [0.9093, 0.9093, 0.9093],
            [0.9093, 0.9093, 0.9093]])