Rate this Page
★★★★★
python.closure
cond_closed_over_variable
Original source code:
# mypy: allow-untyped-defsimporttorchfromfunctorch.experimental.control_flowimportcondclassCondClosedOverVariable(torch.nn.Module):""" torch.cond() supports branches closed over arbitrary variables. """defforward(self,pred,x):deftrue_fn(val):returnx*2deffalse_fn(val):returnx-2returncond(pred,true_fn,false_fn,[x+1])example_args=(torch.tensor(True),torch.randn(3,2))tags={"torch.cond","python.closure"}model=CondClosedOverVariable()torch.export.export(model,example_args)
Result:
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, pred: "b8[]", x: "f32[3, 2]"):
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 1);  add = None

            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            cond = torch.ops.higher_order.cond(pred, true_graph_0, false_graph_0, (x,));  pred = true_graph_0 = false_graph_0 = x = None
            getitem: "f32[3, 2]" = cond[0];  cond = None
            return (getitem,)

        class true_graph_0(torch.nn.Module):
            def forward(self, x: "f32[3, 2]"):
                mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(x, 2);  x = None
                return (mul,)

        class false_graph_0(torch.nn.Module):
            def forward(self, x: "f32[3, 2]"):
                sub: "f32[3, 2]" = torch.ops.aten.sub.Tensor(x, 2);  x = None
                return (sub,)

Graph signature:
    # inputs
    pred: USER_INPUT
    x: USER_INPUT

    # outputs
    getitem: USER_OUTPUT

Range constraints: {}
nested_function
Original source code:
# mypy: allow-untyped-defsimporttorchclassNestedFunction(torch.nn.Module):""" Nested functions are traced through. Side effects on global captures are not supported though. """defforward(self,a,b):x=a+bz=a-bdefclosure(y):nonlocalxx+=1returnx*y+zreturnclosure(x)example_args=(torch.randn(3,2),torch.randn(2))tags={"python.closure"}model=NestedFunction()torch.export.export(model,example_args)
Result:
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, a: "f32[3, 2]", b: "f32[2]"):
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(a, b)
            sub: "f32[3, 2]" = torch.ops.aten.sub.Tensor(a, b);  a = b = None
            add_: "f32[3, 2]" = torch.ops.aten.add_.Tensor(add, 1);  add = None
            mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(add_, add_);  add_ = None
            add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(mul, sub);  mul = sub = None
            return (add_1,)

Graph signature:
    # inputs
    a: USER_INPUT
    b: USER_INPUT

    # outputs
    add_1: USER_OUTPUT

Range constraints: {}
On this page