Nested Graph Breaks
Created On: Jul 28, 2025 | Last Updated On: Jul 28, 2025
Summary:

- Graph breaks in nested functions can result in hard-to-understand compiler behavior, which we document below.
- A nested graph break results in duplicated graph break behavior.
Recall that when `torch.compile` is applied to a function, any nested function calls are also traced. A nested graph break refers to any graph break that happens in a nested function call.
```python
def inner(x):
    ...
    torch._dynamo.graph_break()  # nested graph break
    ...

@torch.compile
def outer(x):
    ...
    y = inner(x)
    ...
```
The resumption semantics around nested graph breaks can be confusing, so we describe the behavior here.
Recall that in `fullgraph=False`, graph breaks are handled by compiling the FX graph that has been determined so far, running the unsupported code in regular Python, then resuming tracing after the unsupported code with a new FX graph. Resuming a function is actually a fairly complicated technical feat, so resuming tracing is only supported on top-level functions.
Under this restriction, we can therefore resume tracing after a nested graph break in the following way:
First, consider the example below, where `torch.compile` traces from `f` all the way until the graph break in `inner1` is encountered.
```python
import torch

def inner1(x):
    x = x + 1
    torch._dynamo.graph_break()  # stop tracing due to graph break
    return x + 2

def inner2(x):
    x = x + 4
    x = inner1(x)
    x = x + 8

@torch.compile
def f(x):  # start tracing from here
    x = x + 16
    x = inner2(x)
    x = x + 32

f(torch.randn(3))
```
Since we can only resume from top-level functions, we graph break on the `inner2` call in `f`.
```python
# The semantics of torch.compile(f)(x) is roughly this:
def compiled_f_semantics(x):
    y = x + 16
    z = inner2(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

compiled_f_semantics(torch.randn(3))
```
`inner2` is then automatically compiled as a top-level function. We trace all the way until the graph break in `inner1` is encountered again.
```python
def inner1(x):
    x = x + 1
    torch._dynamo.graph_break()  # stop tracing due to graph break
    return x + 2

# this torch.compile is automatically applied
@torch.compile
def inner2(x):  # start tracing from here
    x = x + 4
    x = inner1(x)
    x = x + 8

def compiled_f_semantics(x):
    y = x + 16
    z = inner2(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

compiled_f_semantics(torch.randn(3))
```
Then we graph break on the `inner1` call in `inner2`.
```python
def compiled_inner2_semantics(x):
    y = x + 4
    z = inner1(y)
    return torch.compile(resume_inner2_semantics)(z)

def resume_inner2_semantics(x):
    return x + 8
```
`inner1` is then automatically compiled as a top-level function. The graph break is from `inner1` itself, so we handle the graph break normally.
```python
# this torch.compile is automatically applied
@torch.compile
def inner1(x):  # start tracing from here
    x = x + 1
    torch._dynamo.graph_break()  # stop tracing due to graph break
    return x + 2

def compiled_f_semantics(x):
    y = x + 16
    z = compiled_inner2_semantics(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

def compiled_inner2_semantics(x):
    y = x + 4
    z = inner1(y)
    return torch.compile(resume_inner2_semantics)(z)

def resume_inner2_semantics(x):
    return x + 8

compiled_f_semantics(torch.randn(3))
```
`inner1` is handled normally:
```python
def compiled_inner1_semantics(x):
    y = x + 1
    torch._dynamo.graph_break()
    return torch.compile(resume_inner1_semantics)(y)

def resume_inner1_semantics(x):
    return x + 2
```
So the initial code is semantically equivalent to:
```python
def compiled_f_semantics(x):
    y = x + 16
    z = compiled_inner2_semantics(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

def compiled_inner2_semantics(x):
    y = x + 4
    z = compiled_inner1_semantics(y)
    return torch.compile(resume_inner2_semantics)(z)

def resume_inner2_semantics(x):
    return x + 8

def compiled_inner1_semantics(x):
    y = x + 1
    torch._dynamo.graph_break()
    return torch.compile(resume_inner1_semantics)(y)

def resume_inner1_semantics(x):
    return x + 2

compiled_f_semantics(torch.randn(3))
```
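One way to sanity-check this decomposition is to run it as plain Python. The sketch below is an illustration only: `fake_compile` and `fake_graph_break` are hypothetical stand-ins for `torch.compile` and `torch._dynamo.graph_break()`, so we can verify that chaining the compiled and resume functions computes the same result as applying all the increments in order.

```python
def fake_compile(fn):  # stand-in for torch.compile: just run the function
    return fn

def fake_graph_break():  # stand-in for torch._dynamo.graph_break(): no-op
    pass

def compiled_f_semantics(x):
    y = x + 16
    z = compiled_inner2_semantics(y)
    return fake_compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

def compiled_inner2_semantics(x):
    y = x + 4
    z = compiled_inner1_semantics(y)
    return fake_compile(resume_inner2_semantics)(z)

def resume_inner2_semantics(x):
    return x + 8

def compiled_inner1_semantics(x):
    y = x + 1
    fake_graph_break()
    return fake_compile(resume_inner1_semantics)(y)

def resume_inner1_semantics(x):
    return x + 2

# Straight-line computation: x + 16 + 4 + 1 + 2 + 8 + 32 = x + 63
assert compiled_f_semantics(0) == 63
```

The values flow through the six pieces in the same order as the original nested calls, which is exactly what the resume-function rewrite is meant to preserve.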
Note in particular that we traced 3 top-level functions, and that we traced the same graph break 3 times. This explains why you may encounter duplicate graph breaks when using `torch.compile`.
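You can observe these repeated graph break reports yourself via Dynamo's graph break logging. Assuming the first example above is saved as `nested_breaks.py` (a hypothetical filename), something like:

```shell
# Log each graph break encountered while tracing; the same break in inner1
# is reported once per traced frame.
TORCH_LOGS="graph_breaks" python nested_breaks.py
```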
In summary, nested graph breaks are handled by:

1. Tracing from the top-level function all the way to the nested graph break.
2. Graph breaking on the top-level function at the call to the second-level function.
3. Compiling the PyTorch ops tracked so far and running the compiled graph.
4. Calling the second-level function, which gets automatically compiled as a top-level function.
5. Resuming tracing after the second-level function call.
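The steps above can be sketched as a toy model in plain Python (no torch involved). Here a "function" is just a list whose items are ints (traceable ops), nested lists (function calls), or the marker string `"BREAK"` (a graph break); the names and representation are invented for illustration, and the model simplifies by assuming the first nested call in a frame is the one containing the break.

```python
stats = {"frames": 0, "breaks_seen": 0}

class GraphBreak(Exception):
    """Raised when tracing hits a graph break; carries the frame it was in."""

def inline_trace(fn):
    # Trace fn, inlining nested calls, until a graph break is encountered.
    for item in fn:
        if item == "BREAK":
            stats["breaks_seen"] += 1
            raise GraphBreak(fn)
        if isinstance(item, list):
            inline_trace(item)

def trace_top_level(fn):
    # Step 1: trace fn as a top-level function.
    stats["frames"] += 1
    try:
        inline_trace(fn)
    except GraphBreak as gb:
        if gb.args[0] is not fn:
            # Step 2: the break happened inside a nested call, so graph
            # break at the call site. Step 4: the callee is then compiled
            # as its own top-level function.
            callee = next(x for x in fn if isinstance(x, list))
            trace_top_level(callee)
        # Steps 3 and 5 (running the partial graph and resuming after the
        # call) are not modeled here.

inner1 = [1, "BREAK", 2]
inner2 = [4, inner1, 8]
f = [16, inner2, 32]

trace_top_level(f)
assert stats["frames"] == 3       # f, inner2, inner1 each traced top-level
assert stats["breaks_seen"] == 3  # the same graph break is traced 3 times
```

Running the model on the `f`/`inner2`/`inner1` example reproduces the counts from the walkthrough: three top-level frames traced, and the same graph break encountered three times.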
Note that the runtime of handling this graph break is O(N * I), where N is the nesting depth and I is the number of instructions from the top-level function to the graph break. We end up tracing N frames, and we trace the same graph break N times.