# Custom JVP/VJP rules for JAX-transformable functions
This is a design document, explaining some of the thinking behind the design and implementation of `jax.custom_jvp` and `jax.custom_vjp`. For user-oriented documentation, see the tutorial notebook.
There are two ways to define differentiation rules in JAX:

1. using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and
2. defining new `core.Primitive` instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems.
This document is about #1 only.
## Goals
We want **users** to customize the forward- and/or reverse-mode differentiation behavior of their code. This customization

- should have a **clear and consistent semantics** in how it works and how it composes with other JAX transformations; and
- should be **flexible** in supporting use cases and workflows like in Autograd and PyTorch, including cases involving differentiation of Python control flow and workflows for NaN debugging.
As **JAX developers** we want to write library functions, like `logit` and `expit`, that are defined in terms of other primitives, but for the purposes of differentiation have primitive-like behavior in the sense that we want to define custom differentiation rules for them, which may be more numerically stable or performant. In particular, we don't want to have to specify `vmap` or `jit` rules for functions like `logit` and `expit`.
As a stretch goal, we'd like to make JAX a great environment for power users looking to add custom differentiation rules for higher-order functions like `fixed_point`, `odeint`, etc.; this design doc won't solve that problem, but we want to be confident we're not going to preclude good solutions to that problem.
That is, our primary goals are

1. solve the vmap-removes-custom-jvp semantics problem (#1249), and
2. allow Python in custom VJPs, e.g. to debug NaNs (#1275).

Secondary goals are

3. clean up and simplify user experience (symbolic zeros, kwargs, etc.)
4. make progress towards a world where users can easily add `fixed_point`, `odeint`, `root`, etc.
Overall, we want to close #116, #1097, #1249, #1275, #1366, #1723, #1670, #1875, #1938, and replace the `custom_transforms` machinery (from #636, #818, and others).
## Non-goals
Here are objectives we're **not** aiming to achieve:
1. The `custom_transforms` machinery aimed to provide a transformation-generic mechanism for customizing behavior, in principle (though never really used in practice) allowing users to customize rules for any transformation while somehow inheriting the "transparent" behavior for others. **We are instead only going to solve the customization problem for differentiation (JVP and VJP, separately).** Differentiation is the only case actually requested, and by specializing to differentiation we can reduce complexity and improve flexibility. To control all rules one can just write a primitive.
2. **We're not going to prioritize mathematical aesthetics** over flexibility and clarity on the user side, and simplicity on the implementation side. In particular, while the custom VJP signature `a -> (b, CT b --o CT a)` is mathematically pleasing, if it's hard to implement in a Python mechanism because of the closure in the return type, we're fine doing something that handles residuals more explicitly.
3. **Serialization support**, of the form where the staged-out serialized program representation can be loaded and further JAX-transformed as opposed to just evaluated, is currently out of scope for these custom JVP/VJP transformation rules. Serialization may be useful not only for researchers who want to save some representation of their computation (and transform it after loading it), but also for future considerations like having jaxpr transformations implemented outside Python, or having jaxprs as an MLIR dialect. By defining this as a non-goal for the purpose of this design, we have fewer constraints on where we can stash Python callables.
## Main problem descriptions
### The vmap-removes-custom-jvp semantics problem
The vmap-removes-custom-jvp semantics problem is that vmap does not compose properly with differentiation of functions with `custom_transforms` rules:
```python
# old custom_transforms api to be replaced
@jax.custom_transforms
def f(x):
  return 2. * x

# f_vjp :: a -> (b, CT b --o CT a)
def f_vjp(x):
  return f(x), lambda g: 3. * x  # 3 instead of 2
jax.defvjp_all(f, f_vjp)

grad(f)(1.)  # 3.
vmap(grad(f))(np.ones(4))  # [3., 3., 3., 3.]
grad(lambda x: vmap(f)(x).sum())(np.ones(4))  # [2., 2., 2., 2.]
```
The last grad-of-vmap line has an unexpected result! In general, applying `vmap`, or really any non-differentiation transformation, has the effect of removing the custom differentiation rule. (Applying `jvp` causes a failure when a custom VJP rule is defined.)
The problem exists because transformations are like rewrites, and the `vmap` transformation effectively rewrites the function to no longer call the newly-introduced primitive for which there is a custom rule (and hence `grad` then doesn't produce the custom rule's result). In more detail, the `custom_transforms` machinery sets things up so that evaluating `f(x)` applies the function
```
{ lambda  ; ; a.
  let b = f_primitive a
  in [b] }
```
where `f_primitive` is a new primitive (introduced for every `custom_transforms` function and in fact for every call of the function) to which the custom VJP rule is associated. When we evaluate `grad(f)(x)`, the differentiation machinery encounters `f_primitive` and processes it with the custom rule.
However, because `f_primitive` is *transparent* to `vmap`, in the sense that `vmap` operates on (effectively by inlining) the definition of `f_primitive`, the function `vmap(f)` is effectively
```
{ lambda  ; ; a.
  let b = mul 2. a
  in [b] }
```
In words, `vmap` rewrites the function in terms of its underlying primitives and their transformation rules, removing `f_primitive` entirely.
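This rewriting is easy to observe directly: printing the jaxpr of a vmapped function shows only underlying primitives like `mul`. A minimal sketch, using the stable `jax.make_jaxpr` API:

```python
import jax
import jax.numpy as jnp

def f(x):
    return 2. * x

# vmap produces a function expressed purely in terms of f's underlying
# primitives; its jaxpr contains a plain `mul`, with no trace of f itself.
print(jax.make_jaxpr(jax.vmap(f))(jnp.ones(3)))
```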
More generally, **because `vmap(f)` has semantics defined in terms of calls to `f`, it is semantically inconsistent to remove the custom derivative rule**. That is, since we define
```python
vmap(f)(xs) == np.stack([f(x) for x in xs])
```
we must have
```python
jvp(vmap(f))(xs) == jvp(lambda xs: np.stack([f(x) for x in xs]))
```
yet this property is not observed when `f` has a custom derivative rule defined, as the custom derivative rule is used in the right-hand version but not the left-hand one.
This issue isn't specific to `vmap`; it applies to all transformations for which the semantics of transforming a function `f` are defined in terms of calls to the function `f`, rather than rewriting it into another function. The `mask` transformation also falls into this class. Differentiation transforms and the hypothetical all-unary-functions-become-cosine transform are not in this class.
(The interaction between additional custom rules, like custom `vmap` rules, is likely to get even more complex, suggesting the problem framing of `custom_transforms` is too broad.)
### The Python flexibility problem
In JAX, as in Autograd and PyTorch but not TF1, differentiation of a Python function is performed while the function is being executed and traced. This behavior delights users for a few reasons.
**First and most importantly, it enables pdb-based workflows**, e.g. for inspecting numerics or catching NaNs. That is, users can employ the standard Python debugger and other Python-native tools to debug their code, even being able to inspect runtime values to understand numerical behavior on examples and to catch fundamentally runtime errors like NaNs. In fact, just while working on the PR corresponding to this design, especially on the `odeint` primitive, I used runtime value inspection to debug issues many times, increasing my confidence that this is a key user workflow in Python. One especially handy trick, which I've used in both JAX and Autograd many times, is the ability to insert a debugger breakpoint in a custom VJP rule to enter a debugger at a specific point in the backward pass.
**Second, it allows differentiation of Python native control flow.** We're not sure how often this is used in practice in finalized software artifacts, but when users first poke around JAX or Autograd they're often impressed by this freedom. There's a reason we include it at the top of our JAX and Autograd READMEs, slide decks, and demos. Ceding this capability would be a step backward from Autograd. We want JAX to have the best automatic differentiation.
However, the `custom_transforms` machinery does not provide this Python-support flexibility. That is, because it's implemented in terms of up-front jaxpr formation from the Python code for both the user function and custom differentiation rules, code like this leads to an abstract value tracing error:
```python
# old custom_transforms api to be replaced
@jax.custom_transforms
def f(x):
  if x > 0:
    return x
  else:
    return 0.

def f_vjp(x):
  return ...

jax.defvjp_all(f, f_vjp)

grad(f)(1.)  # Error!
```
## Solution idea
The main idea is that **dougalm@ already solved these problems with `core.call`**. That is, we can frame the task of specifying a custom JVP rule for a user function in terms of a new Python-level call primitive (not to be added to the jaxpr language; see below). This new call primitive has a user Python function associated with it just like `core.call`, but additionally has a second Python callable representing the JVP rule. Let's refer to this new call primitive as `custom_jvp_call`.
Transformations like `vmap` interact with `custom_jvp_call` as with `core.call`: they effectively pass right through it and are applied to the underlying Python callables. Schematically, writing in terms of curried versions of the primitives for convenience, analogously to how `vmap` interacts with `core.call` by applying to the function to be called:
```python
vmap(call(f)) == call(vmap(f))
```
for the new primitive `custom_jvp_call` we simply apply `vmap` to the two functions it entails:
```python
vmap(custom_jvp_call(f, f_jvp)) == custom_jvp_call(vmap(f), vmap(f_jvp))
```
This behavior means we've solved the vmap-removes-custom-jvp semantics problem.
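This composition can be checked with the `jax.custom_jvp` API as released. The sketch below uses a deliberately "wrong" tangent rule (3 instead of 2, mirroring the earlier `custom_transforms` example) so that it's observable which rule fires:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
    return 2. * x

@f.defjvp
def f_jvp(primals, tangents):
    x, = primals
    t, = tangents
    return f(x), 3. * t   # 3 instead of 2, to make the custom rule visible

xs = jnp.ones(4)
# Both orders of composition now agree, and both use the custom rule:
print(jax.vmap(jax.grad(f))(xs))                       # [3. 3. 3. 3.]
print(jax.grad(lambda xs: jax.vmap(f)(xs).sum())(xs))  # [3. 3. 3. 3.]
```

Unlike under `custom_transforms`, the grad-of-vmap composition no longer silently falls back to the derivative of the primal function.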
The `jvp` transformation interacts as one might expect: it just calls `f_jvp`,
```python
jvp(call(f)) == call(jvp(f))
jvp(custom_jvp_call(f, f_jvp)) == f_jvp
```
Because `custom_jvp_call` acts like `core.call` (and not like `xla.xla_call`) in that it doesn't raise the abstraction level of its inputs (because it's not delaying anything or staging anything out), it means we've solved the Python flexibility problem: there are no constraints on the user Python function (above the usual functional programming constraints required by `jvp` or `vjp`).
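Concretely, the data-dependent control flow that errored under `custom_transforms` above works with the released `jax.custom_vjp` API, because no up-front jaxpr is formed from the Python code. A minimal sketch (the `jnp.where` backward rule here is an illustrative choice, not prescribed by the design):

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
    if x > 0:            # Python control flow on the runtime value
        return x
    else:
        return 0.

def f_fwd(x):
    return f(x), x       # save x as the residual

def f_bwd(x, g):
    # cotangent is g where x > 0, and 0 otherwise
    return (jnp.where(x > 0, g, 0.),)

f.defvjp(f_fwd, f_bwd)

print(jax.grad(f)(1.))   # no abstract-value tracing error
```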
What about evaluation and compilation? These are two ways to "exit" the JAX system, in the sense that no additional transformations can be applied after these steps. As a result, their rules are trivial:
```python
eval(call(f)) == eval(f)
jit(call(f)) == hlo_call(jit(f))

eval(custom_jvp_call(f, f_jvp)) == eval(f)
jit(custom_jvp_call(f, f_jvp)) == hlo_call(jit(f))
```
In words, if a JVP rule hasn't already rewritten `custom_jvp_call(f, f_jvp)` into `f_jvp`, when we get to the point of evaluation with `eval` or staging out to XLA with `jit`, differentiation is never going to be applied, so we just ignore `f_jvp` and behave just like `core.call`. However, due to the wrinkle discussed next, the partial eval rule for `custom_jvp_call` must be a bit more complex, since partial evaluation isn't just used to stage out to XLA with `jit`.
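A quick sketch of these rules in action with the released API: evaluation and compilation ignore the JVP rule entirely, while differentiation (with or without `jit`) uses it:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
    return jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
    x, = primals
    t, = tangents
    return f(x), jnp.cos(x) * t

# eval and jit behave exactly like the plain primal function...
assert jnp.allclose(f(1.), jnp.sin(1.))
assert jnp.allclose(jax.jit(f)(1.), jnp.sin(1.))
# ...while differentiation applies the custom rule, jitted or not:
assert jnp.allclose(jax.grad(f)(1.), jnp.cos(1.))
assert jnp.allclose(jax.jit(jax.grad(f))(1.), jnp.cos(1.))
```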
The only remaining wrinkle has to do with "initial-style" jaxpr-forming primitives, like `lax.scan`, and their transformation rules. These represent a different kind of "staging out to a jaxpr" than that for compilation, because we can perform additional transformations on the staged-out jaxpr. That is, when `lax.scan` forms a jaxpr, it does not exit the transformation system, since when we apply a jvp or vmap to a `lax.scan` we need to apply it to the function represented by the jaxpr.
Another way to state the wrinkle is that initial-style primitives like `lax.scan` rely on the ability to round-trip to a jaxpr and back to a Python callable while preserving semantics. That must mean preserving custom differentiation rule semantics too.
The solution is to use a bit of dynamic scoping: when we're staging out to a jaxpr for an initial-style primitive, like those in `lax_control_flow.py`, we set a bit on the global trace state. When that bit is set, instead of using the final-style `custom_jvp_call` primitive, we use an initial-style `custom_jvp_call_jaxpr` primitive, and trace the functions `f` and `f_jvp` to jaxprs up-front to make initial-style processing easier. The `custom_jvp_call_jaxpr` primitive is otherwise similar to the final-style version.
(Footnote: while morally we form jaxprs for both `f` and `f_jvp` before binding `custom_jvp_call_jaxpr`, we need to delay the formation of the jaxpr of `f_jvp` because it may call the custom-JVP function and thus eager processing would lead to an infinite recursion. We delay that jaxpr formation in a thunk.)
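As a sanity check of the round-trip property, a rule defined with the released `jax.custom_jvp` API survives being staged into the jaxpr that `lax.scan` forms for its body (`sum_f` below is a made-up example for illustration):

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
    return jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
    x, = primals
    t, = tangents
    return f(x), jnp.cos(x) * t

def sum_f(xs):
    # lax.scan stages its body out to a jaxpr; f's custom rule must
    # survive that round-trip for grad to be correct here.
    def body(carry, x):
        return carry + f(x), None
    total, _ = jax.lax.scan(body, 0., xs)
    return total

xs = jnp.arange(3.)
print(jax.grad(sum_f)(xs))   # cos(xs), via the custom JVP rule
```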
If we gave up on the Python flexibility problem, we could get away with only having `custom_jvp_call_jaxpr` and not having the separate Python-level primitive `custom_jvp_call`.
## API
The custom JVP for an `a -> b` function is specified with an `(a, T a) -> (b, T b)` function:
```python
# f :: a -> b
@jax.custom_jvp
def f(x):
  return np.sin(x)

# f_jvp :: (a, T a) -> (b, T b)
def f_jvp(primals, tangents):
  x, = primals
  t, = tangents
  return f(x), np.cos(x) * t

f.defjvp(f_jvp)
```
(Interesting autodiff aside: for the rule to apply to higher-order differentiation, one must call `f` in the body of `f_jvp`; that precludes some kinds of work sharing between the internals of `f` and the tangent calculation.)
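Because the rule above does call `f` inside `f_jvp`, second-order differentiation works; a quick check (using `jax.numpy` in place of the `np` alias above):

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
    return jnp.sin(x)

def f_jvp(primals, tangents):
    x, = primals
    t, = tangents
    # calling f (not jnp.sin) here keeps the custom rule in play
    # under repeated differentiation
    return f(x), jnp.cos(x) * t

f.defjvp(f_jvp)

print(jax.grad(f)(1.5))            # cos(1.5)
print(jax.grad(jax.grad(f))(1.5))  # -sin(1.5)
```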
The custom VJP for an `a -> b` function is specified with an `a -> (b, c)` forward pass function paired with a `(c, CT b) -> CT a` backward pass function:
```python
# f :: a -> b
@jax.custom_vjp
def f(x):
  return np.sin(x)

# f_fwd :: a -> (b, c)
def f_fwd(x):
  return f(x), np.cos(x)

# f_bwd :: (c, CT b) -> CT a
def f_bwd(cos_x, g):
  return (cos_x * g,)

f.defvjp(f_fwd, f_bwd)
```
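This composes with `jax.vjp` and `jax.grad` in the usual way; a sketch (repeating the definitions with `jax.numpy` so the snippet is self-contained):

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
    return jnp.sin(x)

def f_fwd(x):
    return f(x), jnp.cos(x)   # residual c = cos(x)

def f_bwd(cos_x, g):
    return (cos_x * g,)

f.defvjp(f_fwd, f_bwd)

y, f_vjp = jax.vjp(f, 1.5)
print(y)                  # sin(1.5)
print(f_vjp(1.))          # (cos(1.5),), from the custom backward pass
print(jax.grad(f)(1.5))   # cos(1.5)
```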
The signature `a -> (b, CT b --o CT a)` is more aesthetically pleasing, but supporting it would make the implementation more complex and might require compromising expressibility desiderata. The basic reason is that Python callables are opaque (unless we trace them to a jaxpr eagerly, which places expressiveness constraints), and in this case we may be returning a callable with `vmap` tracers inside its closure that we need to know about during the forward pass.
We could add convenience wrappers, for example to define the JVP rule for a single argument at a time (like we do internally for primitives). But because this proposal is complicated enough as it is, I decided against convenience layers; let's keep things minimal for now.
There are some other bells and whistles to the API:
- Inputs and output types `a`, `b`, and `c` can be arbitrary pytrees of jaxtypes.
- Passing arguments by name (keyword arguments) is supported when they can be resolved to positions using the `inspect` module. This is a bit of an experiment with Python 3's improved ability to programmatically inspect argument signatures. I believe it is sound but not complete, which is a fine place to be. (See also #2069.)
- Arguments can be marked non-differentiable using `nondiff_argnums`, and as with `jit`'s `static_argnums` these arguments don't have to be JAX types. We need to set a convention for how these arguments are passed to the rules. For a primal function with type signature `(d, a) -> b` where `d` represents the non-differentiable type, the JVP rule's signature is `(a, T a, d) -> T b` and the VJP rule's reverse component signature is `(d, c, CT b) -> CT a`. That is, the non-differentiable arguments are passed in order after `primals` and `tangents` for a custom JVP rule, and passed in order preceding the residuals in a custom VJP rule's reverse function.
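A sketch of the VJP-side convention with the released `nondiff_argnums` API, matching the `(d, c, CT b) -> CT a` signature above (`pow_n` is a made-up example; note the released API's JVP-side argument ordering may differ in detail from the convention proposed here):

```python
import jax
from functools import partial

# d = the integer exponent n (non-differentiable), a = x, b = x ** n
@partial(jax.custom_vjp, nondiff_argnums=(0,))
def pow_n(n, x):
    return x ** n

def pow_n_fwd(n, x):
    return pow_n(n, x), x          # residual c = x

def pow_n_bwd(n, x, g):            # signature (d, c, CT b) -> CT a
    return (n * x ** (n - 1) * g,)

pow_n.defvjp(pow_n_fwd, pow_n_bwd)

print(jax.grad(pow_n, argnums=1)(3, 2.))   # 12.0 = 3 * 2**2
```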
## Implementation notes
- Updated `jax.experimental.odeint`
  - Since `odeint` is a pretty complex user of a custom VJP rule, in addition to just updating it to work at all, I wanted to revise it to be a canonical user of the new custom VJP API as a way to test that the API was a good one.
  - Along the way I made other improvements to the `odeint` implementation:
    - remove raveling/unraveling boilerplate
    - make use of `lax.scan` to remove the index-update logic
    - speed up by 20+% on the simple pendulum benchmark
- Added a custom bind method on each transform for the custom derivative call primitives, `custom_jvp_call` and `custom_vjp_call`. It's like `core.call_bind`, except we don't process env traces: those are just errors.
- Added a `custom_lin` primitive, which gets staged out into linear jaxprs to be transposed when using a custom VJP rule.
  - Because our reverse-mode autodiff is decomposed into linearization, partial evaluation, and transposition, our custom VJP rules are processed in two separate steps: one during linearization and one during transposition.
  - The linearization step, i.e. the JVP rule for `custom_vjp_call`, applies `custom_lin` to the tangent values; `custom_lin` carries with it the user's custom backward-pass function, and as a primitive it only has a transpose rule.
  - This mechanism is described more in #636.
- To prevent
