
python.closure

cond_closed_over_variable

Note

Tags: python.closure, torch.cond

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defs
import torch

from functorch.experimental.control_flow import cond

class CondClosedOverVariable(torch.nn.Module):
    """
    torch.cond() supports branches closed over arbitrary variables.
    """

    def forward(self, pred, x):
        def true_fn(val):
            return x * 2

        def false_fn(val):
            return x - 2

        return cond(pred, true_fn, false_fn, [x + 1])

example_args = (torch.tensor(True), torch.randn(3, 2))
tags = {"torch.cond", "python.closure"}
model = CondClosedOverVariable()

torch.export.export(model, example_args)
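In eager terms, the branch selection can be pictured with a plain-Python stand-in. The helper `py_cond` below is hypothetical (it is not `torch.cond`), and plain floats stand in for tensors. The sketch shows that both branches ignore their `val` operand and read `x` from the enclosing scope, so the `x + 1` operand has no effect on the result.

```python
def py_cond(pred, true_fn, false_fn, operands):
    # Hypothetical eager stand-in for cond: call exactly one branch with the operands.
    return true_fn(*operands) if pred else false_fn(*operands)


def cond_forward(pred, x):
    def true_fn(val):
        # `val` is ignored; `x` is captured from cond_forward's scope.
        return x * 2

    def false_fn(val):
        # Same here: the closed-over `x` is used, not the operand.
        return x - 2

    return py_cond(pred, true_fn, false_fn, [x + 1])


print(cond_forward(True, 3.0))   # 6.0
print(cond_forward(False, 3.0))  # 1.0
```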

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, pred: "b8[]", x: "f32[3, 2]"):
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 1);  add = None

            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            cond = torch.ops.higher_order.cond(pred, true_graph_0, false_graph_0, (x,));  pred = true_graph_0 = false_graph_0 = x = None
            getitem: "f32[3, 2]" = cond[0];  cond = None
            return (getitem,)

        class true_graph_0(torch.nn.Module):
            def forward(self, x: "f32[3, 2]"):
                mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(x, 2);  x = None
                return (mul,)

        class false_graph_0(torch.nn.Module):
            def forward(self, x: "f32[3, 2]"):
                sub: "f32[3, 2]" = torch.ops.aten.sub.Tensor(x, 2);  x = None
                return (sub,)

Graph signature:
    # inputs
    pred: USER_INPUT
    x: USER_INPUT

    # outputs
    getitem: USER_OUTPUT

Range constraints: {}
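In the result, the unused `x + 1` is computed and immediately discarded (`add = None`), and `cond` receives `(x,)`: the closed-over `x` has become an explicit operand of each branch subgraph. The plain-Python sketch below illustrates that idea using the standard `__code__.co_freevars` and `__closure__` attributes to find a closure's captured variables, then calls hand-lifted branches that take `x` as a parameter. The helper names (`make_branches`, `free_variables`) are hypothetical, and this is only an illustration, not how `torch.export` performs the lifting internally.

```python
def make_branches(x):
    # Same shape as the example: both branches close over `x` and ignore `val`.
    def true_fn(val):
        return x * 2

    def false_fn(val):
        return x - 2

    return true_fn, false_fn


def free_variables(fn):
    # co_freevars and __closure__ are aligned: the i-th name maps to the i-th cell.
    cells = fn.__closure__ or ()
    return {name: cell.cell_contents for name, cell in zip(fn.__code__.co_freevars, cells)}


# Hand-lifted branches, mirroring true_graph_0 / false_graph_0: `x` is now a parameter.
def true_graph_0(x):
    return x * 2


def false_graph_0(x):
    return x - 2


true_fn, false_fn = make_branches(3.0)
print(free_variables(true_fn))  # {'x': 3.0}

lifted_operands = tuple(free_variables(true_fn).values())
print(true_graph_0(*lifted_operands) == true_fn(None))  # True
```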

nested_function

Note

Tags: python.closure

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defs
import torch

class NestedFunction(torch.nn.Module):
    """
    Nested functions are traced through. Side effects on global captures
    are not supported though.
    """

    def forward(self, a, b):
        x = a + b
        z = a - b

        def closure(y):
            nonlocal x
            x += 1
            return x * y + z

        return closure(x)

example_args = (torch.randn(3, 2), torch.randn(2))
tags = {"python.closure"}
model = NestedFunction()

torch.export.export(model, example_args)
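The closure rebinds `x` through `nonlocal` and is called with `x` itself as `y`. For a tensor, `x += 1` dispatches to the in-place `__iadd__`, so `y`, which refers to the same object, observes the increment. The plain-Python sketch below contrasts an immutable float, where `+=` rebinds the name and `y` keeps the old value, with a hypothetical mutable `Cell` class whose `__iadd__` mutates in place as a stand-in for a tensor.

```python
class Cell:
    """Hypothetical mutable number: `+=` updates it in place, like a tensor."""

    def __init__(self, value):
        self.value = value

    def __iadd__(self, other):
        # Mutate, then return self so the name stays bound to this same object.
        self.value += other
        return self

    def __mul__(self, other):
        return self.value * other.value


def nested_forward(x, z):
    def closure(y):
        nonlocal x
        x += 1
        return x * y + z

    return closure(x)


# Float: `x += 1` creates a new float, so y is still 2.0 -> 3.0 * 2.0 + 0.5
print(nested_forward(2.0, 0.5))        # 6.5
# Cell: x and y are one object, so y sees 3.0 too -> 3.0 * 3.0 + 0.5
print(nested_forward(Cell(2.0), 0.5))  # 9.5
```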

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, a: "f32[3, 2]", b: "f32[2]"):
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(a, b)
            sub: "f32[3, 2]" = torch.ops.aten.sub.Tensor(a, b);  a = b = None
            add_: "f32[3, 2]" = torch.ops.aten.add_.Tensor(add, 1);  add = None
            mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(add_, add_);  add_ = None
            add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(mul, sub);  mul = sub = None
            return (add_1,)

Graph signature:
    # inputs
    a: USER_INPUT
    b: USER_INPUT

    # outputs
    add_1: USER_OUTPUT

Range constraints: {}
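Read top to bottom, the graph computes `(a + b + 1) ** 2 + (a - b)`: the in-place `add_` updates the tensor that both `x` and `y` refer to, so `mul` squares it. The sketch below replays the graph one node per line with plain floats. It assumes elementwise semantics, so scalars suffice for illustration; the real example broadcasts `b` of shape `[2]` against `a` of shape `[3, 2]`. The names `exported_forward` and `reference` are hypothetical.

```python
def exported_forward(a, b):
    # One line per node in the exported graph above.
    add = a + b
    sub = a - b
    add_ = add + 1      # aten.add_ updates `add` in place in the real graph
    mul = add_ * add_   # x * y, where x and y alias the same tensor
    add_1 = mul + sub
    return add_1


def reference(a, b):
    # Closed form of the closure when x and y are the same mutable object.
    return (a + b + 1) ** 2 + (a - b)


for a, b in [(1.0, 2.0), (-0.5, 0.25), (3.0, 3.0)]:
    print(exported_forward(a, b), reference(a, b))
```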