python.control-flow

dynamic_shape_if_guard

Note

Tags: torch.dynamic-shape, python.control-flow

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defs
import torch


class DynamicShapeIfGuard(torch.nn.Module):
    """
    An `if` statement with a backed dynamic-shape predicate will be specialized
    into one particular branch and generate a guard. However, export will fail
    if the dimension is marked as dynamic from the higher-level API.
    """

    def forward(self, x):
        if x.shape[0] == 3:
            return x.cos()

        return x.sin()


example_args = (torch.randn(3, 2, 2),)
tags = {"torch.dynamic-shape", "python.control-flow"}
model = DynamicShapeIfGuard()
torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2, 2]"):
            cos: "f32[3, 2, 2]" = torch.ops.aten.cos.default(x);  x = None
            return (cos,)

Graph signature:
    # inputs
    x: USER_INPUT

    # outputs
    cos: USER_OUTPUT

Range constraints: {}

list_unpack

Note

Tags: python.data-structure, python.control-flow

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defs
import torch


class ListUnpack(torch.nn.Module):
    """
    Lists are treated as static constructs, therefore unpacking should be
    erased after tracing.
    """

    def forward(self, args: list[torch.Tensor]):
        """
        Lists are treated as static constructs, therefore unpacking should be
        erased after tracing.
        """
        x, *y = args
        return x + y[0]


example_args = ([torch.randn(3, 2), torch.tensor(4), torch.tensor(5)],)
tags = {"python.control-flow", "python.data-structure"}
model = ListUnpack()
torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, args_0: "f32[3, 2]", args_1: "i64[]", args_2: "i64[]"):
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(args_0, args_1);  args_0 = args_1 = None
            return (add,)

Graph signature:
    # inputs
    args_0: USER_INPUT
    args_1: USER_INPUT
    args_2: USER_INPUT

    # outputs
    add: USER_OUTPUT

Range constraints: {}

static_for_loop

Note

Tags: python.control-flow

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defs
import torch


class StaticForLoop(torch.nn.Module):
    """
    A for loop with a constant number of iterations should be unrolled in the
    exported graph.
    """

    def forward(self, x):
        # constant
        ret = [i + x for i in range(10)]
        return ret


example_args = (torch.randn(3, 2),)
tags = {"python.control-flow"}
model = StaticForLoop()
torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]"):
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 0)
            add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 1)
            add_2: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 2)
            add_3: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 3)
            add_4: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 4)
            add_5: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 5)
            add_6: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 6)
            add_7: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 7)
            add_8: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 8)
            add_9: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 9);  x = None
            return (add, add_1, add_2, add_3, add_4, add_5, add_6, add_7, add_8, add_9)

Graph signature:
    # inputs
    x: USER_INPUT

    # outputs
    add: USER_OUTPUT
    add_1: USER_OUTPUT
    add_2: USER_OUTPUT
    add_3: USER_OUTPUT
    add_4: USER_OUTPUT
    add_5: USER_OUTPUT
    add_6: USER_OUTPUT
    add_7: USER_OUTPUT
    add_8: USER_OUTPUT
    add_9: USER_OUTPUT

Range constraints: {}

static_if

Note

Tags: python.control-flow

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defs
import torch


class StaticIf(torch.nn.Module):
    """
    An `if` statement with a static predicate value should be traced through
    with the taken branch.
    """

    def forward(self, x):
        if len(x.shape) == 3:
            return x + torch.ones(1, 1, 1)

        return x


example_args = (torch.randn(3, 2, 2),)
tags = {"python.control-flow"}
model = StaticIf()
torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2, 2]"):
            ones: "f32[1, 1, 1]" = torch.ops.aten.ones.default([1, 1, 1], device=device(type='cpu'), pin_memory=False)
            add: "f32[3, 2, 2]" = torch.ops.aten.add.Tensor(x, ones);  x = ones = None
            return (add,)

Graph signature:
    # inputs
    x: USER_INPUT

    # outputs
    add: USER_OUTPUT

Range constraints: {}