torch.func.functional_call
- torch.func.functional_call(module, parameter_and_buffer_dicts, args=None, kwargs=None, *, tie_weights=True, strict=False)
Performs a functional call on the module by replacing the module parameters and buffers with the provided ones.
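As a minimal sketch, assuming PyTorch 2.x with `torch.func` available, the call below runs an `nn.Linear` with an all-zero weight and bias supplied in a dict. The module's own parameters are not modified, so calling the module directly afterwards still uses its original weights. The names `zero_params`, `out` and `ref` are illustrative only.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

torch.manual_seed(0)
model = nn.Linear(3, 2)
x = torch.randn(4, 3)

# Replacement tensors, keyed by the names that model.named_parameters() yields
zero_params = {"weight": torch.zeros(2, 3), "bias": torch.zeros(2)}

out = functional_call(model, zero_params, (x,))  # forward runs with the zero tensors
ref = model(x)                                   # forward runs with the module's own parameters

print(out.abs().sum().item())  # 0.0
# The module still holds its original, randomly initialized parameters
print(torch.allclose(ref, x @ model.weight.T + model.bias, atol=1e-6))  # True
```

The keys must match the dotted names from `named_parameters()` / `named_buffers()` (for nested modules, e.g. `"layer1.weight"`).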
Note
If the module has active parametrizations, passing a value in the parameter_and_buffer_dicts argument with the name set to the regular parameter name will completely disable the parametrization. If you want to apply the parametrization function to the value passed, set the key as {submodule_name}.parametrizations.{parameter_name}.original.
Note
If the module performs in-place operations on parameters/buffers, these will be reflected in the parameter_and_buffer_dicts input.
Example:
>>> a = {'foo': torch.zeros(())}
>>> mod = Foo()  # does self.foo = self.foo + 1
>>> print(mod.foo)  # tensor(0.)
>>> functional_call(mod, a, torch.ones(()))
>>> print(mod.foo)  # tensor(0.)
>>> print(a['foo'])  # tensor(1.)
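The `Foo` in the example above is not defined in the text. The sketch below is one hypothetical definition that matches its comment: a module with a scalar buffer `foo` that it reassigns during `forward`. It assumes PyTorch 2.x and a single dict (not a tuple of dicts) passed as `parameter_and_buffer_dicts`; the `forward` return value `x + self.foo` is an invented detail.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

class Foo(nn.Module):
    """Hypothetical module mirroring the example: it reassigns its buffer."""

    def __init__(self):
        super().__init__()
        self.register_buffer("foo", torch.zeros(()))

    def forward(self, x):
        # Reassign the buffer during forward; functional_call writes the
        # new value back into the dict that was passed in.
        self.foo = self.foo + 1
        return x + self.foo

a = {"foo": torch.zeros(())}
mod = Foo()
out = functional_call(mod, a, torch.ones(()))

print(out.item())       # 2.0: x (1) + updated foo (1)
print(mod.foo.item())   # 0.0: the module's own buffer is unchanged
print(a["foo"].item())  # 1.0: the update landed in the dict
```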
Note
If the module has tied weights, whether or not functional_call respects the tying is determined by the tie_weights flag.
Example:
>>> a = {'foo': torch.zeros(())}
>>> mod = Foo()  # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied
>>> print(mod.foo)  # tensor(1.)
>>> mod(torch.zeros(()))  # tensor(2.)
>>> functional_call(mod, a, torch.zeros(()))  # tensor(0.) since it will change self.foo_tied too
>>> functional_call(mod, a, torch.zeros(()), tie_weights=False)  # tensor(1.); self.foo_tied is not updated
>>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())}
>>> functional_call(mod, new_a, torch.zeros(()), tie_weights=False)  # tensor(0.)
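To make the tied-weights example concrete, here is a sketch with one hypothetical `Foo` that fits its comments: the same `nn.Parameter` object is registered under two names, `foo` and `foo_tied`. This assumes PyTorch 2.x; the class body is an illustration, not the definition used by the original authors.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

class Foo(nn.Module):
    """Hypothetical module with two attribute names bound to one Parameter."""

    def __init__(self):
        super().__init__()
        self.foo = nn.Parameter(torch.ones(()))
        self.foo_tied = self.foo  # same Parameter object, registered a second time

    def forward(self, x):
        return x + self.foo + self.foo_tied

mod = Foo()
a = {"foo": torch.zeros(())}
x = torch.zeros(())

plain = mod(x).item()                                        # 0 + 1 + 1
tied = functional_call(mod, a, x).item()                     # foo_tied follows foo
untied = functional_call(mod, a, x, tie_weights=False).item()  # foo_tied keeps its value

print(plain, tied, untied)  # 2.0 0.0 1.0
```

With the default `tie_weights=True`, supplying only `"foo"` is enough: the replacement is propagated to every name that shared the original tensor.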
An example of passing multiple dictionaries
a = ({"weight": torch.ones(1, 1)}, {"buffer": torch.zeros(1)})  # two separate dictionaries
mod = nn.Bar(1, 1)  # nn.Bar is a placeholder, not a real torch.nn class; forward returns self.weight @ x + self.buffer
print(mod.weight)  # tensor(...)
print(mod.buffer)  # tensor(...)
x = torch.randn((1, 1))
print(x)
functional_call(mod, a, x)  # same as x
print(mod.weight)  # same as before functional_call
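Since `nn.Bar` does not exist in `torch.nn`, the sketch below defines a hypothetical `Bar` with a `weight` parameter and a registered `buffer` so the multiple-dictionary example can run as written (assuming PyTorch 2.x). The two dicts have disjoint keys, as the parameter description requires.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

class Bar(nn.Module):
    """Hypothetical stand-in for the nn.Bar placeholder: weight @ x + buffer."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.register_buffer("buffer", torch.randn(out_features))

    def forward(self, x):
        return self.weight @ x + self.buffer

mod = Bar(1, 1)
weight_before = mod.weight.detach().clone()

# Parameters and buffers supplied in two separate dicts with distinct keys
a = ({"weight": torch.ones(1, 1)}, {"buffer": torch.zeros(1)})
x = torch.randn(1, 1)

out = functional_call(mod, a, x)
print(torch.allclose(out, x))                  # True: 1 * x + 0
print(torch.equal(mod.weight, weight_before))  # True: the module is unchanged
```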
And here is an example of applying the grad transform over the parameters of a model.
import torch
import torch.nn as nn
from torch.func import functional_call, grad

x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)

def compute_loss(params, x, t):
    y = functional_call(model, params, x)
    return nn.functional.mse_loss(y, t)

grad_weights = grad(compute_loss)(dict(model.named_parameters()), x, t)
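As a sanity check on the example above, the sketch below compares the gradients returned by `grad` with those produced by ordinary `loss.backward()` on the same model. It assumes PyTorch 2.x; the seed and tolerance are arbitrary choices. `grad` returns a dict with the same keys as the `params` input.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad

torch.manual_seed(0)
x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)

def compute_loss(params, x, t):
    y = functional_call(model, params, x)
    return nn.functional.mse_loss(y, t)

params = dict(model.named_parameters())
grad_weights = grad(compute_loss)(params, x, t)

# Reference: the same gradients via ordinary autograd on the module
model.zero_grad()
nn.functional.mse_loss(model(x), t).backward()

for name, g in grad_weights.items():
    # One gradient per parameter, with matching shape and value
    assert g.shape == params[name].shape
    assert torch.allclose(g, params[name].grad, atol=1e-6)

print(sorted(grad_weights))  # ['bias', 'weight']
```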
Note
If the user does not need grad tracking outside of grad transforms, they can detach all of the parameters for better performance and memory usage.
Example:
>>> detached_params = {k: v.detach() for k, v in model.named_parameters()}
>>> grad_weights = grad(compute_loss)(detached_params, x, t)
>>> grad_weights['weight'].grad_fn  # None; it's not tracking gradients outside of grad
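A self-contained version of the detaching example, assuming PyTorch 2.x, is sketched below. It computes the gradients once from the live parameters and once from detached copies, then checks that the values agree while the detached result carries no autograd history. The names `tracked` and `detached` are illustrative.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad

torch.manual_seed(0)
x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)

def compute_loss(params, x, t):
    y = functional_call(model, params, x)
    return nn.functional.mse_loss(y, t)

# Gradients from the live (requires_grad=True) parameters
tracked = grad(compute_loss)(dict(model.named_parameters()), x, t)

# Gradients from detached copies: no graph is recorded outside grad
detached_params = {k: v.detach() for k, v in model.named_parameters()}
detached = grad(compute_loss)(detached_params, x, t)

print(detached["weight"].grad_fn)        # None
print(detached["weight"].requires_grad)  # False
# Same gradient values either way; only the outer autograd tracking differs
print(all(torch.allclose(tracked[k], detached[k]) for k in tracked))  # True
```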
This means that the user cannot call grad_weights['weight'].backward(). However, if they don't need autograd tracking outside of the transforms, this results in lower memory usage and faster execution.
- Parameters
module (torch.nn.Module) – the module to call
parameter_and_buffer_dicts (Dict[str, Tensor] or tuple of Dict[str, Tensor]) – the parameters and buffers that will be used in the module call. If given a tuple of dictionaries, they must have distinct keys so that all dictionaries can be used together.
args (Any or tuple) – arguments to be passed to the module call. If not a tuple, it is treated as a single argument.
kwargs (dict) – keyword arguments to be passed to the module call
tie_weights (bool, optional) – If True, then parameters and buffers tied in the original model will be treated as tied in the reparameterized version. Therefore, if True and different values are passed for the tied parameters and buffers, it will error. If False, it will not respect the originally tied parameters and buffers unless the values passed for both weights are the same. Default: True.
strict (bool, optional) – If True, then the parameters and buffers passed in must match the parameters and buffers in the original module. Therefore, if True and there are any missing or unexpected keys, it will error. Default: False.
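The effect of `strict` can be sketched with a dict that supplies only the weight of an `nn.Linear` (assuming PyTorch 2.x). With the default `strict=False`, the missing `bias` falls back to the module's own value; with `strict=True`, the call raises. The exact exception type is an assumption here (RuntimeError in the releases checked), so the sketch catches `Exception` broadly.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

torch.manual_seed(0)
model = nn.Linear(3, 2)
x = torch.randn(4, 3)

# Only the weight is supplied; "bias" is missing from the dict
partial = {"weight": torch.zeros(2, 3)}

# strict=False (default): zero weight, so the output is just model.bias
out = functional_call(model, partial, x)
print(torch.allclose(out, model.bias.expand(4, 2)))  # True

# strict=True: the missing key is reported as an error
raised = False
try:
    functional_call(model, partial, x, strict=True)
except Exception as err:  # exception type may differ between releases
    raised = True
    print(type(err).__name__)
print(raised)  # True
```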
- Returns
the result of calling module.
- Return type
Any
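Because the return value is whatever the module's `forward` returns, it need not be a single tensor. The sketch below (assuming PyTorch 2.x) uses a hypothetical module `Scaled` whose `forward` takes a positional and a keyword argument and returns a tuple, and shows both forms of `args` together with `kwargs`.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

class Scaled(nn.Module):
    """Hypothetical module: positional x, keyword scale, tuple output."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(3))

    def forward(self, x, scale=1.0):
        y = self.weight * x * scale
        return y, y.sum()

mod = Scaled()
params = {"weight": torch.full((3,), 2.0)}
x = torch.arange(3.0)  # values 0, 1, 2

# A tuple of args is unpacked into positional arguments; kwargs is a dict
y, total = functional_call(mod, params, (x,), {"scale": 10.0})
print(y.tolist(), total.item())  # [0.0, 20.0, 40.0] 60.0

# A non-tuple args value is passed as the single positional argument
y1, _ = functional_call(mod, params, x)
print(y1.tolist())  # [0.0, 2.0, 4.0]
```

If a module's forward itself expects a single tuple argument, wrap it once more, e.g. `args=((a, b),)`, so it is not unpacked.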