Writing custom Jaxpr interpreters in JAX
JAX offers several composable function transformations (jit, grad, vmap, etc.) that enable writing concise, accelerated code.
Here we show how to add your own function transformation to the system by writing a custom Jaxpr interpreter. The new transformation composes with all the other transformations for free.
This example uses internal JAX APIs, which may break at any time. Anything not in the API Documentation should be assumed internal.
import jax
import jax.numpy as jnp
from jax import jit, grad, vmap
from jax import random
What is JAX doing?#
JAX provides a NumPy-like API for numerical computing which can be used as is, but JAX’s true power comes from composable function transformations. Take the jit function transformation, which takes in a function and returns a semantically identical function that is lazily compiled by XLA for accelerators.
x = random.normal(random.key(0), (5000, 5000))

def f(w, b, x):
    return jnp.tanh(jnp.dot(x, w) + b)

fast_f = jit(f)
When we call fast_f, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly: they first trace the function and then handle the resulting trace in some way. To learn more about JAX’s tracing machinery, you can refer to the “How it works” section in the README.
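One way to observe the tracing step described above is to put a Python side effect inside the function. The sketch below assumes a small list, trace_log, which is a hypothetical helper introduced only for this illustration. The append runs while JAX traces f, not when the compiled code executes, so the list grows only when a new input shape forces a retrace.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical helper: records each time Python executes f's body


def f(w, b, x):
    # This append is a Python side effect, so it only runs during tracing.
    trace_log.append(x.shape)
    return jnp.tanh(jnp.dot(x, w) + b)


fast_f = jax.jit(f)

w = jnp.ones((3, 2))
b = jnp.ones(2)

fast_f(w, b, jnp.ones((4, 3)))  # first call: f is traced and compiled
fast_f(w, b, jnp.ones((4, 3)))  # same shapes and dtypes: cached, no retrace
calls_after_repeat = len(trace_log)

fast_f(w, b, jnp.ones((8, 3)))  # new input shape: f is traced again
calls_after_new_shape = len(trace_log)

print(calls_after_repeat, calls_after_new_shape)
```

The compiled function is reused for as long as the shapes and dtypes of the arguments stay the same.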
Jaxpr tracer#
A tracer of special importance in JAX is the Jaxpr tracer, which records ops into a Jaxpr (JAX expression). A Jaxpr is a data structure that can be evaluated like a mini functional programming language, and thus Jaxprs are a useful intermediate representation for function transformation.
To get a first look at Jaxprs, consider the make_jaxpr transformation. make_jaxpr is essentially a “pretty-printing” transformation: it transforms a function into one that, given example arguments, produces a Jaxpr representation of its computation. make_jaxpr is useful for debugging and introspection. Let’s use it to look at how some example Jaxprs are structured.
def examine_jaxpr(closed_jaxpr):
    jaxpr = closed_jaxpr.jaxpr
    print("invars:", jaxpr.invars)
    print("outvars:", jaxpr.outvars)
    print("constvars:", jaxpr.constvars)
    for eqn in jaxpr.eqns:
        print("equation:", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)
    print()
    print("jaxpr:", jaxpr)

def foo(x):
    return x + 1
print("foo")
print("=====")
examine_jaxpr(jax.make_jaxpr(foo)(5))

print()

def bar(w, b, x):
    return jnp.dot(w, x) + b + jnp.ones(5), x
print("bar")
print("=====")
examine_jaxpr(jax.make_jaxpr(bar)(jnp.ones((5, 10)), jnp.ones(5), jnp.ones(10)))
foo
=====
invars: [Var(id=139132117288768):int32[]]
outvars: [Var(id=139132117298176):int32[]]
constvars: []
equation: [Var(id=139132117288768):int32[], Literal(TypedInt(1, dtype=int32))] add [Var(id=139132117298176):int32[]] {}

jaxpr: { lambda ; a:i32[]. let b:i32[] = add a 1:i32[] in (b,) }

bar
=====
invars: [Var(id=139132111268352):float32[5,10], Var(id=139132117000128):float32[5], Var(id=139132111268480):float32[10]]
outvars: [Var(id=139132111374016):float32[5], Var(id=139132111268480):float32[10]]
constvars: []
equation: [Var(id=139132111268352):float32[5,10], Var(id=139132111268480):float32[10]] dot_general [Var(id=139132111271744):float32[5]] {'dimension_numbers': (((1,), (0,)), ((), ())), 'precision': None, 'preferred_element_type': dtype('float32'), 'out_sharding': None}
equation: [Var(id=139132111271744):float32[5], Var(id=139132117000128):float32[5]] add [Var(id=139132111373376):float32[5]] {}
equation: [Literal(TypedNdArray(1., dtype=float32))] broadcast_in_dim [Var(id=139132111373888):float32[5]] {'shape': (5,), 'broadcast_dimensions': (), 'sharding': None}
equation: [Var(id=139132111373376):float32[5], Var(id=139132111373888):float32[5]] add [Var(id=139132111374016):float32[5]] {}

jaxpr: { lambda ; a:f32[5,10] b:f32[5] c:f32[10]. let
    d:f32[5] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] a c
    e:f32[5] = add d b
    f:f32[5] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(5,)
      sharding=None
    ] 1.0:f32[]
    g:f32[5] = add e f
  in (g, c) }

- jaxpr.invars: the invars of a Jaxpr are a list of the input variables to the Jaxpr, analogous to arguments in Python functions.
- jaxpr.outvars: the outvars of a Jaxpr are the variables that are returned by the Jaxpr. A Jaxpr always returns a list of outputs, even when there is only one.
- jaxpr.constvars: the constvars are a list of variables that are also inputs to the Jaxpr, but correspond to constants from the trace (we’ll go over these in more detail later).
- jaxpr.eqns: a list of equations, which are essentially let-bindings. Each equation has a list of input variables, a list of output variables, and a primitive, which is used to evaluate inputs to produce outputs. Each equation also has params, a dictionary of parameters.
Altogether, a Jaxpr encapsulates a simple program that can be evaluated with inputs to produce an output. We’ll go over how exactly to do this later. The important thing to note now is that a Jaxpr is a data structure that can be manipulated and evaluated in whatever way we want.
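As a small illustration of treating a Jaxpr as plain data, the sketch below walks the equation list and counts how often each primitive appears. The helper name primitive_counts and the example function are hypothetical, introduced only for this illustration.

```python
from collections import Counter

import jax
import jax.numpy as jnp


def primitive_counts(fun, *example_args):
    """Trace `fun` and count the primitives that appear in its Jaxpr."""
    closed_jaxpr = jax.make_jaxpr(fun)(*example_args)
    return Counter(eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns)


def f(x):
    # tanh is called twice on purpose; tracing records both calls.
    return jnp.exp(jnp.tanh(x)) + jnp.tanh(x)


counts = primitive_counts(f, 1.0)
print(counts)
```

Nothing is evaluated numerically here beyond the trace itself; the counts come purely from inspecting jaxpr.eqns.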
Why are Jaxprs useful?#
Jaxprs are simple program representations that are easy to transform. And because JAX lets us stage out Jaxprs from Python functions, it gives us a way to transform numerical programs written in Python.
Your first interpreter: invert#
Let’s try to implement a simple function “inverter”, which takes in the output of the original function and returns the inputs that produced those outputs. For now, let’s focus on simple, unary functions which are composed of other invertible unary functions.
Goal:
def f(x):
    return jnp.exp(jnp.tanh(x))

f_inv = inverse(f)
assert jnp.allclose(f_inv(f(1.0)), 1.0)
The way we’ll implement this is by (1) tracing f into a Jaxpr, then (2) interpreting the Jaxpr backwards. While interpreting the Jaxpr backwards, for each equation we’ll look up the primitive’s inverse in a table and apply it.
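Before building any machinery, it may help to do the backward walk by hand for the goal function. This sketch simply applies the inverse of each step in reverse order: f applies tanh and then exp, so the inverse applies log and then arctanh. The input value 0.5 is arbitrary.

```python
import jax.numpy as jnp

x = 0.5

# Forward: tanh first, then exp.
y = jnp.exp(jnp.tanh(x))

# Backward: undo exp with log, then undo tanh with arctanh.
x_recovered = jnp.arctanh(jnp.log(y))

print(x_recovered)
```

The interpreter we are about to write automates exactly this reversal, using the equations recorded in the Jaxpr instead of reading the source by eye.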
1. Tracing a function#
Let’s use make_jaxpr to trace a function into a Jaxpr.
# Importing JAX functions useful for tracing/interpreting.
from functools import wraps

from jax import lax
from jax.extend import core
from jax._src.util import safe_map
jax.make_jaxpr returns a closed Jaxpr, which is a Jaxpr that has been bundled with the constants (literals) from the trace.
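To see where such constants can come from, the sketch below traces a function that closes over a NumPy array. The names scale and scaled_tanh are hypothetical. How a captured array is represented (as a constvar paired with an entry in literals, or inlined into the Jaxpr) is an internal detail that may differ between JAX versions, so the example prints what it finds rather than assuming a particular layout.

```python
import jax
import jax.numpy as jnp
import numpy as np

scale = np.arange(3.0, dtype=np.float32)  # array captured by the closure below


def scaled_tanh(x):
    return jnp.tanh(x) * scale


closed_jaxpr = jax.make_jaxpr(scaled_tanh)(jnp.ones(3))

# Each constvar in the inner Jaxpr is paired with one bundled constant.
print("constvars:", closed_jaxpr.jaxpr.constvars)
print("literals:", closed_jaxpr.literals)
print(closed_jaxpr.jaxpr)
```

Whatever the representation, an interpreter has to bind jaxpr.constvars to the bundled constants before evaluating any equation, in the same way it binds jaxpr.invars to the arguments.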
def f(x):
    return jnp.exp(jnp.tanh(x))

closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
print(closed_jaxpr.jaxpr)
print(closed_jaxpr.literals)
{ lambda ; a:f32[5]. let b:f32[5] = tanh a; c:f32[5] = exp b in (c,) }
[]
2. Evaluating a Jaxpr#
Before we write a custom Jaxpr interpreter, let’s first implement the “default” interpreter, eval_jaxpr, which evaluates the Jaxpr as-is, computing the same values that the original, un-transformed Python function would.
To do this, we first create an environment to store the values for each of the variables, and update the environment with each equation we evaluate in the Jaxpr.
def eval_jaxpr(jaxpr, consts, *args):
    # Mapping from variable -> value
    env = {}

    def read(var):
        # Literals are values baked into the Jaxpr
        if type(var) is core.Literal:
            return var.val
        return env[var]

    def write(var, val):
        env[var] = val

    # Bind args and consts to environment
    safe_map(write, jaxpr.invars, args)
    safe_map(write, jaxpr.constvars, consts)

    # Loop through equations and evaluate primitives using `bind`
    for eqn in jaxpr.eqns:
        # Read inputs to equation from environment
        invals = safe_map(read, eqn.invars)
        # `bind` is how a primitive is called
        outvals = eqn.primitive.bind(*invals, **eqn.params)
        # Primitives may return multiple outputs or not
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        # Write the results of the primitive into the environment
        safe_map(write, eqn.outvars, outvals)
    # Read the final result of the Jaxpr from the environment
    return safe_map(read, jaxpr.outvars)
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, jnp.ones(5))
[Array([2.1416876, 2.1416876, 2.1416876, 2.1416876, 2.1416876], dtype=float32)]
Notice that eval_jaxpr will always return a flat list even if the original function does not.
Furthermore, this interpreter does not handle higher-order primitives (like jit and pmap), which we will not cover in this guide. You can refer to core.eval_jaxpr to see the edge cases that this interpreter does not cover.
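If the flat list is inconvenient, one possible way to recover the original output structure is to ask make_jaxpr for the output shapes as well (its return_shape=True option) and unflatten the interpreter's result against that structure. The sketch below assumes a compact re-implementation of the interpreter, here called eval_jaxpr_flat, and a hypothetical wrapper eval_with_structure; neither name comes from JAX. It avoids the internal safe_map helper by using plain zip.

```python
import jax
import jax.numpy as jnp
from jax.extend import core


def eval_jaxpr_flat(jaxpr, consts, *args):
    """Compact version of the default interpreter; returns a flat list."""
    env = {}

    def read(var):
        return var.val if isinstance(var, core.Literal) else env[var]

    for var, val in zip(jaxpr.invars, args):
        env[var] = val
    for var, val in zip(jaxpr.constvars, consts):
        env[var] = val
    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        outvals = eqn.primitive.bind(*invals, **eqn.params)
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        for var, val in zip(eqn.outvars, outvals):
            env[var] = val
    return [read(v) for v in jaxpr.outvars]


def eval_with_structure(fun):
    """Hypothetical wrapper: interpret `fun` and restore its output pytree."""
    def wrapped(*args):
        closed_jaxpr, out_shape = jax.make_jaxpr(fun, return_shape=True)(*args)
        flat_out = eval_jaxpr_flat(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
        out_tree = jax.tree_util.tree_structure(out_shape)
        return jax.tree_util.tree_unflatten(out_tree, flat_out)
    return wrapped


def g(x):
    return {"tanh": jnp.tanh(x), "shifted": x + 1.0}


result = eval_with_structure(g)(jnp.ones(3))
print(result)
```

This only restores the output structure. Like the interpreter above, it assumes flat array arguments and does not handle higher-order primitives.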
Custom inverse Jaxpr interpreter#
An inverse interpreter doesn’t look too different from eval_jaxpr. We’ll first set up the registry which will map primitives to their inverses. We’ll then write a custom interpreter that looks up primitives in the registry.
It turns out that this interpreter will also look similar to the “transpose” interpreter used in reverse-mode autodifferentiation.
inverse_registry={}
We’ll now register inverses for some of the primitives. By convention, primitives in JAX end in _p and a lot of the popular ones live in lax.
inverse_registry[lax.exp_p] = jnp.log
inverse_registry[lax.tanh_p] = jnp.arctanh
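The registry is an ordinary dictionary keyed by primitive objects, so extending it is just a matter of adding entries. The sketch below adds a few more unary primitives. These extra entries are illustrative additions rather than part of the walkthrough, and each inverse is only valid on the domain where the forward function is one-to-one.

```python
import jax.numpy as jnp
from jax import lax

inverse_registry = {}

# Entries used in this guide.
inverse_registry[lax.exp_p] = jnp.log
inverse_registry[lax.tanh_p] = jnp.arctanh

# Illustrative extra entries (assumptions, not part of the original registry).
inverse_registry[lax.log_p] = jnp.exp
inverse_registry[lax.neg_p] = jnp.negative
inverse_registry[lax.sinh_p] = jnp.arcsinh

# Looking up a primitive returns a plain Python callable.
recovered_from_exp = inverse_registry[lax.exp_p](jnp.exp(0.3))
recovered_from_sinh = inverse_registry[lax.sinh_p](jnp.sinh(0.3))
print(recovered_from_exp, recovered_from_sinh)
```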
inverse will first trace the function, then custom-interpret the Jaxpr. Let’s set up a simple skeleton.
def inverse(fun):
    @wraps(fun)
    def wrapped(*args, **kwargs):
        # Since we assume unary functions, we won't worry about flattening and
        # unflattening arguments.
        closed_jaxpr = jax.make_jaxpr(fun)(*args, **kwargs)
        out = inverse_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
        return out[0]
    return wrapped
Now we just need to define inverse_jaxpr, which will walk through the Jaxpr backward and invert primitives when it can.
def inverse_jaxpr(jaxpr, consts, *args):
    env = {}

    def read(var):
        if type(var) is core.Literal:
            return var.val
        return env[var]

    def write(var, val):
        env[var] = val

    # Args now correspond to Jaxpr outvars
    safe_map(write, jaxpr.outvars, args)
    safe_map(write, jaxpr.constvars, consts)

    # Looping backward
    for eqn in jaxpr.eqns[::-1]:
        # outvars are now invars
        invals = safe_map(read, eqn.outvars)
        if eqn.primitive not in inverse_registry:
            raise NotImplementedError(
                f"{eqn.primitive} does not have registered inverse.")
        # Assuming a unary function
        outval = inverse_registry[eqn.primitive](*invals)
        safe_map(write, eqn.invars, [outval])
    return safe_map(read, jaxpr.invars)
That’s it!
def f(x):
    return jnp.exp(jnp.tanh(x))

f_inv = inverse(f)
assert jnp.allclose(f_inv(f(1.0)), 1.0)
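It is also worth seeing what happens when a function uses a primitive with no registered inverse. The sketch below is a condensed, self-contained variant of the interpreter, assuming a unary function with no captured constants so that the environment needs only one value per variable. The function g is a hypothetical example that uses sin, which is deliberately left out of the registry.

```python
import jax
import jax.numpy as jnp
from jax import lax

inverse_registry = {lax.exp_p: jnp.log, lax.tanh_p: jnp.arctanh}


def inverse(fun):
    """Condensed inverse for unary functions built from unary primitives."""
    def wrapped(y):
        jaxpr = jax.make_jaxpr(fun)(y).jaxpr
        env = {jaxpr.outvars[0]: y}  # start from the function's output
        for eqn in reversed(jaxpr.eqns):
            if eqn.primitive not in inverse_registry:
                raise NotImplementedError(
                    f"{eqn.primitive} does not have registered inverse.")
            # The equation's output is known; solve for its single input.
            env[eqn.invars[0]] = inverse_registry[eqn.primitive](env[eqn.outvars[0]])
        return env[jaxpr.invars[0]]
    return wrapped


def f(x):
    return jnp.exp(jnp.tanh(x))


def g(x):
    return jnp.sin(jnp.exp(x))  # sin is not in the registry


round_trip = inverse(f)(f(0.25))

error_message = None
try:
    inverse(g)(0.5)
except NotImplementedError as err:
    error_message = str(err)

print(round_trip, error_message)
```

Because the Jaxpr is walked from the last equation to the first, the error is raised at the first unregistered primitive encountered going backward, which here is sin.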
Importantly, you can trace through a Jaxpr interpreter.
jax.make_jaxpr(inverse(f))(f(1.))
{ lambda ; a:f32[]. let b:f32[] = log a; c:f32[] = atanh b in (c,) }
That’s all it takes to add a new transformation to a system, and you get composition with all the others for free! For example, we can use jit, vmap, and grad with inverse!
jit(vmap(grad(inverse(f))))((jnp.arange(5)+1.)/5.)
Array([-3.1440797, 15.584931 , 2.2551253, 1.3155028, 1. ], dtype=float32, weak_type=True)
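As a sanity check on the composed result, we can compare it against a derivative worked out by hand. For f(x) = exp(tanh(x)) the inverse is arctanh(log(y)), whose derivative is 1 / (y * (1 - log(y)^2)). The sketch below repeats a condensed, self-contained version of the interpreter (unary functions only, no captured constants) so that it runs on its own, and the input values are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax import grad, jit, lax, vmap

inverse_registry = {lax.exp_p: jnp.log, lax.tanh_p: jnp.arctanh}


def inverse(fun):
    """Condensed inverse for unary functions built from unary primitives."""
    def wrapped(y):
        jaxpr = jax.make_jaxpr(fun)(y).jaxpr
        env = {jaxpr.outvars[0]: y}
        for eqn in reversed(jaxpr.eqns):
            env[eqn.invars[0]] = inverse_registry[eqn.primitive](env[eqn.outvars[0]])
        return env[jaxpr.invars[0]]
    return wrapped


def f(x):
    return jnp.exp(jnp.tanh(x))


ys = jnp.array([0.6, 0.8, 1.0])

# Derivative of the interpreted inverse, batched and compiled.
interpreted = jit(vmap(grad(inverse(f))))(ys)

# Hand-derived derivative of arctanh(log(y)).
analytic = 1.0 / (ys * (1.0 - jnp.log(ys) ** 2))

print(interpreted)
print(analytic)
```

The interpreter only ever calls ordinary jnp functions on whatever values it is handed, which is why it works unchanged when those values are tracers from jit, vmap, or grad.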
Exercises for the reader#
- Handle primitives with multiple arguments where inputs are partially known, for example lax.add_p, lax.mul_p.
- Handle xla_call and xla_pmap primitives, which will not work with either eval_jaxpr or inverse_jaxpr as written.
