
Non-strict Tracing Programming Model#

Created On: Jul 28, 2025 | Last Updated On: Jul 28, 2025

Summary:

  • Non-strict tracing is a way to trace Python code that is less strict than Dynamo, but may result in silent incorrectness.

  • Non-strict tracing runs a Python function and uses Python and PyTorch’s operator overloading capabilities to record what Tensor operations occurred during execution into a trace.

  • A function is non-strict traceable if it complies with some constraints, namely, that the function is pure and does not directly manipulate Tensor.data_ptr().

  • Non-strict tracing may specialize on certain variables and treat them as constants, baking the values of the variables into the trace.

torch.compile internals (make_fx, AOTDispatcher) use non-strict tracing. torch._dynamo.nonstrict_trace can also be used in code compiled with torch.compile to mark sections of code to be traced with non-strict tracing.

Non-strict tracing runs a Python function and uses Python and PyTorch’s operator overloading capabilities to record what Tensor operations occurred during execution into a trace.
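The operator-overloading idea can be illustrated without PyTorch. The following is a toy sketch only, not how make_fx is implemented: the `Proxy` class and `toy_trace` helper are made-up names, and the "trace" is just a list of tuples.

```python
class Proxy:
    """Stand-in for a traced value: it holds a name, not data."""

    def __init__(self, name, trace):
        self.name = name
        self.trace = trace  # list of recorded operations, shared by all proxies

    def _record(self, op, other):
        # A Proxy operand is referenced by name; anything else is stored as-is.
        arg = other.name if isinstance(other, Proxy) else other
        out = Proxy(f"t{len(self.trace)}", self.trace)
        self.trace.append((out.name, op, self.name, arg))
        return out

    def __add__(self, other):
        return self._record("add", other)

    def __mul__(self, other):
        return self._record("mul", other)

    def __truediv__(self, other):
        return self._record("div", other)

    def __pow__(self, other):
        return self._record("pow", other)


def toy_trace(fn):
    """Run fn once on a Proxy and return the operations it performed."""
    trace = []
    fn(Proxy("x", trace))
    return trace


def f(x):
    return x ** 2 / 6


for step in toy_trace(f):
    print(step)
# ('t0', 'pow', 'x', 2)
# ('t1', 'div', 't0', 6)
```

The function body runs as ordinary Python; only the operations that touch a `Proxy` end up in the trace, which is the same reason anything else the function does (branching, printing, reading globals) is invisible to the recorded graph.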

make_fx is the main entrypoint for non-strict tracing. In the following function, only the first branch is taken for the given input (x.shape[0] is 3), so the captured graph contains only that branch.

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    if x.shape[0] > 2:
        return x ** 2 / 6
    else:
        return x * 3

x = torch.randn(3)
gm = make_fx(f, tracing_mode="fake")(x)
gm.print_readable()
class f(torch.nn.Module):
    def forward(self, x_1: "f32[3]"):
        # No stacktrace found for following nodes
        pow_1: "f32[3]" = torch.ops.aten.pow.Tensor_Scalar(x_1, 2);  x_1 = None
        div: "f32[3]" = torch.ops.aten.div.Tensor(pow_1, 6);  pow_1 = None
        return div

Non-strict tracing differs from Dynamo (strict) tracing in that it is unsafe: given a function, it captures a graph of Tensor operations that may have different semantics than the original function. Dynamo tracing, by contrast, captures a graph of Tensor operations plus residual bytecode that, when combined, give the same semantics as the Python function.

Pure Functions#

Non-strict tracing is sound only on pure functions, and thus only pure functions should be non-strict traced.

A pure function is a function with the following properties:

  • Determinism. Given the same inputs, the pure function will always return the same output.

  • No side effects. A pure function does not have any side effects such as modifying external state or performing I/O operations.

  • Explicit input/output. All the input data must be passed through the function parameters and all of the outputs are returned from the function.
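To make the three properties concrete, here is a small plain-Python sketch (no PyTorch involved) that contrasts a pure function with one impure variant per property. All function names and the `SCALE` global are made up for illustration.

```python
import random

SCALE = 2  # module-level state read by one of the impure examples


def pure_scale(values, scale):
    """Pure: the output depends only on the arguments; nothing else is touched."""
    return [v * scale for v in values]


def nondeterministic_scale(values, scale):
    """Violates determinism: random noise makes the result vary between calls."""
    return [v * scale + random.random() for v in values]


def mutating_scale(values, scale):
    """Violates 'no side effects': the caller's list is modified in place."""
    for i in range(len(values)):
        values[i] *= scale
    return values


def implicit_input_scale(values):
    """Violates explicit input/output: reads SCALE instead of taking a parameter."""
    return [v * SCALE for v in values]


data = [1, 2, 3]
assert pure_scale(data, 2) == pure_scale(data, 2) == [2, 4, 6]
assert data == [1, 2, 3]  # the pure version leaves its input untouched
```

Only `pure_scale` could be replaced by a recording of its operations without changing what a caller observes.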

Here are some examples of impure functions for which the captured graph behaves differently from the original function.

Example 1: No explicit input (e.g. accesses global tensor)#

var = torch.tensor(1)

def function_with_global_access(y):
    return y + var

x = torch.tensor([0, 1, 2])

# _allow_non_fake_inputs=True is needed to capture the global variable
# for demonstration purposes.
gm = make_fx(
    function_with_global_access, tracing_mode="fake", _allow_non_fake_inputs=True
)(x)

# Non-strict Tracing captures the value of the global (1.)
print("1. call function", function_with_global_access(x))
print("1. call graph", gm(x))

# However, after changing the global, the captured graph
# produces a different result from the original function
var = torch.tensor(2)
print("2. call function", function_with_global_access(x))
print("2. call graph", gm(x))

# To capture a graph that can have a varying `var` tensor,
# it must be an explicit input:
def function_fixed(y, var):
    return y + var

var = torch.tensor(3)
gm = make_fx(function_fixed, tracing_mode="fake")(x, var)
print("3. call function", function_fixed(x, var))
print("3. call graph", gm(x, var))

var = torch.tensor(4)
print("4. call function", function_fixed(x, var))
print("4. call graph", gm(x, var))
1. call function tensor([1, 2, 3])
1. call graph tensor([1, 2, 3])
2. call function tensor([2, 3, 4])
2. call graph tensor([1, 2, 3])
3. call function tensor([3, 4, 5])
3. call graph tensor([3, 4, 5])
4. call function tensor([4, 5, 6])
4. call graph tensor([4, 5, 6])

See Specialization and Constants for an explanation of why.

Example 2: Side effect (printing)#

def function_with_side_effect(y):
    print(y)

x = torch.tensor([0, 1, 2])
_ = function_with_side_effect(x)
tensor([0, 1, 2])

Running function_with_side_effect in Python prints a Tensor as a side effect.

gm = make_fx(function_with_side_effect, tracing_mode="fake")(x)
FakeTensor(..., size=(3,), dtype=torch.int64)

During non-strict tracing, this print occurs during the graph capture.

_ = gm(x)

The graph does not store a call to the print statement, so executing the graph doesn’t print anything.

Example 3: Side effect (input list mutation)#

lst = []

def function_with_input_list_mutation(lst):
    val = lst.pop()
    return val

x = torch.tensor([0, 1, 2])
y = torch.tensor([0, 1, 2])

# Each time the function is executed, the list shrinks in size
lst = [x, y]
function_with_input_list_mutation(lst)
print("len(lst) after one call", len(lst))
function_with_input_list_mutation(lst)
print("len(lst) after two calls", len(lst))

# With Non-strict Tracing, the pop is not recorded in the graph, so the
# caller's list keeps its length across graph capture and graph
# invocations (see the output below).
lst = [x, y]
gm = make_fx(function_with_input_list_mutation, tracing_mode="fake")(lst)
print("len(lst) after graph capture", len(lst))
gm(lst)
print("len(lst) after one call to graph", len(lst))
gm(lst)
print("len(lst) after two calls to graph", len(lst))
len(lst) after one call 1
len(lst) after two calls 0
len(lst) after graph capture 2
len(lst) after one call to graph 2
len(lst) after two calls to graph 2
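One possible way to avoid this mismatch is to stop mutating the input and return the shortened list explicitly, so that all effects flow through the return value. The sketch below uses plain integers in place of Tensors, and `last_item_pure` is a hypothetical name; it illustrates the rewrite and is not a PyTorch API.

```python
def last_item_mutating(lst):
    # Impure: pop() shrinks the caller's list on every call.
    return lst.pop()


def last_item_pure(lst):
    # Pure: return the last value and a new, shorter list; the input is untouched.
    return lst[-1], lst[:-1]


items = [10, 20]
val, rest = last_item_pure(items)
print(val, rest, items)  # 20 [10] [10, 20]
```

With this shape, calling the function twice on the same list gives the same result, which is the behavior a captured graph would reproduce.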

No direct data_ptr manipulation#

Directly manipulating Tensor.data_ptr is not non-strict traceable. The intuition behind this is that PyTorch is unable to tell how you manipulated the data_ptr.

import ctypes

# Create a tensor with a single element
tensor = torch.tensor([42], dtype=torch.int32)  # Using int32 for simplicity

def function_with_data_ptr(tensor):
    # Get the data pointer
    ptr = tensor.data_ptr()

    # Cast the pointer to a ctypes pointer
    ctypes_ptr = ctypes.cast(ptr, ctypes.POINTER(ctypes.c_int32))

    # Increment the value at the pointer
    ctypes_ptr.contents.value += 1

    return tensor

try:
    make_fx(function_with_data_ptr, tracing_mode="fake")(tensor)
except Exception as e:
    print(e)
Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). If you're using torch.compile/export/fx, it is likely that we are erroneously tracing into a custom kernel. To fix this, please wrap the custom kernel into an opaque custom op. Please see the following for details: https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html

Specialization and Constants#

Non-strict tracing captures a graph that may be specialized on some values. This means the captured graph is only valid for those values. We say the graph treats those values as constant.

All non-Tensor variables are treated as constant during Non-strict Tracing:

def f(x, y):
    return x + y

x = torch.tensor([0, 1, 2])
y = 3.14
gm = make_fx(f, tracing_mode="fake")(x, y)
gm.print_readable()
class f(torch.nn.Module):
    def forward(self, x_1: "i64[3]", y_1):
        # No stacktrace found for following nodes
        add: "f32[3]" = torch.ops.aten.add.Tensor(x_1, 3.14);  x_1 = None
        return add

3.14 is a constant in the graph.

Non-strict tracing will also specialize on properties of the input Tensors.

def f(x):
    if x.shape[0] > 2:
        return x ** 2 / 6
    else:
        return x * 3

x = torch.randn(3)
gm = make_fx(f, tracing_mode="fake")(x)
gm.print_readable()
class f(torch.nn.Module):
    def forward(self, x_1: "f32[3]"):
        # No stacktrace found for following nodes
        pow_1: "f32[3]" = torch.ops.aten.pow.Tensor_Scalar(x_1, 2);  x_1 = None
        div: "f32[3]" = torch.ops.aten.div.Tensor(pow_1, 6);  pow_1 = None
        return div

And it will also specialize on any variables not directly passed into the function:

y = 3.14  # not passed to f; read from the enclosing scope

def f(x):
    return x + y

x = torch.randn(3)
gm = make_fx(f, tracing_mode="fake")(x)
gm.print_readable()
class f(torch.nn.Module):
    def forward(self, x_1: "f32[3]"):
        # No stacktrace found for following nodes
        add: "f32[3]" = torch.ops.aten.add.Tensor(x_1, 3.14);  x_1 = None
        return add
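The three kinds of specialization in this section (non-Tensor arguments, properties of the input, and variables not passed in) can be reproduced with a toy tracer in plain Python. This is a sketch under simplifying assumptions, not PyTorch's implementation: lists of numbers stand in for Tensors, `len` stands in for `shape[0]`, and `Proxy`, `capture`, and the example functions are made-up names.

```python
import operator

OPS = {"add": operator.add, "mul": operator.mul}


class Proxy:
    """Stand-in for a traced list of numbers; only its length is concrete."""

    def __init__(self, name, length, trace):
        self.name, self.length, self.trace = name, length, trace

    def __len__(self):
        # Concrete during capture, so `if len(x) > 2` picks one branch for good.
        return self.length

    def _record(self, op, other):
        out = Proxy(f"t{len(self.trace)}", self.length, self.trace)
        # `other` is a plain Python number here; it is stored by value.
        self.trace.append((out.name, op, self.name, other))
        return out

    def __add__(self, other):
        return self._record("add", other)

    def __mul__(self, other):
        return self._record("mul", other)


def capture(fn, example, *extra):
    """Run fn once and return a replayable function of the first argument only."""
    trace = []
    result = fn(Proxy("x", len(example), trace), *extra)

    def replay(xs):
        env = {"x": xs}
        for out, op, src, const in trace:
            env[out] = [OPS[op](v, const) for v in env[src]]
        return env[result.name]

    return replay


def add_scalar(x, y):
    return x + y


def branchy(x):
    if len(x) > 2:
        return x * 10
    return x + 1


offset = 1


def add_offset(x):
    return x + offset


g_const = capture(add_scalar, [0, 1, 2], 5)  # y == 5 is baked in
g_branch = capture(branchy, [0, 1, 2])       # the len(x) > 2 branch is baked in
g_global = capture(add_offset, [0, 1, 2])    # offset == 1 is baked in
offset = 100                                 # has no effect on g_global

print(g_const([0, 1, 2]))   # [5, 6, 7]
print(g_branch([0, 1]))     # [0, 10]
print(g_global([0, 1, 2]))  # [1, 2, 3]
```

In the second call, the original `branchy` would take the other branch for a two-element input, and in the third the original `add_offset` would now add 100; the replayed trace does neither, which is the sense in which a specialized graph is only valid for the values seen during capture.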