torch.export API Reference
Created On: Jul 17, 2025 | Last Updated On: Jul 17, 2025
- torch.export.export(mod, args, kwargs=None, *, dynamic_shapes=None, strict=False, preserve_module_call_signature=(), prefer_deferred_runtime_asserts_over_guards=False)
export() takes any nn.Module along with example inputs, and produces a traced graph representing only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion, which can subsequently be executed with different inputs or serialized. The traced graph (1) produces normalized operators in the functional ATen operator set (as well as any user-specified custom operators), (2) has eliminated all Python control flow and data structures (with certain exceptions), and (3) records the set of shape constraints needed to show that this normalization and control-flow elimination is sound for future inputs.
Soundness Guarantee
While tracing, export() takes note of shape-related assumptions made by the user program and the underlying PyTorch operator kernels. The output ExportedProgram is considered valid only when these assumptions hold true.
Tracing makes assumptions on the shapes (not values) of input tensors. Such assumptions must be validated at graph capture time for export() to succeed. Specifically:
Assumptions on static shapes of input tensors are automatically validated without additional effort.
Assumptions on dynamic shapes of input tensors require explicit specification by using the Dim() API to construct dynamic dimensions and by associating them with example inputs through the dynamic_shapes argument.
If any assumption cannot be validated, a fatal error will be raised. When that happens, the error message will include suggested fixes to the specification that are needed to validate the assumptions. For example, export() might suggest the following fix to the definition of a dynamic dimension dim0_x, say appearing in the shape associated with input x, that was previously defined as Dim("dim0_x"):
    dim = Dim("dim0_x", max=5)
This example means the generated code requires dimension 0 of input x to be less than or equal to 5 to be valid. You can inspect the suggested fixes to dynamic dimension definitions and then copy them verbatim into your code without needing to change the dynamic_shapes argument to your export() call.
- Parameters
mod (Module) – We will trace the forward method of this module.
args (tuple[Any, ...]) – Example positional inputs.
kwargs (Optional[Mapping[str, Any]]) – Optional example keyword inputs.
dynamic_shapes (Optional[Union[dict[str, Any], tuple[Any, ...], list[Any]]]) – An optional argument where the type should either be: 1) a dict from argument names of mod's forward method to their dynamic shape specifications, or 2) a tuple that specifies dynamic shape specifications for each input in original order. If you are specifying dynamism on keyword args, you will need to pass them in the order that is defined in the original function signature.
The dynamic shape of a tensor argument can be specified as either (1) a dict from dynamic dimension indices to Dim() types, where it is not required to include static dimension indices in this dict, but when they are, they should be mapped to None; or (2) a tuple / list of Dim() types or None, where the Dim() types correspond to dynamic dimensions, and static dimensions are denoted by None. Arguments that are dicts or tuples / lists of tensors are recursively specified by using mappings or sequences of contained specifications.
strict (bool) – When disabled (default), the export function will trace the program through the Python runtime, which by itself will not validate some of the implicit assumptions baked into the graph. It will still validate most critical assumptions, such as shape safety. When enabled (by setting strict=True), the export function will trace the program through TorchDynamo, which will ensure the soundness of the resulting graph. TorchDynamo has limited Python feature coverage, so you may experience more errors. Note that toggling this argument does not change the resulting IR spec, and the model will be serialized in the same way regardless of what value is passed here.
preserve_module_call_signature (tuple[str, ...]) – A list of submodule paths for which the original calling conventions are preserved as metadata. The metadata will be used when calling torch.export.unflatten to preserve the original calling conventions of modules.
- Returns
An ExportedProgram containing the traced callable.
- Return type
ExportedProgram
Acceptable input/output types
Acceptable types of inputs (for args and kwargs) and outputs include:
Primitive types, i.e. torch.Tensor, int, float, bool and str.
Dataclasses, but they must be registered by calling register_dataclass() first.
(Nested) data structures comprising dict, list, tuple, namedtuple and OrderedDict containing all of the above types.
- class torch.export.ExportedProgram(root, graph, graph_signature, state_dict, range_constraints, module_call_graph, example_inputs=None, constants=None, *, verifiers=None)
Package of a program from export(). It contains a torch.fx.Graph that represents Tensor computation, a state_dict containing tensor values of all lifted parameters and buffers, and various metadata.
To execute it, call module() and invoke the returned module with the same calling convention as the original callable traced by export().
To perform transformations on the graph, use the module() method to obtain a torch.fx.GraphModule. You can then use FX transformations to rewrite the graph. Afterwards, you can simply use export() again to construct a correct ExportedProgram.
- buffers()
Returns an iterator over original module buffers.
Warning
This API is experimental and is NOT backward-compatible.
- property call_spec
Warning
This API is experimental and is NOT backward-compatible.
- property constants
Warning
This API is experimental and is NOT backward-compatible.
- property example_inputs
Warning
This API is experimental and is NOT backward-compatible.
- property graph
Warning
This API is experimental and is NOT backward-compatible.
- property graph_module
Warning
This API is experimental and is NOT backward-compatible.
- property graph_signature
Warning
This API is experimental and is NOT backward-compatible.
- module(check_guards=True)
Returns a self-contained GraphModule with all the parameters/buffers inlined.
When check_guards=True (default), a _guards_fn submodule is generated and a call to the _guards_fn submodule is inserted right after the placeholders in the graph. This module checks guards on inputs.
When check_guards=False, a subset of these checks is performed by a forward pre-hook on the graph module. No _guards_fn submodule is generated.
- Return type
torch.fx.GraphModule
- property module_call_graph
Warning
This API is experimental and is NOT backward-compatible.
- named_buffers()
Returns an iterator over original module buffers, yielding both the name of the buffer as well as the buffer itself.
Warning
This API is experimental and is NOT backward-compatible.
- Return type
Iterator[tuple[str, torch.Tensor]]
- named_parameters()
Returns an iterator over original module parameters, yielding both the name of the parameter as well as the parameter itself.
Warning
This API is experimental and is NOT backward-compatible.
- Return type
Iterator[tuple[str, torch.nn.Parameter]]
- parameters()
Returns an iterator over original module's parameters.
Warning
This API is experimental and is NOT backward-compatible.
- property range_constraints
Warning
This API is experimental and is NOT backward-compatible.
- run_decompositions(decomp_table=None, decompose_custom_triton_ops=False)
Runs a set of decompositions on the exported program and returns a new exported program. By default, the Core ATen decompositions are run to get operators in the Core ATen Operator Set.
For now, we do not decompose joint graphs.
- Parameters
decomp_table (Optional[dict[torch._ops.OperatorBase, Callable]]) – An optional argument that specifies decomposition behaviour for ATen ops: (1) if None, we decompose to the core ATen decompositions; (2) if empty, we don't decompose any operator.
- Return type
ExportedProgram
Some examples:
If you don't want to decompose anything:
    ep = torch.export.export(model, ...)
    ep = ep.run_decompositions(decomp_table={})
If you want to get a core ATen operator set except for certain operators, you can do the following:
    ep = torch.export.export(model, ...)
    decomp_table = torch.export.default_decompositions()
    decomp_table[your_op] = your_custom_decomp
    ep = ep.run_decompositions(decomp_table=decomp_table)
- property state_dict
Warning
This API is experimental and is NOT backward-compatible.
- property tensor_constants
Warning
This API is experimental and is NOT backward-compatible.
- property verifiers
Warning
This API is experimental and is NOT backward-compatible.
- class torch.export.dynamic_shapes.AdditionalInputs
Infers dynamic_shapes based on additional inputs.
This is particularly useful for deployment engineers who, on the one hand, may have access to ample testing or profiling data that can provide a fair sense of representative inputs for a model, but on the other hand, may not know enough about the model to guess which input shapes should be dynamic.
Input shapes that are different from the original are considered dynamic; conversely, those that are the same as the original are considered static. Moreover, we verify that the additional inputs are valid for the exported program. This guarantees that tracing with them instead of the original would have generated the same graph.
Example:
    args0, kwargs0 = ...  # example inputs for export

    # other representative inputs that the exported program will run on
    dynamic_shapes = torch.export.AdditionalInputs()
    dynamic_shapes.add(args1, kwargs1)
    ...
    dynamic_shapes.add(argsN, kwargsN)

    torch.export.export(..., args0, kwargs0, dynamic_shapes=dynamic_shapes)
- dynamic_shapes(m, args, kwargs=None)
Infers a dynamic_shapes pytree structure by merging the shapes of the original input args and kwargs and of each additional input's args and kwargs.
- class torch.export.dynamic_shapes.Dim(name, *, min=None, max=None)
The Dim class allows users to specify dynamism in their exported programs. By marking a dimension with a Dim, the compiler associates the dimension with a symbolic integer containing a dynamic range.
The API can be used in two ways: Dim hints (i.e. automatic dynamic shapes: Dim.AUTO, Dim.DYNAMIC, Dim.STATIC), or named Dims (i.e. Dim("name", min=1, max=2)).
Dim hints provide the lowest barrier to exportability, with the user only needing to specify whether a dimension is dynamic, static, or left for the compiler to decide (Dim.AUTO). The export process will automatically infer the remaining constraints on min/max ranges and relationships between dimensions.
Example:
    import torch
    from torch import nn
    from torch.export import Dim

    class Foo(nn.Module):
        def forward(self, x, y):
            assert x.shape[0] == 4
            assert y.shape[1] >= 16
            return x @ y

    x = torch.randn(4, 8)
    y = torch.randn(8, 16)
    dynamic_shapes = {
        "x": {0: Dim.AUTO, 1: Dim.AUTO},
        "y": {0: Dim.AUTO, 1: Dim.AUTO},
    }
    ep = torch.export.export(Foo(), (x, y), dynamic_shapes=dynamic_shapes)
Here, export would raise an exception if we replaced all uses of Dim.AUTO with Dim.DYNAMIC, as x.shape[0] is constrained to be static by the model.
More complex relations between dimensions may also be codegened as runtime assertion nodes by the compiler, e.g. (x.shape[0] + y.shape[1]) % 4 == 0, to be raised if runtime inputs do not satisfy such constraints.
You may also specify min-max bounds for Dim hints, e.g. Dim.AUTO(min=16, max=32), Dim.DYNAMIC(max=64), with the compiler inferring the remaining constraints within the ranges. An exception will be raised if the valid range is entirely outside the user-specified range.
Named Dims provide a stricter way of specifying dynamism, where exceptions are raised if the compiler infers constraints that do not match the user specification. For example, to export the previous model, the user would need the following dynamic_shapes argument:
    s0 = Dim("s0")
    s1 = Dim("s1", min=16)
    dynamic_shapes = {
        "x": {0: 4, 1: s0},
        "y": {0: s0, 1: s1},
    }
    ep = torch.export.export(Foo(), (x, y), dynamic_shapes=dynamic_shapes)
Named Dims also allow specification of relationships between dimensions, up to univariate linear relations. For example, the following indicates one dimension is a multiple of another plus 4:
    s0 = Dim("s0")
    s1 = 3 * s0 + 4
- class torch.export.dynamic_shapes.ShapesCollection
Builder for dynamic_shapes. Used to assign dynamic shape specifications to tensors that appear in inputs.
This is particularly useful when args is a nested input structure, and it's easier to index the input tensors than to replicate the structure of args in the dynamic_shapes specification.
Example:
    args = {"x": tensor_x, "others": [tensor_y, tensor_z]}

    dim = torch.export.Dim(...)
    dynamic_shapes = torch.export.ShapesCollection()
    dynamic_shapes[tensor_x] = (dim, dim + 1, 8)
    dynamic_shapes[tensor_y] = {0: dim * 2}
    # This is equivalent to the following (now auto-generated):
    # dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [{0: dim * 2}, None]}

    torch.export.export(..., args, dynamic_shapes=dynamic_shapes)
To specify dynamism for integers, we need to first wrap the integers using _IntWrapper so that we have a "unique identification tag" for each integer.
Example:
    args = {"x": tensor_x, "others": [int_x, int_y]}
    # Wrap all ints with _IntWrapper
    mapped_args = pytree.tree_map_only(int, lambda a: _IntWrapper(a), args)

    dynamic_shapes = torch.export.ShapesCollection()
    dynamic_shapes[tensor_x] = (dim, dim + 1, 8)
    dynamic_shapes[mapped_args["others"][0]] = Dim.DYNAMIC

    # This is equivalent to the following (now auto-generated):
    # dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [Dim.DYNAMIC, None]}

    torch.export.export(..., args, dynamic_shapes=dynamic_shapes)
- dynamic_shapes(m, args, kwargs=None)
Generates the dynamic_shapes pytree structure according to args and kwargs.
- torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes(msg, dynamic_shapes)
When exporting with dynamic_shapes, export may fail with a ConstraintViolation error if the specification doesn't match the constraints inferred from tracing the model. The error message may provide suggested fixes: changes that can be made to dynamic_shapes to export successfully.
Example ConstraintViolation error message:
    Suggested fixes:

        dim = Dim('dim', min=3, max=6)  # this just refines the dim's range
        dim = 4  # this specializes to a constant
        dy = dx + 1  # dy was specified as an independent dim, but is actually tied to dx with this relation
This is a helper function that takes the ConstraintViolation error message and the original dynamic_shapes spec, and returns a new dynamic_shapes spec that incorporates the suggested fixes.
Example usage:
    try:
        ep = export(mod, args, dynamic_shapes=dynamic_shapes)
    except torch._dynamo.exc.UserError as exc:
        new_shapes = refine_dynamic_shapes_from_suggested_fixes(exc.msg, dynamic_shapes)
        ep = export(mod, args, dynamic_shapes=new_shapes)
- torch.export.save(ep, f, *, extra_files=None, opset_version=None, pickle_protocol=2)
Warning
Under active development, saved files may not be usable in newer versions of PyTorch.
Saves an ExportedProgram to a file-like object. It can then be loaded using the Python API torch.export.load.
- Parameters
ep (ExportedProgram) – The exported program to save.
f (str | os.PathLike[str] | IO[bytes]) – A file-like object (has to implement write and flush) or a string containing a file name.
extra_files (Optional[Dict[str, Any]]) – Map from filename to contents which will be stored as part of f.
opset_version (Optional[Dict[str, int]]) – A map of opset names to the version of this opset.
pickle_protocol (int) – Can be specified to override the default protocol.
Example:
    import io

    import torch

    class MyModule(torch.nn.Module):
        def forward(self, x):
            return x + 10

    ep = torch.export.export(MyModule(), (torch.randn(5),))

    # Save to file
    torch.export.save(ep, "exported_program.pt2")

    # Save to io.BytesIO buffer
    buffer = io.BytesIO()
    torch.export.save(ep, buffer)

    # Save with extra files
    extra_files = {"foo.txt": b"bar".decode("utf-8")}
    torch.export.save(ep, "exported_program.pt2", extra_files=extra_files)
- torch.export.load(f, *, extra_files=None, expected_opset_version=None)
Warning
Under active development, saved files may not be usable in newer versions of PyTorch.
Loads an ExportedProgram previously saved with torch.export.save.
- Parameters
f (str | os.PathLike[str] | IO[bytes]) – A file-like object opened for reading (for example, an io.BytesIO buffer) or a string containing a file name.
extra_files (Optional[Dict[str, Any]]) – The extra filenames given in this map will be loaded, and their contents will be stored in the provided map.
expected_opset_version (Optional[Dict[str, int]]) – A map of opset names to expected opset versions.
- Returns
An ExportedProgram object.
- Return type
ExportedProgram
Example:
    import io

    import torch

    # Load ExportedProgram from file
    ep = torch.export.load("exported_program.pt2")

    # Load ExportedProgram from io.BytesIO object
    with open("exported_program.pt2", "rb") as f:
        buffer = io.BytesIO(f.read())
    buffer.seek(0)
    ep = torch.export.load(buffer)

    # Load with extra files.
    extra_files = {"foo.txt": ""}  # values will be replaced with data
    ep = torch.export.load("exported_program.pt2", extra_files=extra_files)
    print(extra_files["foo.txt"])
    # Run the loaded program through the module returned by .module()
    print(ep.module()(torch.randn(5)))
- torch.export.pt2_archive._package.package_pt2(f, *, exported_programs=None, aoti_files=None, extra_files=None, opset_version=None, pickle_protocol=2)
Saves the artifacts to a PT2Archive format. The artifact can then be loaded using load_pt2.
- Parameters
f (str | os.PathLike[str] | IO[bytes]) – A file-like object (has to implement write and flush) or a string containing a file name.
exported_programs (Union[ExportedProgram, dict[str, ExportedProgram]]) – The exported program to save, or a dictionary mapping model name to an exported program to save. The exported program will be saved under models/*.json. If only one ExportedProgram is specified, it will automatically be named "model".
aoti_files (Union[list[str], dict[str, list[str]]]) – A list of files generated by AOTInductor via torch._inductor.aot_compile(..., {"aot_inductor.package": True}), or a dictionary mapping model name to its AOTInductor generated files. If only one set of files is specified, it will automatically be named "model".
extra_files (Optional[Dict[str, Any]]) – Map from filename to contents which will be stored as part of the pt2.
opset_version (Optional[Dict[str, int]]) – A map of opset names to the version of this opset.
pickle_protocol (int) – Can be specified to override the default protocol.
- Return type
- torch.export.pt2_archive._package.load_pt2(f, *, expected_opset_version=None, run_single_threaded=False, num_runners=1, device_index=-1, load_weights_from_disk=False)
Loads all the artifacts previously saved with package_pt2.
- Parameters
f (str | os.PathLike[str] | IO[bytes]) – A file-like object opened for reading or a string containing a file name.
expected_opset_version (Optional[Dict[str, int]]) – A map of opset names to expected opset versions.
num_runners (int) – Number of runners to load AOTInductor artifacts.
run_single_threaded (bool) – Whether the model should be run without thread synchronization logic. This is useful to avoid conflicts with CUDAGraphs.
device_index (int) – The index of the device to which the PT2 package is to be loaded. By default, device_index=-1 is used, which corresponds to the device cuda when using CUDA. Passing device_index=1 would load the package to cuda:1, for example.
- Returns
A PT2ArchiveContents object which contains all the objects in the PT2.
- Return type
PT2ArchiveContents
- torch.export.draft_export(mod, args, kwargs=None, *, dynamic_shapes=None, preserve_module_call_signature=(), strict=False, prefer_deferred_runtime_asserts_over_guards=False)
A version of torch.export.export which is designed to consistently produce an ExportedProgram, even if there are potential soundness issues, and to generate a report listing the issues found.
- Return type
ExportedProgram
- class torch.export.unflatten.FlatArgsAdapter
Adapts input arguments with input_spec to align with target_spec.
- class torch.export.unflatten.InterpreterModule(graph, ty=None)
A module that uses torch.fx.Interpreter to execute instead of the usual codegen that GraphModule uses. This provides better stack trace information and makes it easier to debug execution.
- class torch.export.unflatten.InterpreterModuleDispatcher(attrs, call_modules)
A module that carries a sequence of InterpreterModules corresponding to a sequence of calls of that module. Each call to the module dispatches to the next InterpreterModule, and wraps back around after the last.
- torch.export.unflatten.unflatten(module, flat_args_adapter=None)
Unflatten an ExportedProgram, producing a module with the same module hierarchy as the original eager module. This can be useful if you are trying to use torch.export with another system that expects a module hierarchy instead of the flat graph that torch.export usually produces.
Note
The args/kwargs of unflattened modules will not necessarily match the eager module, so doing a module swap (e.g. self.submod = new_mod) will not necessarily work. If you need to swap a module out, you need to set the preserve_module_call_signature parameter of torch.export.export().
- Parameters
module (ExportedProgram) – The ExportedProgram to unflatten.
flat_args_adapter (Optional[FlatArgsAdapter]) – Adapt flat args if the input TreeSpec does not match the exported module's.
- Returns
An instance of UnflattenedModule, which has the same module hierarchy as the original eager module pre-export.
- Return type
UnflattenedModule
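- torch.export.register_dataclass(cls, *, serialized_type_name=None)
Registers a dataclass as a valid input/output type for torch.export.export().
- Parameters
cls (type) – The dataclass type to register.
serialized_type_name (Optional[str]) – The serialized name for the dataclass. This is required if you want to serialize the pytree TreeSpec containing this dataclass.
Example: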
    from dataclasses import dataclass

    import torch

    @dataclass
    class InputDataClass:
        feature: torch.Tensor
        bias: int

    @dataclass
    class OutputDataClass:
        res: torch.Tensor

    torch.export.register_dataclass(InputDataClass)
    torch.export.register_dataclass(OutputDataClass)

    class Mod(torch.nn.Module):
        def forward(self, x: InputDataClass) -> OutputDataClass:
            res = x.feature + x.bias
            return OutputDataClass(res=res)

    ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1),))
    print(ep)
- class torch.export.decomp_utils.CustomDecompTable
This is a custom dictionary that is specifically used for handling decomp_table in export. We need it because, in the new world, you can only delete an op from the decomp table to preserve it. This is problematic for custom ops because we don't know when the custom op will actually be loaded into the dispatcher. As a result, we need to record the custom op operations until we really need to materialize them (which is when we run the decomposition pass).
- Invariants we hold are:
All ATen decompositions are loaded at init time.
We materialize ALL ops whenever the user reads from the table, to make it more likely that the dispatcher picks up the custom op.
If it is a write operation, we don't necessarily materialize.
We load the final time during export, right before calling run_decompositions().
- torch.export.passes.move_to_device_pass(ep, location)
Move the exported program to the given device.
- Parameters
ep (ExportedProgram) – The exported program to move.
location (Union[torch.device, str, Dict[str, str]]) – The device to move the exported program to. If a string, it is interpreted as a device name. If a dict, it is interpreted as a mapping from the existing device to the intended one.
- Returns
The moved exported program.
- Return type
ExportedProgram
- class torch.export.pt2_archive.PT2ArchiveReader(archive_path_or_buffer)
Context manager for reading a PT2 archive.
- class torch.export.pt2_archive.PT2ArchiveWriter(archive_path_or_buffer)
Context manager for writing a PT2 archive.
- write_bytes(name, data)
Write a bytes object to the archive.
name: The destination file inside the archive.
data: The bytes object to write.
- write_file(name, file_path)
Copy a file into the archive.
name: The destination file inside the archive.
file_path: The source file on disk.
- torch.export.pt2_archive.is_pt2_package(serialized_model)
Check if the serialized model is a PT2 Archive package.
- Return type
bool
- class torch.export.exported_program.ModuleCallEntry(fqn: str, signature: Optional[torch.export.exported_program.ModuleCallSignature] = None)
- class torch.export.exported_program.ModuleCallSignature(inputs: list[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument]], outputs: list[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument]], in_spec: torch.utils._pytree.TreeSpec, out_spec: torch.utils._pytree.TreeSpec, forward_arg_names: Optional[list[str]] = None)
- torch.export.exported_program.default_decompositions()
This is the default decomposition table, which contains decompositions of all ATen operators to the core ATen opset. Use this API together with run_decompositions().
- Return type
CustomDecompTable
- class torch.export.custom_obj.ScriptObjectMeta(constant_name, class_fqn)
Metadata which is stored on nodes representing ScriptObjects.
- class torch.export.graph_signature.ConstantArgument(name: str, value: Union[int, float, bool, str, NoneType])
- class torch.export.graph_signature.CustomObjArgument(name: str, class_fqn: str, fake_val: Optional[torch._library.fake_class_registry.FakeScriptObject] = None)
- class torch.export.graph_signature.ExportBackwardSignature(gradients_to_parameters: dict[str, str], gradients_to_user_inputs: dict[str, str], loss_output: str)
- class torch.export.graph_signature.ExportGraphSignature(input_specs, output_specs)
ExportGraphSignature models the input/output signature of Export Graph, which is an fx.Graph with stronger invariant guarantees.
Export Graph is functional and does not access "states" like parameters or buffers within the graph via getattr nodes. Instead, export() guarantees that parameters, buffers, and constant tensors are lifted out of the graph as inputs. Similarly, any mutations to buffers are not included in the graph either; instead, the updated values of mutated buffers are modeled as additional outputs of Export Graph.
The ordering of all inputs and outputs is:
    Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs]
    Outputs = [*mutated_inputs, *flattened_user_outputs]
For example, if the following module is exported:
    import torch
    from torch import nn

    class CustomModule(nn.Module):
        def __init__(self) -> None:
            super().__init__()

            # Define a parameter
            self.my_parameter = nn.Parameter(torch.tensor(2.0))

            # Define two buffers
            self.register_buffer("my_buffer1", torch.tensor(3.0))
            self.register_buffer("my_buffer2", torch.tensor(4.0))

        def forward(self, x1, x2):
            # Use the parameter, buffers, and both inputs in the forward method
            output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2

            # Mutate one of the buffers (e.g., increment it by 1)
            self.my_buffer2.add_(1.0)  # In-place addition

            return output

    mod = CustomModule()
    ep = torch.export.export(mod, (torch.tensor(1.0), torch.tensor(2.0)))
Resulting Graph is non-functional:
    graph():
        %p_my_parameter : [num_users=1] = placeholder[target=p_my_parameter]
        %b_my_buffer1 : [num_users=1] = placeholder[target=b_my_buffer1]
        %b_my_buffer2 : [num_users=2] = placeholder[target=b_my_buffer2]
        %x1 : [num_users=1] = placeholder[target=x1]
        %x2 : [num_users=1] = placeholder[target=x2]
        %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x1, %p_my_parameter), kwargs = {})
        %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %b_my_buffer1), kwargs = {})
        %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x2, %b_my_buffer2), kwargs = {})
        %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_1), kwargs = {})
        %add_ : [num_users=0] = call_function[target=torch.ops.aten.add_.Tensor](args = (%b_my_buffer2, 1.0), kwargs = {})
        return (add_1,)
Resulting ExportGraphSignature of the non-functional Graph would be:
    # inputs
    p_my_parameter: PARAMETER target='my_parameter'
    b_my_buffer1: BUFFER target='my_buffer1' persistent=True
    b_my_buffer2: BUFFER target='my_buffer2' persistent=True
    x1: USER_INPUT
    x2: USER_INPUT

    # outputs
    add_1: USER_OUTPUT
To get a functional Graph, you can use run_decompositions():
    mod = CustomModule()
    ep = torch.export.export(mod, (torch.tensor(1.0), torch.tensor(2.0)))
    ep = ep.run_decompositions()
Resulting Graph is functional:
    graph():
        %p_my_parameter : [num_users=1] = placeholder[target=p_my_parameter]
        %b_my_buffer1 : [num_users=1] = placeholder[target=b_my_buffer1]
        %b_my_buffer2 : [num_users=2] = placeholder[target=b_my_buffer2]
        %x1 : [num_users=1] = placeholder[target=x1]
        %x2 : [num_users=1] = placeholder[target=x2]
        %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x1, %p_my_parameter), kwargs = {})
        %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %b_my_buffer1), kwargs = {})
        %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x2, %b_my_buffer2), kwargs = {})
        %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_1), kwargs = {})
        %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_my_buffer2, 1.0), kwargs = {})
        return (add_2, add_1)
Resulting ExportGraphSignature of the functional Graph would be:
    # inputs
    p_my_parameter: PARAMETER target='my_parameter'
    b_my_buffer1: BUFFER target='my_buffer1' persistent=True
    b_my_buffer2: BUFFER target='my_buffer2' persistent=True
    x1: USER_INPUT
    x2: USER_INPUT

    # outputs
    add_2: BUFFER_MUTATION target='my_buffer2'
    add_1: USER_OUTPUT
- property backward_signature: Optional[ExportBackwardSignature]
- property buffers: Collection[str]
- input_specs: list[torch.export.graph_signature.InputSpec]
- property input_tokens: Collection[str]
- property lifted_custom_objs: Collection[str]
- property lifted_tensor_constants: Collection[str]
- property non_persistent_buffers: Collection[str]
- output_specs: list[torch.export.graph_signature.OutputSpec]
- property output_tokens: Collection[str]
- property parameters: Collection[str]
- class torch.export.graph_signature.InputKind(value)
An enumeration.
- BUFFER = 3
- CONSTANT_TENSOR = 4
- CUSTOM_OBJ = 5
- PARAMETER = 2
- TOKEN = 6
- USER_INPUT = 1
- class torch.export.graph_signature.InputSpec(kind: torch.export.graph_signature.InputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], target: Optional[str], persistent: Optional[bool] = None)
- class torch.export.graph_signature.OutputKind(value)
An enumeration.
- BUFFER_MUTATION = 3
- GRADIENT_TO_PARAMETER = 5
- GRADIENT_TO_USER_INPUT = 6
- LOSS_OUTPUT = 2
- PARAMETER_MUTATION = 4
- TOKEN = 8
- USER_INPUT_MUTATION = 7
- USER_OUTPUT = 1
- class torch.export.graph_signature.OutputSpec(kind: torch.export.graph_signature.OutputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], target: Optional[str])
- arg: Union[TensorArgument, SymIntArgument, SymFloatArgument, SymBoolArgument, ConstantArgument, CustomObjArgument, TokenArgument]
- kind: OutputKind