ExportDB#
ExportDB is a centralized dataset of supported and unsupported export cases. It is targeted towards users who want to understand specifically what types of code are supported, the subtleties of export, and how to modify their existing code to be compatible with export. Note that this is not an exhaustive set of everything that is supported by export, but it covers the most common and confusing use cases that users will run into.
If you have a feature that you think needs a stronger guarantee from us to support in export, please create an issue in the pytorch/pytorch repo with a module:export tag.
Tags
Supported#
assume_constant_result#
Original source code:
# mypy: allow-untyped-defsimporttorchimporttorch._dynamoastorchdynamoclassAssumeConstantResult(torch.nn.Module):""" Applying `assume_constant_result` decorator to burn make non-tracable code as constant. """@torchdynamo.assume_constant_resultdefget_item(self,y):returny.int().item()defforward(self,x,y):returnx[:self.get_item(y)]example_args=(torch.randn(3,2),torch.tensor(4))tags={"torch.escape-hatch"}model=AssumeConstantResult()torch.export.export(model,example_args)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,x:"f32[3, 2]",y:"i64[]"):slice_1:"f32[3, 2]"=torch.ops.aten.slice.Tensor(x,0,0,4);x=Nonereturn(slice_1,)Graphsignature:# inputsx:USER_INPUTy:USER_INPUT# outputsslice_1:USER_OUTPUTRangeconstraints:{}
autograd_function#
Note
Tags:
Support Level: SUPPORTED
Original source code:
# mypy: allow-untyped-defsimporttorchclassMyAutogradFunction(torch.autograd.Function):@staticmethoddefforward(ctx,x):returnx.clone()@staticmethoddefbackward(ctx,grad_output):returngrad_output+1classAutogradFunction(torch.nn.Module):""" TorchDynamo does not keep track of backward() on autograd functions. We recommend to use `allow_in_graph` to mitigate this problem. """defforward(self,x):returnMyAutogradFunction.apply(x)example_args=(torch.randn(3,2),)model=AutogradFunction()torch.export.export(model,example_args)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,x:"f32[3, 2]"):clone:"f32[3, 2]"=torch.ops.aten.clone.default(x);x=Nonereturn(clone,)Graphsignature:# inputsx:USER_INPUT# outputsclone:USER_OUTPUTRangeconstraints:{}
class_method#
Note
Tags:
Support Level: SUPPORTED
Original source code:
# mypy: allow-untyped-defsimporttorchclassClassMethod(torch.nn.Module):""" Class methods are inlined during tracing. """@classmethoddefmethod(cls,x):returnx+1def__init__(self)->None:super().__init__()self.linear=torch.nn.Linear(4,2)defforward(self,x):x=self.linear(x)returnself.method(x)*self.__class__.method(x)*type(self).method(x)example_args=(torch.randn(3,4),)model=ClassMethod()torch.export.export(model,example_args)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,p_linear_weight:"f32[2, 4]",p_linear_bias:"f32[2]",x:"f32[3, 4]"):linear:"f32[3, 2]"=torch.ops.aten.linear.default(x,p_linear_weight,p_linear_bias);x=p_linear_weight=p_linear_bias=Noneadd:"f32[3, 2]"=torch.ops.aten.add.Tensor(linear,1)add_1:"f32[3, 2]"=torch.ops.aten.add.Tensor(linear,1)mul:"f32[3, 2]"=torch.ops.aten.mul.Tensor(add,add_1);add=add_1=Noneadd_2:"f32[3, 2]"=torch.ops.aten.add.Tensor(linear,1);linear=Nonemul_1:"f32[3, 2]"=torch.ops.aten.mul.Tensor(mul,add_2);mul=add_2=Nonereturn(mul_1,)Graphsignature:# inputsp_linear_weight:PARAMETERtarget='linear.weight'p_linear_bias:PARAMETERtarget='linear.bias'x:USER_INPUT# outputsmul_1:USER_OUTPUTRangeconstraints:{}
cond_branch_class_method#
Original source code:
# mypy: allow-untyped-defsimporttorchfromfunctorch.experimental.control_flowimportcondclassMySubModule(torch.nn.Module):deffoo(self,x):returnx.cos()defforward(self,x):returnself.foo(x)classCondBranchClassMethod(torch.nn.Module):""" The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: - both branches must take the same args, which must also match the branch args passed to cond. - both branches must return a single tensor - returned tensor must have the same tensor metadata, e.g. shape and dtype - branch function can be free function, nested function, lambda, class methods - branch function can not have closure variables - no inplace mutations on inputs or global variables This example demonstrates using class method in cond(). NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. """def__init__(self)->None:super().__init__()self.subm=MySubModule()defbar(self,x):returnx.sin()defforward(self,x):returncond(x.shape[0]<=2,self.subm.forward,self.bar,[x])example_args=(torch.randn(3),)tags={"torch.cond","torch.dynamic-shape",}model=CondBranchClassMethod()torch.export.export(model,example_args)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,x:"f32[3]"):sin:"f32[3]"=torch.ops.aten.sin.default(x);x=Nonereturn(sin,)Graphsignature:# inputsx:USER_INPUT# outputssin:USER_OUTPUTRangeconstraints:{}
cond_branch_nested_function#
Original source code:
# mypy: allow-untyped-defsimporttorchfromfunctorch.experimental.control_flowimportcondclassCondBranchNestedFunction(torch.nn.Module):""" The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: - both branches must take the same args, which must also match the branch args passed to cond. - both branches must return a single tensor - returned tensor must have the same tensor metadata, e.g. shape and dtype - branch function can be free function, nested function, lambda, class methods - branch function can not have closure variables - no inplace mutations on inputs or global variables This example demonstrates using nested function in cond(). NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. """defforward(self,x):deftrue_fn(x):definner_true_fn(y):returnx+yreturninner_true_fn(x)deffalse_fn(x):definner_false_fn(y):returnx-yreturninner_false_fn(x)returncond(x.shape[0]<10,true_fn,false_fn,[x])example_args=(torch.randn(3),)tags={"torch.cond","torch.dynamic-shape",}model=CondBranchNestedFunction()torch.export.export(model,example_args)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,x:"f32[3]"):add:"f32[3]"=torch.ops.aten.add.Tensor(x,x);x=Nonereturn(add,)Graphsignature:# inputsx:USER_INPUT# outputsadd:USER_OUTPUTRangeconstraints:{}
cond_branch_nonlocal_variables#
Original source code:
# mypy: allow-untyped-defsimporttorchfromfunctorch.experimental.control_flowimportcondclassCondBranchNonlocalVariables(torch.nn.Module):""" The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: - both branches must take the same args, which must also match the branch args passed to cond. - both branches must return a single tensor - returned tensor must have the same tensor metadata, e.g. shape and dtype - branch function can be free function, nested function, lambda, class methods - branch function can not have closure variables - no inplace mutations on inputs or global variables This example demonstrates how to rewrite code to avoid capturing closure variables in branch functions. The code below will not work because capturing closure variables is not supported. ``` my_tensor_var = x + 100 my_primitive_var = 3.14 def true_fn(y): nonlocal my_tensor_var, my_primitive_var return y + my_tensor_var + my_primitive_var def false_fn(y): nonlocal my_tensor_var, my_primitive_var return y - my_tensor_var - my_primitive_var return cond(x.shape[0] > 5, true_fn, false_fn, [x]) ``` NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. """defforward(self,x):my_tensor_var=x+100my_primitive_var=3.14deftrue_fn(x,y,z):returnx+y+zdeffalse_fn(x,y,z):returnx-y-zreturncond(x.shape[0]>5,true_fn,false_fn,[x,my_tensor_var,torch.tensor(my_primitive_var)],)example_args=(torch.randn(6),)tags={"torch.cond","torch.dynamic-shape",}model=CondBranchNonlocalVariables()torch.export.export(model,example_args)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,c_lifted_tensor_0:"f32[]",x:"f32[6]"):add:"f32[6]"=torch.ops.aten.add.Tensor(x,100)lift_fresh_copy:"f32[]"=torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_0);c_lifted_tensor_0=Nonedetach_:"f32[]"=torch.ops.aten.detach_.default(lift_fresh_copy);lift_fresh_copy=Noneadd_1:"f32[6]"=torch.ops.aten.add.Tensor(x,add);x=add=Noneadd_2:"f32[6]"=torch.ops.aten.add.Tensor(add_1,detach_);add_1=detach_=Nonereturn(add_2,)Graphsignature:# inputsc_lifted_tensor_0:CONSTANT_TENSORtarget='lifted_tensor_0'x:USER_INPUT# outputsadd_2:USER_OUTPUTRangeconstraints:{}
cond_closed_over_variable#
Original source code:
# mypy: allow-untyped-defsimporttorchfromfunctorch.experimental.control_flowimportcondclassCondClosedOverVariable(torch.nn.Module):""" torch.cond() supports branches closed over arbitrary variables. """defforward(self,pred,x):deftrue_fn(val):returnx*2deffalse_fn(val):returnx-2returncond(pred,true_fn,false_fn,[x+1])example_args=(torch.tensor(True),torch.randn(3,2))tags={"torch.cond","python.closure"}model=CondClosedOverVariable()torch.export.export(model,example_args)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,pred:"b8[]",x:"f32[3, 2]"):add:"f32[3, 2]"=torch.ops.aten.add.Tensor(x,1);add=Nonetrue_graph_0=self.true_graph_0false_graph_0=self.false_graph_0cond=torch.ops.higher_order.cond(pred,true_graph_0,false_graph_0,(x,));pred=true_graph_0=false_graph_0=x=Nonegetitem:"f32[3, 2]"=cond[0];cond=Nonereturn(getitem,)classtrue_graph_0(torch.nn.Module):defforward(self,x:"f32[3, 2]"):mul:"f32[3, 2]"=torch.ops.aten.mul.Tensor(x,2);x=Nonereturn(mul,)classfalse_graph_0(torch.nn.Module):defforward(self,x:"f32[3, 2]"):sub:"f32[3, 2]"=torch.ops.aten.sub.Tensor(x,2);x=Nonereturn(sub,)Graphsignature:# inputspred:USER_INPUTx:USER_INPUT# outputsgetitem:USER_OUTPUTRangeconstraints:{}
cond_operands#
Original source code:
# mypy: allow-untyped-defsimporttorchfromtorch.exportimportDimx=torch.randn(3,2)y=torch.randn(2)dim0_x=Dim("dim0_x")classCondOperands(torch.nn.Module):""" The operands passed to cond() must be: - a list of tensors - match arguments of `true_fn` and `false_fn` NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. """defforward(self,x,y):deftrue_fn(x,y):returnx+ydeffalse_fn(x,y):returnx-yreturntorch.cond(x.shape[0]>2,true_fn,false_fn,[x,y])example_args=(x,y)tags={"torch.cond","torch.dynamic-shape",}extra_inputs=(torch.randn(2,2),torch.randn(2))dynamic_shapes={"x":{0:dim0_x},"y":None}model=CondOperands()torch.export.export(model,example_args,dynamic_shapes=dynamic_shapes)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,x:"f32[s77, 2]",y:"f32[2]"):#sym_size_int_1:"Sym(s77)"=torch.ops.aten.sym_size.int(x,0)gt:"Sym(s77 > 2)"=sym_size_int_1>2;sym_size_int_1=Nonetrue_graph_0=self.true_graph_0false_graph_0=self.false_graph_0cond=torch.ops.higher_order.cond(gt,true_graph_0,false_graph_0,(x,y));gt=true_graph_0=false_graph_0=x=y=Nonegetitem:"f32[s77, 2]"=cond[0];cond=Nonereturn(getitem,)classtrue_graph_0(torch.nn.Module):defforward(self,x:"f32[s77, 2]",y:"f32[2]"):add:"f32[s77, 2]"=torch.ops.aten.add.Tensor(x,y);x=y=Nonereturn(add,)classfalse_graph_0(torch.nn.Module):defforward(self,x:"f32[s77, 2]",y:"f32[2]"):sub:"f32[s77, 2]"=torch.ops.aten.sub.Tensor(x,y);x=y=Nonereturn(sub,)Graphsignature:# inputsx:USER_INPUTy:USER_INPUT# outputsgetitem:USER_OUTPUTRangeconstraints:{s77:VR[0,int_oo]}
cond_predicate#
Original source code:
# mypy: allow-untyped-defsimporttorchfromfunctorch.experimental.control_flowimportcondclassCondPredicate(torch.nn.Module):""" The conditional statement (aka predicate) passed to cond() must be one of the following: - torch.Tensor with a single element - boolean expression NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. """defforward(self,x):pred=x.dim()>2andx.shape[2]>10returncond(pred,lambdax:x.cos(),lambday:y.sin(),[x])example_args=(torch.randn(6,4,3),)tags={"torch.cond","torch.dynamic-shape",}model=CondPredicate()torch.export.export(model,example_args)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,x:"f32[6, 4, 3]"):sin:"f32[6, 4, 3]"=torch.ops.aten.sin.default(x);x=Nonereturn(sin,)Graphsignature:# inputsx:USER_INPUT# outputssin:USER_OUTPUTRangeconstraints:{}
constrain_as_size_example#
Original source code:
# mypy: allow-untyped-defsimporttorchclassConstrainAsSizeExample(torch.nn.Module):""" If the value is not known at tracing time, you can provide hint so that we can trace further. Please look at torch._check and torch._check_is_size APIs. torch._check_is_size is used for values that NEED to be used for constructing tensor. """defforward(self,x):a=x.item()torch._check_is_size(a)torch._check(a<=5)returntorch.zeros((a,5))example_args=(torch.tensor(4),)tags={"torch.dynamic-value","torch.escape-hatch",}model=ConstrainAsSizeExample()torch.export.export(model,example_args)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,x:"i64[]"):item:"Sym(u0)"=torch.ops.aten.item.default(x);x=None#sym_constrain_range_for_size_default=torch.ops.aten.sym_constrain_range_for_size.default(item);sym_constrain_range_for_size_default=Nonege_1:"Sym(u0 >= 0)"=item>=0_assert_scalar_default=torch.ops.aten._assert_scalar.default(ge_1,"Runtime assertion failed for expression u0 >= 0 on node 'ge_1'");ge_1=_assert_scalar_default=Nonele_1:"Sym(u0 <= 5)"=item<=5_assert_scalar_default_1=torch.ops.aten._assert_scalar.default(le_1,"Runtime assertion failed for expression u0 <= 5 on node 'le_1'");le_1=_assert_scalar_default_1=Nonezeros:"f32[u0, 5]"=torch.ops.aten.zeros.default([item,5],device=device(type='cpu'),pin_memory=False);item=Nonereturn(zeros,)Graphsignature:# inputsx:USER_INPUT# outputszeros:USER_OUTPUTRangeconstraints:{u0:VR[0,5],u1:VR[0,5]}
constrain_as_value_example#
Original source code:
# mypy: allow-untyped-defsimporttorchclassConstrainAsValueExample(torch.nn.Module):""" If the value is not known at tracing time, you can provide hint so that we can trace further. Please look at torch._check and torch._check_is_size APIs. torch._check is used for values that don't need to be used for constructing tensor. """defforward(self,x,y):a=x.item()torch._check(a>=0)torch._check(a<=5)ifa<6:returny.sin()returny.cos()example_args=(torch.tensor(4),torch.randn(5,5))tags={"torch.dynamic-value","torch.escape-hatch",}model=ConstrainAsValueExample()torch.export.export(model,example_args)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,x:"i64[]",y:"f32[5, 5]"):item:"Sym(u0)"=torch.ops.aten.item.default(x);x=Nonege_1:"Sym(u0 >= 0)"=item>=0_assert_scalar_default=torch.ops.aten._assert_scalar.default(ge_1,"Runtime assertion failed for expression u0 >= 0 on node 'ge_1'");ge_1=_assert_scalar_default=Nonele_1:"Sym(u0 <= 5)"=item<=5;item=None_assert_scalar_default_1=torch.ops.aten._assert_scalar.default(le_1,"Runtime assertion failed for expression u0 <= 5 on node 'le_1'");le_1=_assert_scalar_default_1=Nonesin:"f32[5, 5]"=torch.ops.aten.sin.default(y);y=Nonereturn(sin,)Graphsignature:# inputsx:USER_INPUTy:USER_INPUT# outputssin:USER_OUTPUTRangeconstraints:{u0:VR[0,5],u1:VR[0,5]}
decorator#
Note
Tags:
Support Level: SUPPORTED
Original source code:
# mypy: allow-untyped-defs
import functools

import torch


def test_decorator(func):
    # Wraps `func` so that 1 is added to its result; functools.wraps keeps the
    # wrapped function's name and docstring.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs) + 1

    return wrapper


class Decorator(torch.nn.Module):
    """
    Decorator calls are inlined into the exported function during tracing.
    """

    @test_decorator
    def forward(self, x, y):
        return x + y


example_args = (torch.randn(3, 2), torch.randn(3, 2))
model = Decorator()
torch.export.export(model, example_args)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,x:"f32[3, 2]",y:"f32[3, 2]"):add:"f32[3, 2]"=torch.ops.aten.add.Tensor(x,y);x=y=Noneadd_1:"f32[3, 2]"=torch.ops.aten.add.Tensor(add,1);add=Nonereturn(add_1,)Graphsignature:# inputsx:USER_INPUTy:USER_INPUT# outputsadd_1:USER_OUTPUTRangeconstraints:{}
dictionary#
Original source code:
# mypy: allow-untyped-defsimporttorchclassDictionary(torch.nn.Module):""" Dictionary structures are inlined and flattened along tracing. """defforward(self,x,y):elements={}elements["x2"]=x*xy=y*elements["x2"]return{"y":y}example_args=(torch.randn(3,2),torch.tensor(4))tags={"python.data-structure"}model=Dictionary()torch.export.export(model,example_args)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,x:"f32[3, 2]",y:"i64[]"):mul:"f32[3, 2]"=torch.ops.aten.mul.Tensor(x,x);x=Nonemul_1:"f32[3, 2]"=torch.ops.aten.mul.Tensor(y,mul);y=mul=Nonereturn(mul_1,)Graphsignature:# inputsx:USER_INPUTy:USER_INPUT# outputsmul_1:USER_OUTPUTRangeconstraints:{}
dynamic_shape_assert#
Original source code:
# mypy: allow-untyped-defsimporttorchclassDynamicShapeAssert(torch.nn.Module):""" A basic usage of python assertion. """defforward(self,x):# assertion with error messageassertx.shape[0]>2,f"{x.shape[0]} is greater than 2"# assertion without error messageassertx.shape[0]>1returnxexample_args=(torch.randn(3,2),)tags={"python.assert"}model=DynamicShapeAssert()torch.export.export(model,example_args)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,x:"f32[3, 2]"):return(x,)Graphsignature:# inputsx:USER_INPUT# outputsx:USER_OUTPUTRangeconstraints:{}
dynamic_shape_constructor#
Original source code:
# mypy: allow-untyped-defsimporttorchclassDynamicShapeConstructor(torch.nn.Module):""" Tensor constructors should be captured with dynamic shape inputs rather than being baked in with static shape. """defforward(self,x):returntorch.zeros(x.shape[0]*2)example_args=(torch.randn(3,2),)tags={"torch.dynamic-shape"}model=DynamicShapeConstructor()torch.export.export(model,example_args)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,x:"f32[3, 2]"):zeros:"f32[6]"=torch.ops.aten.zeros.default([6],device=device(type='cpu'),pin_memory=False)return(zeros,)Graphsignature:# inputsx:USER_INPUT# outputszeros:USER_OUTPUTRangeconstraints:{}
dynamic_shape_if_guard#
Original source code:
# mypy: allow-untyped-defsimporttorchclassDynamicShapeIfGuard(torch.nn.Module):""" `if` statement with backed dynamic shape predicate will be specialized into one particular branch and generate a guard. However, export will fail if the the dimension is marked as dynamic shape from higher level API. """defforward(self,x):ifx.shape[0]==3:returnx.cos()returnx.sin()example_args=(torch.randn(3,2,2),)tags={"torch.dynamic-shape","python.control-flow"}model=DynamicShapeIfGuard()torch.export.export(model,example_args)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,x:"f32[3, 2, 2]"):cos:"f32[3, 2, 2]"=torch.ops.aten.cos.default(x);x=Nonereturn(cos,)Graphsignature:# inputsx:USER_INPUT# outputscos:USER_OUTPUTRangeconstraints:{}
dynamic_shape_map#
Original source code:
# mypy: allow-untyped-defsimporttorchfromfunctorch.experimental.control_flowimportmapclassDynamicShapeMap(torch.nn.Module):""" functorch map() maps a function over the first tensor dimension. """defforward(self,xs,y):defbody(x,y):returnx+yreturnmap(body,xs,y)example_args=(torch.randn(3,2),torch.randn(2))tags={"torch.dynamic-shape","torch.map"}model=DynamicShapeMap()torch.export.export(model,example_args)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,xs:"f32[3, 2]",y:"f32[2]"):body_graph_0=self.body_graph_0map_impl=torch.ops.higher_order.map_impl(body_graph_0,[xs],[y]);body_graph_0=xs=y=Nonegetitem:"f32[3, 2]"=map_impl[0];map_impl=Nonereturn(getitem,)classbody_graph_0(torch.nn.Module):defforward(self,xs:"f32[2]",y:"f32[2]"):add:"f32[2]"=torch.ops.aten.add.Tensor(xs,y);xs=y=Nonereturn(add,)Graphsignature:# inputsxs:USER_INPUTy:USER_INPUT# outputsgetitem:USER_OUTPUTRangeconstraints:{}
dynamic_shape_slicing#
Original source code:
# mypy: allow-untyped-defsimporttorchclassDynamicShapeSlicing(torch.nn.Module):""" Slices with dynamic shape arguments should be captured into the graph rather than being baked in. """defforward(self,x):returnx[:x.shape[0]-2,x.shape[1]-1::2]example_args=(torch.randn(3,2),)tags={"torch.dynamic-shape"}model=DynamicShapeSlicing()torch.export.export(model,example_args)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,x:"f32[3, 2]"):slice_1:"f32[1, 2]"=torch.ops.aten.slice.Tensor(x,0,0,1);x=Noneslice_2:"f32[1, 1]"=torch.ops.aten.slice.Tensor(slice_1,1,1,9223372036854775807,2);slice_1=Nonereturn(slice_2,)Graphsignature:# inputsx:USER_INPUT# outputsslice_2:USER_OUTPUTRangeconstraints:{}
dynamic_shape_view#
Original source code:
# mypy: allow-untyped-defsimporttorchclassDynamicShapeView(torch.nn.Module):""" Dynamic shapes should be propagated to view arguments instead of being baked into the exported graph. """defforward(self,x):new_x_shape=x.size()[:-1]+(2,5)x=x.view(*new_x_shape)returnx.permute(0,2,1)example_args=(torch.randn(10,10),)tags={"torch.dynamic-shape"}model=DynamicShapeView()torch.export.export(model,example_args)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,x:"f32[10, 10]"):view:"f32[10, 2, 5]"=torch.ops.aten.view.default(x,[10,2,5]);x=Nonepermute:"f32[10, 5, 2]"=torch.ops.aten.permute.default(view,[0,2,1]);view=Nonereturn(permute,)Graphsignature:# inputsx:USER_INPUT# outputspermute:USER_OUTPUTRangeconstraints:{}
fn_with_kwargs#
Original source code:
# mypy: allow-untyped-defsimporttorchclassFnWithKwargs(torch.nn.Module):""" Keyword arguments are not supported at the moment. """defforward(self,pos0,tuple0,*myargs,mykw0,**mykwargs):out=pos0forargintuple0:out=out*argforarginmyargs:out=out*argout=out*mykw0out=out*mykwargs["input0"]*mykwargs["input1"]returnoutexample_args=(torch.randn(4),(torch.randn(4),torch.randn(4)),*[torch.randn(4),torch.randn(4)])example_kwargs={"mykw0":torch.randn(4),"input0":torch.randn(4),"input1":torch.randn(4),}tags={"python.data-structure"}model=FnWithKwargs()torch.export.export(model,example_args,example_kwargs)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,pos0:"f32[4]",tuple0_0:"f32[4]",tuple0_1:"f32[4]",myargs_0:"f32[4]",myargs_1:"f32[4]",mykw0:"f32[4]",input0:"f32[4]",input1:"f32[4]"):mul:"f32[4]"=torch.ops.aten.mul.Tensor(pos0,tuple0_0);pos0=tuple0_0=Nonemul_1:"f32[4]"=torch.ops.aten.mul.Tensor(mul,tuple0_1);mul=tuple0_1=Nonemul_2:"f32[4]"=torch.ops.aten.mul.Tensor(mul_1,myargs_0);mul_1=myargs_0=Nonemul_3:"f32[4]"=torch.ops.aten.mul.Tensor(mul_2,myargs_1);mul_2=myargs_1=Nonemul_4:"f32[4]"=torch.ops.aten.mul.Tensor(mul_3,mykw0);mul_3=mykw0=Nonemul_5:"f32[4]"=torch.ops.aten.mul.Tensor(mul_4,input0);mul_4=input0=Nonemul_6:"f32[4]"=torch.ops.aten.mul.Tensor(mul_5,input1);mul_5=input1=Nonereturn(mul_6,)Graphsignature:# inputspos0:USER_INPUTtuple0_0:USER_INPUTtuple0_1:USER_INPUTmyargs_0:USER_INPUTmyargs_1:USER_INPUTmykw0:USER_INPUTinput0:USER_INPUTinput1:USER_INPUT# outputsmul_6:USER_OUTPUTRangeconstraints:{}
list_contains#
Original source code:
# mypy: allow-untyped-defsimporttorchclassListContains(torch.nn.Module):""" List containment relation can be checked on a dynamic shape or constants. """defforward(self,x):assertx.size(-1)in[6,2]assertx.size(0)notin[4,5,6]assert"monkey"notin["cow","pig"]returnx+xexample_args=(torch.randn(3,2),)tags={"torch.dynamic-shape","python.data-structure","python.assert"}model=ListContains()torch.export.export(model,example_args)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,x:"f32[3, 2]"):add:"f32[3, 2]"=torch.ops.aten.add.Tensor(x,x);x=Nonereturn(add,)Graphsignature:# inputsx:USER_INPUT# outputsadd:USER_OUTPUTRangeconstraints:{}
list_unpack#
Original source code:
# mypy: allow-untyped-defsimporttorchclassListUnpack(torch.nn.Module):""" Lists are treated as static construct, therefore unpacking should be erased after tracing. """defforward(self,args:list[torch.Tensor]):""" Lists are treated as static construct, therefore unpacking should be erased after tracing. """x,*y=argsreturnx+y[0]example_args=([torch.randn(3,2),torch.tensor(4),torch.tensor(5)],)tags={"python.control-flow","python.data-structure"}model=ListUnpack()torch.export.export(model,example_args)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,args_0:"f32[3, 2]",args_1:"i64[]",args_2:"i64[]"):add:"f32[3, 2]"=torch.ops.aten.add.Tensor(args_0,args_1);args_0=args_1=Nonereturn(add,)Graphsignature:# inputsargs_0:USER_INPUTargs_1:USER_INPUTargs_2:USER_INPUT# outputsadd:USER_OUTPUTRangeconstraints:{}
nested_function#
Original source code:
# mypy: allow-untyped-defsimporttorchclassNestedFunction(torch.nn.Module):""" Nested functions are traced through. Side effects on global captures are not supported though. """defforward(self,a,b):x=a+bz=a-bdefclosure(y):nonlocalxx+=1returnx*y+zreturnclosure(x)example_args=(torch.randn(3,2),torch.randn(2))tags={"python.closure"}model=NestedFunction()torch.export.export(model,example_args)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,a:"f32[3, 2]",b:"f32[2]"):add:"f32[3, 2]"=torch.ops.aten.add.Tensor(a,b)sub:"f32[3, 2]"=torch.ops.aten.sub.Tensor(a,b);a=b=Noneadd_:"f32[3, 2]"=torch.ops.aten.add_.Tensor(add,1);add=Nonemul:"f32[3, 2]"=torch.ops.aten.mul.Tensor(add_,add_);add_=Noneadd_1:"f32[3, 2]"=torch.ops.aten.add.Tensor(mul,sub);mul=sub=Nonereturn(add_1,)Graphsignature:# inputsa:USER_INPUTb:USER_INPUT# outputsadd_1:USER_OUTPUTRangeconstraints:{}
null_context_manager#
Original source code:
# mypy: allow-untyped-defsimportcontextlibimporttorchclassNullContextManager(torch.nn.Module):""" Null context manager in Python will be traced out. """defforward(self,x):""" Null context manager in Python will be traced out. """ctx=contextlib.nullcontext()withctx:returnx.sin()+x.cos()example_args=(torch.randn(3,2),)tags={"python.context-manager"}model=NullContextManager()torch.export.export(model,example_args)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,x:"f32[3, 2]"):sin:"f32[3, 2]"=torch.ops.aten.sin.default(x)cos:"f32[3, 2]"=torch.ops.aten.cos.default(x);x=Noneadd:"f32[3, 2]"=torch.ops.aten.add.Tensor(sin,cos);sin=cos=Nonereturn(add,)Graphsignature:# inputsx:USER_INPUT# outputsadd:USER_OUTPUTRangeconstraints:{}
pytree_flatten#
Note
Tags:
Support Level: SUPPORTED
Original source code:
# mypy: allow-untyped-defsimporttorchfromtorch.utilsimport_pytreeaspytreeclassPytreeFlatten(torch.nn.Module):""" Pytree from PyTorch can be captured by TorchDynamo. """defforward(self,x):y,_spec=pytree.tree_flatten(x)returny[0]+1example_args=({1:torch.randn(3,2),2:torch.randn(3,2)},),model=PytreeFlatten()torch.export.export(model,example_args)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,x_0_1:"f32[3, 2]",x_0_2:"f32[3, 2]"):add:"f32[3, 2]"=torch.ops.aten.add.Tensor(x_0_1,1);x_0_1=Nonereturn(add,)Graphsignature:# inputsx_0_1:USER_INPUTx_0_2:USER_INPUT# outputsadd:USER_OUTPUTRangeconstraints:{}
scalar_output#
Original source code:
# mypy: allow-untyped-defsimporttorchfromtorch.exportimportDimx=torch.randn(3,2)dim1_x=Dim("dim1_x")classScalarOutput(torch.nn.Module):""" Returning scalar values from the graph is supported, in addition to Tensor outputs. Symbolic shapes are captured and rank is specialized. """def__init__(self)->None:super().__init__()defforward(self,x):returnx.shape[1]+1example_args=(x,)tags={"torch.dynamic-shape"}dynamic_shapes={"x":{1:dim1_x}}model=ScalarOutput()torch.export.export(model,example_args,dynamic_shapes=dynamic_shapes)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,x:"f32[3, s27]"):#sym_size_int_1:"Sym(s27)"=torch.ops.aten.sym_size.int(x,1);x=Noneadd:"Sym(s27 + 1)"=sym_size_int_1+1;sym_size_int_1=Nonereturn(add,)Graphsignature:# inputsx:USER_INPUT# outputsadd:USER_OUTPUTRangeconstraints:{s27:VR[0,int_oo]}
specialized_attribute#
Note
Tags:
Support Level: SUPPORTED
Original source code:
# mypy: allow-untyped-defsfromenumimportEnumimporttorchclassAnimal(Enum):COW="moo"classSpecializedAttribute(torch.nn.Module):""" Model attributes are specialized. """def__init__(self)->None:super().__init__()self.a="moo"self.b=4defforward(self,x):ifself.a==Animal.COW.value:returnx*x+self.belse:raiseValueError("bad")example_args=(torch.randn(3,2),)model=SpecializedAttribute()torch.export.export(model,example_args)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,x:"f32[3, 2]"):mul:"f32[3, 2]"=torch.ops.aten.mul.Tensor(x,x);x=Noneadd:"f32[3, 2]"=torch.ops.aten.add.Tensor(mul,4);mul=Nonereturn(add,)Graphsignature:# inputsx:USER_INPUT# outputsadd:USER_OUTPUTRangeconstraints:{}
static_for_loop#
Original source code:
# mypy: allow-untyped-defsimporttorchclassStaticForLoop(torch.nn.Module):""" A for loop with constant number of iterations should be unrolled in the exported graph. """defforward(self,x):# constantret=[i+xforiinrange(10)]returnretexample_args=(torch.randn(3,2),)tags={"python.control-flow"}model=StaticForLoop()torch.export.export(model,example_args)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,x:"f32[3, 2]"):add:"f32[3, 2]"=torch.ops.aten.add.Tensor(x,0)add_1:"f32[3, 2]"=torch.ops.aten.add.Tensor(x,1)add_2:"f32[3, 2]"=torch.ops.aten.add.Tensor(x,2)add_3:"f32[3, 2]"=torch.ops.aten.add.Tensor(x,3)add_4:"f32[3, 2]"=torch.ops.aten.add.Tensor(x,4)add_5:"f32[3, 2]"=torch.ops.aten.add.Tensor(x,5)add_6:"f32[3, 2]"=torch.ops.aten.add.Tensor(x,6)add_7:"f32[3, 2]"=torch.ops.aten.add.Tensor(x,7)add_8:"f32[3, 2]"=torch.ops.aten.add.Tensor(x,8)add_9:"f32[3, 2]"=torch.ops.aten.add.Tensor(x,9);x=Nonereturn(add,add_1,add_2,add_3,add_4,add_5,add_6,add_7,add_8,add_9)Graphsignature:# inputsx:USER_INPUT# outputsadd:USER_OUTPUTadd_1:USER_OUTPUTadd_2:USER_OUTPUTadd_3:USER_OUTPUTadd_4:USER_OUTPUTadd_5:USER_OUTPUTadd_6:USER_OUTPUTadd_7:USER_OUTPUTadd_8:USER_OUTPUTadd_9:USER_OUTPUTRangeconstraints:{}
static_if#
Original source code:
# mypy: allow-untyped-defsimporttorchclassStaticIf(torch.nn.Module):""" `if` statement with static predicate value should be traced through with the taken branch. """defforward(self,x):iflen(x.shape)==3:returnx+torch.ones(1,1,1)returnxexample_args=(torch.randn(3,2,2),)tags={"python.control-flow"}model=StaticIf()torch.export.export(model,example_args)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,x:"f32[3, 2, 2]"):ones:"f32[1, 1, 1]"=torch.ops.aten.ones.default([1,1,1],device=device(type='cpu'),pin_memory=False)add:"f32[3, 2, 2]"=torch.ops.aten.add.Tensor(x,ones);x=ones=Nonereturn(add,)Graphsignature:# inputsx:USER_INPUT# outputsadd:USER_OUTPUTRangeconstraints:{}
tensor_setattr#
Original source code:
# mypy: allow-untyped-defsimporttorchclassTensorSetattr(torch.nn.Module):""" setattr() call onto tensors is not supported. """defforward(self,x,attr):setattr(x,attr,torch.randn(3,2))returnx+4example_args=(torch.randn(3,2),"attr")tags={"python.builtin"}model=TensorSetattr()torch.export.export(model,example_args)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,x:"f32[3, 2]",attr):randn:"f32[3, 2]"=torch.ops.aten.randn.default([3,2],device=device(type='cpu'),pin_memory=False);randn=Noneadd:"f32[3, 2]"=torch.ops.aten.add.Tensor(x,4);x=Nonereturn(add,)Graphsignature:# inputsx:USER_INPUTattr:USER_INPUT# outputsadd:USER_OUTPUTRangeconstraints:{}
type_reflection_method#
Original source code:
# mypy: allow-untyped-defsimporttorchclassA:@classmethoddeffunc(cls,x):return1+xclassTypeReflectionMethod(torch.nn.Module):""" type() calls on custom objects followed by attribute accesses are not allowed due to its overly dynamic nature. """defforward(self,x):a=A()returntype(a).func(x)example_args=(torch.randn(3,4),)tags={"python.builtin"}model=TypeReflectionMethod()torch.export.export(model,example_args)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,x:"f32[3, 4]"):add:"f32[3, 4]"=torch.ops.aten.add.Tensor(x,1);x=Nonereturn(add,)Graphsignature:# inputsx:USER_INPUT# outputsadd:USER_OUTPUTRangeconstraints:{}
user_input_mutation#
Original source code:
# mypy: allow-untyped-defsimporttorchclassUserInputMutation(torch.nn.Module):""" Directly mutate user input in forward """defforward(self,x):x.mul_(2)returnx.cos()example_args=(torch.randn(3,2),)tags={"torch.mutation"}model=UserInputMutation()torch.export.export(model,example_args)
Result:
ExportedProgram:classGraphModule(torch.nn.Module):defforward(self,x:"f32[3, 2]"):mul_:"f32[3, 2]"=torch.ops.aten.mul_.Tensor(x,2);x=Nonecos:"f32[3, 2]"=torch.ops.aten.cos.default(mul_);mul_=Nonereturn(cos,)Graphsignature:# inputsx:USER_INPUT# outputscos:USER_OUTPUTRangeconstraints:{}
Not Supported Yet#
dynamic_shape_round#
Original source code:
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import SupportLevel
from torch.export import Dim


class DynamicShapeRound(torch.nn.Module):
    """
    Calling round on dynamic shapes is not supported.
    """

    def forward(self, x):
        return x[: round(x.shape[0] / 2)]


x = torch.randn(3, 2)
dim0_x = Dim("dim0_x")
example_args = (x,)
tags = {"torch.dynamic-shape", "python.builtin"}
support_level = SupportLevel.NOT_SUPPORTED_YET
dynamic_shapes = {"x": {0: dim0_x}}
model = DynamicShapeRound()
# Expected to fail for this case (see the Result below).
torch.export.export(model, example_args, dynamic_shapes=dynamic_shapes)
Result:
Unsupported: Constraints violated (dim0_x)! For more information, run with TORCH_LOGS="+dynamic".
model_attr_mutation#
Original source code:
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import SupportLevel


class ModelAttrMutation(torch.nn.Module):
    """
    Attribute mutation is not supported.
    """

    def __init__(self) -> None:
        super().__init__()
        self.attr_list = [torch.randn(3, 2), torch.randn(3, 2)]

    def recreate_list(self):
        return [torch.zeros(3, 2), torch.zeros(3, 2)]

    def forward(self, x):
        # Rebinding a module attribute inside forward is what export rejects.
        self.attr_list = self.recreate_list()
        return x.sum() + self.attr_list[0].sum()


example_args = (torch.randn(3, 2),)
tags = {"python.object-model"}
support_level = SupportLevel.NOT_SUPPORTED_YET
model = ModelAttrMutation()
# Expected to fail for this case (see the Result below).
torch.export.export(model, example_args)
Result:
AssertionError: Mutating module attribute attr_list during export.
optional_input#
Original source code:
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import SupportLevel


class OptionalInput(torch.nn.Module):
    """
    Tracing through optional input is not supported yet
    """

    # NOTE: the default tensor is created once, when the method is defined,
    # and is shared by every call that omits `y`.
    def forward(self, x, y=torch.randn(2, 3)):
        if y is not None:
            return x + y
        return x


example_args = (torch.randn(2, 3),)
tags = {"python.object-model"}
support_level = SupportLevel.NOT_SUPPORTED_YET
model = OptionalInput()
# Expected to fail for this case (see the Result below).
torch.export.export(model, example_args)
Result:
Unsupported: Tracing through optional input is not supported yet
unsupported_operator#
Original source code:
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import SupportLevel


class TorchSymMin(torch.nn.Module):
    """
    torch.sym_min operator is not supported in export.
    """

    def forward(self, x):
        return x.sum() + torch.sym_min(x.size(0), 100)


example_args = (torch.randn(3, 2),)
tags = {"torch.operator"}
support_level = SupportLevel.NOT_SUPPORTED_YET
model = TorchSymMin()
# Expected to fail for this case (see the Result below).
torch.export.export(model, example_args)
Result:
Unsupported: torch.* op returned non-Tensor