torch.mutation

user_input_mutation

Note

Tags: torch.mutation

Support Level: SUPPORTED

Original source code:

```python
# mypy: allow-untyped-defs
import torch


class UserInputMutation(torch.nn.Module):
    """
    Directly mutate user input in forward
    """

    def forward(self, x):
        x.mul_(2)  # in-place multiply: modifies the tensor the caller passed in
        return x.cos()


example_args = (torch.randn(3, 2),)
tags = {"torch.mutation"}
model = UserInputMutation()

torch.export.export(model, example_args)
```
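Because `mul_` works in place, calling this module changes the caller's tensor, not just a local copy inside `forward`. The minimal sketch below runs the same module eagerly to make that visible; the names `x_before` and `out` are illustrative and not part of the original example.

```python
import torch


class UserInputMutation(torch.nn.Module):
    def forward(self, x):
        x.mul_(2)  # in-place: overwrites the values in the caller's tensor
        return x.cos()


model = UserInputMutation()
x = torch.randn(3, 2)
x_before = x.clone()  # untouched copy of the original values, for comparison

out = model(x)

# The caller's tensor now holds the doubled values...
assert torch.allclose(x, x_before * 2)
# ...and the returned tensor is the cosine of those doubled values.
assert torch.allclose(out, (x_before * 2).cos())
```

Running the model a second time on the same `x` would therefore start from already-doubled values, which is why a caller that needs the original input should clone it first.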

Result:

```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]"):
            mul_: "f32[3, 2]" = torch.ops.aten.mul_.Tensor(x, 2);  x = None
            cos: "f32[3, 2]" = torch.ops.aten.cos.default(mul_);  mul_ = None
            return (cos,)

Graph signature:
    # inputs
    x: USER_INPUT

    # outputs
    cos: USER_OUTPUT

Range constraints: {}
```