
Migrating from functorch to torch.func

Created On: Jun 11, 2025 | Last Updated On: Jun 11, 2025

torch.func, previously known as “functorch”, provides JAX-like composable function transforms for PyTorch.

functorch started as an out-of-tree library over at the pytorch/functorch repository. Our goal has always been to upstream functorch directly into PyTorch and provide it as a core PyTorch library.

As the final step of the upstreaming process, we’ve decided to migrate from being a top-level package (functorch) to being a part of PyTorch, to reflect how the function transforms are integrated directly into PyTorch core. As of PyTorch 2.0, we are deprecating `import functorch` and ask that users migrate to the newest APIs, which we will maintain going forward. `import functorch` will be kept around to maintain backwards compatibility for a couple of releases.

Function transforms

The following torch.func APIs are drop-in replacements for the corresponding functorch APIs. They are fully backwards compatible.

| functorch API          | PyTorch API (as of PyTorch 2.0)     |
|------------------------|-------------------------------------|
| functorch.vmap         | torch.vmap() or torch.func.vmap()   |
| functorch.grad         | torch.func.grad()                   |
| functorch.vjp          | torch.func.vjp()                    |
| functorch.jvp          | torch.func.jvp()                    |
| functorch.jacrev       | torch.func.jacrev()                 |
| functorch.jacfwd       | torch.func.jacfwd()                 |
| functorch.hessian      | torch.func.hessian()                |
| functorch.functionalize| torch.func.functionalize()          |
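Because these are drop-in replacements, migrating usually comes down to changing the imports. The sketch below assumes a toy function f(x) = sum(x**3), chosen only for illustration. It composes vmap with grad to get per-row gradients and uses hessian on a single row, and the results can be checked against the closed-form derivatives 3x**2 and diag(6x).

```python
import torch
from torch.func import grad, hessian, vmap

# Toy scalar-valued function, chosen only for illustration: f(x) = sum(x**3).
def f(x):
    return (x ** 3).sum()

torch.manual_seed(0)
x = torch.randn(5, 3)

# Previously: functorch.vmap(functorch.grad(f))(x)
# Per-row gradient, shape (5, 3); analytically d/dx sum(x**3) = 3 * x**2.
per_row_grads = vmap(grad(f))(x)

# Previously: functorch.hessian(f)(x[0])
# Hessian of a single row, shape (3, 3); analytically diag(6 * x).
h = hessian(f)(x[0])
```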

Furthermore, if you are using torch.autograd.functional APIs, please try out the torch.func equivalents instead. torch.func function transforms are more composable and, in many cases, more performant.
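As a small illustration of the switch, assuming an elementwise toy function f(x) = x * sin(x) picked for this sketch, torch.autograd.functional.jacobian and the torch.func transforms produce the same Jacobian. One practical difference is calling style: torch.autograd.functional takes the function and the inputs together, while jacrev and jacfwd return a new function, which is what lets them compose with other transforms such as vmap.

```python
import torch
from torch.autograd.functional import jacobian
from torch.func import jacfwd, jacrev

# Toy elementwise function (an assumption for this sketch),
# so its Jacobian is diagonal with entries sin(x) + x * cos(x).
def f(x):
    return x.sin() * x

torch.manual_seed(0)
x = torch.randn(4)

j_old = jacobian(f, x)   # torch.autograd.functional: takes (func, inputs)
j_rev = jacrev(f)(x)     # torch.func: transform first, then call
j_fwd = jacfwd(f)(x)     # forward-mode variant of the same Jacobian
```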

NN module utilities

We’ve changed the APIs for applying function transforms over NN modules so that they fit better into the PyTorch design philosophy. The new API is different, so please read this section carefully.

functorch.make_functional

torch.func.functional_call() is the replacement for functorch.make_functional and functorch.make_functional_with_buffers. However, it is not a drop-in replacement.

If you’re in a hurry, you can use helper functions in this gist that emulate the behavior of functorch.make_functional and functorch.make_functional_with_buffers. We recommend using torch.func.functional_call() directly because it is a more explicit and flexible API.
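To give a rough idea of what such a helper involves, here is a minimal sketch of a hypothetical make_functional_like function. It is not the gist's code; it only reproduces the old calling convention fmodel(params, *inputs), where params is a tuple, by re-attaching parameter names and delegating to torch.func.functional_call(). It ignores buffers (modules with buffers would simply use their own, since unlisted entries fall back to the module's tensors), so it emulates make_functional rather than make_functional_with_buffers.

```python
import torch
from torch.func import functional_call, grad

def make_functional_like(model):
    """Hypothetical helper (not the gist's code): mimics the
    fmodel(params, *inputs) convention of functorch.make_functional."""
    names = [name for name, _ in model.named_parameters()]
    # Positional tuple of parameters, like make_functional returned.
    params = tuple(p.detach() for _, p in model.named_parameters())

    def fmodel(params, *args, **kwargs):
        # Re-attach the names so functional_call knows which tensor is which.
        return functional_call(model, dict(zip(names, params)), args, kwargs)

    return fmodel, params

torch.manual_seed(0)
model = torch.nn.Linear(3, 3)
fmodel, params = make_functional_like(model)

x = torch.randn(8, 3)
out = fmodel(params, x)
# grad returns a tuple matching the structure of `params`.
grads = grad(lambda p: fmodel(p, x).sum())(params)
```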

Concretely, functorch.make_functional returns a functional module and parameters. The functional module accepts parameters and inputs to the model as arguments. torch.func.functional_call() allows one to call the forward pass of an existing module using new parameters, buffers, and inputs.
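A minimal sketch of that behavior, assuming a torch.nn.Linear model for illustration: the dictionary passed to functional_call() is keyed by the names from model.named_parameters(), the module's own tensors are left untouched, and entries you leave out fall back to the module's own tensors (the default, non-strict behavior).

```python
import torch
from torch.func import functional_call

torch.manual_seed(0)
model = torch.nn.Linear(3, 3)
x = torch.randn(4, 3)
original_weight = model.weight.detach().clone()

# Replacement tensors keyed by the names in model.named_parameters().
zero_params = {name: torch.zeros_like(p) for name, p in model.named_parameters()}
# Forward pass of `model` with W = 0 and b = 0, so the output is all zeros.
out_zero = functional_call(model, zero_params, (x,))

# Only override the bias; "weight" falls back to model.weight.
out_no_bias = functional_call(model, {"bias": torch.zeros(3)}, (x,))
```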

Here’s an example of how to compute gradients of parameters of a model using functorch vs torch.func:

```python
# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

fmodel, params = functorch.make_functional(model)

def compute_loss(params, inputs, targets):
    prediction = fmodel(params, inputs)
    return torch.nn.functional.mse_loss(prediction, targets)

grads = functorch.grad(compute_loss)(params, inputs, targets)

# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

params = dict(model.named_parameters())

def compute_loss(params, inputs, targets):
    prediction = torch.func.functional_call(model, params, (inputs,))
    return torch.nn.functional.mse_loss(prediction, targets)

grads = torch.func.grad(compute_loss)(params, inputs, targets)
```
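One way to sanity-check a migration like this is to compare the dictionary returned by torch.func.grad against the classic loss.backward() path, which accumulates into each parameter's .grad. The sketch below assumes the same toy Linear model and random data as the example above, with a fixed seed added for reproducibility.

```python
import torch

torch.manual_seed(0)
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
params = dict(model.named_parameters())

def compute_loss(params, inputs, targets):
    prediction = torch.func.functional_call(model, params, (inputs,))
    return torch.nn.functional.mse_loss(prediction, targets)

# torch.func.grad returns a dict keyed like `params`; it does not touch .grad.
grads = torch.func.grad(compute_loss)(params, inputs, targets)

# Reference: classic autograd, which accumulates into p.grad on the module.
torch.nn.functional.mse_loss(model(inputs), targets).backward()
```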

And here’s an example of how to compute jacobians of model parameters:

```python
# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

fmodel, params = functorch.make_functional(model)
jacobians = functorch.jacrev(fmodel)(params, inputs)

# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
from torch.func import jacrev, functional_call
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

params = dict(model.named_parameters())
# jacrev computes jacobians of argnums=0 by default.
# We set it to 1 to compute jacobians of params
jacobians = jacrev(functional_call, argnums=1)(model, params, (inputs,))
```
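With argnums=1, the result mirrors the structure of params: a dictionary with one Jacobian per parameter, each shaped output.shape + parameter.shape. The sketch below assumes the same Linear(3, 3) model and checks the entries against the closed form for y = x Wᵀ + b, where dy[i, j]/db[k] is 1 if j == k (else 0) and dy[i, j]/dW[k, l] equals x[i, l] if j == k (else 0).

```python
import torch
from torch.func import functional_call, jacrev

torch.manual_seed(0)
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
params = dict(model.named_parameters())

jacobians = jacrev(functional_call, argnums=1)(model, params, (inputs,))

jw = jacobians["weight"]  # output (64, 3) + weight (3, 3) -> (64, 3, 3, 3)
jb = jacobians["bias"]    # output (64, 3) + bias (3,)     -> (64, 3, 3)

# Closed-form Jacobians of a linear layer, for comparison.
expected_jb = torch.eye(3).expand(64, 3, 3)
expected_jw = torch.einsum("jk,il->ijkl", torch.eye(3), inputs)
```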

Note that, to keep memory consumption down, you should carry around only a single copy of your parameters. model.named_parameters() does not copy the parameters. If your training loop updates the parameters of the model in-place, then the nn.Module that is your model holds the single copy of the parameters and everything is OK.

However, if you want to carry your parameters around in a dictionary and update them out-of-place, then there are two copies of the parameters: the one in the dictionary and the one in the model. In this case, you should change model so that it does not hold memory by converting it to the meta device via model.to('meta').
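A minimal sketch of the out-of-place pattern, assuming plain SGD with a hand-picked learning rate of 0.1 on the same toy Linear regression problem: the dictionary holds the only real copy of the parameters, the module is moved to the meta device so it keeps just the structure, and each step builds a new dictionary instead of mutating tensors.

```python
import torch
from torch.func import functional_call, grad

torch.manual_seed(0)
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

# Our own copy of the parameters, which we will update out-of-place.
params = {name: p.detach().clone() for name, p in model.named_parameters()}
# Drop the module's copy: meta tensors carry shapes/dtypes but no storage.
model.to('meta')

def compute_loss(params, inputs, targets):
    prediction = functional_call(model, params, (inputs,))
    return torch.nn.functional.mse_loss(prediction, targets)

lr = 0.1  # assumed learning rate for this sketch
initial_loss = compute_loss(params, inputs, targets).item()
for _ in range(10):
    grads = grad(compute_loss)(params, inputs, targets)
    # Out-of-place SGD step: a fresh dict, no in-place mutation.
    params = {name: p - lr * grads[name] for name, p in params.items()}
final_loss = compute_loss(params, inputs, targets).item()
```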

functorch.combine_state_for_ensemble

Please use torch.func.stack_module_state() instead of functorch.combine_state_for_ensemble. torch.func.stack_module_state() returns two dictionaries, one of stacked parameters and one of stacked buffers, that can then be used with torch.vmap() and torch.func.functional_call() for ensembling.

For example, here is how to ensemble over a very simple model:

```python
import torch
num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for _ in range(num_models)]
data = torch.randn(batch_size, 3)

# ---------------
# using functorch
# ---------------
import functorch
fmodel, params, buffers = functorch.combine_state_for_ensemble(models)
output = functorch.vmap(fmodel, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)

# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import copy

# Construct a version of the model with no memory by putting the Tensors on
# the meta device.
base_model = copy.deepcopy(models[0])
base_model.to('meta')

params, buffers = torch.func.stack_module_state(models)

# It is possible to vmap directly over torch.func.functional_call,
# but wrapping it in a function makes it clearer what is going on.
def call_single_model(params, buffers, data):
    return torch.func.functional_call(base_model, (params, buffers), (data,))

output = torch.vmap(call_single_model, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)
```
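The in_dims tuple controls which arguments are mapped over. In the example above, the data uses None, so every model sees the same minibatch. A variant sketch, assuming you instead want each model to get its own minibatch, maps over dimension 0 of the data as well; the result can be checked against looping over the original models one by one.

```python
import copy

import torch

torch.manual_seed(0)
num_models, batch_size = 5, 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for _ in range(num_models)]

params, buffers = torch.func.stack_module_state(models)
# Each stacked tensor gains a leading dimension of size num_models,
# e.g. params["weight"] has shape (5, 3, 3).

base_model = copy.deepcopy(models[0]).to('meta')

def call_single_model(params, buffers, data):
    return torch.func.functional_call(base_model, (params, buffers), (data,))

# One minibatch per model: map over dim 0 of the data too.
per_model_data = torch.randn(num_models, batch_size, in_features)
output = torch.vmap(call_single_model, in_dims=(0, 0, 0))(params, buffers, per_model_data)

# Reference: run each original model on its own minibatch.
expected = torch.stack([m(d) for m, d in zip(models, per_model_data)])
```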

functorch.compile

We are no longer supporting functorch.compile (also known as AOTAutograd) as a frontend for compilation in PyTorch; we have integrated AOTAutograd into PyTorch’s compilation stack. If you are a user, please use torch.compile() instead.
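A minimal sketch of the replacement, assuming a toy function f(x) = sin²(x) + cos²(x) chosen so the expected result is easy to check. It uses the "aot_eager" debugging backend, which routes the captured graph through AOTAutograd but executes it eagerly, so no C++ or Triton toolchain is needed; in practice you would usually keep the default backend.

```python
import torch

# Toy function (an assumption for this sketch): identically 1, so its gradient is 0.
def f(x):
    return torch.sin(x) ** 2 + torch.cos(x) ** 2

# "aot_eager" traces through AOTAutograd, then runs the graphs eagerly.
compiled_f = torch.compile(f, backend="aot_eager")

x = torch.randn(8, requires_grad=True)
y = compiled_f(x)
y.sum().backward()  # the backward graph also comes from AOTAutograd
```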