Jacobians, Hessians, hvp, vhp, and more: composing function transforms
Created On: Mar 15, 2023 | Last Updated: Apr 18, 2023 | Last Verified: Nov 05, 2024
Computing jacobians or hessians is useful in a number of non-traditional deep learning models. It is difficult (or annoying) to compute these quantities efficiently using PyTorch’s regular autodiff APIs (Tensor.backward(), torch.autograd.grad). PyTorch’s JAX-inspired function transforms API provides ways of computing various higher-order autodiff quantities efficiently.
Note: This tutorial requires PyTorch 2.0.0 or later.
Computing the Jacobian
import torch
import torch.nn.functional as F
from functools import partial

_ = torch.manual_seed(0)
Let’s start with a function that we’d like to compute the jacobian of. This is a simple linear function with non-linear activation.
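The predict function itself is not shown in this excerpt; a minimal definition consistent with the description above and with how predict is called below (a linear layer followed by a tanh non-linearity) is:

def predict(weight, bias, x):
    # assumed definition: a linear transform followed by a tanh non-linearity
    return F.linear(x, weight, bias).tanh()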
Let’s add some dummy data: a weight, a bias, and a feature vector x.
D = 16
weight = torch.randn(D, D)
bias = torch.randn(D)
x = torch.randn(D)  # feature vector
Let’s think of predict as a function that maps the input x from \(R^D \to R^D\). PyTorch Autograd computes vector-Jacobian products. In order to compute the full Jacobian of this \(R^D \to R^D\) function, we would have to compute it row-by-row by using a different unit vector each time.
def compute_jac(xp):
    jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]
                     for vec in unit_vectors]
    return torch.stack(jacobian_rows)

xp = x.clone().requires_grad_()
unit_vectors = torch.eye(D)

jacobian = compute_jac(xp)

print(jacobian.shape)
print(jacobian[0])  # show first row
torch.Size([16, 16])
tensor([-0.5956, -0.6096, -0.1326, -0.2295, 0.4490, 0.3661, -0.1672, -1.1190, 0.1705, -0.6683, 0.1851, 0.1630, 0.0634, 0.6547, 0.5908, -0.1308])
Instead of computing the jacobian row-by-row, we can use PyTorch’s torch.vmap function transform to get rid of the for-loop and vectorize the computation. We can’t directly apply vmap to torch.autograd.grad; instead, PyTorch provides a torch.func.vjp transform that composes with torch.vmap:
from torch.func import vmap, vjp

_, vjp_fn = vjp(partial(predict, weight, bias), x)

ft_jacobian, = vmap(vjp_fn)(unit_vectors)

# let's confirm both methods compute the same result
assert torch.allclose(ft_jacobian, jacobian)
In a later tutorial a composition of reverse-mode AD and vmap will give us per-sample-gradients. In this tutorial, composing reverse-mode AD and vmap gives us Jacobian computation! Various compositions of vmap and autodiff transforms can give us different interesting quantities.
PyTorch provides torch.func.jacrev as a convenience function that performs the vmap-vjp composition to compute jacobians. jacrev accepts an argnums argument that says which argument we would like to compute Jacobians with respect to.
from torch.func import jacrev

ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)

# Confirm by running the following:
assert torch.allclose(ft_jacobian, jacobian)
Let’s compare the performance of the two ways to compute the jacobian. The function transform version is much faster (and becomes even faster the more outputs there are).
In general, we expect that vectorization via vmap can help eliminate overhead and give better utilization of your hardware.
vmap does this magic by pushing the outer loop down into the function’s primitive operations in order to obtain better performance.
Let’s make a quick function to evaluate performance and deal with microsecond and millisecond measurements:
def get_perf(first, first_descriptor, second, second_descriptor):
    """takes torch.benchmark objects and compares delta of second vs first."""
    faster = second.times[0]
    slower = first.times[0]
    gain = (slower - faster) / slower
    if gain < 0:
        gain *= -1
    final_gain = gain * 100
    print(f" Performance delta: {final_gain:.4f} percent improvement with {second_descriptor} ")
And then run the performance comparison:
from torch.utils.benchmark import Timer

without_vmap = Timer(stmt="compute_jac(xp)", globals=globals())
with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())

no_vmap_timer = without_vmap.timeit(500)
with_vmap_timer = with_vmap.timeit(500)

print(no_vmap_timer)
print(with_vmap_timer)
<torch.utils.benchmark.utils.common.Measurement object at 0x7fe6fc8ad3c0>
compute_jac(xp)
  1.38 ms
  1 measurement, 500 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7fe7350e7fd0>
jacrev(predict, argnums=2)(weight, bias, x)
  418.32 us
  1 measurement, 500 runs , 1 thread
Let’s do a relative performance comparison of the above with our get_perf function:
get_perf(no_vmap_timer, "without vmap", with_vmap_timer, "vmap")
Performance delta: 69.7824 percent improvement with vmap
Furthermore, it’s pretty easy to flip the problem around and say we want to compute Jacobians of the parameters to our model (weight, bias) instead of the input:
# note the change in input via ``argnums`` parameters of 0,1 to map to weight and bias
ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)
Reverse-mode Jacobian (jacrev) vs forward-mode Jacobian (jacfwd)
We offer two APIs to compute jacobians: jacrev and jacfwd:
- jacrev uses reverse-mode AD. As you saw above, it is a composition of our vjp and vmap transforms.
- jacfwd uses forward-mode AD. It is implemented as a composition of our jvp and vmap transforms.
jacfwd and jacrev can be substituted for each other but they have different performance characteristics.
As a general rule of thumb, if you’re computing the jacobian of an \(R^N \to R^M\) function, and there are many more outputs than inputs (for example, \(M > N\)), then jacfwd is preferred; otherwise use jacrev. There are exceptions to this rule, but a non-rigorous argument for this follows:
In reverse-mode AD, we are computing the jacobian row-by-row, while in forward-mode AD (which computes Jacobian-vector products), we are computing it column-by-column. The Jacobian matrix has M rows and N columns, so if it is taller or wider one way we may prefer the method that deals with fewer rows or columns.
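To make the row-versus-column picture concrete, here is a rough, hypothetical sketch (not jacfwd’s actual implementation) of building a Jacobian column-by-column by composing jvp with vmap; jacfwd_sketch and single_jvp are illustrative names, and the code assumes a 1-D input:

from torch.func import jvp

def jacfwd_sketch(f, x):
    # one unit vector per input dimension; each jvp computes one Jacobian column J @ e_j
    basis = torch.eye(x.numel(), dtype=x.dtype)
    def single_jvp(v):
        return jvp(f, (x,), (v,))[1]
    # vmap stacks those columns as rows, so transpose to recover the (M, N) Jacobian
    return vmap(single_jvp)(basis).T

# should match the reverse-mode Jacobian computed earlier (here M == N == 16)
assert torch.allclose(jacfwd_sketch(partial(predict, weight, bias), x), jacobian)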
First, let’s benchmark with more outputs than inputs:
from torch.func import jacfwd  # jacrev was imported above

Din = 32
Dout = 2048
weight = torch.randn(Dout, Din)

bias = torch.randn(Dout)
x = torch.randn(Din)

# remember the general rule about taller vs wider... here we have a taller matrix:
print(weight.shape)

using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())

jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)

print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')
torch.Size([2048, 32])
jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fe6fc7445e0>
jacfwd(predict, argnums=2)(weight, bias, x)
  794.57 us
  1 measurement, 500 runs , 1 thread
jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fe734f56140>
jacrev(predict, argnums=2)(weight, bias, x)
  8.38 ms
  1 measurement, 500 runs , 1 thread
and then do a relative benchmark:
get_perf(jacfwd_timing, "jacfwd", jacrev_timing, "jacrev")
Performance delta: 955.1558 percent improvement with jacrev
and now the reverse - more inputs (N) than outputs (M):
Din = 2048
Dout = 32

weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)

using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())

jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)

print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')
jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fe734bf68c0>
jacfwd(predict, argnums=2)(weight, bias, x)
  7.16 ms
  1 measurement, 500 runs , 1 thread
jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fe730957e80>
jacrev(predict, argnums=2)(weight, bias, x)
  505.35 us
  1 measurement, 500 runs , 1 thread
and a relative performance comparison:
get_perf(jacrev_timing, "jacrev", jacfwd_timing, "jacfwd")
Performance delta: 1316.5774 percent improvement with jacfwd
Hessian computation with functorch.hessian
We offer a convenience API to compute hessians: torch.func.hessian. Hessians are the jacobian of the jacobian (or the partial derivative of the partial derivative, aka second order).
This suggests that one can just compose functorch jacobian transforms to compute the Hessian. Indeed, under the hood, hessian(f) is simply jacfwd(jacrev(f)).
Note: to boost performance, depending on your model you may also want to use jacfwd(jacfwd(f)) or jacrev(jacrev(f)) instead to compute hessians, leveraging the rule of thumb above regarding wider vs taller matrices.
from torch.func import hessian

# let's reduce the size in order not to overwhelm Colab. Hessians require
# significant memory:
Din = 512
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)

hess_api = hessian(predict, argnums=2)(weight, bias, x)
hess_fwdfwd = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)
hess_revrev = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)
Let’s verify we have the same result regardless of using the hessian API or using jacfwd(jacfwd()).
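One way to check (using the tensors defined above) is to compare the hessian API result against the jacfwd-of-jacfwd result with torch.allclose:

torch.allclose(hess_api, hess_fwdfwd)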
True
Batch Jacobian and Batch Hessian
In the above examples we’ve been operating with a single feature vector. In some cases you might want to take the Jacobian of a batch of outputs with respect to a batch of inputs. That is, given a batch of inputs of shape (B, N) and a function that goes from \(R^N \to R^M\), we would like a Jacobian of shape (B, M, N).
The easiest way to do this is to use vmap:
batch_size = 64
Din = 31
Dout = 33

weight = torch.randn(Dout, Din)
print(f"weight shape = {weight.shape}")

bias = torch.randn(Dout)
x = torch.randn(batch_size, Din)

compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))
batch_jacobian0 = compute_batch_jacobian(weight, bias, x)
weight shape = torch.Size([33, 31])
If you have a function that goes from (B, N) -> (B, M) instead and are certain that each input produces an independent output, then it’s also sometimes possible to do this without using vmap by summing the outputs and then computing the Jacobian of that function:
def predict_with_output_summed(weight, bias, x):
    return predict(weight, bias, x).sum(0)

batch_jacobian1 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x).movedim(1, 0)
assert torch.allclose(batch_jacobian0, batch_jacobian1)
If you instead have a function that goes from \(R^N \to R^M\) but inputs that are batched, you compose vmap with jacrev to compute batched jacobians, as we did above with compute_batch_jacobian.
Finally, batch hessians can be computed similarly. It’s easiest to think about them by using vmap to batch over hessian computation, but in some cases the sum trick also works.
compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0))

batch_hess = compute_batch_hessian(weight, bias, x)
batch_hess.shape
torch.Size([64, 33, 31, 31])
Computing Hessian-vector products
The naive way to compute a Hessian-vector product (hvp) is to materialize the full Hessian and perform a dot-product with a vector. We can do better: it turns out we don’t need to materialize the full Hessian to do this. We’ll go through two (of many) different strategies to compute Hessian-vector products:
- composing reverse-mode AD with reverse-mode AD
- composing reverse-mode AD with forward-mode AD
Composing reverse-mode AD with forward-mode AD (as opposed to reverse-mode with reverse-mode) is generally the more memory efficient way to compute an hvp because forward-mode AD doesn’t need to construct an Autograd graph and save intermediates for backward:
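The hvp helper used in the example below is not defined elsewhere in this excerpt; a minimal sketch consistent with the composition just described (reverse-mode grad inside a forward-mode jvp) looks like this:

from torch.func import grad, jvp

def hvp(f, primals, tangents):
    # differentiate grad(f) (reverse-mode) along the tangent direction (forward-mode)
    return jvp(grad(f), primals, tangents)[1]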
Here’s some sample usage.
def f(x):
    return x.sin().sum()

x = torch.randn(2048)
tangent = torch.randn(2048)

result = hvp(f, (x,), (tangent,))
If PyTorch forward-AD does not have coverage for your operations, then we caninstead compose reverse-mode AD with reverse-mode AD:
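Here is a sketch of that reverse-over-reverse variant (hvp_revrev is a hypothetical name; it reuses f, x, tangent, and result from above):

from torch.func import grad, vjp

def hvp_revrev(f, primals, tangents):
    # apply a vector-Jacobian product to grad(f): reverse-mode over reverse-mode
    _, vjp_fn = vjp(grad(f), *primals)
    return vjp_fn(*tangents)

result_hvp_revrev = hvp_revrev(f, (x,), (tangent,))
assert torch.allclose(result, result_hvp_revrev[0])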