
torch.func.functional_call

torch.func.functional_call(module, parameter_and_buffer_dicts, args=None, kwargs=None, *, tie_weights=True, strict=False)[source]

Performs a functional call on the module by replacing the module parameters and buffers with the provided ones.

Note

If the module has active parametrizations, passing a value in the parameter_and_buffer_dicts argument with the name set to the regular parameter name will completely disable the parametrization. If you want to apply the parametrization function to the value passed, please set the key as {submodule_name}.parametrizations.{parameter_name}.original.

Note

If the module performs in-place operations on parameters/buffers, these will be reflected in the parameter_and_buffer_dicts input.

Example:

>>> a = {'foo': torch.zeros(())}
>>> mod = Foo()  # does self.foo = self.foo + 1
>>> print(mod.foo)  # tensor(0.)
>>> functional_call(mod, a, torch.ones(()))
>>> print(mod.foo)  # tensor(0.)
>>> print(a['foo'])  # tensor(1.)

Note

If the module has tied weights, whether or not functional_call respects the tying is determined by the tie_weights flag.

Example:

>>> a = {'foo': torch.zeros(())}
>>> mod = Foo()  # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied
>>> print(mod.foo)  # tensor(1.)
>>> mod(torch.zeros(()))  # tensor(2.)
>>> functional_call(mod, a, torch.zeros(()))  # tensor(0.) since it will change self.foo_tied too
>>> functional_call(mod, a, torch.zeros(()), tie_weights=False)  # tensor(1.) -- self.foo_tied is not updated
>>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())}
>>> functional_call(mod, new_a, torch.zeros(()))  # tensor(0.)
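Since Foo is not defined in this document, here is a self-contained sketch of a module with tied weights matching the example above (the TiedFoo name and its attributes are illustrative):

```python
import torch
import torch.nn as nn
from torch.func import functional_call

class TiedFoo(nn.Module):
    def __init__(self):
        super().__init__()
        self.foo = nn.Parameter(torch.ones(()))
        self.foo_tied = self.foo  # same tensor object: the weights are tied

    def forward(self, x):
        return x + self.foo + self.foo_tied

mod = TiedFoo()
a = {'foo': torch.zeros(())}
# tie_weights=True (the default): replacing 'foo' also replaces 'foo_tied'
out_tied = functional_call(mod, a, torch.zeros(()))
# tie_weights=False: 'foo_tied' keeps its original value of 1.
out_untied = functional_call(mod, a, torch.zeros(()), tie_weights=False)
```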

An example of passing multiple dictionaries:

a = ({"weight": torch.ones(1, 1)}, {"buffer": torch.zeros(1)})  # two separate dictionaries
mod = nn.Bar(1, 1)  # return self.weight @ x + self.buffer
print(mod.weight)  # tensor(...)
print(mod.buffer)  # tensor(...)
x = torch.randn((1, 1))
print(x)
functional_call(mod, a, x)  # same as x
print(mod.weight)  # same as before functional_call
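Because nn.Bar above is only illustrative, a runnable variant of the same idea can use nn.Linear, whose parameter names "weight" and "bias" supply the distinct keys:

```python
import torch
import torch.nn as nn
from torch.func import functional_call

mod = nn.Linear(1, 1)
# Two separate dictionaries; their keys are distinct, so they are merged.
a = ({"weight": torch.ones(1, 1)}, {"bias": torch.zeros(1)})
x = torch.randn(1, 1)
out = functional_call(mod, a, x)  # weight is 1 and bias is 0, so out equals x
```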

And here is an example of applying the grad transform over the parameters of a model.

import torch
import torch.nn as nn
from torch.func import functional_call, grad

x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)

def compute_loss(params, x, t):
    y = functional_call(model, params, x)
    return nn.functional.mse_loss(y, t)

grad_weights = grad(compute_loss)(dict(model.named_parameters()), x, t)

Note

If the user does not need grad tracking outside of grad transforms, they can detach all of the parameters for better performance and memory usage.

Example:

>>> detached_params = {k: v.detach() for k, v in model.named_parameters()}
>>> grad_weights = grad(compute_loss)(detached_params, x, t)
>>> grad_weights['weight'].grad_fn  # None -- it's not tracking gradients outside of grad

This means that the user cannot call grad_weight.backward(). However, if they don't need autograd tracking outside of the transforms, this will result in less memory usage and faster speeds.

Parameters
  • module (torch.nn.Module) – the module to call

  • parameter_and_buffer_dicts (Dict[str, Tensor] or tuple of Dict[str, Tensor]) – the parameters that will be used in the module call. If given a tuple of dictionaries, they must have distinct keys so that all dictionaries can be used together

  • args (Any or tuple) – arguments to be passed to the module call. If not a tuple, considered a single argument.

  • kwargs (dict) – keyword arguments to be passed to the module call

  • tie_weights (bool, optional) – If True, then parameters and buffers tied in the original model will be treated as tied in the reparameterized version. Therefore, if True and different values are passed for the tied parameters and buffers, it will error. If False, it will not respect the originally tied parameters and buffers unless the values passed for both weights are the same. Default: True.

  • strict (bool, optional) – If True, then the parameters and buffers passed in must match the parameters and buffers in the original module. Therefore, if True and there are any missing or unexpected keys, it will error. Default: False.
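A brief sketch of the strict behavior, assuming an nn.Linear module (with the default strict=False, missing entries simply fall back to the module's own tensors):

```python
import torch
import torch.nn as nn
from torch.func import functional_call

mod = nn.Linear(3, 3)
params = {"weight": torch.zeros(3, 3)}  # no entry for "bias"

# Default strict=False: the missing "bias" falls back to mod's own bias.
out = functional_call(mod, params, torch.randn(3))

# strict=True: missing or unexpected keys raise an error.
raised = False
try:
    functional_call(mod, params, torch.randn(3), strict=True)
except RuntimeError:
    raised = True
```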

Returns

the result of calling module.

Return type

Any