Forward-mode Automatic Differentiation (Beta)#

Created On: Dec 07, 2021 | Last Updated: Apr 18, 2023 | Last Verified: Nov 05, 2024

This tutorial demonstrates how to use forward-mode AD to compute directional derivatives (or equivalently, Jacobian-vector products).

The tutorial below uses some APIs only available in versions >= 1.11 (or nightly builds).

Also note that forward-mode AD is currently in beta. The API is subject to change and operator coverage is still incomplete.

Basic Usage#

Unlike reverse-mode AD, forward-mode AD computes gradients eagerly alongside the forward pass. We can use forward-mode AD to compute a directional derivative by performing the forward pass as before, except we first associate our input with another tensor representing the direction of the directional derivative (or equivalently, the v in a Jacobian-vector product). When an input, which we call “primal”, is associated with a “direction” tensor, which we call “tangent”, the resultant new tensor object is called a “dual tensor” for its connection to dual numbers [0].

As the forward pass is performed, if any input tensors are dual tensors, extra computation is performed to propagate this “sensitivity” of the function.

import torch
import torch.autograd.forward_ad as fwAD

primal = torch.randn(10, 10)
tangent = torch.randn(10, 10)

def fn(x, y):
    return x ** 2 + y ** 2

# All forward AD computation must be performed in the context of
# a ``dual_level`` context. All dual tensors created in such a context
# will have their tangents destroyed upon exit. This is to ensure that
# if the output or intermediate results of this computation are reused
# in a future forward AD computation, their tangents (which are associated
# with this computation) won't be confused with tangents from the later
# computation.
with fwAD.dual_level():
    # To create a dual tensor we associate a tensor, which we call the
    # primal with another tensor of the same size, which we call the tangent.
    # If the layout of the tangent is different from that of the primal,
    # the values of the tangent are copied into a new tensor with the same
    # metadata as the primal. Otherwise, the tangent itself is used as-is.
    #
    # It is also important to note that the dual tensor created by
    # ``make_dual`` is a view of the primal.
    dual_input = fwAD.make_dual(primal, tangent)
    assert fwAD.unpack_dual(dual_input).tangent is tangent

    # To demonstrate the case where the copy of the tangent happens,
    # we pass in a tangent with a layout different from that of the primal
    dual_input_alt = fwAD.make_dual(primal, tangent.T)
    assert fwAD.unpack_dual(dual_input_alt).tangent is not tangent

    # Tensors that do not have an associated tangent are automatically
    # considered to have a zero-filled tangent of the same shape.
    plain_tensor = torch.randn(10, 10)
    dual_output = fn(dual_input, plain_tensor)

    # Unpacking the dual returns a ``namedtuple`` with ``primal`` and ``tangent``
    # as attributes
    jvp = fwAD.unpack_dual(dual_output).tangent

# After exiting the ``dual_level`` context, the tangent has been destroyed.
assert fwAD.unpack_dual(dual_output).tangent is None
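As a sanity check on the example above, the tangent it produces can be compared against a derivative worked out by hand. The following is a minimal sketch, not part of the original tutorial: since only ``x`` carries a tangent, the directional derivative of ``x**2 + y**2`` along ``tangent`` should be ``2 * x * tangent``. The names ``plain``, ``jvp_out``, and ``expected`` are chosen here for illustration.

```python
import torch
import torch.autograd.forward_ad as fwAD

def fn(x, y):
    return x ** 2 + y ** 2

primal = torch.randn(10, 10)
tangent = torch.randn(10, 10)
plain = torch.randn(10, 10)  # no tangent attached, so it acts like a constant

with fwAD.dual_level():
    dual_input = fwAD.make_dual(primal, tangent)
    dual_output = fn(dual_input, plain)
    jvp_out = fwAD.unpack_dual(dual_output).tangent

# Hand-derived directional derivative: d/dx (x**2 + y**2) applied to tangent.
expected = 2 * primal * tangent
assert torch.allclose(jvp_out, expected)
```

Because ``plain`` has an implicit zero tangent, the ``y**2`` term contributes nothing to the result.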

Usage with Modules#

To use nn.Module with forward AD, replace the parameters of your model with dual tensors before performing the forward pass. At the time of writing, it is not possible to create dual tensor `nn.Parameter`s. As a workaround, one must register the dual tensor as a non-parameter attribute of the module.

import torch.nn as nn

model = nn.Linear(5, 5)
input = torch.randn(16, 5)

params = {name: p for name, p in model.named_parameters()}
tangents = {name: torch.rand_like(p) for name, p in params.items()}

with fwAD.dual_level():
    for name, p in params.items():
        delattr(model, name)
        setattr(model, name, fwAD.make_dual(p, tangents[name]))

    out = model(input)
    jvp = fwAD.unpack_dual(out).tangent
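The workaround above leaves the module with plain tensor attributes in place of its registered parameters. The sketch below is an illustration rather than part of the original tutorial. It repeats the same steps, checks the result against the hand-derived JVP of a linear layer (``x @ tW.T + tb`` for tangents ``tW`` and ``tb``), and then re-registers the original ``nn.Parameter`` objects. The restore loop is an assumption about how one might undo the swap, not an officially documented recipe.

```python
import torch
import torch.nn as nn
import torch.autograd.forward_ad as fwAD

model = nn.Linear(5, 5)
x = torch.randn(16, 5)
params = {name: p for name, p in model.named_parameters()}
tangents = {name: torch.rand_like(p) for name, p in params.items()}

with fwAD.dual_level():
    for name, p in params.items():
        # Swap each registered parameter for a dual tensor attribute.
        delattr(model, name)
        setattr(model, name, fwAD.make_dual(p, tangents[name]))
    out = model(x)
    jvp_out = fwAD.unpack_dual(out).tangent

# For y = x @ W.T + b, the directional derivative along (tW, tb)
# is x @ tW.T + tb.
expected = x @ tangents["weight"].T + tangents["bias"]
assert torch.allclose(jvp_out, expected, atol=1e-5)

# Hypothetical cleanup: put the original parameters back so the module
# can be trained or serialized as usual afterwards.
for name, p in params.items():
    delattr(model, name)
    setattr(model, name, p)
restored_names = sorted(name for name, _ in model.named_parameters())
```

The functional Module API described in the next section avoids this swap-and-restore step entirely.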

Using the functional Module API (beta)#

Another way to use nn.Module with forward AD is to utilize the functional Module API (also known as the stateless Module API).

from torch.func import functional_call

# We need a fresh module because the functional call requires
# the model to have parameters registered.
model = nn.Linear(5, 5)

dual_params = {}
with fwAD.dual_level():
    for name, p in params.items():
        # Using the same ``tangents`` from the above section
        dual_params[name] = fwAD.make_dual(p, tangents[name])
    out = functional_call(model, dual_params, input)
    jvp2 = fwAD.unpack_dual(out).tangent

# Check our results
assert torch.allclose(jvp, jvp2)

Custom autograd Function#

Custom Functions also support forward-mode AD. To create a custom Function supporting forward-mode AD, register the jvp() static method. It is possible, but not mandatory, for custom Functions to support both forward and backward AD. See the documentation for more information.

class Fn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        result = torch.exp(foo)
        # Tensors stored in ``ctx`` can be used in the subsequent forward grad
        # computation.
        ctx.result = result
        return result

    @staticmethod
    def jvp(ctx, gI):
        gO = gI * ctx.result
        # If the tensor stored in ``ctx`` will not also be used in the backward pass,
        # one can manually free it using ``del``
        del ctx.result
        return gO

fn = Fn.apply

primal = torch.randn(10, 10, dtype=torch.double, requires_grad=True)
tangent = torch.randn(10, 10)

with fwAD.dual_level():
    dual_input = fwAD.make_dual(primal, tangent)
    dual_output = fn(dual_input)
    jvp = fwAD.unpack_dual(dual_output).tangent

# It is important to use ``autograd.gradcheck`` to verify that your
# custom autograd Function computes the gradients correctly. By default,
# ``gradcheck`` only checks the backward-mode (reverse-mode) AD gradients. Specify
# ``check_forward_ad=True`` to also check forward grads. If you did not
# implement the backward formula for your function, you can also tell ``gradcheck``
# to skip the tests that require backward-mode AD by specifying
# ``check_backward_ad=False``, ``check_undefined_grad=False``, and
# ``check_batched_grad=False``.
torch.autograd.gradcheck(Fn.apply, (primal,), check_forward_ad=True,
                         check_backward_ad=False, check_undefined_grad=False,
                         check_batched_grad=False)
True
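Alongside ``gradcheck``, a custom ``jvp`` can also be compared directly with a closed-form derivative when one is known. The sketch below is an illustration, not part of the original tutorial. It restates the exponential Function under the hypothetical name ``ExpFn`` and checks that its JVP equals ``tangent * exp(primal)``. Both tensors are created as ``torch.double`` here, which is a choice made for this sketch.

```python
import torch
import torch.autograd.forward_ad as fwAD

class ExpFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        result = torch.exp(foo)
        # Keep the output around so ``jvp`` can reuse it.
        ctx.result = result
        return result

    @staticmethod
    def jvp(ctx, gI):
        # d/dx exp(x) = exp(x), so the tangent is scaled by the output.
        return gI * ctx.result

primal = torch.randn(10, 10, dtype=torch.double, requires_grad=True)
tangent = torch.randn(10, 10, dtype=torch.double)

with fwAD.dual_level():
    dual_input = fwAD.make_dual(primal, tangent)
    dual_output = ExpFn.apply(dual_input)
    jvp_out = fwAD.unpack_dual(dual_output).tangent

# Closed-form directional derivative of exp along ``tangent``.
expected = tangent * torch.exp(primal.detach())
assert torch.allclose(jvp_out, expected)
```

A check like this only covers one input and one direction, so it complements ``gradcheck`` rather than replacing it.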

Functional API (beta)#

We also offer a higher-level functional API in functorch for computing Jacobian-vector products that you may find simpler to use depending on your use case.

The benefit of the functional API is that there isn’t a need to understand or use the lower-level dual tensor API and that you can compose it with other functorch transforms (like vmap); the downside is that it offers you less control.

Note that the remainder of this tutorial will require functorch (pytorch/functorch) to run. Please find installation instructions at the specified link.

import functorch as ft

primal0 = torch.randn(10, 10)
tangent0 = torch.randn(10, 10)
primal1 = torch.randn(10, 10)
tangent1 = torch.randn(10, 10)

def fn(x, y):
    return x ** 2 + y ** 2

# Here is a basic example to compute the JVP of the above function.
# The ``jvp(func, primals, tangents)`` returns ``func(*primals)`` as well as the
# computed Jacobian-vector product (JVP). Each primal must be associated with a tangent of the same shape.
primal_out, tangent_out = ft.jvp(fn, (primal0, primal1), (tangent0, tangent1))

# ``functorch.jvp`` requires every primal to be associated with a tangent.
# If we only want to associate certain inputs to `fn` with tangents,
# then we'll need to create a new function that captures inputs without tangents:
primal = torch.randn(10, 10)
tangent = torch.randn(10, 10)
y = torch.randn(10, 10)

import functools
new_fn = functools.partial(fn, y=y)
primal_out, tangent_out = ft.jvp(new_fn, (primal,), (tangent,))
/var/lib/workspace/intermediate_source/forward_ad_usage.py:203: FutureWarning: We've integrated functorch into PyTorch. As the final step of the integration, `functorch.jvp` is deprecated as of PyTorch 2.0 and will be deleted in a future version of PyTorch >= 2.3. Please use `torch.func.jvp` instead; see the PyTorch 2.0 release notes and/or the `torch.func` migration guide for more details https://pytorch.org/docs/main/func.migrating.html

/var/lib/workspace/intermediate_source/forward_ad_usage.py:214: FutureWarning: We've integrated functorch into PyTorch. As the final step of the integration, `functorch.jvp` is deprecated as of PyTorch 2.0 and will be deleted in a future version of PyTorch >= 2.3. Please use `torch.func.jvp` instead; see the PyTorch 2.0 release notes and/or the `torch.func` migration guide for more details https://pytorch.org/docs/main/func.migrating.html
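The warning above names ``torch.func.jvp`` as the replacement for ``functorch.jvp``. A minimal sketch of the same two calls using ``torch.func`` might look like the following. It assumes PyTorch 2.0 or later, and the ``expected_*`` values are hand-derived for this illustration rather than taken from the tutorial.

```python
import functools

import torch
from torch.func import jvp

def fn(x, y):
    return x ** 2 + y ** 2

primal0 = torch.randn(10, 10)
tangent0 = torch.randn(10, 10)
primal1 = torch.randn(10, 10)
tangent1 = torch.randn(10, 10)

# Every primal is paired with a tangent of the same shape.
primal_out, tangent_out = jvp(fn, (primal0, primal1), (tangent0, tangent1))
expected_both = 2 * primal0 * tangent0 + 2 * primal1 * tangent1
assert torch.allclose(tangent_out, expected_both)

# To differentiate with respect to ``x`` only, capture ``y`` in a new function.
y = torch.randn(10, 10)
new_fn = functools.partial(fn, y=y)
primal_out_x, tangent_out_x = jvp(new_fn, (primal0,), (tangent0,))
expected_x = 2 * primal0 * tangent0
assert torch.allclose(tangent_out_x, expected_x)
```

The call signature is the same as in the functorch version, so only the import changes in this case.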

Using the functional API with Modules#

To use nn.Module with functorch.jvp to compute Jacobian-vector products with respect to the model parameters, we need to reformulate the nn.Module as a function that accepts both the model parameters and inputs to the module.

model = nn.Linear(5, 5)
input = torch.randn(16, 5)
tangents = tuple([torch.rand_like(p) for p in model.parameters()])

# Given a ``torch.nn.Module``, ``ft.make_functional_with_buffers`` extracts the state
# (``params`` and buffers) and returns a functional version of the model that
# can be invoked like a function.
# That is, the returned ``func`` can be invoked like
# ``func(params, buffers, input)``.
# ``ft.make_functional_with_buffers`` is analogous to the ``nn.Modules`` stateless API
# that you saw previously and we're working on consolidating the two.
func, params, buffers = ft.make_functional_with_buffers(model)

# Because ``jvp`` requires every input to be associated with a tangent, we need to
# create a new function that, when given the parameters, produces the output
def func_params_only(params):
    return func(params, buffers, input)

model_output, jvp_out = ft.jvp(func_params_only, (params,), (tangents,))
/var/lib/workspace/intermediate_source/forward_ad_usage.py:235: FutureWarning: We've integrated functorch into PyTorch. As the final step of the integration, `functorch.make_functional_with_buffers` is deprecated as of PyTorch 2.0 and will be deleted in a future version of PyTorch >= 2.3. Please use `torch.func.functional_call` instead; see the PyTorch 2.0 release notes and/or the `torch.func` migration guide for more details https://pytorch.org/docs/main/func.migrating.html

/var/lib/workspace/intermediate_source/forward_ad_usage.py:242: FutureWarning: We've integrated functorch into PyTorch. As the final step of the integration, `functorch.jvp` is deprecated as of PyTorch 2.0 and will be deleted in a future version of PyTorch >= 2.3. Please use `torch.func.jvp` instead; see the PyTorch 2.0 release notes and/or the `torch.func` migration guide for more details https://pytorch.org/docs/main/func.migrating.html
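Following the warnings above, the same parameter JVP can be sketched with ``torch.func.functional_call`` and ``torch.func.jvp`` in place of the deprecated functorch helpers. This is an illustrative sketch assuming PyTorch 2.0 or later. Passing the parameters as a dictionary and detaching them are choices made here, and ``expected`` is the hand-derived JVP of a linear layer.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, jvp

model = nn.Linear(5, 5)
x = torch.randn(16, 5)

# Parameters and tangents are passed as dictionaries with matching keys.
params = {name: p.detach() for name, p in model.named_parameters()}
tangents = {name: torch.rand_like(p) for name, p in params.items()}

# ``jvp`` needs a tangent for every input, so close over the module input
# and expose only the parameters as the function argument.
def func_params_only(params):
    return functional_call(model, params, (x,))

model_output, jvp_out = jvp(func_params_only, (params,), (tangents,))

# For y = x @ W.T + b, the JVP along (tW, tb) is x @ tW.T + tb.
expected = x @ tangents["weight"].T + tangents["bias"]
assert torch.allclose(jvp_out, expected, atol=1e-5)
```

Here ``functional_call`` plays the role that ``make_functional_with_buffers`` played in the functorch version, without extracting a separate function object from the module.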

[0] https://en.wikipedia.org/wiki/Dual_number
