
Extending torch.func with autograd.Function#

Created On: Jan 03, 2023 | Last Updated On: Sep 14, 2023

So you’d like to use torch.autograd.Function with the torch.func transforms like torch.vmap(), torch.func.grad(), etc.

There are two main use cases:

  • you wish to call code that does not contain PyTorch operations and have it work with function transforms. That is, the torch.autograd.Function’s forward/backward/etc call into functions from other systems like C++, CUDA, numpy.

  • you wish to specify custom gradient rules, like JAX’s custom_vjp/custom_jvp

PyTorch combines both of these concepts into torch.autograd.Function.

Basic Usage#

This guide assumes you are familiar with Extending torch.autograd, which explains how to use torch.autograd.Function.

torch.autograd.Function can either have a forward() that accepts a ctx object, or it can have separate forward() (that does not accept ctx) and a setup_context() staticmethod that modifies the ctx object.

Only the latter is supported with function transforms:

  • forward() is the code that performs the operation and it should not accept a ctx object.

  • setup_context(ctx, inputs, output) is the code where you can call methods on ctx. Here is where you should save Tensors for backward (by calling ctx.save_for_backward(*tensors)), or save non-Tensors (by assigning them to the ctx object).

Because setup_context() accepts only inputs and output, the only quantities that can be saved are either objects (such as Tensors) in the inputs or outputs, or quantities (like Tensor.shape) derived from them. If you wish to save a non-input intermediate activation from Function.forward() for backward, then you’ll need to return it as an output from forward() so that it gets passed to setup_context().
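For orientation, here is a minimal sketch of this structure. The MyMul Function and the multiply operation are purely illustrative (they are not one of this page’s examples); the point is only the split between forward() and setup_context():

import torch


class MyMul(torch.autograd.Function):
    # forward only computes; it does not receive a ctx object.
    @staticmethod
    def forward(x, y):
        return x * y

    # setup_context sees only (inputs, output) and is the only place
    # that touches ctx.
    @staticmethod
    def setup_context(ctx, inputs, output):
        x, y = inputs
        ctx.save_for_backward(x, y)

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        return grad_output * y, grad_output * x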

Depending on the transform,

  • to support reverse-mode AD (torch.func.grad(), torch.func.vjp()), the torch.autograd.Function needs a backward() staticmethod.

  • to support torch.func.jvp(), the torch.autograd.Function needs a jvp() staticmethod.

  • to support torch.vmap(), the torch.autograd.Function needs a vmap() staticmethod (or can ask PyTorch to generate one by setting generate_vmap_rule=True).

In order for the torch.autograd.Function to be arbitrarily composable with function transforms, we recommend that all staticmethods other than forward() and setup_context() be transformable: that is, they should consist only of PyTorch operators or calls to other torch.autograd.Function (which may themselves call into C++/CUDA/etc).

Let’s go over some examples of common use cases.

Example 1: autograd.Function calls into another system#

A common case is a torch.autograd.Function with both forward() and backward() calling into another system (like C++, CUDA, numpy, triton).

import torch
import numpy as np


def to_numpy(tensor):
    return tensor.cpu().numpy()


class NumpySort(torch.autograd.Function):
    # Note that forward does not take ctx
    @staticmethod
    def forward(x, dim):
        device = x.device
        x = to_numpy(x)
        ind = np.argsort(x, axis=dim)
        ind_inv = np.argsort(ind, axis=dim)
        result = np.take_along_axis(x, ind, axis=dim)
        # Any intermediates to be saved in backward must be returned as
        # outputs.
        return (
            # The desired output
            torch.tensor(result, device=device),
            # intermediate to save for backward
            torch.tensor(ind, device=device),
            # intermediate to save for backward
            torch.tensor(ind_inv, device=device),
        )

    # setup_context is responsible for calling methods and/or assigning to
    # the ctx object. Please do not do additional compute (e.g. add
    # Tensors together) in setup_context.
    @staticmethod
    def setup_context(ctx, inputs, output):
        x, dim = inputs
        # Note that output is whatever you returned from forward.
        # If you returned multiple values, then output is a Tuple of multiple values.
        # If you returned a single Tensor, then output is a Tensor.
        # If you returned a Tuple with a single Tensor, then output is a
        # Tuple with a single Tensor.
        _, ind, ind_inv = output
        ctx.mark_non_differentiable(ind, ind_inv)
        # Tensors must be saved via ctx.save_for_backward. Please do not
        # assign them directly onto the ctx object.
        ctx.save_for_backward(ind, ind_inv)
        # Non-tensors may be saved by assigning them as attributes on the ctx object.
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output, _0, _1):
        # For the autograd.Function to be arbitrarily composable with function
        # transforms, all staticmethods other than forward and setup_context
        # must be implemented in a "transformable" way; that is, they must
        # only consist of PyTorch operations or autograd.Function.
        #
        # For example, this allows us to do double backwards and/or compute
        # second order gradients.
        #
        # We've written the backward pass of NumpySort in terms of another
        # autograd.Function, NumpyTake.
        ind, ind_inv = ctx.saved_tensors
        return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None


class NumpyTake(torch.autograd.Function):
    @staticmethod
    def forward(x, ind, ind_inv, dim):
        device = x.device
        x = to_numpy(x)
        ind = to_numpy(ind)
        return torch.tensor(np.take_along_axis(x, ind, dim), device=device)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, ind, ind_inv, dim = inputs
        ctx.save_for_backward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output):
        ind, ind_inv = ctx.saved_tensors
        result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
        return result, None, None, None

Now, to make it easier to use NumpySort (to hide away the intermediates we returned as outputs, as well as to allow default args and kwargs), we create a new function that invokes it:

def numpy_sort(x, dim=-1):
    result, _, _ = NumpySort.apply(x, dim)
    return result

And here’s a sanity check:

x = torch.randn(2, 3)
grad_x = torch.func.grad(lambda x: numpy_sort(x).sum())(x)
assert torch.allclose(grad_x, torch.ones_like(x))

Example 2: autograd.Function specifies custom gradient rules#

Another common case is a torch.autograd.Function that is implemented with PyTorch operations. PyTorch is able to compute gradients for PyTorch operations automatically, but perhaps we wish to customize how the gradients are computed. Some reasons why we may want a custom backward different from the one PyTorch gives us are:

  • improving numeric stability

  • changing the performance characteristics of the backward

  • changing how edge cases are handled (e.g. nans, inf)

  • modifying the gradient (e.g. gradient clipping)

Here’s an example of a torch.autograd.Function for the function y = x ** 3 where we change the performance characteristics (some computation that would normally happen during the backward pass, computing dx, happens in the forward pass instead).

class MyCube(torch.autograd.Function):
    @staticmethod
    def forward(x):
        result = x ** 3
        # In regular PyTorch, if we had just run y = x ** 3, then the backward
        # pass computes dx = 3 * x ** 2. In this autograd.Function, we've done
        # that computation here in the forward pass instead.
        dx = 3 * x ** 2
        return result, dx

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        result, dx = output
        ctx.save_for_backward(x, dx)

    @staticmethod
    def backward(ctx, grad_output, grad_dx):
        x, dx = ctx.saved_tensors
        # In order for the autograd.Function to work with higher-order
        # gradients, we must add the gradient contribution of `dx`.
        result = grad_output * dx + grad_dx * 6 * x
        return result

Now, to make it easier to use MyCube (and hide away the intermediates we returned as outputs), we create a new function that invokes it:

def my_cube(x):
    result, _ = MyCube.apply(x)
    return result

Here’s a sanity check computing the second-order gradients:

x = torch.randn([])
ggx = torch.func.grad(torch.func.grad(my_cube))(x)
assert torch.allclose(ggx, 6 * x)

Limitations and gotchas#

Warning

Please read these limitations of torch.autograd.Function with torch.func transforms carefully. We are not able to catch many of these situations and error out gracefully, so they will lead to undefined behavior.

Please do not capture Tensors that are being transformed over, have requires_grad=True, or are dual tensors, in the methods of the torch.autograd.Function. The way to be completely safe is to ensure that the only Tensors used inside any method of the torch.autograd.Function are those passed in directly as inputs (or via the ctx object), rather than Tensors that come from outside the torch.autograd.Function.
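As an illustration (the names here are made up for this sketch), prefer passing such Tensors in as explicit inputs rather than closing over them:

import torch

scale = torch.randn(3, requires_grad=True)


class ScaleBad(torch.autograd.Function):
    # Risky: `scale` is captured from the enclosing scope. Because it has
    # requires_grad=True (and may be transformed over), using it here can
    # lead to undefined behavior under torch.func transforms.
    @staticmethod
    def forward(x):
        return x * scale

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * scale


class ScaleGood(torch.autograd.Function):
    # Safer: `scale` is an explicit input, so autograd and the function
    # transforms can see it.
    @staticmethod
    def forward(x, scale):
        return x * scale

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, scale = inputs
        ctx.save_for_backward(x, scale)

    @staticmethod
    def backward(ctx, grad_output):
        x, scale = ctx.saved_tensors
        return grad_output * scale, grad_output * x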

torch.autograd.Function does not handle Tensors in pytrees (arbitrary nested Python data structures that may or may not contain Tensors). For those Tensors to be tracked by autograd, they must be passed directly as arguments to torch.autograd.Function. This is in contrast to jax.{custom_vjp, custom_jvp}, which do accept pytrees.
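For example (a sketch with made-up names), if your parameters naturally live in a dict, unpack the Tensors at the call site and pass them as separate arguments:

import torch


class ScaleAndShift(torch.autograd.Function):
    # The Tensors are passed individually; passing the dict itself as a
    # single argument would mean its Tensors are not tracked by autograd.
    @staticmethod
    def forward(x, scale, shift):
        return x * scale + shift

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, scale, shift = inputs
        ctx.save_for_backward(x, scale)

    @staticmethod
    def backward(ctx, grad_output):
        x, scale = ctx.saved_tensors
        # Assumes x, scale, and shift all have the same shape (no broadcasting).
        return grad_output * scale, grad_output * x, grad_output


def scale_and_shift(x, params):
    # Unpack the pytree (here, a dict) before calling into the Function.
    return ScaleAndShift.apply(x, params["scale"], params["shift"])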

Please only use save_for_backward() or save_for_forward() to save Tensors. Please do not assign Tensors or collections of Tensors directly onto the ctx object; these Tensors will not get tracked.
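Concretely, inside setup_context() the pattern looks like this (MyLinear is an illustrative sketch, not one of this page’s examples):

import torch


class MyLinear(torch.autograd.Function):
    @staticmethod
    def forward(x, weight, bias):
        return x @ weight.t() + bias

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, weight, bias = inputs
        # Do: save Tensors via save_for_backward so they are tracked.
        ctx.save_for_backward(x, weight)
        # Don't: ctx.x = x (or ctx.tensors = [x, weight]); such Tensors
        # will not be tracked.
        # Non-Tensor metadata is fine to store as an attribute:
        ctx.x_shape = x.shape

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        return grad_output @ weight, grad_output.t() @ x, grad_output.sum(0)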

torch.vmap() Support#

To use a torch.autograd.Function with torch.vmap(), you must either:

  • provide a vmap() staticmethod that tells us the behavior of the torch.autograd.Function under torch.vmap(), or

  • ask us to autogenerate it by setting generate_vmap_rule=True.

Automatically generate a vmap rule#

If your torch.autograd.Function fulfills the following additional constraints, then we are able to generate a vmap rule for it. If it doesn’t fulfill the constraints, or if you want custom behavior under vmap, please manually define a vmap staticmethod (see next section).

Warning

We are not easily able to check for the following constraints and errorout gracefully. Violation of the constraints may lead to undefinedbehavior.

  • The forward(), backward() (if it exists), and jvp() (if it exists) staticmethods of the torch.autograd.Function must be transformable via torch.vmap(). That is, they must consist of only PyTorch operations (as opposed to, e.g., NumPy or custom CUDA kernels).

Example:

class MyCube(torch.autograd.Function):
    # Set generate_vmap_rule to True to ask PyTorch to automatically generate
    # a vmap rule.
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        result = x ** 3
        dx = 3 * x ** 2
        return result, dx

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        result, dx = output
        ctx.save_for_backward(x, dx)

    @staticmethod
    def backward(ctx, grad_output, grad_dx):
        x, dx = ctx.saved_tensors
        result = grad_output * dx + grad_dx * 6 * x
        return result


def my_cube(x):
    result, dx = MyCube.apply(x)
    return result


x = torch.randn(3)
result = torch.vmap(my_cube)(x)
assert torch.allclose(result, x ** 3)

Defining the vmap staticmethod#

If your torch.autograd.Function calls into another system (like NumPy, C++, CUDA, triton), then to get it to work with torch.vmap() or transforms that use it, you’ll need to manually define a vmap() staticmethod.

Depending on what transforms you want to use and your use case, you may not need to add a vmap() staticmethod to all of your torch.autograd.Function:

  • For example, torch.func.jacrev() performs vmap() over the backward pass, so if you are only interested in using torch.func.jacrev(), only the backward() staticmethod needs to be vmappable.

We do recommend ensuring that all of your torch.autograd.Function have support for torch.vmap() though, especially if you are writing a third-party library and you want your torch.autograd.Function to work with all combinations of torch.func transforms.

Conceptually, the vmap staticmethod is responsible for defining how the forward() should behave under torch.vmap(). That is, it defines how to transform the forward() to run over inputs with an additional dimension (the dimension being vmapped over). This is similar to how torch.vmap() is implemented over PyTorch operations: for each operation, we define a vmap rule (sometimes also referred to as a "batching rule").

Here’s how to define the vmap() staticmethod:

  • the signature is vmap(info, in_dims: Tuple[Optional[int]], *args), where *args is the same as the args to forward().

  • The vmap staticmethod is responsible for defining how the forward() should behave under torch.vmap(). That is, given inputs with an additional dimension (specified by in_dims), how do we compute the batched version of forward()?

  • For each arg in args, in_dims has a corresponding Optional[int]. It is None if the arg is not a Tensor or if the arg is not being vmapped over; otherwise, it is an integer specifying what dimension of the Tensor is being vmapped over.

  • info is a collection of additional metadata that may be helpful: info.batch_size specifies the size of the dimension being vmapped over, while info.randomness is the randomness option that was passed to torch.vmap().

  • The return of the vmap staticmethod is a tuple of (output, out_dims). Similar to in_dims, out_dims should be of the same structure as output and contain one out_dim per output that specifies if the output has the vmapped dimension and what index it is in.

Example:

def to_numpy(tensor):
    return tensor.cpu().numpy()


class NumpySort(torch.autograd.Function):
    @staticmethod
    def forward(x, dim):
        device = x.device
        x = to_numpy(x)
        ind = np.argsort(x, axis=dim)
        ind_inv = np.argsort(ind, axis=dim)
        result = np.take_along_axis(x, ind, axis=dim)
        return (
            torch.tensor(result, device=device),
            torch.tensor(ind, device=device),
            torch.tensor(ind_inv, device=device),
        )

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, dim = inputs
        _, ind, ind_inv = output
        ctx.mark_non_differentiable(ind, ind_inv)
        ctx.save_for_backward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output, _0, _1):
        ind, ind_inv = ctx.saved_tensors
        return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None

    # The signature of the vmap staticmethod is:
    # vmap(info, in_dims: Tuple[Optional[int]], *args)
    # where *args is the same as the arguments to `forward`.
    @staticmethod
    def vmap(info, in_dims, x, dim):
        # For every input (x and dim), in_dims stores an Optional[int]
        # that is:
        # - None if the input is not being vmapped over or if the input
        #   is not a Tensor
        # - an integer if the input is being vmapped over that represents
        #   the index of the dimension being vmapped over.
        x_bdim, _ = in_dims

        # A "vmap rule" is the logic of how to perform the operation given
        # inputs with one additional dimension. In NumpySort, x has an
        # additional dimension (x_bdim). The vmap rule is simply
        # to call NumpySort again but pass it a different `dim`.
        x = x.movedim(x_bdim, 0)
        # Handle negative dims correctly
        dim = dim if dim >= 0 else dim + x.dim() - 1
        result = NumpySort.apply(x, dim + 1)

        # The vmap rule must return a tuple of two things:
        # 1. the output. Should be the same number of things
        #    as returned by the forward().
        # 2. one Optional[int] for each output specifying if each output
        #    is being vmapped over, and if so, the index of the
        #    dimension being vmapped over.
        #
        # NumpySort.forward returns a Tuple of 3 Tensors. Since we moved the
        # dimension being vmapped over to the front of `x`, that appears at
        # dimension 0 of all outputs.
        # The return is (output, out_dims) -- output is a tuple of 3 Tensors
        # and out_dims is a Tuple of 3 Optional[int].
        return result, (0, 0, 0)


class NumpyTake(torch.autograd.Function):
    @staticmethod
    def forward(x, ind, ind_inv, dim):
        device = x.device
        x = to_numpy(x)
        ind = to_numpy(ind)
        return torch.tensor(np.take_along_axis(x, ind, dim), device=device)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, ind, ind_inv, dim = inputs
        ctx.save_for_backward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output):
        ind, ind_inv = ctx.saved_tensors
        result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
        return result, None, None, None

    @staticmethod
    def vmap(info, in_dims, x, ind, ind_inv, dim):
        x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims

        # The strategy is: expand {x, ind, ind_inv} to all have the dimension
        # being vmapped over.
        # Then, call back into NumpyTake(expanded_x, expanded_ind, expanded_ind_inv, new_dim).

        # Handle negative dims by wrapping them to be positive
        logical_dim = x.dim() if x_bdim is None else x_bdim - 1
        dim = dim if dim >= 0 else dim + logical_dim

        def maybe_expand_bdim_at_front(x, x_bdim):
            if x_bdim is None:
                return x.expand(info.batch_size, *x.shape)
            return x.movedim(x_bdim, 0)

        # If the Tensor doesn't have the dimension being vmapped over,
        # expand it out. Otherwise, move it to the front of the Tensor
        x = maybe_expand_bdim_at_front(x, x_bdim)
        ind = maybe_expand_bdim_at_front(ind, ind_bdim)
        ind_inv = maybe_expand_bdim_at_front(ind_inv, ind_inv_bdim)

        # The return is a tuple (output, out_dims). Since output is a Tensor,
        # then out_dims is an Optional[int] (instead of being a Tuple).
        return NumpyTake.apply(x, ind, ind_inv, dim + 1), 0


def numpy_sort(x, dim=-1):
    result, _, _ = NumpySort.apply(x, dim)
    return result


x = torch.randn(2, 3)
result = torch.vmap(numpy_sort)(x)
assert torch.allclose(result, numpy_sort(result, 1))

Note

The vmap staticmethod should aim to preserve the semantics of the entire Function. That is, (pseudocode) grad(vmap(MyFunc)) should be replaceable with grad(map(MyFunc)).

If your autograd.Function has any custom behavior in the backward pass, please keep this in mind.

Note

It is a legitimate use case to write a custom vmap staticmethod for a Function that PyTorch is able to generate a vmap rule for via generate_vmap_rule=True. You may wish to do this if the generated vmap rule doesn’t have the semantics you’re looking for.

torch.func.jvp() Support#

To support forward-mode AD, a torch.autograd.Function must have a jvp() staticmethod. Please see Forward mode AD for details.
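For instance, a jvp() staticmethod for the MyCube Function above might look like the following. This is a sketch rather than an excerpt from this page; it assumes the Tensors needed in jvp() are saved via ctx.save_for_forward() in setup_context():

import torch


class MyCubeWithJVP(torch.autograd.Function):
    @staticmethod
    def forward(x):
        result = x ** 3
        dx = 3 * x ** 2
        return result, dx

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        result, dx = output
        # Tensors needed by backward() go through save_for_backward();
        # Tensors needed by jvp() go through save_for_forward().
        ctx.save_for_backward(x, dx)
        ctx.save_for_forward(x, dx)

    @staticmethod
    def backward(ctx, grad_output, grad_dx):
        x, dx = ctx.saved_tensors
        return grad_output * dx + grad_dx * 6 * x

    @staticmethod
    def jvp(ctx, x_t):
        x, dx = ctx.saved_tensors
        # Tangent of result (= x ** 3) is 3 * x ** 2 * x_t, i.e. dx * x_t;
        # tangent of dx (= 3 * x ** 2) is 6 * x * x_t.
        return dx * x_t, 6 * x * x_t


def my_cube_jvp(x):
    result, _ = MyCubeWithJVP.apply(x)
    return result


x = torch.randn([])
out, tangent = torch.func.jvp(my_cube_jvp, (x,), (torch.ones_like(x),))
assert torch.allclose(tangent, 3 * x ** 2)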