torch.export Programming Model#
Created On: Dec 18, 2024 | Last Updated On: Jul 16, 2025
This document aims to explain the behaviors and capabilities of torch.export.export(). It is intended to help build your intuition for how torch.export.export() handles code.
Basics of Tracing#
torch.export.export() captures a graph representing your model by tracing its execution on “example” inputs and recording the PyTorch operations and conditions observed along the traced path. This graph can then be run on different inputs as long as they satisfy the same conditions.
The basic output of torch.export.export() is a single graph of PyTorch operations, with associated metadata. The exact format of this output is covered in the export IR spec.
Strict vs. Non-Strict Tracing#
torch.export.export() provides two modes of tracing.
In non-strict mode, we trace through the program using the normal Python interpreter. Your code executes exactly as it would in eager mode; the only difference is that all Tensors are replaced by fake Tensors, which have shapes and other forms of metadata but no data, wrapped in Proxy objects that record all operations on them into a graph. We also capture conditions on Tensor shapes that guard the correctness of the generated code.
In strict mode, we first trace through the program using TorchDynamo, a Python bytecode analysis engine. TorchDynamo does not actually execute your Python code. Instead, it symbolically analyzes it and builds a graph based on the results. On the one hand, this analysis allows torch.export.export() to provide additional guarantees on Python-level safety (beyond capturing conditions on Tensor shapes, as in non-strict mode). On the other hand, not all Python features are supported by this analysis.
Although currently the default mode of tracing is strict, we strongly recommend using non-strict, which will soon become the default. For most models, conditions on Tensor shapes are enough for soundness, and the additional guarantees on Python-level safety have no impact; at the same time, the possibility of hitting unsupported Python features in TorchDynamo presents an unnecessary risk.
In the rest of this document we assume we are tracing in non-strict mode; in particular, we assume that all Python features are supported.
Values: Static vs. Dynamic#
A key concept in understanding the behavior of torch.export.export() is the difference between static and dynamic values.
Static Values#
A static value is a value that is fixed at export time and cannot change between executions of the exported program. When the value is encountered during tracing, we treat it as a constant and hard-code it into the graph.
When an operation is performed (e.g. x + y) and all inputs are static, the output of the operation is directly hard-coded into the graph and the operation does not show up (i.e. it gets “constant-folded”).
When a value has been hard-coded into the graph, we say that the graph has been specialized to that value. For example:
    import torch

    class MyMod(torch.nn.Module):
        def forward(self, x, y):
            z = y + 7
            return x + z

    m = torch.export.export(MyMod(), (torch.randn(1), 3))
    print(m.graph_module.code)
    """
    def forward(self, arg0_1, arg1_1):
        add = torch.ops.aten.add.Tensor(arg0_1, 10);  arg0_1 = None
        return (add,)
    """

Here, we provide 3 as the traced value for y; it is treated as a static value and added to 7, burning in the static value 10 in the graph.
Dynamic Values#
A dynamic value is one that can change from run to run. It behaves just like a “normal” function argument: you can pass different inputs and expect your function to do the right thing.
Which values are static vs. dynamic?#
Whether a value is static or dynamic depends on its type:
- For Tensor:
  - Tensor data is treated as dynamic.
  - Tensor shapes can be treated by the system as static or dynamic.
    - By default, shapes of all input Tensors are considered static. The user can override this behavior for any input Tensor by specifying a dynamic shape for it.
    - Tensors that are part of module state, i.e., parameters and buffers, always have static shapes.
  - Other forms of Tensor metadata (e.g. device, dtype) are static.
- Python primitives (int, float, bool, str, None) are static.
  - There are dynamic variants for some primitive types (SymInt, SymFloat, SymBool). Typically users do not have to deal with them.
  - Users can specify integer inputs as dynamic by specifying a dynamic shape for them.
- For Python standard containers (list, tuple, dict, namedtuple):
  - The structure (i.e., length for list and tuple values, and key sequence for dict and namedtuple values) is static.
  - The contained elements have these rules applied to them recursively (basically the PyTree scheme) with leaves that are either Tensor or primitive types.
- Other classes (including data classes) can be registered with PyTree (see below), and follow the same rules as the standard containers.
Input types#
Inputs will be treated as either static or dynamic, based on their type (as explained above).
A static input will get hard-coded into the graph, and passing a different value at run time will result in an error. Recall that these are mostly values of primitive types.
A dynamic input behaves like a “normal” function input. Recall that these are mostly values of Tensor types.
By default, the types of inputs you can use for your program are:
- Tensor
- Python primitives (int, float, bool, str, None)
- Python standard containers (list, tuple, dict, namedtuple)
Custom Input Types (PyTree)#
In addition, you can also define your own (custom) class and use it as an input type, but you will need to register such a class as a PyTree.
Here’s an example of using a utility to register a dataclass that is used as an input type.
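    from dataclasses import dataclass

    import torch
    import torch.utils._pytree as pytree

    @dataclass
    class Input:
        f: torch.Tensor
        p: torch.Tensor

    # Register the dataclass so export can flatten it into its Tensor leaves.
    pytree.register_dataclass(Input)

    class M(torch.nn.Module):
        def forward(self, x: Input):
            return x.f + 1

    torch.export.export(M(), (Input(f=torch.ones(10, 4), p=torch.zeros(10, 4)),))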
Optional input types#
For optional inputs to the program that are not passed in, torch.export.export() will specialize to their default values. As a result, the exported program will require users to explicitly pass in all arguments, and will lose the defaulting behavior. For example:
    class M(torch.nn.Module):
        def forward(self, x, y=None):
            if y is not None:
                return y * x
            return x + x

    # Optional input is passed in
    ep = torch.export.export(M(), (torch.randn(3, 3), torch.randn(3, 3)))
    print(ep)
    """
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 3]", y: "f32[3, 3]"):
                # File: /data/users/angelayi/pytorch/moo.py:15 in forward, code: return y * x
                mul: "f32[3, 3]" = torch.ops.aten.mul.Tensor(y, x);  y = x = None
                return (mul,)
    """

    # Optional input is not passed in
    ep = torch.export.export(M(), (torch.randn(3, 3),))
    print(ep)
    """
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 3]", y):
                # File: /data/users/angelayi/pytorch/moo.py:16 in forward, code: return x + x
                add: "f32[3, 3]" = torch.ops.aten.add.Tensor(x, x);  x = None
                return (add,)
    """
Control Flow: Static vs. Dynamic#
Control flow is supported by torch.export.export(). The behavior of control flow depends on whether the value you are branching on is static or dynamic.
Static Control Flow#
Python control flow over static values is supported transparently. (Recall that static values include static shapes, so control flow over static shapes is also covered by this case.)
As mentioned above, we “burn in” static values, so the exported graph will never see any control flow over static values.
In the case of an if statement, we will continue tracing the branch taken at export time. In the case of a for or while statement, we will continue tracing by unrolling the loop.
Dynamic Control Flow: Shape-Dependent vs. Data-Dependent#
When the value involved in a control flow is dynamic, it could depend on dynamic shapes or dynamic data. Given that the compiler traces with information on shapes rather than data, the implications on the programming model are different in these cases.
Dynamic Shape-Dependent Control Flow#
When the value involved in a control flow is a dynamic shape, in most cases we will also know the concrete value of the dynamic shape during tracing: see the following section for more details on how the compiler tracks this information.
In these cases we say that the control flow is shape-dependent. We use the concrete value of the dynamic shape to evaluate the condition to either True or False and continue tracing (as discussed above), additionally emitting a guard corresponding to the condition just evaluated.
Otherwise the control flow is considered data-dependent. We cannot evaluate the condition to either True or False, so we cannot continue tracing and must raise an error at export time. See the next section.
Dynamic Data-Dependent Control Flow#
Data-dependent control flow over dynamic values is supported, but you must use one of PyTorch’s explicit operators to continue tracing. Using Python control flow statements over dynamic values is not permitted, because the compiler cannot evaluate the conditions necessary to continue tracing and thus an error must be raised at export time.
We provide operators to express general conditionals and loops over dynamic values, e.g., torch.cond, torch.map. Note that you only need to use these if you truly want data-dependent control flow.
Here’s an example of an if statement on a data-dependent condition, x.sum() > 0, where x is an input Tensor, rewritten using torch.cond. Instead of having to decide which branch to trace, now both branches are traced.
    class M_old(torch.nn.Module):
        def forward(self, x):
            if x.sum() > 0:
                return x.sin()
            else:
                return x.cos()

    class M_new(torch.nn.Module):
        def forward(self, x):
            return torch.cond(
                pred=x.sum() > 0,
                true_fn=lambda x: x.sin(),
                false_fn=lambda x: x.cos(),
                operands=(x,),
            )
A special case of data-dependent control flow is where it involves a data-dependent dynamic shape: typically, the shape of some intermediate Tensor that depends on input data rather than on input shapes (thus not shape-dependent). Instead of using a control flow operator, in this case you can provide an assertion that decides whether the condition is True or False. Given such an assertion, we can continue tracing, emitting a guard as above.
We provide operators to express assertions on dynamic shapes, e.g., torch._check. Note that you only need to use this when there is control flow on data-dependent dynamic shapes.
Here’s an example of an if statement on a condition involving a data-dependent dynamic shape, nz.shape[0] > 0, where nz is the result of calling torch.nonzero(), an operator whose output shape depends on input data. Instead of rewriting it, you can add an assertion using torch._check to effectively decide which branch to trace.
    class M_old(torch.nn.Module):
        def forward(self, x):
            nz = x.nonzero()
            if nz.shape[0] > 0:
                return x.sin()
            else:
                return x.cos()

    class M_new(torch.nn.Module):
        def forward(self, x):
            nz = x.nonzero()
            torch._check(nz.shape[0] > 0)
            if nz.shape[0] > 0:
                return x.sin()
            else:
                return x.cos()
Basics of Symbolic Shapes#
During tracing, dynamic Tensor shapes and conditions over them are encoded as “symbolic expressions.” (In contrast, static Tensor shapes and conditions over them are simply int and bool values.)
A symbol is like a variable; it describes a dynamic Tensor shape.
As tracing proceeds, shapes of intermediate Tensors may be described by more general expressions, typically involving integer arithmetic operators. This is because for most PyTorch operators, shapes of output Tensors can be described as functions of shapes of input Tensors. For example, along the concatenation dimension, the size of the output of torch.cat() is the sum of the sizes of its inputs.
Moreover, as we encounter control flow in the program, we create boolean expressions, typically involving relational operators, describing conditions along the traced path. These expressions are evaluated to decide which path to trace through the program, and recorded in a shape environment to guard the correctness of the traced path and to evaluate subsequently created expressions.
We briefly introduce these subsystems next.
Fake Implementations of PyTorch Operators#
Recall that during tracing, we are executing the program with fake Tensors, which have no data. In general we cannot call the actual implementations of PyTorch operators with fake Tensors. Thus each operator needs to have an additional fake (a.k.a. “meta”) implementation, which inputs and outputs fake Tensors, that matches the behavior of the actual implementation in terms of shapes and other forms of metadata carried by fake Tensors.
For example, note how the fake implementation of torch.index_select() computes the shape of the output using the shape of the input (while ignoring input data and returning empty output data).
    def meta_index_select(self, dim, index):
        result_size = list(self.size())
        if self.dim() > 0:
            result_size[dim] = index.numel()
        return self.new_empty(result_size)
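One way to see this kind of shape-only computation outside of export is with Tensors on the "meta" device, which also carry metadata but no data. This is only an approximation of the fake Tensors used during tracing, shown here as a rough analogy.

```python
import torch

# Meta Tensors have shape, dtype, and device, but no storage.
x = torch.empty(5, 3, device="meta")
index = torch.empty(2, dtype=torch.int64, device="meta")

# The output shape is derived from the input shapes alone:
# dimension 0 takes the number of indices, dimension 1 is unchanged.
out = torch.index_select(x, 0, index)
print(out.shape, out.device)
```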
Shape Propagation: Backed vs. Unbacked Dynamic Shapes#
Shapes are propagated using fake implementations of PyTorch operators.
A key concept to understand the propagation of dynamic shapes in particular is the difference between backed and unbacked dynamic shapes: we know the concrete values of the former but not the latter.
Propagation of shapes, including tracking backed and unbacked dynamic shapes, proceeds as follows:
- The shapes of Tensors representing inputs can be static or dynamic. When dynamic, they are described by symbols; moreover, such symbols are backed since we also know their concrete values given the “real” example inputs provided by the user at export time.
- The output shape of an operator is computed by its fake implementation, and is either static or dynamic. When dynamic, in general it is described by a symbolic expression. Moreover:
  - If the output shape depends only on input shapes, it is either static or backed dynamic whenever the input shapes are all static or backed dynamic.
  - On the other hand, if the output shape depends on input data, it is necessarily dynamic, and moreover, because we cannot know its concrete value it is unbacked.
Control Flow: Guards and Assertions#
When a condition on shapes is encountered, it either involves only static shapes, in which case it is a bool, or it involves dynamic shapes, in which case it is a symbolic boolean expression. For the latter:
- When the condition involves only backed dynamic shapes, we can use the concrete values of those dynamic shapes to evaluate the condition to True or False. We can then add a guard to the shape environment that states that the corresponding symbolic boolean expression is True or False, and continue tracing.
- Otherwise the condition involves unbacked dynamic shapes. In general we cannot evaluate such a condition without additional information; thus we cannot continue tracing, and we must raise an error at export time. The user is expected to use an explicit PyTorch operator for tracing to continue. This information is added as a guard in the shape environment, and can also possibly help evaluate other subsequently encountered conditions to True or False.
Once the model is exported, any guards on backed dynamic shapes can be understood as conditions on input dynamic shapes. These are verified against a dynamic shape specification that must have been provided to export, describing conditions on dynamic shapes that not only example inputs but also all future inputs are expected to satisfy for the generated code to be correct. More precisely, the dynamic shape specification must logically imply the generated guards, otherwise an error is raised at export time (along with suggested fixes to the dynamic shape specification). On the other hand, when there are no generated guards on backed dynamic shapes (in particular, when all shapes are static) no dynamic shape specification needs to be provided to export. In general, the dynamic shape specification is converted to runtime assertions on the inputs of the generated code.
Finally, any guards on unbacked dynamic shapes are converted to “inline” runtime assertions. These are added in the generated code at the locations where those unbacked dynamic shapes were created: typically, right after data-dependent operator calls.
Allowed PyTorch operators#
All PyTorch operators are permitted.
Custom operators#
In addition, you can define and use custom operators. Defining a custom operator includes defining a fake implementation for it, just like any other PyTorch operator (see previous section).
Here’s an example of a custom sin operator that wraps NumPy, and its registered (trivial) fake implementation.
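    import numpy as np
    import torch
    from torch import Tensor

    @torch.library.custom_op("mylib::sin", mutates_args=())
    def sin(x: Tensor) -> Tensor:
        # Real implementation: round-trip through NumPy.
        x_np = x.numpy()
        y_np = np.sin(x_np)
        return torch.from_numpy(y_np)

    @torch.library.register_fake("mylib::sin")
    def _(x: Tensor) -> Tensor:
        # Fake implementation: same shape and metadata as the input, no data.
        return torch.empty_like(x)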
Sometimes your custom operator’s fake implementation will involve data-dependent shapes. Here’s what a fake implementation for a custom nonzero might look like.
    ...

    @torch.library.register_fake("mylib::custom_nonzero")
    def _(x):
        nnz = torch.library.get_ctx().new_dynamic_size()
        shape = [nnz, x.dim()]
        return x.new_empty(shape, dtype=torch.int64)
Module State: Reads vs. Updates#
Module states include parameters, buffers, and regular attributes.
A regular attribute can be of any type.
On the other hand, parameters and buffers are always Tensors.
Module states can be dynamic or static, based on their types as outlined above. For example, self.training is a bool, which means it is static; on the other hand, any parameter or buffer is dynamic.
The shapes of any Tensors contained in module states cannot be dynamic, i.e., those shapes are fixed at export time, and cannot change between executions of the exported program.
Access rules#
All module states must be initialized. Accessing a module state that is not already initialized causes an error to be raised at export time.
Reading module states is always permitted.
Updating module states is possible, but must follow the rules below:
- A static regular attribute (e.g., of primitive type) can be updated. Reads and updates can be freely interleaved, and as expected, any reads will always see the values of the latest updates. Because these attributes are static, we will also burn the values in, so the generated code will not have any instructions to actually “get” or “set” such attributes.
- A dynamic regular attribute (e.g., of Tensor type) cannot be updated. To do so, it must be registered as a buffer during module initialization.
- A buffer can be updated, where the updating can be in-place (e.g., self.buffer[:] = ...) or not (e.g., self.buffer = ...).
- A parameter cannot be updated. Typically parameters are updated only during training, not during inference. We recommend exporting with torch.no_grad() to avoid parameter updates at export time.
Effects of functionalization#
Any dynamic module state that is read and/or updated is “lifted” (respectively) as an input and/or output of the generated code.
The exported program stores, along with the generated code, the initial values of parameters and buffers and the constant values of other Tensor attributes.