ExportDB

ExportDB is a centralized dataset of supported and unsupported export cases. It is targeted towards users who want to understand specifically what types of code are supported, the subtleties of export, and how to modify their existing code to be compatible with export. Note that this is not an exhaustive set of everything that is supported by export, but it covers the most common and confusing use cases that users will run into.

If you have a feature that you think needs a stronger guarantee from us to support in export, please create an issue in the pytorch/pytorch repo with a module:export tag.

Supported

assume_constant_result

Note

Tags: torch.escape-hatch

Support Level: SUPPORTED

Original source code:

    # mypy: allow-untyped-defs
    import torch
    import torch._dynamo as torchdynamo


    class AssumeConstantResult(torch.nn.Module):
        """
        Applying the `assume_constant_result` decorator burns the result of
        non-traceable code into the graph as a constant.
        """

        @torchdynamo.assume_constant_result
        def get_item(self, y):
            return y.int().item()

        def forward(self, x, y):
            return x[: self.get_item(y)]


    example_args = (torch.randn(3, 2), torch.tensor(4))
    tags = {"torch.escape-hatch"}
    model = AssumeConstantResult()

    torch.export.export(model, example_args)

Result:

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2]", y: "i64[]"):
                slice_1: "f32[3, 2]" = torch.ops.aten.slice.Tensor(x, 0, 0, 4);  x = None
                return (slice_1,)

    Graph signature:
        # inputs
        x: USER_INPUT
        y: USER_INPUT

        # outputs
        slice_1: USER_OUTPUT

    Range constraints: {}

autograd_function

Note

Tags:

Support Level: SUPPORTED

Original source code:

    # mypy: allow-untyped-defs
    import torch


    class MyAutogradFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x.clone()

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output + 1


    class AutogradFunction(torch.nn.Module):
        """
        TorchDynamo does not keep track of backward() on autograd functions. We recommend
        using `allow_in_graph` to mitigate this problem.
        """

        def forward(self, x):
            return MyAutogradFunction.apply(x)


    example_args = (torch.randn(3, 2),)
    model = AutogradFunction()

    torch.export.export(model, example_args)

Result:

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2]"):
                clone: "f32[3, 2]" = torch.ops.aten.clone.default(x);  x = None
                return (clone,)

    Graph signature:
        # inputs
        x: USER_INPUT

        # outputs
        clone: USER_OUTPUT

    Range constraints: {}

class_method

Note

Tags:

Support Level: SUPPORTED

Original source code:

    # mypy: allow-untyped-defs
    import torch


    class ClassMethod(torch.nn.Module):
        """
        Class methods are inlined during tracing.
        """

        @classmethod
        def method(cls, x):
            return x + 1

        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(4, 2)

        def forward(self, x):
            x = self.linear(x)
            return self.method(x) * self.__class__.method(x) * type(self).method(x)


    example_args = (torch.randn(3, 4),)
    model = ClassMethod()

    torch.export.export(model, example_args)

Result:

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, p_linear_weight: "f32[2, 4]", p_linear_bias: "f32[2]", x: "f32[3, 4]"):
                linear: "f32[3, 2]" = torch.ops.aten.linear.default(x, p_linear_weight, p_linear_bias);  x = p_linear_weight = p_linear_bias = None
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(linear, 1)
                add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(linear, 1)
                mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(add, add_1);  add = add_1 = None
                add_2: "f32[3, 2]" = torch.ops.aten.add.Tensor(linear, 1);  linear = None
                mul_1: "f32[3, 2]" = torch.ops.aten.mul.Tensor(mul, add_2);  mul = add_2 = None
                return (mul_1,)

    Graph signature:
        # inputs
        p_linear_weight: PARAMETER target='linear.weight'
        p_linear_bias: PARAMETER target='linear.bias'
        x: USER_INPUT

        # outputs
        mul_1: USER_OUTPUT

    Range constraints: {}

cond_branch_class_method

Note

Tags: torch.dynamic-shape, torch.cond

Support Level: SUPPORTED

Original source code:

    # mypy: allow-untyped-defs
    import torch

    from functorch.experimental.control_flow import cond


    class MySubModule(torch.nn.Module):
        def foo(self, x):
            return x.cos()

        def forward(self, x):
            return self.foo(x)


    class CondBranchClassMethod(torch.nn.Module):
        """
        The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
          - both branches must take the same args, which must also match the branch args passed to cond.
          - both branches must return a single tensor
          - returned tensor must have the same tensor metadata, e.g. shape and dtype
          - branch function can be free function, nested function, lambda, class methods
          - branch function cannot have closure variables
          - no inplace mutations on inputs or global variables

        This example demonstrates using class method in cond().

        NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
        """

        def __init__(self) -> None:
            super().__init__()
            self.subm = MySubModule()

        def bar(self, x):
            return x.sin()

        def forward(self, x):
            return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])


    example_args = (torch.randn(3),)
    tags = {
        "torch.cond",
        "torch.dynamic-shape",
    }
    model = CondBranchClassMethod()

    torch.export.export(model, example_args)

Result:

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3]"):
                sin: "f32[3]" = torch.ops.aten.sin.default(x);  x = None
                return (sin,)

    Graph signature:
        # inputs
        x: USER_INPUT

        # outputs
        sin: USER_OUTPUT

    Range constraints: {}

cond_branch_nested_function

Note

Tags: torch.dynamic-shape, torch.cond

Support Level: SUPPORTED

Original source code:

    # mypy: allow-untyped-defs
    import torch

    from functorch.experimental.control_flow import cond


    class CondBranchNestedFunction(torch.nn.Module):
        """
        The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
          - both branches must take the same args, which must also match the branch args passed to cond.
          - both branches must return a single tensor
          - returned tensor must have the same tensor metadata, e.g. shape and dtype
          - branch function can be free function, nested function, lambda, class methods
          - branch function cannot have closure variables
          - no inplace mutations on inputs or global variables

        This example demonstrates using nested function in cond().

        NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
        """

        def forward(self, x):
            def true_fn(x):
                def inner_true_fn(y):
                    return x + y

                return inner_true_fn(x)

            def false_fn(x):
                def inner_false_fn(y):
                    return x - y

                return inner_false_fn(x)

            return cond(x.shape[0] < 10, true_fn, false_fn, [x])


    example_args = (torch.randn(3),)
    tags = {
        "torch.cond",
        "torch.dynamic-shape",
    }
    model = CondBranchNestedFunction()

    torch.export.export(model, example_args)

Result:

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3]"):
                add: "f32[3]" = torch.ops.aten.add.Tensor(x, x);  x = None
                return (add,)

    Graph signature:
        # inputs
        x: USER_INPUT

        # outputs
        add: USER_OUTPUT

    Range constraints: {}

cond_branch_nonlocal_variables

Note

Tags: torch.dynamic-shape, torch.cond

Support Level: SUPPORTED

Original source code:

    # mypy: allow-untyped-defs
    import torch

    from functorch.experimental.control_flow import cond


    class CondBranchNonlocalVariables(torch.nn.Module):
        """
        The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
        - both branches must take the same args, which must also match the branch args passed to cond.
        - both branches must return a single tensor
        - returned tensor must have the same tensor metadata, e.g. shape and dtype
        - branch function can be free function, nested function, lambda, class methods
        - branch function cannot have closure variables
        - no inplace mutations on inputs or global variables

        This example demonstrates how to rewrite code to avoid capturing closure variables in branch functions.

        The code below will not work because capturing closure variables is not supported.
        ```
        my_tensor_var = x + 100
        my_primitive_var = 3.14

        def true_fn(y):
            nonlocal my_tensor_var, my_primitive_var
            return y + my_tensor_var + my_primitive_var

        def false_fn(y):
            nonlocal my_tensor_var, my_primitive_var
            return y - my_tensor_var - my_primitive_var

        return cond(x.shape[0] > 5, true_fn, false_fn, [x])
        ```

        NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
        """

        def forward(self, x):
            my_tensor_var = x + 100
            my_primitive_var = 3.14

            def true_fn(x, y, z):
                return x + y + z

            def false_fn(x, y, z):
                return x - y - z

            return cond(
                x.shape[0] > 5,
                true_fn,
                false_fn,
                [x, my_tensor_var, torch.tensor(my_primitive_var)],
            )


    example_args = (torch.randn(6),)
    tags = {
        "torch.cond",
        "torch.dynamic-shape",
    }
    model = CondBranchNonlocalVariables()

    torch.export.export(model, example_args)

Result:

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, c_lifted_tensor_0: "f32[]", x: "f32[6]"):
                add: "f32[6]" = torch.ops.aten.add.Tensor(x, 100)
                lift_fresh_copy: "f32[]" = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_0);  c_lifted_tensor_0 = None
                detach_: "f32[]" = torch.ops.aten.detach_.default(lift_fresh_copy);  lift_fresh_copy = None
                add_1: "f32[6]" = torch.ops.aten.add.Tensor(x, add);  x = add = None
                add_2: "f32[6]" = torch.ops.aten.add.Tensor(add_1, detach_);  add_1 = detach_ = None
                return (add_2,)

    Graph signature:
        # inputs
        c_lifted_tensor_0: CONSTANT_TENSOR target='lifted_tensor_0'
        x: USER_INPUT

        # outputs
        add_2: USER_OUTPUT

    Range constraints: {}

cond_closed_over_variable

Note

Tags: python.closure, torch.cond

Support Level: SUPPORTED

Original source code:

    # mypy: allow-untyped-defs
    import torch

    from functorch.experimental.control_flow import cond


    class CondClosedOverVariable(torch.nn.Module):
        """
        torch.cond() supports branches closed over arbitrary variables.
        """

        def forward(self, pred, x):
            def true_fn(val):
                return x * 2

            def false_fn(val):
                return x - 2

            return cond(pred, true_fn, false_fn, [x + 1])


    example_args = (torch.tensor(True), torch.randn(3, 2))
    tags = {"torch.cond", "python.closure"}
    model = CondClosedOverVariable()

    torch.export.export(model, example_args)

Result:

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, pred: "b8[]", x: "f32[3, 2]"):
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 1);  add = None
                true_graph_0 = self.true_graph_0
                false_graph_0 = self.false_graph_0
                cond = torch.ops.higher_order.cond(pred, true_graph_0, false_graph_0, (x,));  pred = true_graph_0 = false_graph_0 = x = None
                getitem: "f32[3, 2]" = cond[0];  cond = None
                return (getitem,)

            class true_graph_0(torch.nn.Module):
                def forward(self, x: "f32[3, 2]"):
                    mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(x, 2);  x = None
                    return (mul,)

            class false_graph_0(torch.nn.Module):
                def forward(self, x: "f32[3, 2]"):
                    sub: "f32[3, 2]" = torch.ops.aten.sub.Tensor(x, 2);  x = None
                    return (sub,)

    Graph signature:
        # inputs
        pred: USER_INPUT
        x: USER_INPUT

        # outputs
        getitem: USER_OUTPUT

    Range constraints: {}

cond_operands

Note

Tags: torch.dynamic-shape, torch.cond

Support Level: SUPPORTED

Original source code:

    # mypy: allow-untyped-defs
    import torch

    from torch.export import Dim

    x = torch.randn(3, 2)
    y = torch.randn(2)
    dim0_x = Dim("dim0_x")


    class CondOperands(torch.nn.Module):
        """
        The operands passed to cond() must be:
        - a list of tensors
        - matched to the arguments of `true_fn` and `false_fn`

        NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
        """

        def forward(self, x, y):
            def true_fn(x, y):
                return x + y

            def false_fn(x, y):
                return x - y

            return torch.cond(x.shape[0] > 2, true_fn, false_fn, [x, y])


    example_args = (x, y)
    tags = {
        "torch.cond",
        "torch.dynamic-shape",
    }
    extra_inputs = (torch.randn(2, 2), torch.randn(2))
    dynamic_shapes = {"x": {0: dim0_x}, "y": None}
    model = CondOperands()

    torch.export.export(model, example_args, dynamic_shapes=dynamic_shapes)

Result:

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[s77, 2]", y: "f32[2]"):
                #
                sym_size_int_1: "Sym(s77)" = torch.ops.aten.sym_size.int(x, 0)
                gt: "Sym(s77 > 2)" = sym_size_int_1 > 2;  sym_size_int_1 = None
                true_graph_0 = self.true_graph_0
                false_graph_0 = self.false_graph_0
                cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (x, y));  gt = true_graph_0 = false_graph_0 = x = y = None
                getitem: "f32[s77, 2]" = cond[0];  cond = None
                return (getitem,)

            class true_graph_0(torch.nn.Module):
                def forward(self, x: "f32[s77, 2]", y: "f32[2]"):
                    add: "f32[s77, 2]" = torch.ops.aten.add.Tensor(x, y);  x = y = None
                    return (add,)

            class false_graph_0(torch.nn.Module):
                def forward(self, x: "f32[s77, 2]", y: "f32[2]"):
                    sub: "f32[s77, 2]" = torch.ops.aten.sub.Tensor(x, y);  x = y = None
                    return (sub,)

    Graph signature:
        # inputs
        x: USER_INPUT
        y: USER_INPUT

        # outputs
        getitem: USER_OUTPUT

    Range constraints: {s77: VR[0, int_oo]}

cond_predicate

Note

Tags: torch.dynamic-shape, torch.cond

Support Level: SUPPORTED

Original source code:

    # mypy: allow-untyped-defs
    import torch

    from functorch.experimental.control_flow import cond


    class CondPredicate(torch.nn.Module):
        """
        The conditional statement (aka predicate) passed to cond() must be one of the following:
          - torch.Tensor with a single element
          - boolean expression

        NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
        """

        def forward(self, x):
            pred = x.dim() > 2 and x.shape[2] > 10

            return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])


    example_args = (torch.randn(6, 4, 3),)
    tags = {
        "torch.cond",
        "torch.dynamic-shape",
    }
    model = CondPredicate()

    torch.export.export(model, example_args)

Result:

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[6, 4, 3]"):
                sin: "f32[6, 4, 3]" = torch.ops.aten.sin.default(x);  x = None
                return (sin,)

    Graph signature:
        # inputs
        x: USER_INPUT

        # outputs
        sin: USER_OUTPUT

    Range constraints: {}

constrain_as_size_example

Note

Tags: torch.dynamic-value, torch.escape-hatch

Support Level: SUPPORTED

Original source code:

    # mypy: allow-untyped-defs
    import torch


    class ConstrainAsSizeExample(torch.nn.Module):
        """
        If the value is not known at tracing time, you can provide a hint so that we
        can trace further. Please look at the torch._check and torch._check_is_size APIs.
        torch._check_is_size is used for values that NEED to be used for constructing
        a tensor.
        """

        def forward(self, x):
            a = x.item()
            torch._check_is_size(a)
            torch._check(a <= 5)
            return torch.zeros((a, 5))


    example_args = (torch.tensor(4),)
    tags = {
        "torch.dynamic-value",
        "torch.escape-hatch",
    }
    model = ConstrainAsSizeExample()

    torch.export.export(model, example_args)

Result:

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "i64[]"):
                item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None
                #
                sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(item);  sym_constrain_range_for_size_default = None
                ge_1: "Sym(u0 >= 0)" = item >= 0
                _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'");  ge_1 = _assert_scalar_default = None
                le_1: "Sym(u0 <= 5)" = item <= 5
                _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 5 on node 'le_1'");  le_1 = _assert_scalar_default_1 = None
                zeros: "f32[u0, 5]" = torch.ops.aten.zeros.default([item, 5], device=device(type='cpu'), pin_memory=False);  item = None
                return (zeros,)

    Graph signature:
        # inputs
        x: USER_INPUT

        # outputs
        zeros: USER_OUTPUT

    Range constraints: {u0: VR[0, 5], u1: VR[0, 5]}

constrain_as_value_example

Note

Tags: torch.dynamic-value, torch.escape-hatch

Support Level: SUPPORTED

Original source code:

    # mypy: allow-untyped-defs
    import torch


    class ConstrainAsValueExample(torch.nn.Module):
        """
        If the value is not known at tracing time, you can provide a hint so that we
        can trace further. Please look at the torch._check and torch._check_is_size APIs.
        torch._check is used for values that don't need to be used for constructing
        a tensor.
        """

        def forward(self, x, y):
            a = x.item()
            torch._check(a >= 0)
            torch._check(a <= 5)

            if a < 6:
                return y.sin()
            return y.cos()


    example_args = (torch.tensor(4), torch.randn(5, 5))
    tags = {
        "torch.dynamic-value",
        "torch.escape-hatch",
    }
    model = ConstrainAsValueExample()

    torch.export.export(model, example_args)

Result:

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "i64[]", y: "f32[5, 5]"):
                item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None
                ge_1: "Sym(u0 >= 0)" = item >= 0
                _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'");  ge_1 = _assert_scalar_default = None
                le_1: "Sym(u0 <= 5)" = item <= 5;  item = None
                _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 5 on node 'le_1'");  le_1 = _assert_scalar_default_1 = None
                sin: "f32[5, 5]" = torch.ops.aten.sin.default(y);  y = None
                return (sin,)

    Graph signature:
        # inputs
        x: USER_INPUT
        y: USER_INPUT

        # outputs
        sin: USER_OUTPUT

    Range constraints: {u0: VR[0, 5], u1: VR[0, 5]}

decorator

Note

Tags:

Support Level: SUPPORTED

Original source code:

    # mypy: allow-untyped-defs
    import functools

    import torch


    def test_decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs) + 1

        return wrapper


    class Decorator(torch.nn.Module):
        """
        Decorator calls are inlined into the exported function during tracing.
        """

        @test_decorator
        def forward(self, x, y):
            return x + y


    example_args = (torch.randn(3, 2), torch.randn(3, 2))
    model = Decorator()

    torch.export.export(model, example_args)

Result:

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2]", y: "f32[3, 2]"):
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, y);  x = y = None
                add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(add, 1);  add = None
                return (add_1,)

    Graph signature:
        # inputs
        x: USER_INPUT
        y: USER_INPUT

        # outputs
        add_1: USER_OUTPUT

    Range constraints: {}

dictionary

Note

Tags: python.data-structure

Support Level: SUPPORTED

Original source code:

    # mypy: allow-untyped-defs
    import torch


    class Dictionary(torch.nn.Module):
        """
        Dictionary structures are inlined and flattened during tracing.
        """

        def forward(self, x, y):
            elements = {}
            elements["x2"] = x * x
            y = y * elements["x2"]
            return {"y": y}


    example_args = (torch.randn(3, 2), torch.tensor(4))
    tags = {"python.data-structure"}
    model = Dictionary()

    torch.export.export(model, example_args)

Result:

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2]", y: "i64[]"):
                mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(x, x);  x = None
                mul_1: "f32[3, 2]" = torch.ops.aten.mul.Tensor(y, mul);  y = mul = None
                return (mul_1,)

    Graph signature:
        # inputs
        x: USER_INPUT
        y: USER_INPUT

        # outputs
        mul_1: USER_OUTPUT

    Range constraints: {}

dynamic_shape_assert

Note

Tags: python.assert

Support Level: SUPPORTED

Original source code:

    # mypy: allow-untyped-defs
    import torch


    class DynamicShapeAssert(torch.nn.Module):
        """
        A basic usage of Python assertions.
        """

        def forward(self, x):
            # assertion with error message
            assert x.shape[0] > 2, f"{x.shape[0]} is not greater than 2"
            # assertion without error message
            assert x.shape[0] > 1
            return x


    example_args = (torch.randn(3, 2),)
    tags = {"python.assert"}
    model = DynamicShapeAssert()

    torch.export.export(model, example_args)

Result:

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2]"):
                return (x,)

    Graph signature:
        # inputs
        x: USER_INPUT

        # outputs
        x: USER_OUTPUT

    Range constraints: {}

dynamic_shape_constructor

Note

Tags: torch.dynamic-shape

Support Level: SUPPORTED

Original source code:

    # mypy: allow-untyped-defs
    import torch


    class DynamicShapeConstructor(torch.nn.Module):
        """
        Tensor constructors should be captured with dynamic shape inputs rather
        than being baked in with static shape.
        """

        def forward(self, x):
            return torch.zeros(x.shape[0] * 2)


    example_args = (torch.randn(3, 2),)
    tags = {"torch.dynamic-shape"}
    model = DynamicShapeConstructor()

    torch.export.export(model, example_args)

Result:

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2]"):
                zeros: "f32[6]" = torch.ops.aten.zeros.default([6], device=device(type='cpu'), pin_memory=False)
                return (zeros,)

    Graph signature:
        # inputs
        x: USER_INPUT

        # outputs
        zeros: USER_OUTPUT

    Range constraints: {}

dynamic_shape_if_guard

Note

Tags: torch.dynamic-shape, python.control-flow

Support Level: SUPPORTED

Original source code:

    # mypy: allow-untyped-defs
    import torch


    class DynamicShapeIfGuard(torch.nn.Module):
        """
        An `if` statement with a backed dynamic shape predicate will be specialized into
        one particular branch and generate a guard. However, export will fail if the
        dimension is marked as a dynamic shape from a higher-level API.
        """

        def forward(self, x):
            if x.shape[0] == 3:
                return x.cos()

            return x.sin()


    example_args = (torch.randn(3, 2, 2),)
    tags = {"torch.dynamic-shape", "python.control-flow"}
    model = DynamicShapeIfGuard()

    torch.export.export(model, example_args)

Result:

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2, 2]"):
                cos: "f32[3, 2, 2]" = torch.ops.aten.cos.default(x);  x = None
                return (cos,)

    Graph signature:
        # inputs
        x: USER_INPUT

        # outputs
        cos: USER_OUTPUT

    Range constraints: {}

dynamic_shape_map

Note

Tags: torch.dynamic-shape, torch.map

Support Level: SUPPORTED

Original source code:

    # mypy: allow-untyped-defs
    import torch

    from functorch.experimental.control_flow import map


    class DynamicShapeMap(torch.nn.Module):
        """
        functorch map() maps a function over the first tensor dimension.
        """

        def forward(self, xs, y):
            def body(x, y):
                return x + y

            return map(body, xs, y)


    example_args = (torch.randn(3, 2), torch.randn(2))
    tags = {"torch.dynamic-shape", "torch.map"}
    model = DynamicShapeMap()

    torch.export.export(model, example_args)

Result:

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, xs: "f32[3, 2]", y: "f32[2]"):
                body_graph_0 = self.body_graph_0
                map_impl = torch.ops.higher_order.map_impl(body_graph_0, [xs], [y]);  body_graph_0 = xs = y = None
                getitem: "f32[3, 2]" = map_impl[0];  map_impl = None
                return (getitem,)

            class body_graph_0(torch.nn.Module):
                def forward(self, xs: "f32[2]", y: "f32[2]"):
                    add: "f32[2]" = torch.ops.aten.add.Tensor(xs, y);  xs = y = None
                    return (add,)

    Graph signature:
        # inputs
        xs: USER_INPUT
        y: USER_INPUT

        # outputs
        getitem: USER_OUTPUT

    Range constraints: {}

dynamic_shape_slicing

Note

Tags: torch.dynamic-shape

Support Level: SUPPORTED

Original source code:

    # mypy: allow-untyped-defs
    import torch


    class DynamicShapeSlicing(torch.nn.Module):
        """
        Slices with dynamic shape arguments should be captured into the graph
        rather than being baked in.
        """

        def forward(self, x):
            return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2]


    example_args = (torch.randn(3, 2),)
    tags = {"torch.dynamic-shape"}
    model = DynamicShapeSlicing()

    torch.export.export(model, example_args)

Result:

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2]"):
                slice_1: "f32[1, 2]" = torch.ops.aten.slice.Tensor(x, 0, 0, 1);  x = None
                slice_2: "f32[1, 1]" = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 9223372036854775807, 2);  slice_1 = None
                return (slice_2,)

    Graph signature:
        # inputs
        x: USER_INPUT

        # outputs
        slice_2: USER_OUTPUT

    Range constraints: {}

dynamic_shape_view

Note

Tags: torch.dynamic-shape

Support Level: SUPPORTED

Original source code:

    # mypy: allow-untyped-defs
    import torch


    class DynamicShapeView(torch.nn.Module):
        """
        Dynamic shapes should be propagated to view arguments instead of being
        baked into the exported graph.
        """

        def forward(self, x):
            new_x_shape = x.size()[:-1] + (2, 5)
            x = x.view(*new_x_shape)
            return x.permute(0, 2, 1)


    example_args = (torch.randn(10, 10),)
    tags = {"torch.dynamic-shape"}
    model = DynamicShapeView()

    torch.export.export(model, example_args)

Result:

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[10, 10]"):
                view: "f32[10, 2, 5]" = torch.ops.aten.view.default(x, [10, 2, 5]);  x = None
                permute: "f32[10, 5, 2]" = torch.ops.aten.permute.default(view, [0, 2, 1]);  view = None
                return (permute,)

    Graph signature:
        # inputs
        x: USER_INPUT

        # outputs
        permute: USER_OUTPUT

    Range constraints: {}

fn_with_kwargs

Note

Tags: python.data-structure

Support Level: SUPPORTED

Original source code:

    # mypy: allow-untyped-defs
    import torch


    class FnWithKwargs(torch.nn.Module):
        """
        Keyword arguments are not supported at the moment.
        """

        def forward(self, pos0, tuple0, *myargs, mykw0, **mykwargs):
            out = pos0
            for arg in tuple0:
                out = out * arg
            for arg in myargs:
                out = out * arg
            out = out * mykw0
            out = out * mykwargs["input0"] * mykwargs["input1"]
            return out


    example_args = (
        torch.randn(4),
        (torch.randn(4), torch.randn(4)),
        *[torch.randn(4), torch.randn(4)]
    )
    example_kwargs = {
        "mykw0": torch.randn(4),
        "input0": torch.randn(4),
        "input1": torch.randn(4),
    }
    tags = {"python.data-structure"}
    model = FnWithKwargs()

    torch.export.export(model, example_args, example_kwargs)

Result:

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, pos0: "f32[4]", tuple0_0: "f32[4]", tuple0_1: "f32[4]", myargs_0: "f32[4]", myargs_1: "f32[4]", mykw0: "f32[4]", input0: "f32[4]", input1: "f32[4]"):
                mul: "f32[4]" = torch.ops.aten.mul.Tensor(pos0, tuple0_0);  pos0 = tuple0_0 = None
                mul_1: "f32[4]" = torch.ops.aten.mul.Tensor(mul, tuple0_1);  mul = tuple0_1 = None
                mul_2: "f32[4]" = torch.ops.aten.mul.Tensor(mul_1, myargs_0);  mul_1 = myargs_0 = None
                mul_3: "f32[4]" = torch.ops.aten.mul.Tensor(mul_2, myargs_1);  mul_2 = myargs_1 = None
                mul_4: "f32[4]" = torch.ops.aten.mul.Tensor(mul_3, mykw0);  mul_3 = mykw0 = None
                mul_5: "f32[4]" = torch.ops.aten.mul.Tensor(mul_4, input0);  mul_4 = input0 = None
                mul_6: "f32[4]" = torch.ops.aten.mul.Tensor(mul_5, input1);  mul_5 = input1 = None
                return (mul_6,)

    Graph signature:
        # inputs
        pos0: USER_INPUT
        tuple0_0: USER_INPUT
        tuple0_1: USER_INPUT
        myargs_0: USER_INPUT
        myargs_1: USER_INPUT
        mykw0: USER_INPUT
        input0: USER_INPUT
        input1: USER_INPUT

        # outputs
        mul_6: USER_OUTPUT

    Range constraints: {}

list_contains

Note

Tags: torch.dynamic-shape, python.data-structure, python.assert

Support Level: SUPPORTED

Original source code:

    # mypy: allow-untyped-defs
    import torch


    class ListContains(torch.nn.Module):
        """
        List containment relation can be checked on a dynamic shape or constants.
        """

        def forward(self, x):
            assert x.size(-1) in [6, 2]
            assert x.size(0) not in [4, 5, 6]
            assert "monkey" not in ["cow", "pig"]
            return x + x


    example_args = (torch.randn(3, 2),)
    tags = {"torch.dynamic-shape", "python.data-structure", "python.assert"}
    model = ListContains()

    torch.export.export(model, example_args)

Result:

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2]"):
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, x);  x = None
                return (add,)

    Graph signature:
        # inputs
        x: USER_INPUT

        # outputs
        add: USER_OUTPUT

    Range constraints: {}

list_unpack

Note

Tags: python.data-structure, python.control-flow

Support Level: SUPPORTED

Original source code:

    # mypy: allow-untyped-defs
    import torch


    class ListUnpack(torch.nn.Module):
        """
        Lists are treated as static construct, therefore unpacking should be
        erased after tracing.
        """

        def forward(self, args: list[torch.Tensor]):
            """
            Lists are treated as static construct, therefore unpacking should be
            erased after tracing.
            """
            x, *y = args
            return x + y[0]


    example_args = ([torch.randn(3, 2), torch.tensor(4), torch.tensor(5)],)
    tags = {"python.control-flow", "python.data-structure"}
    model = ListUnpack()

    torch.export.export(model, example_args)

Result:

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, args_0: "f32[3, 2]", args_1: "i64[]", args_2: "i64[]"):
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(args_0, args_1);  args_0 = args_1 = None
                return (add,)

    Graph signature:
        # inputs
        args_0: USER_INPUT
        args_1: USER_INPUT
        args_2: USER_INPUT

        # outputs
        add: USER_OUTPUT

    Range constraints: {}

nested_function

Note

Tags: python.closure

Support Level: SUPPORTED

Original source code:

    # mypy: allow-untyped-defs
    import torch


    class NestedFunction(torch.nn.Module):
        """
        Nested functions are traced through. Side effects on global captures
        are not supported though.
        """

        def forward(self, a, b):
            x = a + b
            z = a - b

            def closure(y):
                nonlocal x
                x += 1
                return x * y + z

            return closure(x)


    example_args = (torch.randn(3, 2), torch.randn(2))
    tags = {"python.closure"}
    model = NestedFunction()

    torch.export.export(model, example_args)

Result:

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, a: "f32[3, 2]", b: "f32[2]"):
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(a, b)
                sub: "f32[3, 2]" = torch.ops.aten.sub.Tensor(a, b);  a = b = None
                add_: "f32[3, 2]" = torch.ops.aten.add_.Tensor(add, 1);  add = None
                mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(add_, add_);  add_ = None
                add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(mul, sub);  mul = sub = None
                return (add_1,)

    Graph signature:
        # inputs
        a: USER_INPUT
        b: USER_INPUT

        # outputs
        add_1: USER_OUTPUT

    Range constraints: {}

null_context_manager

Note

Tags: python.context-manager

Support Level: SUPPORTED

Original source code:

    # mypy: allow-untyped-defs
    import contextlib

    import torch


    class NullContextManager(torch.nn.Module):
        """
        Null context manager in Python will be traced out.
        """

        def forward(self, x):
            """
            Null context manager in Python will be traced out.
            """
            ctx = contextlib.nullcontext()
            with ctx:
                return x.sin() + x.cos()


    example_args = (torch.randn(3, 2),)
    tags = {"python.context-manager"}
    model = NullContextManager()

    torch.export.export(model, example_args)

Result:

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2]"):
                sin: "f32[3, 2]" = torch.ops.aten.sin.default(x)
                cos: "f32[3, 2]" = torch.ops.aten.cos.default(x);  x = None
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(sin, cos);  sin = cos = None
                return (add,)

    Graph signature:
        # inputs
        x: USER_INPUT

        # outputs
        add: USER_OUTPUT

    Range constraints: {}

pytree_flatten

Note

Tags:

Support Level: SUPPORTED

Original source code:

    # mypy: allow-untyped-defs
    import torch

    from torch.utils import _pytree as pytree


    class PytreeFlatten(torch.nn.Module):
        """
        Pytree from PyTorch can be captured by TorchDynamo.
        """

        def forward(self, x):
            y, _spec = pytree.tree_flatten(x)
            return y[0] + 1


    example_args = ({1: torch.randn(3, 2), 2: torch.randn(3, 2)},),
    model = PytreeFlatten()

    torch.export.export(model, example_args)

Result:

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x_0_1: "f32[3, 2]", x_0_2: "f32[3, 2]"):
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x_0_1, 1);  x_0_1 = None
                return (add,)

    Graph signature:
        # inputs
        x_0_1: USER_INPUT
        x_0_2: USER_INPUT

        # outputs
        add: USER_OUTPUT

    Range constraints: {}

scalar_output

Note

Tags: torch.dynamic-shape

Support Level: SUPPORTED

Original source code:

    # mypy: allow-untyped-defs
    import torch

    from torch.export import Dim

    x = torch.randn(3, 2)
    dim1_x = Dim("dim1_x")


    class ScalarOutput(torch.nn.Module):
        """
        Returning scalar values from the graph is supported, in addition to Tensor
        outputs. Symbolic shapes are captured and rank is specialized.
        """

        def __init__(self) -> None:
            super().__init__()

        def forward(self, x):
            return x.shape[1] + 1


    example_args = (x,)
    tags = {"torch.dynamic-shape"}
    dynamic_shapes = {"x": {1: dim1_x}}
    model = ScalarOutput()

    torch.export.export(model, example_args, dynamic_shapes=dynamic_shapes)

Result:

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, s27]"):
                #
                sym_size_int_1: "Sym(s27)" = torch.ops.aten.sym_size.int(x, 1);  x = None
                add: "Sym(s27 + 1)" = sym_size_int_1 + 1;  sym_size_int_1 = None
                return (add,)

    Graph signature:
        # inputs
        x: USER_INPUT

        # outputs
        add: USER_OUTPUT

    Range constraints: {s27: VR[0, int_oo]}

specialized_attribute

Note

Tags:

Support Level: SUPPORTED

Original source code:

    # mypy: allow-untyped-defs
    from enum import Enum

    import torch


    class Animal(Enum):
        COW = "moo"


    class SpecializedAttribute(torch.nn.Module):
        """
        Model attributes are specialized.
        """

        def __init__(self) -> None:
            super().__init__()
            self.a = "moo"
            self.b = 4

        def forward(self, x):
            if self.a == Animal.COW.value:
                return x * x + self.b
            else:
                raise ValueError("bad")


    example_args = (torch.randn(3, 2),)
    model = SpecializedAttribute()

    torch.export.export(model, example_args)

Result:

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2]"):
                mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(x, x);  x = None
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(mul, 4);  mul = None
                return (add,)

    Graph signature:
        # inputs
        x: USER_INPUT

        # outputs
        add: USER_OUTPUT

    Range constraints: {}

static_for_loop

Note

Tags: python.control-flow

Support Level: SUPPORTED

Original source code:

    # mypy: allow-untyped-defs
    import torch


    class StaticForLoop(torch.nn.Module):
        """
        A for loop with a constant number of iterations should be unrolled in the exported graph.
        """

        def forward(self, x):
            # constant
            ret = [i + x for i in range(10)]
            return ret


    example_args = (torch.randn(3, 2),)
    tags = {"python.control-flow"}
    model = StaticForLoop()

    torch.export.export(model, example_args)

Result:

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2]"):
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 0)
                add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 1)
                add_2: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 2)
                add_3: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 3)
                add_4: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 4)
                add_5: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 5)
                add_6: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 6)
                add_7: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 7)
                add_8: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 8)
                add_9: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 9);  x = None
                return (add, add_1, add_2, add_3, add_4, add_5, add_6, add_7, add_8, add_9)

    Graph signature:
        # inputs
        x: USER_INPUT

        # outputs
        add: USER_OUTPUT
        add_1: USER_OUTPUT
        add_2: USER_OUTPUT
        add_3: USER_OUTPUT
        add_4: USER_OUTPUT
        add_5: USER_OUTPUT
        add_6: USER_OUTPUT
        add_7: USER_OUTPUT
        add_8: USER_OUTPUT
        add_9: USER_OUTPUT

    Range constraints: {}

static_if#

Note

Tags:python.control-flow

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defsimporttorchclassStaticIf(torch.nn.Module):"""    `if` statement with static predicate value should be traced through with the    taken branch.    """defforward(self,x):iflen(x.shape)==3:returnx+torch.ones(1,1,1)returnxexample_args=(torch.randn(3,2,2),)tags={"python.control-flow"}model=StaticIf()torch.export.export(model,example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2, 2]"):
            ones: "f32[1, 1, 1]" = torch.ops.aten.ones.default([1, 1, 1], device=device(type='cpu'), pin_memory=False)
            add: "f32[3, 2, 2]" = torch.ops.aten.add.Tensor(x, ones);  x = ones = None
            return (add,)

Graph signature:
    # inputs
    x: USER_INPUT

    # outputs
    add: USER_OUTPUT

Range constraints: {}

tensor_setattr

Note

Tags: python.builtin

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defs
import torch

class TensorSetattr(torch.nn.Module):
    """
    setattr() call onto tensors is not supported.
    """

    def forward(self, x, attr):
        setattr(x, attr, torch.randn(3, 2))
        return x + 4

example_args = (torch.randn(3, 2), "attr")
tags = {"python.builtin"}
model = TensorSetattr()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]", attr):
            randn: "f32[3, 2]" = torch.ops.aten.randn.default([3, 2], device=device(type='cpu'), pin_memory=False);  randn = None
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 4);  x = None
            return (add,)

Graph signature:
    # inputs
    x: USER_INPUT
    attr: USER_INPUT

    # outputs
    add: USER_OUTPUT

Range constraints: {}

type_reflection_method

Note

Tags: python.builtin

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defs
import torch

class A:
    @classmethod
    def func(cls, x):
        return 1 + x

class TypeReflectionMethod(torch.nn.Module):
    """
    type() calls on custom objects followed by attribute accesses are not allowed
    due to its overly dynamic nature.
    """

    def forward(self, x):
        a = A()
        return type(a).func(x)

example_args = (torch.randn(3, 4),)
tags = {"python.builtin"}
model = TypeReflectionMethod()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 4]"):
            add: "f32[3, 4]" = torch.ops.aten.add.Tensor(x, 1);  x = None
            return (add,)

Graph signature:
    # inputs
    x: USER_INPUT

    # outputs
    add: USER_OUTPUT

Range constraints: {}

user_input_mutation

Note

Tags: torch.mutation

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defs
import torch

class UserInputMutation(torch.nn.Module):
    """
    Directly mutate user input in forward
    """

    def forward(self, x):
        x.mul_(2)
        return x.cos()

example_args = (torch.randn(3, 2),)
tags = {"torch.mutation"}
model = UserInputMutation()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]"):
            mul_: "f32[3, 2]" = torch.ops.aten.mul_.Tensor(x, 2);  x = None
            cos: "f32[3, 2]" = torch.ops.aten.cos.default(mul_);  mul_ = None
            return (cos,)

Graph signature:
    # inputs
    x: USER_INPUT

    # outputs
    cos: USER_OUTPUT

Range constraints: {}

Not Supported Yet

dynamic_shape_round

Note

Tags: python.builtin, torch.dynamic-shape

Support Level: NOT_SUPPORTED_YET

Original source code:

# mypy: allow-untyped-defs
import torch

from torch._export.db.case import SupportLevel
from torch.export import Dim

class DynamicShapeRound(torch.nn.Module):
    """
    Calling round on dynamic shapes is not supported.
    """

    def forward(self, x):
        return x[: round(x.shape[0] / 2)]

x = torch.randn(3, 2)
dim0_x = Dim("dim0_x")
example_args = (x,)
tags = {"torch.dynamic-shape", "python.builtin"}
support_level = SupportLevel.NOT_SUPPORTED_YET
dynamic_shapes = {"x": {0: dim0_x}}
model = DynamicShapeRound()

torch.export.export(model, example_args, dynamic_shapes=dynamic_shapes)

Result:

Unsupported: Constraints violated (dim0_x)! For more information, run with TORCH_LOGS="+dynamic".

model_attr_mutation

Note

Tags: python.object-model

Support Level: NOT_SUPPORTED_YET

Original source code:

# mypy: allow-untyped-defs
import torch

from torch._export.db.case import SupportLevel

class ModelAttrMutation(torch.nn.Module):
    """
    Attribute mutation is not supported.
    """

    def __init__(self) -> None:
        super().__init__()
        self.attr_list = [torch.randn(3, 2), torch.randn(3, 2)]

    def recreate_list(self):
        return [torch.zeros(3, 2), torch.zeros(3, 2)]

    def forward(self, x):
        self.attr_list = self.recreate_list()
        return x.sum() + self.attr_list[0].sum()

example_args = (torch.randn(3, 2),)
tags = {"python.object-model"}
support_level = SupportLevel.NOT_SUPPORTED_YET
model = ModelAttrMutation()

torch.export.export(model, example_args)

Result:

AssertionError: Mutating module attribute attr_list during export.

optional_input

Note

Tags: python.object-model

Support Level: NOT_SUPPORTED_YET

Original source code:

# mypy: allow-untyped-defs
import torch

from torch._export.db.case import SupportLevel

class OptionalInput(torch.nn.Module):
    """
    Tracing through optional input is not supported yet
    """

    def forward(self, x, y=torch.randn(2, 3)):
        if y is not None:
            return x + y
        return x

example_args = (torch.randn(2, 3),)
tags = {"python.object-model"}
support_level = SupportLevel.NOT_SUPPORTED_YET
model = OptionalInput()

torch.export.export(model, example_args)

Result:

Unsupported: Tracing through optional input is not supported yet

unsupported_operator

Note

Tags: torch.operator

Support Level: NOT_SUPPORTED_YET

Original source code:

# mypy: allow-untyped-defs
import torch

from torch._export.db.case import SupportLevel

class TorchSymMin(torch.nn.Module):
    """
    torch.sym_min operator is not supported in export.
    """

    def forward(self, x):
        return x.sum() + torch.sym_min(x.size(0), 100)

example_args = (torch.randn(3, 2),)
tags = {"torch.operator"}
support_level = SupportLevel.NOT_SUPPORTED_YET
model = TorchSymMin()

torch.export.export(model, example_args)

Result:

Unsupported: torch.* op returned non-Tensor