torch.func.functionalize

torch.func.functionalize(func, *, remove='mutations')

functionalize is a transform that can be used to remove (intermediate) mutations and aliasing from a function, while preserving the function's semantics.

functionalize(func) returns a new function with the same semantics as func, but with all intermediate mutations removed. Every inplace operation performed on an intermediate tensor, intermediate.foo_(), gets replaced by its out-of-place equivalent, intermediate_updated = intermediate.foo().
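For instance, a minimal sketch of this rewrite (using add_ as the inplace op; any mutating op behaves the same way):

import torch
from torch.func import functionalize

def f(x):
    intermediate = x + 1
    intermediate.add_(1)  # inplace mutation on an intermediate tensor
    return intermediate

# functionalize(f) computes the same result, but internally the mutation
# is replaced with: intermediate_updated = intermediate.add(1)
x = torch.randn(3)
print(torch.allclose(f(x), functionalize(f)(x)))  # True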

functionalize is useful for shipping a pytorch program off to backends or compilers that aren't able to easily represent mutations or aliasing operators.

Parameters
  • func (Callable) – A Python function that takes one or more arguments.

  • remove (str) – An optional string argument that takes on either the value 'mutations' or 'mutations_and_views'. If 'mutations' is passed in, then all mutating operators will be replaced with their non-mutating equivalents. If 'mutations_and_views' is passed in, then additionally, all aliasing operators will be replaced with their non-aliasing equivalents. Default: 'mutations'.

Returns

Returns a new "functionalized" function. It takes the same inputs as func, and has the same behavior, but any mutations (and optionally aliasing) performed on intermediate tensors in the function will be removed.

Return type

Callable

functionalize will also remove mutations (and views) that were performed on function inputs. However, to preserve semantics, functionalize will "fix up" the mutations after the transform has finished running, by detecting whether any tensor inputs "should have" been mutated, and copying the new data back to the inputs if necessary.
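As a rough illustration of this input "fix up" behavior in eager mode (a minimal sketch; the traced version of this pattern is shown at the end of the example below):

import torch
from torch.func import functionalize

def f(a):
    a.add_(1)  # mutates the input tensor
    return a * 2

x = torch.zeros(3)
out = functionalize(f)(x)
# The mutation is removed inside the transform, but the new data is
# copied back to the input afterwards, so semantics are preserved:
print(x)  # tensor([1., 1., 1.])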

Example:

>>> import torch
>>> from torch.fx.experimental.proxy_tensor import make_fx
>>> from torch.func import functionalize
>>>
>>> # A function that uses mutations and views, but only on intermediate tensors.
>>> def f(a):
...     b = a + 1
...     c = b.view(-1)
...     c.add_(1)
...     return b
...
>>> inpt = torch.randn(2)
>>>
>>> out1 = f(inpt)
>>> out2 = functionalize(f)(inpt)
>>>
>>> # semantics are the same (outputs are equivalent)
>>> print(torch.allclose(out1, out2))
True
>>>
>>> f_traced = make_fx(f)(inpt)
>>> f_no_mutations_traced = make_fx(functionalize(f))(inpt)
>>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt)
>>>
>>> print(f_traced.code)

def forward(self, a_1):
    add = torch.ops.aten.add(a_1, 1);  a_1 = None
    view = torch.ops.aten.view(add, [-1])
    add_ = torch.ops.aten.add_(view, 1);  view = None
    return add

>>> print(f_no_mutations_traced.code)

def forward(self, a_1):
    add = torch.ops.aten.add(a_1, 1);  a_1 = None
    view = torch.ops.aten.view(add, [-1]);  add = None
    add_1 = torch.ops.aten.add(view, 1);  view = None
    view_1 = torch.ops.aten.view(add_1, [2]);  add_1 = None
    return view_1

>>> print(f_no_mutations_and_views_traced.code)

def forward(self, a_1):
    add = torch.ops.aten.add(a_1, 1);  a_1 = None
    view_copy = torch.ops.aten.view_copy(add, [-1]);  add = None
    add_1 = torch.ops.aten.add(view_copy, 1);  view_copy = None
    view_copy_1 = torch.ops.aten.view_copy(add_1, [2]);  add_1 = None
    return view_copy_1

>>> # A function that mutates its input tensor
>>> def f(a):
...     b = a.view(-1)
...     b.add_(1)
...     return a
...
>>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt)
>>>
>>> # All mutations and views have been removed,
>>> # but there is an extra copy_ in the graph to correctly apply the mutation to the input
>>> # after the function has completed.
>>> print(f_no_mutations_and_views_traced.code)

def forward(self, a_1):
    view_copy = torch.ops.aten.view_copy(a_1, [-1])
    add = torch.ops.aten.add(view_copy, 1);  view_copy = None
    view_copy_1 = torch.ops.aten.view_copy(add, [2]);  add = None
    copy_ = torch.ops.aten.copy_(a_1, view_copy_1);  a_1 = None
    return view_copy_1
There are a few “failure modes” for functionalize that are worth calling out:
  1. Like other torch.func transforms, functionalize() doesn't work with functions that directly use .backward(). The same is true for torch.autograd.grad. If you want to use autograd, you can compute gradients directly with functionalize(grad(f)) (see the first sketch after this list).

  2. Like other torch.func transforms, functionalize() doesn't work with global state. If you call functionalize(f) on a function that takes views / mutations of non-local state, functionalization will simply no-op and pass the view/mutation calls directly to the backend. One way to work around this is to ensure that any non-local state creation is wrapped into a larger function, which you then call functionalize on (see the second sketch after this list).

  3. resize_() has some limitations: functionalize will only work on programs that use resize_() as long as the tensor being resized is not a view.

  4. as_strided() has some limitations: functionalize will not work on as_strided() calls that result in tensors with overlapping memory.
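To illustrate point 1, here is a minimal sketch of composing grad inside functionalize (the function f and its analytic gradient 2 * cos(x) are just for demonstration):

import torch
from torch.func import functionalize, grad

def f(x):
    y = x.sin()
    y.mul_(2)  # intermediate mutation
    return y.sum()

x = torch.randn(3)
# Compute gradients by composing grad with functionalize,
# instead of calling .backward() inside the transformed function:
print(torch.allclose(functionalize(grad(f))(x), 2 * x.cos()))  # True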
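And a minimal sketch of the workaround for point 2 (state is a hypothetical name for some non-local tensor):

import torch
from torch.func import functionalize

state = torch.zeros(3)

def leaky(x):
    # Mutates non-local state: functionalization will no-op on this
    # and pass the mutation straight through to the backend.
    state.add_(x)
    return state

def wrapped(x):
    # Workaround: create and mutate the state inside the function
    # being functionalized.
    local_state = torch.zeros(3)
    local_state.add_(x)
    return local_state

out = functionalize(wrapped)(torch.ones(3))
print(out)  # tensor([1., 1., 1.])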

Finally, a helpful mental model for understanding functionalization is that most user pytorch programs are written with the public torch API. When executed, torch operators are generally decomposed into our internal C++ "ATen" API. The logic for functionalization happens entirely at the level of ATen. Functionalization knows how to take every aliasing operator in ATen and map it to its non-aliasing equivalent (e.g. tensor.view({-1}) -> at::view_copy(tensor, {-1})), and how to take every mutating operator in ATen and map it to its non-mutating equivalent (e.g. tensor.add_(1) -> at::add(tensor, 1)), while tracking aliases and mutations out-of-line to know when to fix things up. Information about which ATen operators are aliasing or mutating all comes from pytorch/pytorch.
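A small sketch of the aliasing vs. non-aliasing distinction that this mapping removes (view_copy is the ATen-level non-aliasing equivalent of view):

import torch

t = torch.zeros(2, 2)
alias = t.view(-1)                        # aliases t's storage
copy = torch.ops.aten.view_copy(t, [-1])  # independent copy

t.add_(1)
print(alias)  # reflects the mutation: tensor([1., 1., 1., 1.])
print(copy)   # unaffected:            tensor([0., 0., 0., 0.])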