
Use torch._dynamo.nonstrict_trace

Created On: Jul 28, 2025 | Last Updated On: Jul 28, 2025

Summary:

  • Use nonstrict_trace to trace a function with non-strict tracing inside of a torch.compile’d region. You may wish to do this because Dynamo graph breaks on something inside of the function and you are sure that the function is non-strict traceable.

Consider the following scenario:

```python
def get_magic_num():
    # This explicit graph break call is meant to emulate any kind of Dynamo
    # graph break, e.g., the function is implemented in C, or uses some python
    # language feature Dynamo doesn't yet support.
    torch._dynamo.graph_break()
    return torch.tensor([42])

@torch.compile(fullgraph=True)
def func(x):
    n = get_magic_num()
    return x + n

try:
    func(torch.rand(10))
except Exception as e:
    print(e)
```
```
Call to `torch._dynamo.graph_break()`
  Explanation: User-inserted graph break. Message: None
  Hint: Remove the `torch._dynamo.graph_break()` call.
  Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`

For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html

from user code:
  File "/tmp/ipykernel_316/2253748958.py", line 9, in func
    n = get_magic_num()
  File "/tmp/ipykernel_316/2253748958.py", line 5, in get_magic_num
    torch._dynamo.graph_break()

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
```

If we run the code above, we’ll get an error from Dynamo, because it sees a graph break while the user specified fullgraph=True.

In these situations, if a user still wants to keep fullgraph=True, they typically have several options:

  1. The graph break is due to a language feature Dynamo doesn’t yet support. In this case, the user either rewrites their code, or files an issue on GitHub.

  2. The graph break is due to a call to a function implemented in C. In this case, the user can try to use a custom op. The user could also try providing a polyfill (a reference implementation in Python) so that Dynamo can trace through it.

  3. Worst case scenario – an internal compiler error. In this case, the user likely has to file an issue on GitHub.

In addition to all these options, PyTorch does provide an alternative, torch._dynamo.nonstrict_trace, if the function call that induced the graph break satisfies certain requirements:

  • The requirements of general non-strict tracing.

  • The inputs and outputs must contain either basic types (e.g., int, float, list, dict, torch.Tensor), or user-defined types that are registered to torch.utils._pytree.

  • The function must be defined outside the torch.compile’d region.

  • Any non-input values read by the function will be treated as constants (e.g., a global tensor), and will not be guarded on.

When tracing through a call to a torch._dynamo.nonstrict_trace’d function, torch.compile switches to non-strict tracing, and the FX graph will eventually contain all the relevant tensor operations which happened inside that function.

For the example above, we can use torch._dynamo.nonstrict_trace to eliminate the graph break:

```python
@torch._dynamo.nonstrict_trace
def get_magic_num():
    # This explicit graph break call is meant to emulate any kind of Dynamo
    # graph break, e.g., the function is implemented in C, or uses some python
    # language feature Dynamo doesn't yet support.
    torch._dynamo.graph_break()
    return torch.tensor([42])

@torch.compile(fullgraph=True)
def func(x):
    n = get_magic_num()
    return x + n

print(func(torch.rand(10)))  # No graph break and no error.
```
```
tensor([42.6048, 42.6620, 42.3730, 42.1969, 42.7525, 42.4637, 42.1544, 42.3485,
        42.8316, 42.8360])
```

Note that one can use it inside a torch.compile’d region as well:

```python
def get_magic_num():
    # This explicit graph break call is meant to emulate any kind of Dynamo
    # graph break, e.g., the function is implemented in C, or uses some python
    # language feature Dynamo doesn't yet support.
    torch._dynamo.graph_break()
    return torch.tensor([42])

@torch.compile(fullgraph=True)
def func(x):
    n = torch._dynamo.nonstrict_trace(get_magic_num)()
    return x + n

print(func(torch.rand(10)))  # No graph break and no error.
```
```
tensor([42.8445, 42.3965, 42.4669, 42.6729, 42.7607, 42.7630, 42.2933, 42.7238,
        42.9972, 42.4853])
```