Custom derivative rules
There are two ways to define differentiation rules in JAX:
1. using jax.custom_jvp and jax.custom_vjp to define custom differentiation rules for Python functions that are already JAX-transformable; and
2. defining new core.Primitive instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems.
This notebook is about #1. To read instead about #2, see the notebook on adding primitives.
For an introduction to JAX’s automatic differentiation API, see The Autodiff Cookbook. This notebook assumes some familiarity with jax.jvp and jax.grad, and the mathematical meaning of JVPs and VJPs.
Summary#
Custom JVPs with jax.custom_jvp#

import jax.numpy as jnp
from jax import custom_jvp

@custom_jvp
def f(x, y):
  return jnp.sin(x) * y

@f.defjvp
def f_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = f(x, y)
  tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
  return primal_out, tangent_out

from jax import jvp, grad

print(f(2., 3.))
y, y_dot = jvp(f, (2., 3.), (1., 0.))
print(y)
print(y_dot)
print(grad(f)(2., 3.))

2.7278922
2.7278922
-1.2484405
-1.2484405

# Equivalent alternative using the defjvps convenience wrapper

@custom_jvp
def f(x, y):
  return jnp.sin(x) * y

f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
          lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot)

print(f(2., 3.))
y, y_dot = jvp(f, (2., 3.), (1., 0.))
print(y)
print(y_dot)
print(grad(f)(2., 3.))

2.7278922
2.7278922
-1.2484405
-1.2484405

Custom VJPs with jax.custom_vjp#

from jax import custom_vjp

@custom_vjp
def f(x, y):
  return jnp.sin(x) * y

def f_fwd(x, y):
  # Returns primal output and residuals to be used in backward pass by f_bwd.
  return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
  cos_x, sin_x, y = res  # Gets residuals computed in f_fwd
  return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)

print(grad(f)(2.,3.))
-1.2484405
Example problems#
To get an idea of what problems jax.custom_jvp and jax.custom_vjp are meant to solve, let’s go over a few examples. A more thorough introduction to the jax.custom_jvp and jax.custom_vjp APIs is in the next section.
Numerical stability#
One application of jax.custom_jvp is to improve the numerical stability of differentiation.
Say we want to write a function called log1pexp, which computes \(x \mapsto \log ( 1 + e^x )\). We can write that using jax.numpy:

def log1pexp(x):
  return jnp.log(1. + jnp.exp(x))

log1pexp(3.)

Array(3.0485873, dtype=float32, weak_type=True)

Since it’s written in terms of jax.numpy, it’s JAX-transformable:

from jax import jit, grad, vmap

print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))

3.0485873
0.95257413
[0.5 0.7310586 0.8807971]

But there’s a numerical stability problem lurking here:
print(grad(log1pexp)(100.))
nan
That doesn’t seem right! After all, the derivative of \(x \mapsto \log (1 + e^x)\) is \(x \mapsto \frac{e^x}{1 + e^x}\), and so for large values of \(x\) we’d expect the value to be about 1.
We can get a bit more insight into what’s going on by looking at the jaxpr for the gradient computation:

from jax import make_jaxpr

make_jaxpr(grad(log1pexp))(100.)

{ lambda ; a:f32[]. let
    b:f32[] = exp a
    c:f32[] = add 1.0:f32[] b
    _:f32[] = log c
    d:f32[] = div 1.0:f32[] c
    e:f32[] = mul d b
  in (e,) }

Stepping through how the jaxpr would be evaluated, we can see that the last line would involve multiplying values that floating point math will round to 0 and \(\infty\), respectively, which is never a good idea. That is, we’re effectively evaluating lambda x: (1 / (1 + jnp.exp(x))) * jnp.exp(x) for large x, which effectively turns into 0. * jnp.inf.
Instead of generating such large and small values, hoping for a cancellation that floats can’t always provide, we’d rather just express the derivative function as a more numerically stable program. In particular, we can write a program that more closely evaluates the equivalent mathematical expression \(1 - \frac{1}{1 + e^x}\), with no cancellation in sight.
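To make the cancellation issue concrete, the following sketch evaluates both forms of the derivative directly at a large input. The helper names naive_derivative and stable_derivative are illustrative and not part of JAX, and the comparison assumes float32 inputs, for which jnp.exp(100.) overflows to inf.

```python
import jax.numpy as jnp

def naive_derivative(x):
  # Mirrors what the jaxpr above computes: (1 / (1 + e^x)) * e^x.
  return (1. / (1. + jnp.exp(x))) * jnp.exp(x)

def stable_derivative(x):
  # Mathematically equivalent form that never multiplies a tiny value by a huge one.
  return 1. - 1. / (1. + jnp.exp(x))

x = jnp.float32(100.)
naive = naive_derivative(x)    # exp overflows to inf, so this is 0 * inf
stable = stable_derivative(x)  # 1 - 1/inf

print(naive, stable)
```

For moderate inputs such as 3., the two helpers agree up to rounding; they only diverge once the exponential overflows.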
This problem is interesting because even though our definition of log1pexp could already be JAX-differentiated (and transformed with jit, vmap, …), we’re not happy with the result of applying standard autodiff rules to the primitives comprising log1pexp and composing the result. Instead, we’d like to specify how the whole function log1pexp should be differentiated, as a unit, and thus arrange those exponentials better.
This is one application of custom derivative rules for Python functions that are already JAX transformable: specifying how a composite function should be differentiated, while still using its original Python definition for other transformations (like jit, vmap, …).
Here’s a solution using jax.custom_jvp:

from jax import custom_jvp

@custom_jvp
def log1pexp(x):
  return jnp.log(1. + jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  ans = log1pexp(x)
  ans_dot = (1 - 1/(1 + jnp.exp(x))) * x_dot
  return ans, ans_dot

print(grad(log1pexp)(100.))

1.0

print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))

3.0485873
0.95257413
[0.5 0.7310586 0.8807971]

Here’s a defjvps convenience wrapper to express the same thing:

@custom_jvp
def log1pexp(x):
  return jnp.log(1. + jnp.exp(x))

log1pexp.defjvps(lambda t, ans, x: (1 - 1/(1 + jnp.exp(x))) * t)

print(grad(log1pexp)(100.))
print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))

1.0
3.0485873
0.95257413
[0.5 0.7310586 0.8807971]

Enforcing a differentiation convention#
A related application is to enforce a differentiation convention, perhaps at a boundary.
Consider the function \(f : \mathbb{R}_+ \to \mathbb{R}_+\) with \(f(x) = \frac{x}{1 + \sqrt{x}}\), where we take \(\mathbb{R}_+ = [0, \infty)\). We might implement \(f\) as a program like this:

def f(x):
  return x / (1 + jnp.sqrt(x))

As a mathematical function on \(\mathbb{R}\) (the full real line), \(f\) is not differentiable at zero (because the limit defining the derivative doesn’t exist from the left). Correspondingly, autodiff produces a nan value:

print(grad(f)(0.))

nan

But mathematically if we think of \(f\) as a function on \(\mathbb{R}_+\) then it is differentiable at 0 [Rudin’s Principles of Mathematical Analysis Definition 5.1, or Tao’s Analysis I 3rd ed. Definition 10.1.1 and Example 10.1.6]. Alternatively, we might say as a convention we want to consider the directional derivative from the right. So there is a sensible value for the Python function grad(f) to return at 0.0, namely 1.0. By default, JAX’s machinery for differentiation assumes all functions are defined over \(\mathbb{R}\) and thus doesn’t produce 1.0 here.
We can use a custom JVP rule! In particular, we can define the JVP rule in terms of the derivative function \(x \mapsto \frac{\sqrt{x} + 2}{2(\sqrt{x} + 1)^2}\) on \(\mathbb{R}_+\):

@custom_jvp
def f(x):
  return x / (1 + jnp.sqrt(x))

@f.defjvp
def f_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  ans = f(x)
  ans_dot = ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * x_dot
  return ans, ans_dot

print(grad(f)(0.))
1.0
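As a sanity check on the closed-form derivative used in the rule, we can compare it with standard autodiff at a point away from the boundary, where both are well defined. This is only a sketch: f_plain and f_prime are illustrative names for an undecorated copy of the function and for the closed-form derivative.

```python
import jax.numpy as jnp
from jax import grad

def f_plain(x):
  # Undecorated copy of f, differentiated by standard autodiff.
  return x / (1 + jnp.sqrt(x))

def f_prime(x):
  # Closed-form derivative on R_+ used in the custom JVP rule.
  return (jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)

autodiff_value = grad(f_plain)(4.)  # interior point, autodiff is fine here
closed_form_value = f_prime(4.)
boundary_value = f_prime(0.)        # the convention enforced at the boundary

print(autodiff_value, closed_form_value, boundary_value)
```

The two interior values should agree up to floating point rounding, and the closed form evaluates to 1.0 at the boundary, which is the value the custom rule returns there.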
Here’s the convenience wrapper version:
@custom_jvp
def f(x):
  return x / (1 + jnp.sqrt(x))

f.defjvps(lambda t, ans, x: ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * t)

print(grad(f)(0.))
1.0
Gradient clipping#
While in some cases we want to express a mathematical differentiation computation, in other cases we may even want to take a step away from mathematics to adjust the computation autodiff performs. One canonical example is reverse-mode gradient clipping.
For gradient clipping, we can use jnp.clip together with a jax.custom_vjp reverse-mode-only rule:

from functools import partial
from jax import custom_vjp

@custom_vjp
def clip_gradient(lo, hi, x):
  return x  # identity function

def clip_gradient_fwd(lo, hi, x):
  return x, (lo, hi)  # save bounds as residuals

def clip_gradient_bwd(res, g):
  lo, hi = res
  return (None, None, jnp.clip(g, lo, hi))  # use None to indicate zero cotangents for lo and hi

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

import matplotlib.pyplot as plt
from jax import vmap

t = jnp.linspace(0, 10, 1000)

plt.plot(jnp.sin(t))
plt.plot(vmap(grad(jnp.sin))(t))

[Figure: jnp.sin(t) and its gradient vmap(grad(jnp.sin))(t), evaluated on t = jnp.linspace(0, 10, 1000)]

def clip_sin(x):
  x = clip_gradient(-0.75, 0.75, x)
  return jnp.sin(x)

plt.plot(clip_sin(t))
plt.plot(vmap(grad(clip_sin))(t))

[Figure: clip_sin(t) and its clipped gradient vmap(grad(clip_sin))(t), evaluated on the same t]
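Since the plots are not always available, here is a numerical sketch of the same check. It repeats the clip_gradient definition so that it runs on its own, and it relies on the observation that the cotangent reaching clip_gradient_bwd in clip_sin is cos(x), so the clipped gradient should equal jnp.clip(jnp.cos(t), -0.75, 0.75).

```python
import jax.numpy as jnp
from jax import custom_vjp, grad, vmap

@custom_vjp
def clip_gradient(lo, hi, x):
  return x  # identity in the forward direction

def clip_gradient_fwd(lo, hi, x):
  return x, (lo, hi)

def clip_gradient_bwd(res, g):
  lo, hi = res
  return (None, None, jnp.clip(g, lo, hi))

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

def clip_sin(x):
  x = clip_gradient(-0.75, 0.75, x)
  return jnp.sin(x)

t = jnp.linspace(0, 10, 1000)
values = clip_sin(t)             # primal values are unaffected by the rule
grads = vmap(grad(clip_sin))(t)  # gradients are clipped on the backward pass

print(jnp.max(jnp.abs(grads)))
```

The primal values match jnp.sin(t) exactly, while the gradient magnitudes never exceed the clipping bound of 0.75.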

Python debugging#
Another application that is motivated by development workflow rather than numerics is to set a pdb debugger trace in the backward pass of reverse-mode autodiff.
When trying to track down the source of a nan runtime error, or just examine carefully the cotangent (gradient) values being propagated, it can be useful to insert a debugger at a point in the backward pass that corresponds to a specific point in the primal computation. You can do that with jax.custom_vjp.
We’ll defer an example until the next section.
Implicit function differentiation of iterative implementations#
This example gets pretty deep in the mathematical weeds!
Another application for jax.custom_vjp is reverse-mode differentiation of functions that are JAX-transformable (by jit, vmap, …) but not efficiently JAX-differentiable for some reason, perhaps because they involve lax.while_loop. (It’s not possible to produce an XLA HLO program that efficiently computes the reverse-mode derivative of an XLA HLO While loop because that would require a program with unbounded memory use, which isn’t possible to express in XLA HLO, at least without side-effecting interactions through infeed/outfeed.)
For example, consider this fixed_point routine which computes a fixed point by iteratively applying a function in a while_loop:

from jax.lax import while_loop

def fixed_point(f, a, x_guess):
  def cond_fun(carry):
    x_prev, x = carry
    return jnp.abs(x_prev - x) > 1e-6

  def body_fun(carry):
    _, x = carry
    return x, f(a, x)

  _, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
  return x_star

This is an iterative procedure for numerically solving the equation \(x = f(a, x)\) for \(x\), by iterating \(x_{t+1} = f(a, x_t)\) until \(x_{t+1}\) is sufficiently close to \(x_t\). The result \(x^*\) depends on the parameters \(a\), and so we can think of there being a function \(a \mapsto x^*(a)\) that is implicitly defined by equation \(x = f(a, x)\).
We can use fixed_point to run iterative procedures to convergence, for example running Newton’s method to calculate square roots while only executing adds, multiplies, and divides:

def newton_sqrt(a):
  update = lambda a, x: 0.5 * (x + a / x)
  return fixed_point(update, a, a)

print(newton_sqrt(2.))
1.4142135
We can vmap or jit the function as well:

print(jit(vmap(newton_sqrt))(jnp.array([1., 2., 3., 4.])))

[1. 1.4142135 1.7320509 2. ]
We can’t apply reverse-mode automatic differentiation because of the while_loop, but it turns out we wouldn’t want to anyway: instead of differentiating through the implementation of fixed_point and all its iterations, we can exploit the mathematical structure to do something that is much more memory-efficient (and FLOP-efficient in this case, too!). We can instead use the implicit function theorem [Prop A.25 of Bertsekas’s Nonlinear Programming, 2nd ed.], which guarantees (under some conditions) the existence of the mathematical objects we’re about to use. In essence, we linearize at the solution and solve those linear equations iteratively to compute the derivatives we want.
Consider again the equation \(x = f(a, x)\) and the function \(x^*\). We want to evaluate vector-Jacobian products like \(v^\mathsf{T} \mapsto v^\mathsf{T} \partial x^*(a_0)\).
At least in an open neighborhood around the point \(a_0\) at which we want to differentiate, let’s assume that the equation \(x^*(a) = f(a, x^*(a))\) holds for all \(a\). Since the two sides are equal as functions of \(a\), their derivatives must be equal as well, so let’s differentiate both sides:
\(\qquad \partial x^*(a) = \partial_0 f(a, x^*(a)) + \partial_1 f(a, x^*(a)) \partial x^*(a)\).
Setting \(A = \partial_1 f(a_0, x^*(a_0))\) and \(B = \partial_0 f(a_0, x^*(a_0))\), we can write the quantity we’re after more simply as
\(\qquad \partial x^*(a_0) = B + A \partial x^*(a_0)\),
or, by rearranging,
\(\qquad \partial x^*(a_0) = (I - A)^{-1} B\).
That means we can evaluate vector-Jacobian products like
\(\qquad v^\mathsf{T} \partial x^*(a_0) = v^\mathsf{T} (I - A)^{-1} B = w^\mathsf{T} B\),
where \(w^\mathsf{T} = v^\mathsf{T} (I - A)^{-1}\), or equivalently \(w^\mathsf{T} = v^\mathsf{T} + w^\mathsf{T} A\), or equivalently \(w^\mathsf{T}\) is the fixed point of the map \(u^\mathsf{T} \mapsto v^\mathsf{T} + u^\mathsf{T} A\). That last characterization gives us a way to write the VJP for fixed_point in terms of a call to fixed_point! Moreover, after expanding \(A\) and \(B\) back out, we can see we need only to evaluate VJPs of \(f\) at \((a_0, x^*(a_0))\).
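To see the formula \(\partial x^*(a_0) = (I - A)^{-1} B\) at work in the scalar case, the sketch below applies it to the Newton update used by newton_sqrt. It takes the fixed point \(x^* = \sqrt{a_0}\) as known (computed here with jnp.sqrt purely for illustration) and evaluates \(A\) and \(B\) with jax.grad; the variable names are illustrative.

```python
import jax.numpy as jnp
from jax import grad

def update(a, x):
  # One step of Newton's method for sqrt(a): x = f(a, x) at the fixed point.
  return 0.5 * (x + a / x)

a0 = 2.
x_star = jnp.sqrt(a0)  # the fixed point, taken as known for this check

A = grad(update, argnums=1)(a0, x_star)  # partial_1 f at (a0, x*)
B = grad(update, argnums=0)(a0, x_star)  # partial_0 f at (a0, x*)

dx_star = B / (1. - A)  # scalar version of (I - A)^{-1} B

print(A, B, dx_star)
```

Here \(A\) is approximately zero, because Newton's method converges quadratically at the solution, so the result reduces to \(B = \frac{1}{2\sqrt{a_0}}\), the familiar derivative of the square root.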
Here’s the upshot:
from jax import vjp

@partial(custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, a, x_guess):
  def cond_fun(carry):
    x_prev, x = carry
    return jnp.abs(x_prev - x) > 1e-6

  def body_fun(carry):
    _, x = carry
    return x, f(a, x)

  _, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
  return x_star

def fixed_point_fwd(f, a, x_init):
  x_star = fixed_point(f, a, x_init)
  return x_star, (a, x_star)

def fixed_point_rev(f, res, x_star_bar):
  a, x_star = res
  _, vjp_a = vjp(lambda a: f(a, x_star), a)
  a_bar, = vjp_a(fixed_point(partial(rev_iter, f),
                             (a, x_star, x_star_bar),
                             x_star_bar))
  return a_bar, jnp.zeros_like(x_star)

def rev_iter(f, packed, u):
  a, x_star, x_star_bar = packed
  _, vjp_x = vjp(lambda x: f(a, x), x_star)
  return x_star_bar + vjp_x(u)[0]

fixed_point.defvjp(fixed_point_fwd, fixed_point_rev)

print(newton_sqrt(2.))
1.4142135
print(grad(newton_sqrt)(2.))
print(grad(grad(newton_sqrt))(2.))

0.35355338
-0.088388346

We can check our answers by differentiating jnp.sqrt, which uses a totally different implementation:

print(grad(jnp.sqrt)(2.))
print(grad(grad(jnp.sqrt))(2.))

0.35355338
-0.08838835

A limitation to this approach is that the argument f can’t close over any values involved in differentiation. That is, you might notice that we kept the parameter a explicit in the argument list of fixed_point. For this use case, consider using the low-level primitive lax.custom_root, which allows for derivatives in closed-over variables with custom root-finding functions.
Basic usage of jax.custom_jvp and jax.custom_vjp APIs#
Use jax.custom_jvp to define forward-mode (and, indirectly, reverse-mode) rules#
Here’s a canonical basic example of using jax.custom_jvp, where the comments use Haskell-like type signatures:

from jax import custom_jvp
import jax.numpy as jnp

# f :: a -> b
@custom_jvp
def f(x):
  return jnp.sin(x)

# f_jvp :: (a, T a) -> (b, T b)
def f_jvp(primals, tangents):
  x, = primals
  t, = tangents
  return f(x), jnp.cos(x) * t

f.defjvp(f_jvp)

<function __main__.f_jvp(primals, tangents)>

from jax import jvp

print(f(3.))

y, y_dot = jvp(f, (3.,), (1.,))
print(y)
print(y_dot)

0.14112
0.14112
-0.9899925

In words, we start with a primal function f that takes inputs of type a and produces outputs of type b. We associate with it a JVP rule function f_jvp that takes a pair of inputs representing the primal inputs of type a and the corresponding tangent inputs of type T a, and produces a pair of outputs representing the primal outputs of type b and tangent outputs of type T b. The tangent outputs should be a linear function of the tangent inputs.
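The linearity requirement can be checked numerically. The sketch below repeats the rule for jnp.sin so that it runs on its own, then evaluates the JVP at the same primal point with several tangent inputs; the names t1, t2, and t3 are illustrative.

```python
import jax.numpy as jnp
from jax import custom_jvp, jvp

@custom_jvp
def f(x):
  return jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
  x, = primals
  t, = tangents
  return f(x), jnp.cos(x) * t  # linear in the tangent input t

# Same primal point, different tangent inputs.
_, t1 = jvp(f, (3.,), (1.,))
_, t2 = jvp(f, (3.,), (2.,))
_, t3 = jvp(f, (3.,), (3.,))

print(t1, t2, t3)
```

Scaling the tangent input scales the tangent output by the same factor, and the output for the tangent 3. equals the sum of the outputs for 1. and 2., which is what linearity means here.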
You can also use f.defjvp as a decorator, as in

@custom_jvp
def f(x):
  ...

@f.defjvp
def f_jvp(primals, tangents):
  ...

Even though we defined only a JVP rule and no VJP rule, we can use both forward- and reverse-mode differentiation on f. JAX will automatically transpose the linear computation on tangent values from our custom JVP rule, computing the VJP as efficiently as if we had written the rule by hand:

from jax import grad

print(grad(f)(3.))
print(grad(grad(f))(3.))

-0.9899925
-0.14112

For automatic transposition to work, the JVP rule’s output tangents must be linear as a function of the input tangents. Otherwise a transposition error is raised.
Multiple arguments work like this:
@custom_jvp
def f(x, y):
  return x ** 2 * y

@f.defjvp
def f_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = f(x, y)
  tangent_out = 2 * x * y * x_dot + x ** 2 * y_dot
  return primal_out, tangent_out

print(grad(f)(2.,3.))
12.0
The defjvps convenience wrapper lets us define a JVP for each argument separately, and the results are computed separately then summed:

@custom_jvp
def f(x):
  return jnp.sin(x)

f.defjvps(lambda t, ans, x: jnp.cos(x) * t)

print(grad(f)(3.))
-0.9899925
Here’s a defjvps example with multiple arguments:

@custom_jvp
def f(x, y):
  return x ** 2 * y

f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot,
          lambda y_dot, primal_out, x, y: x ** 2 * y_dot)

print(grad(f)(2., 3.))
print(grad(f, 0)(2., 3.))  # same as above
print(grad(f, 1)(2., 3.))

12.0
12.0
4.0

As a shorthand, with defjvps you can pass a None value to indicate that the JVP for a particular argument is zero:

@custom_jvp
def f(x, y):
  return x ** 2 * y

f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot,
          None)

print(grad(f)(2., 3.))
print(grad(f, 0)(2., 3.))  # same as above
print(grad(f, 1)(2., 3.))

12.0
12.0
0.0

Calling a jax.custom_jvp function with keyword arguments, or writing a jax.custom_jvp function definition with default arguments, are both allowed so long as they can be unambiguously mapped to positional arguments based on the function signature retrieved by the standard library inspect.signature mechanism.
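As a small sketch of what that looks like in practice, the hypothetical function scale below has a default value for its second argument and is called both with the default and with a keyword argument. The function and variable names are made up for this illustration.

```python
import jax.numpy as jnp
from jax import custom_jvp, grad

@custom_jvp
def scale(x, y=2.):
  return x * y

@scale.defjvp
def scale_jvp(primals, tangents):
  # Both arguments arrive positionally, even if y came from the default
  # value or from a keyword argument at the call site.
  x, y = primals
  x_dot, y_dot = tangents
  return scale(x, y), y * x_dot + x * y_dot

g_default = grad(scale)(3.)        # y takes its default value
g_keyword = grad(scale)(3., y=5.)  # y passed by keyword

print(g_default, g_keyword)
```

In both calls the derivative with respect to x is the value of y in effect for that call.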
When you’re not performing differentiation, the function f is called just as if it weren’t decorated by jax.custom_jvp:

@custom_jvp
def f(x):
  print('called f!')  # a harmless side-effect
  return jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
  print('called f_jvp!')  # a harmless side-effect
  x, = primals
  t, = tangents
  return f(x), jnp.cos(x) * t

from jax import vmap, jit

print(f(3.))

called f!
0.14112

print(vmap(f)(jnp.arange(3.)))
print(jit(f)(3.))

called f!
[0. 0.84147096 0.9092974 ]
called f!
0.14112

The custom JVP rule is invoked during differentiation, whether forward or reverse:

y, y_dot = jvp(f, (3.,), (1.,))
print(y_dot)

called f_jvp!
called f!
-0.9899925

print(grad(f)(3.))

called f_jvp!
called f!
-0.9899925

Notice that f_jvp calls f to compute the primal outputs. In the context of higher-order differentiation, each application of a differentiation transform will use the custom JVP rule if and only if the rule calls the original f to compute the primal outputs. (This represents a kind of fundamental tradeoff, where we can’t make use of intermediate values from the evaluation of f in our rule and also have the rule apply in all orders of higher-order differentiation.)
grad(grad(f))(3.)
called f_jvp!
called f_jvp!
called f!
Array(-0.14112, dtype=float32, weak_type=True)
You can use Python control flow with jax.custom_jvp:

@custom_jvp
def f(x):
  if x > 0:
    return jnp.sin(x)
  else:
    return jnp.cos(x)

@f.defjvp
def f_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  ans = f(x)
  if x > 0:
    return ans, 2 * x_dot
  else:
    return ans, 3 * x_dot

print(grad(f)(1.))
print(grad(f)(-1.))

2.0
3.0
Use jax.custom_vjp to define custom reverse-mode-only rules#
While jax.custom_jvp suffices for controlling both forward- and, via JAX’s automatic transposition, reverse-mode differentiation behavior, in some cases we may want to directly control a VJP rule, for example in the latter two example problems presented above. We can do that with jax.custom_vjp:

from jax import custom_vjp
import jax.numpy as jnp

# f :: a -> b
@custom_vjp
def f(x):
  return jnp.sin(x)

# f_fwd :: a -> (b, c)
def f_fwd(x):
  return f(x), jnp.cos(x)

# f_bwd :: (c, CT b) -> CT a
def f_bwd(cos_x, y_bar):
  return (cos_x * y_bar,)

f.defvjp(f_fwd, f_bwd)

from jax import grad

print(f(3.))
print(grad(f)(3.))

0.14112
-0.9899925

In words, we again start with a primal function f that takes inputs of type a and produces outputs of type b. We associate with it two functions, f_fwd and f_bwd, which describe how to perform the forward- and backward-passes of reverse-mode autodiff, respectively.
The function f_fwd describes the forward pass, not only the primal computation but also what values to save for use on the backward pass. Its input signature is just like that of the primal function f, in that it takes a primal input of type a. But as output it produces a pair, where the first element is the primal output b and the second element is any “residual” data of type c to be stored for use by the backward pass. (This second output is analogous to PyTorch’s save_for_backward mechanism.)
The function f_bwd describes the backward pass. It takes two inputs, where the first is the residual data of type c produced by f_fwd and the second is the output cotangents of type CT b corresponding to the output of the primal function. It produces an output of type CT a representing the cotangents corresponding to the input of the primal function. In particular, the output of f_bwd must be a sequence (e.g. a tuple) of length equal to the number of arguments to the primal function.
So multiple arguments work like this:
from jax import custom_vjp

@custom_vjp
def f(x, y):
  return jnp.sin(x) * y

def f_fwd(x, y):
  return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
  cos_x, sin_x, y = res
  return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)

print(grad(f)(2.,3.))
-1.2484405
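One way to sanity-check a hand-written VJP rule like this is to compare its gradients with those that standard autodiff produces for an undecorated copy of the same function. The sketch below does that for both arguments; f_plain is an illustrative name for the undecorated copy.

```python
import jax.numpy as jnp
from jax import custom_vjp, grad

def f_plain(x, y):
  # Undecorated reference, differentiated by standard autodiff.
  return jnp.sin(x) * y

@custom_vjp
def f(x, y):
  return jnp.sin(x) * y

def f_fwd(x, y):
  return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
  cos_x, sin_x, y = res
  # One cotangent per primal argument, in argument order.
  return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)

custom_grads = grad(f, argnums=(0, 1))(2., 3.)
reference_grads = grad(f_plain, argnums=(0, 1))(2., 3.)

print(custom_grads)
print(reference_grads)
```

Agreement here only shows that the rule matches autodiff at the tested point; when a custom rule deliberately departs from the mathematical derivative, as in the gradient clipping example, this comparison is not expected to hold.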
Calling a jax.custom_vjp function with keyword arguments, or writing a jax.custom_vjp function definition with default arguments, are both allowed so long as they can be unambiguously mapped to positional arguments based on the function signature retrieved by the standard library inspect.signature mechanism.
As with jax.custom_jvp, the custom VJP rule made up of f_fwd and f_bwd is not invoked if differentiation is not applied. If the function is evaluated, or transformed with jit, vmap, or other non-differentiation transformations, then only f is called.

@custom_vjp
def f(x):
  print("called f!")
  return jnp.sin(x)

def f_fwd(x):
  print("called f_fwd!")
  return f(x), jnp.cos(x)

def f_bwd(cos_x, y_bar):
  print("called f_bwd!")
  return (cos_x * y_bar,)

f.defvjp(f_fwd, f_bwd)

print(f(3.))

called f!
0.14112

print(grad(f)(3.))

called f_fwd!
called f!
called f_bwd!
-0.9899925

y, f_vjp = vjp(f, 3.)
print(y)

called f_fwd!
called f!
0.14112

print(f_vjp(1.))

called f_bwd!
(Array(-0.9899925, dtype=float32, weak_type=True),)

Forward-mode autodiff cannot be used on the jax.custom_vjp function and will raise an error:

from jax import jvp

try:
  jvp(f, (3.,), (1.,))
except TypeError as e:
  print('ERROR! {}'.format(e))

called f_fwd!
called f!
ERROR! can't apply forward-mode autodiff (jvp) to a custom_vjp function.

If you want to use both forward- and reverse-mode, use jax.custom_jvp instead.
We can use jax.custom_vjp together with pdb to insert a debugger trace in the backward pass:

import jax
import pdb

@custom_vjp
def debug(x):
  return x  # acts like identity

def debug_fwd(x):
  return x, x

def debug_bwd(x, g):
  pdb.set_trace()
  return (g,)

debug.defvjp(debug_fwd, debug_bwd)

def foo(x):
  y = x ** 2
  y = debug(y)  # insert pdb in corresponding backward pass step
  return jnp.sin(y)

jax.grad(foo)(3.)

> <ipython-input-113-b19a2dc1abf7>(12)debug_bwd()
-> return (g,)
(Pdb) p x
Array(9., dtype=float32)
(Pdb) p g
Array(-0.91113025, dtype=float32)
(Pdb) q
More features and details#
Working with list / tuple / dict containers (and other pytrees)#
You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any pytrees are permissible, so long as their structures are consistent according to the type constraints.
Here’s a contrived example with jax.custom_jvp:

from collections import namedtuple
Point = namedtuple("Point", ["x", "y"])

@custom_jvp
def f(pt):
  x, y = pt.x, pt.y
  return {'a': x ** 2,
          'b': (jnp.sin(x), jnp.cos(y))}

@f.defjvp
def f_jvp(primals, tangents):
  pt, = primals
  pt_dot, = tangents
  ans = f(pt)
  ans_dot = {'a': 2 * pt.x * pt_dot.x,
             'b': (jnp.cos(pt.x) * pt_dot.x, -jnp.sin(pt.y) * pt_dot.y)}
  return ans, ans_dot

def fun(pt):
  dct = f(pt)
  return dct['a'] + dct['b'][0]

pt = Point(1., 2.)

print(f(pt))

{'a': 1.0, 'b': (Array(0.84147096, dtype=float32, weak_type=True), Array(-0.41614684, dtype=float32, weak_type=True))}

print(grad(fun)(pt))

Point(x=Array(2.5403023, dtype=float32, weak_type=True), y=Array(0., dtype=float32, weak_type=True))
And an analogous contrived example with jax.custom_vjp:

@custom_vjp
def f(pt):
  x, y = pt.x, pt.y
  return {'a': x ** 2,
          'b': (jnp.sin(x), jnp.cos(y))}

def f_fwd(pt):
  return f(pt), pt

def f_bwd(pt, g):
  a_bar, (b0_bar, b1_bar) = g['a'], g['b']
  x_bar = 2 * pt.x * a_bar + jnp.cos(pt.x) * b0_bar
  y_bar = -jnp.sin(pt.y) * b1_bar
  return (Point(x_bar, y_bar),)

f.defvjp(f_fwd, f_bwd)

def fun(pt):
  dct = f(pt)
  return dct['a'] + dct['b'][0]

pt = Point(1., 2.)

print(f(pt))

{'a': 1.0, 'b': (Array(0.84147096, dtype=float32, weak_type=True), Array(-0.41614684, dtype=float32, weak_type=True))}

print(grad(fun)(pt))

Point(x=Array(2.5403023, dtype=float32, weak_type=True), y=Array(-0., dtype=float32, weak_type=True))
Handling non-differentiable arguments#
Some use cases, like the final example problem, call for non-differentiable arguments like function-valued arguments to be passed to functions with custom differentiation rules, and for those arguments to also be passed to the rules themselves. In the case of fixed_point, the function argument f was such a non-differentiable argument. A similar situation arises with jax.experimental.odeint.
jax.custom_jvp with nondiff_argnums#
Use the optional nondiff_argnums parameter to jax.custom_jvp to indicate arguments like these. Here’s an example with jax.custom_jvp:

from functools import partial

@partial(custom_jvp, nondiff_argnums=(0,))
def app(f, x):
  return f(x)

@app.defjvp
def app_jvp(f, primals, tangents):
  x, = primals
  x_dot, = tangents
  return f(x), 2. * x_dot

print(app(lambda x: x ** 3, 3.))

27.0

print(grad(app, 1)(lambda x: x ** 3, 3.))

2.0

Notice the gotcha here: no matter where in the argument list these parameters appear, they’re placed at the start of the signature of the corresponding JVP rule. Here’s another example:

@partial(custom_jvp, nondiff_argnums=(0, 2))
def app2(f, x, g):
  return f(g((x)))

@app2.defjvp
def app2_jvp(f, g, primals, tangents):
  x, = primals
  x_dot, = tangents
  return f(g(x)), 3. * x_dot

print(app2(lambda x: x ** 3, 3., lambda y: 5 * y))

3375.0

print(grad(app2, 1)(lambda x: x ** 3, 3., lambda y: 5 * y))

3.0
jax.custom_vjp with nondiff_argnums#
A similar option exists for jax.custom_vjp, and, similarly, the convention is that the non-differentiable arguments are passed as the first arguments to the _bwd rule, no matter where they appear in the signature of the original function. The signature of the _fwd rule remains unchanged - it is the same as the signature of the primal function. Here’s an example:

@partial(custom_vjp, nondiff_argnums=(0,))
def app(f, x):
  return f(x)

def app_fwd(f, x):
  return f(x), x

def app_bwd(f, x, g):
  return (5 * g,)

app.defvjp(app_fwd, app_bwd)

print(app(lambda x: x ** 2, 4.))

16.0

print(grad(app, 1)(lambda x: x ** 2, 4.))

5.0

See fixed_point above for another usage example.
You don’t need to use nondiff_argnums with array-valued arguments, for example ones with integer dtype. Instead, nondiff_argnums should only be used for argument values that don’t correspond to JAX types (essentially don’t correspond to array types), like Python callables or strings. If JAX detects that an argument indicated by nondiff_argnums contains a JAX Tracer, then an error is raised. The clip_gradient function above is a good example of not using nondiff_argnums for integer-dtype array arguments.
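As a rough sketch of that guidance, the hypothetical function scale_by_count below takes an integer-dtype array as a regular argument rather than listing it in nondiff_argnums. Its backward rule returns None for that argument, following the same convention clip_gradient uses to indicate a zero cotangent. The function and variable names are made up for this illustration.

```python
import jax.numpy as jnp
from jax import custom_vjp, grad

@custom_vjp
def scale_by_count(n, x):
  return n * x

def scale_by_count_fwd(n, x):
  return n * x, n  # save the integer array as a residual

def scale_by_count_bwd(n, g):
  # None marks a zero cotangent for the integer-valued argument n.
  return (None, n * g)

scale_by_count.defvjp(scale_by_count_fwd, scale_by_count_bwd)

n = jnp.array(3)  # integer dtype, passed as an ordinary array argument
dx = grad(scale_by_count, argnums=1)(n, 2.)

print(dx)
```

Differentiating with respect to x gives the value of n, and no nondiff_argnums entry is needed because n is an ordinary array.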
