torch.func API Reference

Created On: Jun 11, 2025 | Last Updated On: Jun 11, 2025

Function Transforms

vmap

vmap is the vectorizing map; vmap(func) returns a new function that maps func over some dimension of the inputs.
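A minimal sketch of vmap (the function dot and the shapes below are illustrative choices, not part of the API):

```python
import torch
from torch.func import vmap

# A per-example function: dot product of two vectors.
def dot(a, b):
    return (a * b).sum()

# vmap maps dot over the leading (batch) dimension of both inputs,
# producing one scalar per batch element without an explicit loop.
batched_dot = vmap(dot)
a = torch.randn(4, 3)
b = torch.randn(4, 3)
out = batched_dot(a, b)
assert out.shape == (4,)
```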

grad

The grad operator computes gradients of func with respect to the input(s) specified by argnums.
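For instance (the function f here is an illustrative choice; note that grad requires func to return a scalar):

```python
import torch
from torch.func import grad

# grad(f) returns a new function that computes df/dx.
def f(x):
    return (x ** 2).sum()

x = torch.randn(3)
g = grad(f)(x)
# The analytic gradient of sum(x^2) is 2x.
assert torch.allclose(g, 2 * x)
```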

grad_and_value

Returns a function to compute a tuple of the gradient and primal, or forward, computation.
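A short sketch (f is illustrative), showing that the gradient comes first in the returned tuple:

```python
import torch
from torch.func import grad_and_value

def f(x):
    return (x ** 3).sum()

x = torch.randn(3)
# One call yields both the gradient and the primal (forward) value.
g, value = grad_and_value(f)(x)
assert torch.allclose(g, 3 * x ** 2)
assert torch.allclose(value, (x ** 3).sum())
```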

vjp

Standing for the vector-Jacobian product, returns a tuple containing the results of func applied to primals and a function that, when given cotangents, computes the reverse-mode Jacobian of func with respect to primals times cotangents.
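A minimal sketch (f and the cotangent here are illustrative choices):

```python
import torch
from torch.func import vjp

def f(x):
    return x ** 2

x = torch.randn(3)
out, vjp_fn = vjp(f, x)
cotangent = torch.ones(3)
# vjp_fn returns a tuple with one entry per primal.
(grad_x,) = vjp_fn(cotangent)
# For elementwise x**2 and a cotangent of ones, this is just 2x.
assert torch.allclose(grad_x, 2 * x)
```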

jvp

Standing for the Jacobian-vector product, returns a tuple containing the output of func(*primals) and the "Jacobian of func evaluated at primals" times tangents.
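A minimal sketch (f and the tangent are illustrative); note that primals and tangents are passed as tuples:

```python
import torch
from torch.func import jvp

def f(x):
    return x ** 2

x = torch.randn(3)
t = torch.ones(3)
# primals and tangents must be tuples matching func's positional args.
out, jvp_out = jvp(f, (x,), (t,))
assert torch.allclose(out, x ** 2)
# The Jacobian of elementwise x**2 applied to tangent t is 2*x*t.
assert torch.allclose(jvp_out, 2 * x * t)
```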

linearize

Returns the value of func at primals and a linear approximation of func at primals.
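A short sketch (f is illustrative). Unlike jvp, linearize returns a reusable function that can be applied to many tangents:

```python
import torch
from torch.func import linearize

def f(x):
    return x ** 2

x = torch.randn(3)
# out is f(x); jvp_fn applies the Jacobian at x to a tangent.
out, jvp_fn = linearize(f, x)
t = torch.ones(3)
assert torch.allclose(jvp_fn(t), 2 * x * t)
```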

jacrev

Computes the Jacobian of func with respect to the arg(s) at index argnum using reverse-mode autodiff.

jacfwd

Computes the Jacobian of func with respect to the arg(s) at index argnum using forward-mode autodiff.
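Both jacrev and jacfwd compute the same Jacobian; forward mode tends to be preferable when a function has more outputs than inputs, and reverse mode when it has more inputs than outputs. A sketch with an illustrative elementwise function:

```python
import torch
from torch.func import jacfwd, jacrev

def f(x):
    return x.sin()

x = torch.randn(5)
J_fwd = jacfwd(f)(x)
J_rev = jacrev(f)(x)
# Both modes agree; for elementwise sin the Jacobian is diag(cos(x)).
assert torch.allclose(J_fwd, J_rev)
assert torch.allclose(J_fwd, torch.diag(x.cos()))
```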

hessian

Computes the Hessian of func with respect to the arg(s) at index argnum via a forward-over-reverse strategy.
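A minimal sketch (f is an illustrative scalar-valued function):

```python
import torch
from torch.func import hessian

def f(x):
    return (x ** 3).sum()

x = torch.randn(3)
H = hessian(f)(x)
# The Hessian of sum(x^3) is the diagonal matrix diag(6x).
assert torch.allclose(H, torch.diag(6 * x))
```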

functionalize

functionalize is a transform that can be used to remove (intermediate) mutations and aliasing from a function, while preserving the function's semantics.
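A short sketch (the mutating function f is an illustrative choice): the transformed function produces the same result as the original, but without running mutation ops, which matters for compilers and backends that cannot handle them.

```python
import torch
from torch.func import functionalize

# A function with an intermediate in-place mutation.
def f(x):
    y = x.clone()
    y.add_(1)  # in-place op on an intermediate tensor
    return y

x = torch.randn(3)
# functionalize(f) preserves f's semantics without the mutation.
out = functionalize(f)(x)
assert torch.allclose(out, x + 1)
```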

Utilities for working with torch.nn.Modules

In general, you can transform over a function that calls a torch.nn.Module. For example, the following computes a Jacobian of a function that takes three values and returns three values:

```python
model = torch.nn.Linear(3, 3)

def f(x):
    return model(x)

x = torch.randn(3)
jacobian = jacrev(f)(x)
assert jacobian.shape == (3, 3)
```

However, if you want to do something like compute a Jacobian over the parameters of the model, then there needs to be a way to construct a function where the parameters are the inputs to the function. That's what functional_call() is for: it accepts an nn.Module, the transformed parameters, and the inputs to the Module's forward pass. It returns the value of running the Module's forward pass with the replaced parameters.

Here's how we would compute the Jacobian over the parameters:

```python
model = torch.nn.Linear(3, 3)

def f(params, x):
    return torch.func.functional_call(model, params, x)

x = torch.randn(3)
jacobian = jacrev(f)(dict(model.named_parameters()), x)
```

functional_call

Performs a functional call on the module by replacing the module parameters and buffers with the provided ones.

stack_module_state

Prepares a list of torch.nn.Modules for ensembling with vmap().
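A sketch of the ensembling pattern, assuming a list of identically-structured modules (the Linear models, the deepcopy-to-meta base module, and the shapes are illustrative choices):

```python
import copy
import torch
from torch.func import functional_call, stack_module_state, vmap

models = [torch.nn.Linear(3, 3) for _ in range(4)]
# Stack each parameter/buffer across models along a new leading dim.
params, buffers = stack_module_state(models)

# A "stateless" base copy of the module, used only for its structure.
base = copy.deepcopy(models[0]).to("meta")

def f(params, buffers, x):
    return functional_call(base, (params, buffers), (x,))

x = torch.randn(4, 3)
# vmap runs all 4 models at once, one per slice of params and x.
out = vmap(f)(params, buffers, x)
assert out.shape == (4, 3)
```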

replace_all_batch_norm_modules_

In-place updates root by setting the running_mean and running_var to be None and setting track_running_stats to be False for any nn.BatchNorm module in root.

If you're looking for information on fixing Batch Norm modules, please follow the guidance here.

Debug utilities

debug_unwrap

Unwraps a functorch tensor (e.g.