torch.mutation
user_input_mutation
Original source code:
```python
# mypy: allow-untyped-defs
import torch

class UserInputMutation(torch.nn.Module):
    """
    Directly mutate user input in forward
    """

    def forward(self, x):
        x.mul_(2)
        return x.cos()


example_args = (torch.randn(3, 2),)
tags = {"torch.mutation"}
model = UserInputMutation()

torch.export.export(model, example_args)
```
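Why this case matters: `x.mul_(2)` writes into the caller's tensor, so the caller sees the doubled values after `forward` returns, and an exported program has to preserve that side effect. A minimal eager-mode sketch of the behavior follows. It reuses the module from the source; `run_and_compare` is a hypothetical helper added only for illustration.

```python
import torch


class UserInputMutation(torch.nn.Module):
    def forward(self, x):
        x.mul_(2)  # in-place: overwrites the tensor the caller passed in
        return x.cos()


def run_and_compare(x: torch.Tensor):
    """Hypothetical helper: snapshot the input, run the module, return both states."""
    before = x.clone()  # copy taken before the in-place multiply
    out = UserInputMutation()(x)
    return before, x, out  # `x` now holds the mutated values


x = torch.tensor([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]])
before, after, out = run_and_compare(x)
print(torch.equal(after, before * 2))  # True: the caller's tensor was doubled
print(torch.allclose(out, torch.cos(before * 2)))  # True: cos is taken of the doubled values
```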
Result:
```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]"):
            mul_: "f32[3, 2]" = torch.ops.aten.mul_.Tensor(x, 2);  x = None
            cos: "f32[3, 2]" = torch.ops.aten.cos.default(mul_);  mul_ = None
            return (cos,)

Graph signature:
    # inputs
    x: USER_INPUT

    # outputs
    cos: USER_OUTPUT

Range constraints: {}
```