torch.func API Reference
Created On: Jun 11, 2025 | Last Updated On: Jun 11, 2025
Function Transforms
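vmap | vmap is the vectorizing map; vmap(func) returns a new function that maps func over some dimension of the inputs. |
grad | grad computes gradients of func with respect to the input(s) specified by argnums. |
grad_and_value | Returns a function to compute a tuple of the gradient and primal, or forward, computation. |
vjp | Standing for the vector-Jacobian product, returns a tuple containing the results of func applied to primals and a function that, when given cotangents, computes the reverse-mode Jacobian of func with respect to primals times cotangents. |
jvp | Standing for the Jacobian-vector product, returns a tuple containing the output of func(*primals) and the "Jacobian of func evaluated at primals" times tangents. |
linearize | Returns the value of func at primals and a linear approximation of func at primals. |
jacrev | Computes the Jacobian of func with respect to the arg(s) at index argnum using reverse-mode autodiff. |
jacfwd | Computes the Jacobian of func with respect to the arg(s) at index argnum using forward-mode autodiff. |
hessian | Computes the Hessian of func with respect to the arg(s) at index argnum via a forward-over-reverse strategy. |
functionalize | functionalize is a transform that can be used to remove (intermediate) mutations and aliasing from a function, while preserving the function's semantics. |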
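As a hand-checkable illustration of the transforms in the table, the sketch below applies each one to torch.sin or to f(x) = sum(sin(x)), whose derivatives are known in closed form (cos and -sin), and asserts the expected results. The functions f and add_one_inplace, the shapes, and the tensors used are illustrative choices, not part of the API; the sketch assumes a PyTorch release recent enough to ship every transform listed above.

```python
import torch
from torch.func import (vmap, grad, grad_and_value, vjp, jvp, linearize,
                        jacrev, jacfwd, hessian, functionalize)

def f(x):
    # Illustrative scalar-valued function: sum(sin(x)); its gradient is cos(x)
    return x.sin().sum()

x = torch.randn(5)
v = torch.randn(5)

# grad: gradient of a scalar-valued function with respect to its first argument
g = grad(f)(x)
assert torch.allclose(g, x.cos())

# grad_and_value: the gradient and f(x) from a single call
g2, value = grad_and_value(f)(x)
assert torch.allclose(g2, g) and torch.allclose(value, f(x))

# vmap: apply grad(f) independently to each row of a batch (batch dim 0)
batch = torch.randn(4, 5)
per_example_grads = vmap(grad(f))(batch)
assert torch.allclose(per_example_grads, batch.cos())

# jvp: Jacobian of sin (diag(cos(x))) times the tangent v
out, jvp_out = jvp(torch.sin, (x,), (v,))
assert torch.allclose(jvp_out, x.cos() * v)

# vjp: returns the output and a function mapping cotangents to v^T J
out, vjp_fn = vjp(torch.sin, x)
(vjp_out,) = vjp_fn(v)
assert torch.allclose(vjp_out, v * x.cos())

# linearize: the output plus a reusable jvp function evaluated at x
out, lin_jvp_fn = linearize(torch.sin, x)
assert torch.allclose(lin_jvp_fn(v), x.cos() * v)

# jacrev and jacfwd compute the same Jacobian with reverse and forward mode
J_rev = jacrev(torch.sin)(x)
J_fwd = jacfwd(torch.sin)(x)
assert torch.allclose(J_rev, torch.diag(x.cos()))
assert torch.allclose(J_rev, J_fwd)

# hessian: second derivatives of f, i.e. diag(-sin(x))
H = hessian(f)(x)
assert torch.allclose(H, torch.diag(-x.sin()))

# functionalize: same result, with the in-place add_ rewritten out of place
def add_one_inplace(x):
    y = x.clone()
    y.add_(1)  # in-place mutation
    return y

assert torch.allclose(functionalize(add_one_inplace)(x), x + 1)
```

Using functions with closed-form derivatives keeps every check exact up to floating-point tolerance, which is why torch.allclose is used throughout.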
Utilities for working with torch.nn.Modules
In general, you can transform over a function that calls a torch.nn.Module. For example, the following computes the Jacobian of a function that takes three values and returns three values:
```python
import torch
from torch.func import jacrev

model = torch.nn.Linear(3, 3)

def f(x):
    return model(x)

x = torch.randn(3)
jacobian = jacrev(f)(x)
assert jacobian.shape == (3, 3)
```
However, if you want to do something like compute a Jacobian over the parameters of the model, then there needs to be a way to construct a function where the parameters are the inputs to the function. That's what functional_call() is for: it accepts an nn.Module, the transformed parameters, and the inputs to the Module's forward pass. It returns the value of running the Module's forward pass with the replaced parameters.
Here's how we would compute the Jacobian over the parameters:
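```python
import torch
from torch.func import jacrev

model = torch.nn.Linear(3, 3)

def f(params, x):
    # Run model's forward pass with the parameters taken from `params`
    return torch.func.functional_call(model, params, x)

x = torch.randn(3)
# jacrev differentiates with respect to the first argument (params) by default
jacobian = jacrev(f)(dict(model.named_parameters()), x)
```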
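Because params is a dict, the Jacobian returned by jacrev is also a dict with the same keys, one Jacobian per parameter, each shaped as (output shape) + (parameter shape). The sketch below checks this against the closed-form derivatives of a linear layer, y = W x + b, for which dy_i/dW_jk = δ_ij x_k and dy/db = I. The seed and the einsum-based reference are illustrative choices, not part of the API.

```python
import torch
from torch.func import functional_call, jacrev

torch.manual_seed(0)
model = torch.nn.Linear(3, 3)
x = torch.randn(3)

def f(params, x):
    return functional_call(model, params, (x,))

jac = jacrev(f)(dict(model.named_parameters()), x)

# Same keys as the params dict; shapes are (3,) + (3, 3) and (3,) + (3,)
assert set(jac) == {"weight", "bias"}
assert jac["weight"].shape == (3, 3, 3)
assert jac["bias"].shape == (3, 3)

# Closed form for y = W x + b: dy_i/dW_jk = delta_ij * x_k, dy/db = I
expected_w = torch.einsum("ij,k->ijk", torch.eye(3), x)
assert torch.allclose(jac["weight"], expected_w)
assert torch.allclose(jac["bias"], torch.eye(3))
```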
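functional_call | Performs a functional call on the module by replacing the module parameters and buffers with the provided ones. |
stack_module_state | Prepares a list of torch.nn.Modules for ensembling with vmap(). |
replace_all_batch_norm_modules_ | In place updates root by setting running_mean and running_var to None and setting track_running_stats to False for any nn.BatchNorm module in root. |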
If you're looking for information on fixing Batch Norm modules, please follow the guidance on patching Batch Norm in the torch.func documentation.
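A common way to combine these utilities is model ensembling: stack_module_state stacks the parameters and buffers of several identically structured modules along a new leading dimension, and vmap over functional_call then runs every model in one call. The sketch below assumes four small torch.nn.Linear(3, 2) models and a "meta"-device copy used only as a structural template; the model count, sizes, and the helper name call_single_model are hypothetical choices for illustration.

```python
import copy
import torch
from torch.func import stack_module_state, functional_call, vmap

torch.manual_seed(0)
num_models = 4
models = [torch.nn.Linear(3, 2) for _ in range(num_models)]

# Stack each parameter/buffer across models, e.g. params["weight"] is (4, 2, 3)
params, buffers = stack_module_state(models)

# Template module on the meta device: functional_call only needs its structure
base_model = copy.deepcopy(models[0]).to("meta")

def call_single_model(params, buffers, x):
    # Hypothetical helper: run the template with one model's parameters/buffers
    return functional_call(base_model, (params, buffers), (x,))

x = torch.randn(5, 3)
# in_dims=(0, 0, None): map over the stacked model dimension, share x
out = vmap(call_single_model, in_dims=(0, 0, None))(params, buffers, x)
assert out.shape == (num_models, 5, 2)

# Each slice should match running the corresponding model directly
for i, m in enumerate(models):
    assert torch.allclose(out[i], m(x), atol=1e-5)
```

Passing x with in_dims=None feeds the same minibatch to every model; giving x a leading model dimension and using in_dims=0 for it instead would feed each model its own minibatch.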
Debug utilities
debug_unwrap | Unwraps a functorch tensor (e.g. BatchedTensor, GradTrackingTensor) to its underlying tensor. |