Rate this Page
★★★★★
python.control-flow#
dynamic_shape_if_guard#
Original source code:
# mypy: allow-untyped-defsimporttorchclassDynamicShapeIfGuard(torch.nn.Module):""" `if` statement with backed dynamic shape predicate will be specialized into one particular branch and generate a guard. However, export will fail if the the dimension is marked as dynamic shape from higher level API. """defforward(self,x):ifx.shape[0]==3:returnx.cos()returnx.sin()example_args=(torch.randn(3,2,2),)tags={"torch.dynamic-shape","python.control-flow"}model=DynamicShapeIfGuard()torch.export.export(model,example_args)
Result:
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2, 2]"):
            cos: "f32[3, 2, 2]" = torch.ops.aten.cos.default(x);  x = None
            return (cos,)

Graph signature:
    # inputs
    x: USER_INPUT

    # outputs
    cos: USER_OUTPUT

Range constraints: {}
list_unpack#
Original source code:
# mypy: allow-untyped-defsimporttorchclassListUnpack(torch.nn.Module):""" Lists are treated as static construct, therefore unpacking should be erased after tracing. """defforward(self,args:list[torch.Tensor]):""" Lists are treated as static construct, therefore unpacking should be erased after tracing. """x,*y=argsreturnx+y[0]example_args=([torch.randn(3,2),torch.tensor(4),torch.tensor(5)],)tags={"python.control-flow","python.data-structure"}model=ListUnpack()torch.export.export(model,example_args)
Result:
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, args_0: "f32[3, 2]", args_1: "i64[]", args_2: "i64[]"):
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(args_0, args_1);  args_0 = args_1 = None
            return (add,)

Graph signature:
    # inputs
    args_0: USER_INPUT
    args_1: USER_INPUT
    args_2: USER_INPUT

    # outputs
    add: USER_OUTPUT

Range constraints: {}
static_for_loop#
Original source code:
# mypy: allow-untyped-defsimporttorchclassStaticForLoop(torch.nn.Module):""" A for loop with constant number of iterations should be unrolled in the exported graph. """defforward(self,x):# constantret=[i+xforiinrange(10)]returnretexample_args=(torch.randn(3,2),)tags={"python.control-flow"}model=StaticForLoop()torch.export.export(model,example_args)
Result:
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]"):
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 0)
            add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 1)
            add_2: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 2)
            add_3: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 3)
            add_4: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 4)
            add_5: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 5)
            add_6: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 6)
            add_7: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 7)
            add_8: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 8)
            add_9: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 9);  x = None
            return (add, add_1, add_2, add_3, add_4, add_5, add_6, add_7, add_8, add_9)

Graph signature:
    # inputs
    x: USER_INPUT

    # outputs
    add: USER_OUTPUT
    add_1: USER_OUTPUT
    add_2: USER_OUTPUT
    add_3: USER_OUTPUT
    add_4: USER_OUTPUT
    add_5: USER_OUTPUT
    add_6: USER_OUTPUT
    add_7: USER_OUTPUT
    add_8: USER_OUTPUT
    add_9: USER_OUTPUT

Range constraints: {}
static_if#
Original source code:
# mypy: allow-untyped-defsimporttorchclassStaticIf(torch.nn.Module):""" `if` statement with static predicate value should be traced through with the taken branch. """defforward(self,x):iflen(x.shape)==3:returnx+torch.ones(1,1,1)returnxexample_args=(torch.randn(3,2,2),)tags={"python.control-flow"}model=StaticIf()torch.export.export(model,example_args)
Result:
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2, 2]"):
            ones: "f32[1, 1, 1]" = torch.ops.aten.ones.default([1, 1, 1], device = device(type='cpu'), pin_memory = False)
            add: "f32[3, 2, 2]" = torch.ops.aten.add.Tensor(x, ones);  x = ones = None
            return (add,)

Graph signature:
    # inputs
    x: USER_INPUT

    # outputs
    add: USER_OUTPUT

Range constraints: {}
On this page