torch.map
dynamic_shape_map
Original source code:
```python
# mypy: allow-untyped-defs
import torch

from functorch.experimental.control_flow import map

class DynamicShapeMap(torch.nn.Module):
    """
    functorch map() maps a function over the first tensor dimension.
    """

    def forward(self, xs, y):
        def body(x, y):
            return x + y

        return map(body, xs, y)

example_args = (torch.randn(3, 2), torch.randn(2))
tags = {"torch.dynamic-shape", "torch.map"}
model = DynamicShapeMap()

torch.export.export(model, example_args)
```
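As the docstring says, `map()` applies `body` to each slice of `xs` along its first dimension, passes `y` unchanged to every call, and stacks the results. The plain-Python sketch below models that behavior, assuming nested lists stand in for tensors; `reference_map` and `add_rows` are hypothetical names for illustration, not part of the functorch API.

```python
from typing import Callable, List, Sequence

Row = List[float]

def add_rows(x: Row, y: Row) -> Row:
    # Elementwise add of two equal-length rows, mirroring `x + y` in `body`.
    return [a + b for a, b in zip(x, y)]

def reference_map(body: Callable[..., Row], xs: Sequence[Row], *args) -> List[Row]:
    # Hypothetical model of map(): iterate over the first dimension of xs,
    # pass the extra arguments through unchanged, and collect ("stack") results.
    return [body(x, *args) for x in xs]

xs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # models shape [3, 2]
y = [10.0, 20.0]                            # models shape [2]
out = reference_map(add_rows, xs, y)
print(out)  # [[11.0, 22.0], [13.0, 24.0], [15.0, 26.0]]
```

The result keeps the leading dimension of `xs` (3) and the shape of each `body` result (2), which corresponds to the `f32[3, 2]` output in the exported graph.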
Result:
```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, xs: "f32[3, 2]", y: "f32[2]"):
            body_graph_0 = self.body_graph_0
            map_impl = torch.ops.higher_order.map_impl(body_graph_0, [xs], [y]);  body_graph_0 = xs = y = None
            getitem: "f32[3, 2]" = map_impl[0];  map_impl = None
            return (getitem,)

        class body_graph_0(torch.nn.Module):
            def forward(self, xs: "f32[2]", y: "f32[2]"):
                add: "f32[2]" = torch.ops.aten.add.Tensor(xs, y);  xs = y = None
                return (add,)

Graph signature:
    # inputs
    xs: USER_INPUT
    y: USER_INPUT

    # outputs
    getitem: USER_OUTPUT

Range constraints: {}
```
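In the exported graph, the call is lowered to `torch.ops.higher_order.map_impl(body_graph_0, [xs], [y])`: the mapped inputs and the pass-through operands arrive as separate lists, `body_graph_0` sees one `f32[2]` slice of `xs` per call, and the result is a sequence of outputs unpacked with `map_impl[0]`. The sketch below is a hypothetical pure-Python model of that calling convention, assuming nested lists stand in for tensors; `model_map_impl` is an illustrative name and this is not the operator's actual implementation.

```python
from typing import Callable, List, Sequence, Tuple

def body_graph_0(xs: List[float], y: List[float]) -> Tuple[List[float]]:
    # Mirrors the exported subgraph: one elementwise add, returned as a 1-tuple.
    return ([a + b for a, b in zip(xs, y)],)

def model_map_impl(f: Callable[..., Tuple], mapped: Sequence[Sequence], operands: Sequence) -> List[list]:
    # All mapped inputs must share the same leading dimension.
    n = len(mapped[0])
    assert all(len(m) == n for m in mapped)
    # Call f once per index i with the i-th slice of every mapped input,
    # followed by the unmapped operands, which are passed through whole.
    per_step = [f(*(m[i] for m in mapped), *operands) for i in range(n)]
    # Regroup: output k collects the k-th result of every step ("stacking").
    num_outputs = len(per_step[0]) if per_step else 0
    return [[step[k] for step in per_step] for k in range(num_outputs)]

xs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
y = [10.0, 20.0]
result = model_map_impl(body_graph_0, [xs], [y])
getitem = result[0]
print(getitem)  # [[11.0, 22.0], [13.0, 24.0], [15.0, 26.0]]
```

Keeping mapped inputs and operands in separate lists is what lets `y` reach every call unsliced, which is consistent with the subgraph's second parameter having the same `f32[2]` shape as the top-level `y`.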