Nested Graph Breaks

Created On: Jul 28, 2025 | Last Updated On: Jul 28, 2025

Summary:

  • Graph breaks in nested functions can result in hard-to-understand compiler behavior, which we document below

  • A nested graph break results in O(N) duplicate graph break behavior

Recall that when `torch.compile` is applied to a function, any nested function calls are also traced. A *nested graph break* refers to any graph break that happens in a nested function call.

```python
def inner(x):
    ...
    torch._dynamo.graph_break()  # nested graph break
    ...

@torch.compile
def outer(x):
    ...
    y = inner(x)
    ...
```

The resumption semantics around nested graph breaks can be confusing, so we describe the behavior here.

Recall that with `fullgraph=False`, graph breaks are handled by compiling the FX graph that has been determined so far, running the unsupported code in regular Python, then resuming tracing after the unsupported code with a new FX graph. Resuming a function is actually a fairly complicated technical feat, so resuming tracing is only supported for top-level functions.

Given this restriction, we can resume tracing after a nested graph break in the following way:

First, consider the example below, where `torch.compile` starts tracing from `f` and traces all the way until the graph break in `inner1` is encountered.

```python
def inner1(x):
    x = x + 1
    torch._dynamo.graph_break()  # stop tracing due to graph break
    return x + 2

def inner2(x):
    x = x + 4
    x = inner1(x)
    return x + 8

@torch.compile
def f(x):
    # start tracing from here
    x = x + 16
    x = inner2(x)
    return x + 32

f(torch.randn(3))
```

Since we can only resume from top-level functions, we graph break on the `inner2` call in `f`.

```python
# The semantics of torch.compile(f)(x) is roughly this:
def compiled_f_semantics(x):
    y = x + 16
    z = inner2(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

compiled_f_semantics(torch.randn(3))
```

`inner2` is then automatically compiled as a top-level function. We trace all the way until the graph break in `inner1` is encountered again.

```python
def inner1(x):
    x = x + 1
    torch._dynamo.graph_break()  # stop tracing due to graph break
    return x + 2

# this torch.compile is automatically applied
@torch.compile
def inner2(x):
    # start tracing from here
    x = x + 4
    x = inner1(x)
    return x + 8

def compiled_f_semantics(x):
    y = x + 16
    z = inner2(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

compiled_f_semantics(torch.randn(3))
```

Then we graph break on the `inner1` call in `inner2`.

```python
def compiled_inner2_semantics(x):
    y = x + 4
    z = inner1(y)
    return torch.compile(resume_inner2_semantics)(z)

def resume_inner2_semantics(x):
    return x + 8
```

`inner1` is then automatically compiled as a top-level function. The graph break comes from `inner1` itself, so we handle the graph break normally.

```python
# this torch.compile is automatically applied
@torch.compile
def inner1(x):
    # start tracing from here
    x = x + 1
    torch._dynamo.graph_break()  # stop tracing due to graph break
    return x + 2

def compiled_f_semantics(x):
    y = x + 16
    z = compiled_inner2_semantics(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

def compiled_inner2_semantics(x):
    y = x + 4
    z = inner1(y)
    return torch.compile(resume_inner2_semantics)(z)

def resume_inner2_semantics(x):
    return x + 8

compiled_f_semantics(torch.randn(3))
```

`inner1` is handled normally:

```python
def compiled_inner1_semantics(x):
    y = x + 1
    torch._dynamo.graph_break()
    return torch.compile(resume_inner1_semantics)(y)

def resume_inner1_semantics(x):
    return x + 2
```

So the initial code is semantically equivalent to:

```python
def compiled_f_semantics(x):
    y = x + 16
    z = compiled_inner2_semantics(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

def compiled_inner2_semantics(x):
    y = x + 4
    z = compiled_inner1_semantics(y)
    return torch.compile(resume_inner2_semantics)(z)

def resume_inner2_semantics(x):
    return x + 8

def compiled_inner1_semantics(x):
    y = x + 1
    torch._dynamo.graph_break()
    return torch.compile(resume_inner1_semantics)(y)

def resume_inner1_semantics(x):
    return x + 2

compiled_f_semantics(torch.randn(3))
```

Note in particular that we traced 3 top-level functions, and that we traced the same graph break 3 times. This explains why you may encounter duplicate graph breaks when using `torch.compile`.
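The equivalence can be sanity-checked in plain Python by stubbing out `torch.compile` and `torch._dynamo.graph_break()` as no-ops. The stubs below are illustrative assumptions for the check only, not how the compiler actually executes:

```python
# Stand-ins so the check runs without the compiler (illustrative only)
def graph_break():
    pass  # stand-in for torch._dynamo.graph_break()

def fake_compile(fn):
    return fn  # stand-in for torch.compile

# Original program
def inner1(x):
    x = x + 1
    graph_break()
    return x + 2

def inner2(x):
    x = x + 4
    x = inner1(x)
    return x + 8

def f(x):
    x = x + 16
    x = inner2(x)
    return x + 32

# Rewritten semantics from the walkthrough above
def resume_inner1_semantics(x):
    return x + 2

def compiled_inner1_semantics(x):
    y = x + 1
    graph_break()
    return fake_compile(resume_inner1_semantics)(y)

def resume_inner2_semantics(x):
    return x + 8

def compiled_inner2_semantics(x):
    y = x + 4
    z = compiled_inner1_semantics(y)
    return fake_compile(resume_inner2_semantics)(z)

def resume_f_semantics(x):
    return x + 32

def compiled_f_semantics(x):
    y = x + 16
    z = compiled_inner2_semantics(y)
    return fake_compile(resume_f_semantics)(z)

# Both compute x + 63
assert compiled_f_semantics(0.0) == f(0.0)
```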

In summary, nested graph breaks are handled by:

  • Tracing from the top-level function all the way to the nested graph break

  • Graph breaking on the top-level function at the call to the second-level function

  • Compiling the PyTorch ops tracked so far and running the compiled graph

  • Calling the second-level function, which gets automatically compiled as a top-level function

  • Resuming tracing after the second-level function call
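To observe these duplicated breaks in practice, you can enable PyTorch's graph-break logging when running a script containing the example. The filename below is a placeholder:

```shell
# Report each graph break as it is encountered during tracing;
# for the example above, the break in inner1 is reported once
# per recompiled top-level frame.
TORCH_LOGS="graph_breaks" python nested_breaks.py
```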

Note that the runtime of handling this graph break is O(NK), where N is the nesting depth and K is the number of instructions from the top-level function to the graph break. We end up tracing O(N^2) frames, and we trace the same graph break O(N) times.
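These counts can be illustrated with a toy model (an illustration of the arithmetic, not the actual tracer): the compilation that starts at nesting depth i traces the N - i + 1 frames from itself down to the graph break, and every one of the N compilations reaches the same break once.

```python
# Toy model of nested graph-break handling costs.
def traced_frames(n):
    # The compilation starting at depth i traces frames i..n,
    # so the total is n + (n-1) + ... + 1 = n(n+1)/2, i.e. O(n^2).
    return sum(n - i + 1 for i in range(1, n + 1))

def graph_break_visits(n):
    # Each of the n top-level compilations hits the same break once.
    return n

# The example above has nesting depth 3 (f, inner2, inner1):
assert traced_frames(3) == 6       # 3 + 2 + 1 frames
assert graph_break_visits(3) == 3  # same break traced 3 times
```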