torch.func#
torch.func, previously known as "functorch", provides JAX-like composable function transforms for PyTorch.
Note
This library is currently in beta. What this means is that the features generally work (unless otherwise documented) and we (the PyTorch team) are committed to bringing this library forward. However, the APIs may change under user feedback and we don't have full coverage over PyTorch operations.
If you have suggestions on the API or use-cases you'd like to be covered, please open a GitHub issue or reach out. We'd love to hear about how you're using the library.
What are composable function transforms?#
A "function transform" is a higher-order function that accepts a numerical function and returns a new function that computes a different quantity.
torch.func has auto-differentiation transforms (grad(f) returns a function that computes the gradient of f), a vectorization/batching transform (vmap(f) returns a function that computes f over batches of inputs), and others. These function transforms can compose with each other arbitrarily. For example, composing vmap(grad(f)) computes a quantity called per-sample-gradients that stock PyTorch cannot efficiently compute today.
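As a minimal sketch of this composition, consider per-sample gradients of a simple squared-error loss; the loss function, weights, and data below are illustrative placeholders, not part of the torch.func API:

```python
import torch
from torch.func import grad, vmap

# Illustrative loss over a single example.
def loss(weights, example, target):
    prediction = example @ weights
    return ((prediction - target) ** 2).sum()

weights = torch.randn(3)
examples = torch.randn(8, 3)   # a batch of 8 examples
targets = torch.randn(8)

# grad(loss) computes the gradient w.r.t. weights for one example;
# vmap maps it over the batch dimension of examples and targets.
per_sample_grads = vmap(grad(loss), in_dims=(None, 0, 0))(weights, examples, targets)
print(per_sample_grads.shape)  # torch.Size([8, 3])
```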
Why composable function transforms?#
There are a number of use cases that are tricky to do in PyTorch today:
computing per-sample-gradients (or other per-sample quantities)
running ensembles of models on a single machine
efficiently batching together tasks in the inner-loop of MAML
efficiently computing Jacobians and Hessians
efficiently computing batched Jacobians and Hessians
Composing vmap(), grad(), and vjp() transforms allows us to express the above without designing a separate subsystem for each. This idea of composable function transforms comes from the JAX framework.
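As a rough illustration of how the Jacobian and Hessian use cases reduce to compositions of these transforms, here is a sketch using an arbitrary example function (the function f and the input shapes are illustrative assumptions):

```python
import torch
from torch.func import vmap, jacrev, hessian

# Illustrative scalar-valued function.
def f(x):
    return x.sin().sum()

x = torch.randn(5)

# hessian(f) is built by composing forward- and reverse-mode Jacobian
# transforms (jacfwd(jacrev(f))).
hess = hessian(f)(x)                       # shape (5, 5)

# Batched Jacobians: vmap the Jacobian transform over a batch of inputs.
batched_x = torch.randn(4, 5)
batched_jac = vmap(jacrev(f))(batched_x)   # shape (4, 5)
print(hess.shape, batched_jac.shape)
```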