
## 1. Summary of `torch.export`

`torch.export.export()` performs ahead-of-time (AOT) compilation on a Python callable (e.g., a `torch.nn.Module` with a `forward()` method), producing an `ExportedProgram`: a sound, functional graph of tensor computations.

If you're confused about the difference between `torch.compile` and `torch.export`, check out this comparison.
Internally, `torch.export()` uses:

- **TorchDynamo**: traces PyTorch programs at the bytecode level for broader code coverage.
- **AOT Autograd**: functionalizes the graph and lowers it to ATen operators.
- **torch.fx.graph**: provides the graph representation used for transformations.
### Comparison Table

| Component | Role |
| --- | --- |
| TorchDynamo | Bytecode-level tracing |
| AOT Autograd | Graph functionalization, ATen lowering |
| torch.fx.graph | Graph representation, transformations |
### Example
```python
import torch
from torch.export import export

class Mod(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        a = torch.sin(x)
        b = torch.cos(y)
        return a + b

example_args = (torch.randn(10, 10), torch.randn(10, 10))
exported_program: torch.export.ExportedProgram = export(Mod(), args=example_args)
print(exported_program)
```
```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[10, 10]", y: "f32[10, 10]"):
            # code: a = torch.sin(x)
            sin: "f32[10, 10]" = torch.ops.aten.sin.default(x)

            # code: b = torch.cos(y)
            cos: "f32[10, 10]" = torch.ops.aten.cos.default(y)

            # code: return a + b
            add: "f32[10, 10]" = torch.ops.aten.add.Tensor(sin, cos)
            return (add,)

Graph signature:
    ExportGraphSignature(
        input_specs=[
            InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None),
            InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)
        ],
        output_specs=[
            OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)
        ]
    )
Range constraints: {}
```
## 2. `ExportedProgram`

An `ExportedProgram` consists of two main components:

- GraphModule
- Graph Signature
### ExportedProgram::GraphModule

The `GraphModule` contains the traced computation, with every instruction lowered to low-level ATen operations.
**What is ATen?**

> ATen is fundamentally a tensor library, on top of which almost all other Python and C++ interfaces in PyTorch are built. It provides a core Tensor class, on which many hundreds of operations are defined. Most of these operations have both CPU and GPU implementations.
Additionally, all model parameters and buffers are lifted to become inputs of the graph's `forward()` method.
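For intuition, the ATen operators that appear in the exported graph are directly callable through the `torch.ops.aten` namespace. A quick sketch, mirroring the graph printed above:

```python
import torch

# The same ATen ops that appear in the exported graph, called directly.
x, y = torch.randn(10, 10), torch.randn(10, 10)
sin = torch.ops.aten.sin.default(x)
cos = torch.ops.aten.cos.default(y)
add = torch.ops.aten.add.Tensor(sin, cos)
assert torch.allclose(add, torch.sin(x) + torch.cos(y))
```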
### ExportedProgram::Graph Signature

The graph signature describes the graph's inputs and outputs, including lifted parameters and buffers. The graph itself is functional, meaning it has no side effects and will always produce the same output given the same input.
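As a minimal sketch (reusing `exported_program` from the example above), both components are exposed as attributes:

```python
# Inspect the two components of the ExportedProgram.
gm = exported_program.graph_module          # torch.fx.GraphModule of ATen ops
print(gm.graph)                             # the underlying torch.fx.Graph
print(exported_program.graph_signature)     # input/output specs
```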
## 3. Strict vs Non-Strict Modes

Both modes eventually compile the model to a `torch.fx.Graph`.

Non-Strict Mode:

- Requires the Python runtime for compilation
- Runs the program in eager mode
- Uses tracing with `ProxyTensor`

Strict Mode:

- `TorchDynamo` inspects the bytecode and compiles it
- Can potentially generate an IR graph with `cuda.graph`
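As a sketch, the mode is selected through the `strict` keyword of `export()` (note that the default has varied across PyTorch releases):

```python
# Same model, two tracing front ends; both produce a torch.fx.Graph.
ep_strict = export(Mod(), args=example_args, strict=True)      # TorchDynamo
ep_nonstrict = export(Mod(), args=example_args, strict=False)  # ProxyTensor tracing
```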
## 4. `export_for_training()`

Use `export_for_training()` for training with non-functional ops (e.g., `BatchNorm` with state updates). It creates a generic IR containing all ATen operators, suitable for eager PyTorch Autograd, which makes it ideal for cases like PT2 Quantization. The result can be converted to an inference IR via `run_decompositions()`.
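A minimal sketch, assuming PyTorch 2.5+ where `export_for_training()` is available:

```python
import torch
from torch.export import export_for_training

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm1d(4)  # stateful: updates running stats

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.bn(x)

# The training IR keeps the non-functional batch-norm op intact...
ep_train = export_for_training(Net(), args=(torch.randn(8, 4),))
# ...and can be lowered to a functional inference IR afterwards.
ep_infer = ep_train.run_decompositions()
```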
## 5. Dynamism

Dynamic shapes are supported using `Dim()`, which generates range constraints during compilation.
```python
from torch.export import Dim

batch = Dim("batch")
dynamic_shapes = {"x": {0: batch}, "y": {0: batch}}
exported_program: torch.export.ExportedProgram = export(
    Mod(), args=example_args, dynamic_shapes=dynamic_shapes
)
```
```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        ...
Graph signature: ...
Range constraints: {batch: VR[0, int_oo]}
```
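A quick sketch of exercising the dynamic dimension: the resulting module accepts any batch size within the range constraint:

```python
# The batch dimension is symbolic, so different sizes work at runtime.
module = exported_program.module()
module(torch.randn(3, 10), torch.randn(3, 10))
module(torch.randn(64, 10), torch.randn(64, 10))
```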
## 6. Serialization

- `torch.export.save()`: saves an `ExportedProgram` to the `*.pt2` format
- `torch.export.load()`: loads the saved program
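A minimal round-trip sketch:

```python
# Save the exported program to the .pt2 format and load it back.
torch.export.save(exported_program, "mod.pt2")
loaded = torch.export.load("mod.pt2")
```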
## 7. Specialization

Contrast this with generalization: certain values (e.g., input shapes, Python primitives, container structures) are fixed as constants during export.

Effect: static values enable constant folding; operations whose inputs are all static can be precomputed and removed from the runtime graph.

`torch.export` fixes certain values as static constants in the graph:

- **Tensor shapes**: static by default unless marked dynamic with `dynamic_shapes`.
- **Python primitives**: `int`, `float`, etc., are hardcoded unless using `SymInt`.
- **Python containers**: lists, dictionaries, etc., have fixed structures at export time.

Static inputs lead to precomputed results, simplifying the graph, as the sketch below shows.
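A minimal sketch with a hypothetical `Scale` module: the Python `int` passed at export time is burned into the graph.

```python
import torch
from torch.export import export

class Scale(torch.nn.Module):
    def forward(self, x: torch.Tensor, factor: int) -> torch.Tensor:
        return x * factor

# `factor=2` is a Python int, so it is specialized: the exported graph
# computes `x * 2`; re-export to use a different constant.
ep = export(Scale(), args=(torch.randn(4), 2))
print(ep)  # the graph hardcodes the constant 2
```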
## 8. Limitations

### Limitation::Graph Breaks

`torch.export` may fail on untraceable code, requiring rewrites or extra information (unlike `torch.compile`, which can fall back to eager execution). `TorchDynamo` reduces the number of rewrites needed; use ExportDB or non-strict mode for help. A sketch of a typical failure follows below.
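For illustration, a hypothetical module with data-dependent control flow, a common source of such failures:

```python
import torch
from torch.export import export

class Branch(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:  # data-dependent control flow cannot be traced
            return x + 1
        return x - 1

# export(Branch(), args=(torch.randn(4),)) raises an error here instead of
# silently falling back to eager the way torch.compile would.
```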
### Limitation::Missing Fake Kernels

Tracing needs `FakeTensor` kernels for shape inference; a missing kernel causes export to fail with an error, as sketched below.