
Dynamo Core Concepts#

Created On: Jul 28, 2025 | Last Updated On: Jul 28, 2025

Summary:

  • Dynamo, torch.compile’s frontend, performs tracing to capture the semantics of a Python function (and its nested function calls) into a linear sequence of operations (the “(FX) graph”), residual bytecode, and “guards” (a list of conditions under which the graph and bytecode are valid).

  • Unsupported Python features lead to graph breaks, where Dynamo compiles a partial graph acquired from tracing, then runs the unsupported code, then resumes tracing.

  • Graph breaks may lead to slowness in torch.compile and prevent backend optimization opportunities. If you’re not seeing the performance you expect, then check for graph breaks.

Dynamo Tracing#

torch.compile’s frontend (Dynamo) is a custom Python bytecode interpreter designed to allow graph compilation in PyTorch programs while retaining the full flexibility of Python. Given a function to be compiled, Dynamo interprets Python bytecode to extract sequences of PyTorch operations into one or more FX graphs that may be further optimized by a backend.
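As a minimal sketch of this flow (inspect_backend is our name, not a built-in), a custom backend can print the FX graph that Dynamo captures and then run it unoptimized:

import torch

def inspect_backend(gm, example_inputs):
    # Dynamo hands the backend an fx.GraphModule plus example inputs;
    # print the captured graph, then run it as-is (no optimization).
    print(gm.graph)
    return gm.forward

@torch.compile(backend=inspect_backend)
def f(x):
    return torch.relu(x) + 1

f(torch.randn(4))  # triggers tracing and prints the captured FX graph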

[Figure: Summary diagram of Dynamo]

For example, for the function f in the above diagram, Dynamo produces the following (a hypothetical sketch of such a function appears after this list):

  • a single FX graph that takes in the original input plus some additional inputs required by the function.

  • Python bytecode that can be used as a drop-in replacement for f. In our example, the bytecode retrieves the additional inputs and passes them to the graph, and also contains unoptimizable Python side effects (the list append).

  • guards that specify the conditions under which the graph and bytecode are valid. Unless otherwise specified, the graph produced by Dynamo specializes on the shapes of input Tensors.
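The diagram itself is not reproduced here; as a hypothetical illustration (the function body below is ours), a function of this shape exercises all three outputs: tensor operations that land in the FX graph, and a list append that survives only in the residual bytecode:

import torch

log = []  # module-level Python list

def f(x):
    y = x.sin() + x.cos()  # traceable PyTorch ops: captured in the FX graph
    log.append(y)          # Python side effect: kept in the residual bytecode
    return y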

Graph Breaks#

Dynamo traces your code and attempts to capture your PyTorch code into a single computation graph of PyTorch operators (FX graph). However, this is not always possible. When Dynamo encounters code that can’t be traced, a “graph break” occurs. In the default torch.compile settings, a graph break involves compiling the FX graph that has been determined so far, running the unsupported code in regular Python, then resuming tracing after the unsupported code with a new FX graph.

Graph breaks are a feature that allows Dynamo to run over arbitrary Python code and carve out functional subgraphs that can each be individually optimized.

However, it is possible for graph breaks to lead to unexpected slowness in torch.compile. If you’re not getting the speedups you expect, we recommend checking for graph breaks and removing them.
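One way to check (a sketch; this is not the only tool) is to compile with fullgraph=True, which raises an error at the first graph break instead of silently splitting the graph:

import torch

@torch.compile(fullgraph=True)  # error out instead of graph-breaking
def f(x):
    y = x ** 2 / 2
    torch.save(y, "foo.pt")  # unsupported operation: raises under fullgraph=True
    return y ** 3 / 6

# f(torch.randn(3))  # uncommenting this raises an error pointing at torch.save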

Graph breaks may occur on things like:

  • Data-dependent if-statements (see the sketch after this list)

  • Many Python built-in functions

  • C functions
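As a sketch of the first case (the function and values are illustrative), a branch on a tensor’s runtime value forces a graph break under the default settings, because Dynamo cannot know at trace time which side will be taken:

import torch

@torch.compile
def g(x):
    if x.sum() > 0:    # data-dependent: the branch depends on x's values
        return x + 1   # one graph is compiled up to the branch...
    return x - 1       # ...and tracing resumes separately afterwards

print(g(torch.randn(3)))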

Below is an example of a graph break due to calling the unsupported operation torch.save:

import torch

@torch.compile
def f(x):
    y = x ** 2 / 2
    torch.save(y, "foo.pt")  # torch.save is an unsupported operation
    z = y ** 3 / 6
    return z

x = torch.randn(3)
print(f(x))
tensor([6.3085e-03, 8.2592e-01, 5.1903e-08])
Graph break in user code at /tmp/ipykernel_265/215272159.py:4
Graph Break Reason: Attempted to call function marked as skipped
  Explanation: Dynamo developers have intentionally marked that the function `save` in file `/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/serialization.py` should not be traced.
  Hint: Avoid calling the function `save`.
  Hint: Apply `@torch._dynamo.dont_skip_tracing` to the function `save` to force tracing into the function. More graph breaks may occur as a result of attempting to trace into the function.
  Hint: Please file an issue to PyTorch.
  Developer debug context: module: torch.serialization, qualname: save, skip reason: <missing reason>
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html
User code traceback:
  ... (IPython/Jupyter framework frames omitted) ...
  File "/tmp/ipykernel_265/215272159.py", line 9, in <module>
    print(f(x))
  File "/tmp/ipykernel_265/215272159.py", line 4, in f
    torch.save(y, "foo.pt")  # torch.save is an unsupported operation

The semantics of torch.compile(f)(x) are roughly this:

def compiled_f_semantics(x):
    y = torch.compile(g, fullgraph=True)(x)
    torch.save(y, "foo.pt")
    z = torch.compile(h, fullgraph=True)(y)
    return z

def g(x):
    return x ** 2 / 2

def h(y):
    return y ** 3 / 6

Guards#

torch.compile makes some assumptions about runtime values as we trace through code. During tracing, we generate “guards”, which are runtime checks for these assumptions. Guards are run in future calls to the compiled function to determine if we can reuse previously compiled code. Examples of runtime checks include checks on constant values, types, and object IDs.

Below is an example of generated guards. The TENSOR_MATCH guard checks the input’s type, device, dtype, shape, etc.

@torch.compile
def fn(x):
    return x + 1

print(fn(torch.ones(3, 3)))
tensor([[2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.]])
GUARDS:

TREE_GUARD_MANAGER:
+- RootGuardManager
| +- LAMBDA_GUARD: torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None  # _dynamo/output_graph.py:688 in init_ambient_guards
| +- DEFAULT_DEVICE: utils_device.CURRENT_DEVICE == None  # _dynamo/output_graph.py:676 in init_ambient_guards
| +- GLOBAL_STATE: ___check_global_state()
| +- TORCH_FUNCTION_MODE_STACK: ___check_torch_function_mode_stack()
| +- GuardManager: source=L['x'], accessed_by=FrameLocalsGuardAccessor(key='x', framelocals_idx=0), type=<class 'torch.Tensor'>, tag_safe=(True, False)
| | +- TENSOR_MATCH: check_tensor(L['x'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[3, 3], stride=[3, 1])  # return x + 1  # mp/ipykernel_265/1068332425.py:3 in fn
| | +- NO_HASATTR: hasattr(L['x'], '_dynamo_dynamic_indices') == False  # return x + 1  # mp/ipykernel_265/1068332425.py:3 in fn

Guard eval latency = 155.78 us
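The dump above comes from guard logging. As a usage sketch (assuming the guards and recompiles logging artifacts in torch._logging), output like this, and the recompilation messages in the next section, can be reproduced with:

import torch

# Print guards as they are created and explain any later recompilations.
torch._logging.set_logs(guards=True, recompiles=True)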

Recompilations#

If the guards fail for every instance of previously compiled code, then torch.compile must “recompile” the function, requiring the original code to be traced again. In the example below, recompilation is necessary because the guard checking the tensor argument’s shape failed.

@torch.compile
def fn(x):
    return x + 1

print(fn(torch.ones(3, 3)))
print(fn(torch.ones(4, 4)))
tensor([[2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.]])
tensor([[2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.]])
Recompiling function fn in /tmp/ipykernel_265/420870727.py:1
    triggered by the following guard failure(s):
    - 3/0: tensor 'x' size mismatch at index 0. expected 3, actual 4

Dynamic Shapes#

torch.compile initially assumes tensor shapes are static/constant and guards based on these assumptions. By using “dynamic shapes,” we can get torch.compile to produce compiled code that can accept tensor inputs with different shapes, avoiding recompilation every time shapes differ. By default, automatic dynamic shapes are enabled (torch.compile(dynamic=None)): if a guard fails due to a shape mismatch, recompilation is attempted with dynamic shapes. Dynamic shapes can also be fully enabled (dynamic=True) or disabled (dynamic=False).

Below, we enable dynamic shapes and note that we no longer need to recompile.

@torch.compile(dynamic=True)
def fn(x):
    return x + 1

print(fn(torch.ones(3, 3)))
print(fn(torch.ones(4, 4)))
tensor([[2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.]])
tensor([[2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.]])
create_env
create_symbol s77 = 3 for L['x'].size()[0] [2, int_oo] return x + 1  # mp/ipykernel_265/1458103805.py:3 in fn (_dynamo/variables/builder.py:3508 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s77" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
create_symbol s77 duck sized L['x'].size()[1]
eval False == False [statically known]
eval False == False [statically known]
produce_guards
track_symint L['x'].size()[0] s77 None
track_symint L['x'].size()[1] s77 None
track_symint L['x'].stride()[0] s77 None
track_symint L['x'].stride()[1] 1 None
track_symint L['x'].storage_offset() 0 None
Skipping guard L['x'].stride()[1] == 1
Skipping guard L['x'].storage_offset() == 0
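Instead of making the whole function dynamic, a single dimension can be marked dynamic ahead of time so the first compilation is already dynamic in that dimension (a sketch using torch._dynamo.mark_dynamic):

import torch

@torch.compile
def fn(x):
    return x + 1

x = torch.ones(3, 3)
torch._dynamo.mark_dynamic(x, 0)  # compile with a symbolic size for dim 0
print(fn(x))
print(fn(torch.ones(4, 3)))  # a new dim-0 size does not force a recompile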

For more information on dynamic shapes, see The dynamic shapes manual.