custom_vjp and nondiff_argnums update guide
mattjj@ Oct 14 2020
This doc assumes familiarity with jax.custom_vjp, as described in the Custom derivative rules for JAX-transformable Python functions notebook.
What to update
After JAX PR #4008, the arguments passed into a custom_vjp function’s nondiff_argnums can’t be Tracers (or containers of Tracers). In practice this means that, to allow for arbitrarily-transformable code, nondiff_argnums shouldn’t be used for array-valued arguments. Instead, nondiff_argnums should be used only for non-array values, like Python callables, shape tuples, or strings.
Wherever we used to use nondiff_argnums for array values, we should just pass those as regular arguments. In the bwd rule, we need to produce values for them, but we can just produce None values to indicate there’s no corresponding gradient value.
For example, here’s the old way to write clip_gradient, which won’t work when hi and/or lo are Tracers from some JAX transformation.
    from functools import partial

    import jax
    import jax.numpy as jnp  # needed for jnp.clip in the bwd rule

    @partial(jax.custom_vjp, nondiff_argnums=(0, 1))
    def clip_gradient(lo, hi, x):
      return x  # identity function

    def clip_gradient_fwd(lo, hi, x):
      return x, None  # no residual values to save

    def clip_gradient_bwd(lo, hi, _, g):
      return (jnp.clip(g, lo, hi),)

    clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
Here’s the new, awesome way, which supports arbitrary transformations:
    import jax
    import jax.numpy as jnp  # needed for jnp.clip in the bwd rule

    @jax.custom_vjp  # no nondiff_argnums!
    def clip_gradient(lo, hi, x):
      return x  # identity function

    def clip_gradient_fwd(lo, hi, x):
      return x, (lo, hi)  # save lo and hi values as residuals

    def clip_gradient_bwd(res, g):
      lo, hi = res
      return (None, None, jnp.clip(g, lo, hi))  # return None for lo and hi

    clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
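To make "supports arbitrary transformations" concrete, here is a minimal sketch that repeats the new-style definition and runs it under jit and vmap, so that lo and hi really are Tracers. The `loss` function, the factor of 10.0, and the input values are made up for illustration; they are not part of the original guide.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp  # no nondiff_argnums
def clip_gradient(lo, hi, x):
  return x  # identity function

def clip_gradient_fwd(lo, hi, x):
  return x, (lo, hi)  # save lo and hi as residuals

def clip_gradient_bwd(res, g):
  lo, hi = res
  return (None, None, jnp.clip(g, lo, hi))  # None for lo and hi

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

# Hypothetical loss: the factor of 10.0 makes the incoming cotangent equal
# to 10.0 per element, so the clipping in the bwd rule is visible.
def loss(lo, hi, x):
  return jnp.sum(10.0 * clip_gradient(lo, hi, x))

x = jnp.array([0.5, 2.0])

# Under jit, lo and hi are Tracers; we differentiate with respect to x only.
grad_x = jax.jit(jax.grad(loss, argnums=2))
g = grad_x(-1.0, 1.0, x)

# Under vmap, hi is a batched Tracer: one clipped gradient per value of hi.
his = jnp.array([1.0, 5.0, 20.0])
gs = jax.vmap(grad_x, in_axes=(None, 0, None))(-1.0, his, x)
```

With these inputs the cotangent of 10.0 is clipped to hi whenever hi is below 10.0, so each row of `gs` reflects its own hi value.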
If you use the old way instead of the new way, you’ll get a loud error in any case where something might go wrong (namely when there’s a Tracer passed into a nondiff_argnums argument).
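As a rough illustration of that failure mode, the sketch below defines the old-style function under the hypothetical name `old_clip_gradient` and calls it once with concrete bounds and once under jit. It deliberately catches a generic `Exception` rather than assuming a specific exception class or message, since those details are not stated in this guide.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def old_clip_gradient(lo, hi, x):
  return x  # identity function

def old_clip_gradient_fwd(lo, hi, x):
  return x, None  # no residual values to save

def old_clip_gradient_bwd(lo, hi, _, g):
  return (jnp.clip(g, lo, hi),)

old_clip_gradient.defvjp(old_clip_gradient_fwd, old_clip_gradient_bwd)

# Concrete (non-Tracer) lo and hi: the old style still works.
g_ok = jax.grad(lambda x: 10.0 * old_clip_gradient(-1.0, 1.0, x))(0.5)

# Under jit every argument becomes a Tracer, including lo and hi,
# so the call is rejected.
try:
  jax.jit(old_clip_gradient)(-1.0, 1.0, 0.5)
  raised = False
except Exception as e:  # exact exception type intentionally not assumed
  raised = True
  error_name = type(e).__name__
```

Printing `error_name` and the exception text on your own JAX version shows the exact wording of the error.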
Here’s a case where we actually need nondiff_argnums with custom_vjp:
    from functools import partial

    import jax

    @partial(jax.custom_vjp, nondiff_argnums=(0,))
    def skip_app(f, x):
      return f(x)

    def skip_app_fwd(f, x):
      return skip_app(f, x), None

    def skip_app_bwd(f, _, g):
      return (g,)

    skip_app.defvjp(skip_app_fwd, skip_app_bwd)
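A short usage sketch for skip_app follows. The choice of `jnp.sin` as the callable and 2.0 as the input is purely illustrative. Because f is a Python callable rather than an array, it is never a Tracer, so marking it with nondiff_argnums stays compatible with jit.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def skip_app(f, x):
  return f(x)

def skip_app_fwd(f, x):
  return skip_app(f, x), None  # no residuals needed

def skip_app_bwd(f, _, g):
  return (g,)  # pass the cotangent through, skipping f's derivative

skip_app.defvjp(skip_app_fwd, skip_app_bwd)

# The forward value is just f(x).
value = jax.jit(lambda x: skip_app(jnp.sin, x))(2.0)

# The gradient ignores f entirely: the bwd rule returns g unchanged,
# so we get 1.0 here rather than cos(2.0).
grad_val = jax.jit(jax.grad(lambda x: skip_app(jnp.sin, x)))(2.0)
```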
Explanation
Passing Tracers into nondiff_argnums arguments was always buggy. While there were some cases that worked correctly, others would lead to complex and confusing error messages.
The essence of the bug was that nondiff_argnums was implemented in a way that acted very much like lexical closure. But lexical closure over Tracers wasn’t at the time intended to work with custom_jvp/custom_vjp. Implementing nondiff_argnums that way was a mistake!
PR #4008 fixes all lexical closure issues with custom_jvp and custom_vjp. Woohoo! That is, now custom_jvp and custom_vjp functions and rules can close over Tracers to our hearts’ content. For all non-autodiff transformations, things will Just Work. For autodiff transformations, we’ll get a clear error message about why we can’t differentiate with respect to values over which a custom_jvp or custom_vjp closes:
Detected differentiation of a custom_jvp function with respect to a closed-over value. That isn’t supported because the custom JVP rule only specifies how to differentiate the custom_jvp function with respect to explicit input parameters.
Try passing the closed-over value into the custom_jvp function as an argument, and adapting the custom_jvp rule.
In tightening up and robustifying custom_jvp and custom_vjp in this way, we found that allowing custom_vjp to accept Tracers in its nondiff_argnums would take a significant amount of bookkeeping: we’d need to rewrite the user’s fwd function to return the values as residuals, and rewrite the user’s bwd function to accept them as normal residuals (rather than accepting them as special leading arguments, as happens with nondiff_argnums). This seems maybe manageable, until you think through how we have to handle arbitrary pytrees! Moreover, that complexity isn’t necessary: if user code treats array-like non-differentiable arguments just like regular arguments and residuals, everything already works. (Before #4039 JAX might’ve complained about involving integer-valued inputs and outputs in autodiff, but after #4039 those will just work!)
Unlike custom_vjp, it was easy to make custom_jvp work with nondiff_argnums arguments that were Tracers. So these updates only need to happen with custom_vjp.
