
# Gradcheck mechanics

Created On: Apr 27, 2021 | Last Updated On: Jun 18, 2025

This note presents an overview of how the `gradcheck()` and `gradgradcheck()` functions work.

It will cover both forward and backward mode AD for both real and complex-valued functions, as well as higher-order derivatives. This note also covers both the default behavior of gradcheck and the case where the `fast_mode=True` argument is passed (referred to as fast gradcheck below).
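
For reference, a typical invocation of these utilities looks like the sketch below (the function `f` is a made-up example; both checks raise an error describing the mismatch if one is found):

```python
import torch
from torch.autograd import gradcheck, gradgradcheck

def f(x):
    # Made-up test function; gradcheck expects double precision inputs.
    return (x * x).sum()

x = torch.randn(4, dtype=torch.double, requires_grad=True)

assert gradcheck(f, (x,))                  # default (slow) gradcheck
assert gradcheck(f, (x,), fast_mode=True)  # fast gradcheck, described below
assert gradgradcheck(f, (x,))              # second order gradients
```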

## Notations and background information

Throughout this note, we will use the following convention:

  1. $x$, $y$, $a$, $b$, $v$, $u$, $ur$ and $ui$ are real-valued vectors and $z$ is a complex-valued vector that can be rewritten in terms of two real-valued vectors as $z = a + ib$.

  2. $N$ and $M$ are two integers that we will use for the dimension of the input and output space respectively.

  3. $f: \mathcal{R}^N \to \mathcal{R}^M$ is our basic real-to-real function such that $y = f(x)$.

  4. $g: \mathcal{C}^N \to \mathcal{R}^M$ is our basic complex-to-real function such that $y = g(z)$.

For the simple real-to-real case, we write $J_f$ for the Jacobian matrix associated with $f$, of size $M \times N$. This matrix contains all the partial derivatives such that the entry at position $(i, j)$ contains $\frac{\partial y_i}{\partial x_j}$. Backward mode AD then computes, for a given vector $v$ of size $M$, the quantity $v^T J_f$. Forward mode AD, on the other hand, computes, for a given vector $u$ of size $N$, the quantity $J_f u$.
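
As a concrete illustration, here is a minimal sketch of these two products using public PyTorch APIs (the function `f` below is a made-up example):

```python
import torch

def f(x):
    # Made-up function from R^3 to R^3.
    return x ** 2 + x.sum()

x = torch.randn(3, requires_grad=True)
y = f(x)

# Backward mode AD: v^T J_f for a chosen v of size M.
v = torch.randn(3)
vjp = torch.autograd.grad(y, x, grad_outputs=v)[0]

# Forward mode AD: J_f u for a chosen u of size N.
u = torch.randn(3)
_, jvp = torch.autograd.functional.jvp(f, (x,), (u,))
```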

For functions that contain complex values, the story is a lot more complex. We only provide the gist here; the full description can be found at Autograd for Complex Numbers.

The constraints to satisfy complex differentiability (Cauchy-Riemann equations) are too restrictive for all real-valued loss functions, so we instead opted to use Wirtinger calculus. In a basic setting of Wirtinger calculus, the chain rule requires access to both the Wirtinger derivative (called $W$ below) and the Conjugate Wirtinger derivative (called $CW$ below). Both $W$ and $CW$ need to be propagated because, in general, despite their name, one is not the complex conjugate of the other.

To avoid having to propagate both values, for backward mode AD, we always work under the assumption that the function whose derivative is being calculated is either a real-valued function or is part of a bigger real-valued function. This assumption means that all the intermediary gradients we compute during the backward pass are also associated with real-valued functions. In practice, this assumption is not restrictive when doing optimization, as such problems require real-valued objectives (there is no natural ordering of the complex numbers).

Under this assumption, using the $W$ and $CW$ definitions, we can show that $W = CW^*$ (we use $*$ to denote complex conjugation here), and so only one of the two values actually needs to be "backwarded through the graph" as the other one can easily be recovered. To simplify internal computations, PyTorch uses $2 * CW$ as the value it backwards and returns when the user asks for gradients. Similarly to the real case, when the output is actually in $\mathcal{R}^M$, backward mode AD does not compute $2 * CW$ but only $v^T (2 * CW)$ for a given vector $v \in \mathcal{R}^M$.

For forward mode AD, we use a similar logic, in this case assuming that the function is part of a larger function whose input is in $\mathcal{R}$. Under this assumption, we can make a similar claim that every intermediary result corresponds to a function whose input is in $\mathcal{R}$, and in this case, using the $W$ and $CW$ definitions, we can show that $W = CW$ for the intermediary functions. To make sure the forward and backward modes compute the same quantities in the elementary case of a one-dimensional function, the forward mode also computes $2 * CW$. Similarly to the real case, when the input is actually in $\mathcal{R}^N$, forward mode AD does not compute $2 * CW$ but only $(2 * CW) u$ for a given vector $u \in \mathcal{R}^N$.

## Default backward mode gradcheck behavior

### Real-to-real functions

To test a function $f: \mathcal{R}^N \to \mathcal{R}^M, x \to y$, we reconstruct the full Jacobian matrix $J_f$ of size $M \times N$ in two ways: analytically and numerically. The analytical version uses our backward mode AD while the numerical version uses finite differences. The two reconstructed Jacobian matrices are then compared elementwise for equality.

#### Default real input numerical evaluation

If we consider the elementary case of a one-dimensional function ($N = M = 1$), then we can use the basic finite difference formula from the Wikipedia article on finite differences. We use the "central difference" for better numerical properties:

$$\frac{\partial y}{\partial x} \approx \frac{f(x + eps) - f(x - eps)}{2 * eps}$$

This formula easily generalizes for multiple outputs ($M \gt 1$) by having $\frac{\partial y}{\partial x}$ be a column vector of size $M \times 1$ like $f(x + eps)$. In that case, the above formula can be reused as-is and approximates the full Jacobian matrix with only two evaluations of the user function (namely $f(x + eps)$ and $f(x - eps)$).

It is more computationally expensive to handle the case with multiple inputs ($N \gt 1$). In this scenario, we loop over the inputs and apply the $eps$ perturbation to each element of $x$ one after the other. This allows us to reconstruct the $J_f$ matrix column by column.
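
As a sketch (a simplification of the actual gradcheck code), the column-by-column reconstruction could look like this for a function `f` mapping a 1D tensor to a 1D tensor:

```python
import torch

def numerical_jacobian(f, x, eps=1e-6):
    """Approximate J_f at x column by column with central differences."""
    x = x.detach()  # finite differences do not need the autograd graph
    y = f(x)
    jac = torch.zeros(y.numel(), x.numel(), dtype=y.dtype)
    for j in range(x.numel()):
        # Perturb only the j-th input coordinate, in both directions.
        x_plus, x_minus = x.clone(), x.clone()
        x_plus[j] += eps
        x_minus[j] -= eps
        # The central difference gives the j-th column of the Jacobian.
        jac[:, j] = (f(x_plus) - f(x_minus)) / (2 * eps)
    return jac
```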

#### Default real input analytical evaluation

For the analytical evaluation, we use the fact, as described above, that backward mode AD computes $v^T J_f$. For functions with a single output, we simply use $v = 1$ to recover the full Jacobian matrix with a single backward pass.

For functions with more than one output, we resort to a for-loop that iterates over the outputs, where each $v$ is a one-hot vector corresponding to one output after the other. This allows us to reconstruct the $J_f$ matrix row by row.
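
A matching sketch of the row-by-row analytical reconstruction, under the same assumptions as the numerical sketch above:

```python
import torch

def analytical_jacobian(f, x):
    """Reconstruct J_f at x row by row using backward mode AD."""
    x = x.detach().requires_grad_(True)
    y = f(x)
    jac = torch.zeros(y.numel(), x.numel(), dtype=y.dtype)
    for i in range(y.numel()):
        # A one-hot v selects the i-th row of the Jacobian: v^T J_f.
        v = torch.zeros_like(y)
        v[i] = 1.0
        jac[i, :] = torch.autograd.grad(y, x, grad_outputs=v, retain_graph=True)[0]
    return jac
```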

### Complex-to-real functions

To test a function $g: \mathcal{C}^N \to \mathcal{R}^M, z \to y$ with $z = a + ib$, we reconstruct the (complex-valued) matrix that contains $2 * CW$.

#### Default complex input numerical evaluation

Consider the elementary case where $N = M = 1$ first. We know from (chapter 3 of) this research paper that:

$$CW := \frac{\partial y}{\partial z^*} = \frac{1}{2} * (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b})$$

Note that $\frac{\partial y}{\partial a}$ and $\frac{\partial y}{\partial b}$, in the above equation, are $\mathcal{R} \to \mathcal{R}$ derivatives. To evaluate these numerically, we use the method described above for the real-to-real case. This allows us to compute the $CW$ matrix and then multiply it by $2$.

Note that the code, at the time of writing, computes this value in a slightly convoluted way:

```python
# Code from https://github.com/pytorch/pytorch/blob/58eb23378f2a376565a66ac32c93a316c45b6131/torch/autograd/gradcheck.py#L99-L105
# Notation changes in this code block:
# s here is y above
# x, y here are a, b above
ds_dx = compute_gradient(eps)
ds_dy = compute_gradient(eps * 1j)
# conjugate wirtinger derivative
conj_w_d = 0.5 * (ds_dx + ds_dy * 1j)
# wirtinger derivative
w_d = 0.5 * (ds_dx - ds_dy * 1j)
d[d_idx] = grad_out.conjugate() * conj_w_d + grad_out * w_d.conj()

# Since grad_out is always 1, and W and CW are complex conjugate of each other, the last line
# ends up computing exactly `conj_w_d + w_d.conj() = conj_w_d + conj_w_d = 2 * conj_w_d`.
```

#### Default complex input analytical evaluation

Since backward mode AD already computes exactly twice the $CW$ derivative, we simply use the same trick as for the real-to-real case and reconstruct the matrix row by row when there are multiple real outputs.

### Functions with complex outputs

In this case, the user-provided function does not follow the assumption from autograd that the function we compute backward AD for is real-valued. This means that using autograd directly on this function is not well defined. To solve this, we will replace the test of the function $h: \mathcal{P}^N \to \mathcal{C}^M$ (where $\mathcal{P}$ can be either $\mathcal{R}$ or $\mathcal{C}$) with two functions, $hr$ and $hi$, such that:

$$\begin{aligned}
hr(q) &:= real(f(q)) \\
hi(q) &:= imag(f(q))
\end{aligned}$$

where $q \in \mathcal{P}$. We then do a basic gradcheck for both $hr$ and $hi$ using either the real-to-real or complex-to-real case described above, depending on $\mathcal{P}$.

Note that the code, at the time of writing, does not create these functions explicitly but performs the chain rule with the $real$ or $imag$ functions manually by passing the $\text{grad\_out}$ arguments to the different functions. When $\text{grad\_out} = 1$, we are considering $hr$. When $\text{grad\_out} = 1j$, we are considering $hi$.
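
gradcheck accepts functions with complex outputs directly, but as an illustration of the decomposition above, one could equivalently check the real and imaginary parts separately (a sketch with a made-up function `h`):

```python
import torch
from torch.autograd import gradcheck

def h(q):
    # Made-up function with complex outputs.
    return (q * (1 + 2j)).exp()

q = torch.randn(3, dtype=torch.double, requires_grad=True)

# hr and hi from above, checked with the real-to-real machinery
# (the real input q makes P = R here).
assert gradcheck(lambda q: h(q).real, (q,))
assert gradcheck(lambda q: h(q).imag, (q,))
```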

## Fast backward mode gradcheck

While the above formulation of gradcheck is great both to ensure correctness and for debuggability, it is very slow because it reconstructs the full Jacobian matrices. This section presents a way to perform gradcheck faster without affecting its correctness. Debuggability can be recovered by adding special logic when we detect an error: in that case, we can run the default version that reconstructs the full matrix to give full details to the user.

The high-level strategy here is to find a scalar quantity that can be computed efficiently by both the numerical and analytical methods, and that represents the full matrix computed by the slow gradcheck well enough to ensure that it will catch any discrepancy in the Jacobians.

### Fast gradcheck for real-to-real functions

The scalar quantity that we want to compute here is $v^T J_f u$ for a given random vector $v \in \mathcal{R}^M$ and a random unit norm vector $u \in \mathcal{R}^N$.

For the numerical evaluation, we can efficiently compute

$$J_f u \approx \frac{f(x + u * eps) - f(x - u * eps)}{2 * eps}.$$

We then perform the dot product between this vector and $v$ to get the scalar value of interest.

For the analytical version, we can use backward mode AD to compute $v^T J_f$ directly. We then perform the dot product with $u$ to get the expected value.
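
Combining the two, a sketch of the fast real-to-real check (a simplification of the actual `fast_mode` logic, assuming `f` maps a 1D tensor to a 1D tensor):

```python
import torch

def fast_check_real(f, x, eps=1e-6, atol=1e-5):
    """Compare v^T J_f u computed numerically and analytically."""
    x = x.detach()
    v = torch.randn_like(f(x))   # random v of size M
    u = torch.randn_like(x)
    u = u / u.norm()             # random unit norm u of size N

    # Numerical side: J_f u via one central difference along u, then dot with v.
    numerical = v.dot((f(x + u * eps) - f(x - u * eps)) / (2 * eps))

    # Analytical side: v^T J_f via one backward pass, then dot with u.
    xg = x.clone().requires_grad_(True)
    vT_Jf = torch.autograd.grad(f(xg), xg, grad_outputs=v)[0]
    analytical = vT_Jf.dot(u)

    return torch.isclose(numerical, analytical, atol=atol)
```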

### Fast gradcheck for complex-to-real functions

Similar to the real-to-real case, we want to perform a reduction of the full matrix. But the $2 * CW$ matrix is complex-valued, and so in this case, we will compare complex scalars.

Due to some constraints on what we can compute efficiently in the numerical case and to keep the number of numerical evaluations to a minimum, we compute the following (albeit surprising) scalar value:

$$s := 2 * v^T (real(CW) ur + i * imag(CW) ui)$$

where $v \in \mathcal{R}^M$, $ur \in \mathcal{R}^N$ and $ui \in \mathcal{R}^N$.

#### Fast complex input numerical evaluation

We first consider how to compute $s$ with a numerical method. To do so, keeping in mind that we're considering $g: \mathcal{C}^N \to \mathcal{R}^M, z \to y$ with $z = a + ib$, and that $CW = \frac{1}{2} * (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b})$, we rewrite it as follows:

$$\begin{aligned}
s &= 2 * v^T (real(CW) ur + i * imag(CW) ui) \\
  &= 2 * v^T (\frac{1}{2} * \frac{\partial y}{\partial a} ur + i * \frac{1}{2} * \frac{\partial y}{\partial b} ui) \\
  &= v^T (\frac{\partial y}{\partial a} ur + i * \frac{\partial y}{\partial b} ui) \\
  &= v^T ((\frac{\partial y}{\partial a} ur) + i * (\frac{\partial y}{\partial b} ui))
\end{aligned}$$

In this formula, we can see that $\frac{\partial y}{\partial a} ur$ and $\frac{\partial y}{\partial b} ui$ can be evaluated the same way as the fast version for the real-to-real case. Once these real-valued quantities have been computed, we can reconstruct the complex vector on the right side and do a dot product with the real-valued $v$ vector.
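
A sketch of this numerical evaluation (a simplification of the actual code, assuming `g` maps a complex 1D double tensor `z` to a real 1D double tensor):

```python
import torch

def fast_numerical_s(g, z, v, ur, ui, eps=1e-6):
    """Numerically evaluate s = v^T ((dy/da) ur + i (dy/db) ui)."""
    z = z.detach()
    # Perturbing the real part of z along ur gives (dy/da) ur ...
    dy_da_ur = (g(z + ur * eps) - g(z - ur * eps)) / (2 * eps)
    # ... and perturbing the imaginary part along ui gives (dy/db) ui.
    dy_db_ui = (g(z + 1j * ui * eps) - g(z - 1j * ui * eps)) / (2 * eps)
    # Reconstruct the complex vector and reduce it with the real-valued v.
    rhs = dy_da_ur + 1j * dy_db_ui
    return torch.dot(v.to(rhs.dtype), rhs)
```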

#### Fast complex input analytical evaluation

For the analytical case, things are simpler and we rewrite the formula as:

$$\begin{aligned}
s &= 2 * v^T (real(CW) ur + i * imag(CW) ui) \\
  &= v^T real(2 * CW) ur + i * v^T imag(2 * CW) ui \\
  &= real(v^T (2 * CW)) ur + i * imag(v^T (2 * CW)) ui
\end{aligned}$$

We can thus use the fact that backward mode AD provides us with an efficient way to compute $v^T (2 * CW)$, and then perform a dot product of the real part with $ur$ and the imaginary part with $ui$ before reconstructing the final complex scalar $s$.
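
And the matching sketch of the analytical evaluation, relying on the fact (stated above) that backward mode AD returns $v^T (2 * CW)$ for complex inputs (`g`, `z`, `v`, `ur`, `ui` as in the previous sketch):

```python
import torch

def fast_analytical_s(g, z, v, ur, ui):
    """Analytically evaluate s = real(v^T (2*CW)) ur + i * imag(v^T (2*CW)) ui."""
    z = z.detach().requires_grad_(True)
    y = g(z)
    # One backward pass gives v^T (2 * CW) as a complex vector of size N.
    vT_2CW = torch.autograd.grad(y, z, grad_outputs=v)[0]
    # Reduce the real part with ur and the imaginary part with ui.
    return torch.dot(vT_2CW.real, ur) + 1j * torch.dot(vT_2CW.imag, ui)
```

Comparing the outputs of these two sketches for the same random $v$, $ur$ and $ui$ is, in essence, what the fast complex-to-real check does.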

#### Why not use a complex $u$

At this point, you might be wondering why we did not select a complex $u$ and just perform the reduction $2 * v^T CW u'$. To dive into this, in this paragraph we will use the complex version of $u$, noted $u' = ur' + i ui'$. Using such a complex $u'$, the problem is that when doing the numerical evaluation, we would need to compute:

$$\begin{aligned}
2 * CW u' &= (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b})(ur' + i ui') \\
          &= \frac{\partial y}{\partial a} ur' + i \frac{\partial y}{\partial a} ui' + i \frac{\partial y}{\partial b} ur' - \frac{\partial y}{\partial b} ui'
\end{aligned}$$

This would require four evaluations of real-to-real finite differences (twice as many as the approach proposed above). Since this approach does not have more degrees of freedom (same number of real-valued variables) and we want the fastest possible evaluation here, we use the other formulation above.

### Fast gradcheck for functions with complex outputs

Just like in the slow case, we consider two real-valued functions and use the appropriate rule from above for each function.

## Gradgradcheck implementation

PyTorch also provides a utility to verify second order gradients. The goal here is to make sure that the backward implementation is also properly differentiable and computes the right thing.

This feature is implemented by considering the function $F: x, v \to v^T J_f$ and using the gradcheck defined above on this function. Note that $v$ in this case is just a random vector with the same type as $f(x)$.

The fast version of gradgradcheck is implemented by using the fast version of gradcheck on that same function $F$.
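
As a sketch of this construction (the actual `gradgradcheck` utility builds $F$ for you; `f` below is a made-up example):

```python
import torch
from torch.autograd import gradcheck, gradgradcheck

def f(x):
    # Made-up example with well-defined second derivatives.
    return (x ** 3).sum() + x

x = torch.randn(3, dtype=torch.double, requires_grad=True)
v = torch.randn(3, dtype=torch.double, requires_grad=True)

def F(x, v):
    # F(x, v) = v^T J_f, computed with a differentiable backward pass.
    y = f(x)
    return torch.autograd.grad(y, x, grad_outputs=v, create_graph=True)[0]

# Checking F with gradcheck verifies the second order gradients of f,
# which is what gradgradcheck does on your behalf.
assert gradcheck(F, (x, v))
assert gradgradcheck(f, (x,))
```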