jax.extend: a module for extensions
@froystig, @sharadmv, @jakevdp, @yashk2810
May 2023
    import jax.extend as jex
Several projects depend on JAX’s codebase internals, often to use its core machinery (e.g. to write a transformation over its IR) or to extend it (e.g. to define new primitives). Two challenges for these dependencies are (a) that our internals aren’t all solidly designed for external use, and (b) that circumventing JAX’s public API is unsupported. In other words, our internals are often used like a library, but are neither structured nor updated like one.
This proposal considers introducing a jax.extend module that defines a library view of some of JAX’s internal components. We would treat this as a second-tier API, still guaranteeing essentially no compatibility policy, but hopefully making it easier to spot changes when they happen.
The audience for jax.extend includes JAX-adjacent Python libraries like Oryx, jax-triton, and many others, as well as projects experimenting with function transformations, autodiff systems, compiler frontends for numerical programming, etc.
This note gives an overview of how jax.extend might look, now and eventually. It doesn’t lay things out in great detail, instead proposing that we begin iteratively developing the module.
Note that jax.extend differs from jax.experimental, which is a staging ground for new features and ideas in progress. Typically, work in jax.experimental eventually makes it into another JAX module or is removed altogether.
No compatibility policy#
To keep development overhead low, jax.extend would not follow the public API compatibility policy. It would promise no deprecation windows nor backwards compatibility between releases. Every release may break existing callers without simple recourse (e.g. without a flag reintroducing prior behavior). We would rely on the changelog to call out such changes.
Callers of jax.extend that need to upgrade their code regularly alongside JAX releases might find it useful to pin JAX versions as an intermediate step between releases. This is a common habit among projects that rely on JAX’s internals today. The difference is that it would now come with the help of changelog announcements and better intentions regarding library design and naming.
Iterative development#
Having no compatibility policy makes it easier to get started on implementation: on day one, we can move a handful of symbols over from internal packages such as jax._src and today’s jax.core and jax.interpreters. Then we can iterate to improve things from there.
Possible module overview#
We can imagine that eventually jax.extend would include the following modules:
- core: primitives, the Jaxpr IR, etc.
- interpreters: core transformations (e.g. autodiff, batching) and lowerings.
- random: random bit generation, key splitting and folding, key arrays.
- sharding: extra functionality around distributed arrays.
We might also have other symbols in the module at first, such as jex.api_util, as we work to remove or replace them. Others will be decided in time. For instance, jex.lib could offer an entry point to jaxlib (and would do so in the immediate term), but it’s not clear whether we want to keep it for long.
Some preliminary thoughts on what each of these might comprise follow.
jax.extend.core#
This should enable callers at least to define new JAX primitives and to process the Jaxpr IR (the output of jax.make_jaxpr(...)). Supporting this might involve providing:
- Access to existing core system primitives, such as today’s jax._src.lax.add_p.
- Access to IR types, such as the current jax._src.core.ShapedArray.
- Functions for checking and pretty-printing jaxprs.
- Functions for building jaxprs explicitly, rather than by staging Python functions via jax.make_jaxpr (or not!).
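As a rough illustration of what "processing the Jaxpr IR" can look like, the sketch below walks the equations of a jaxpr produced by jax.make_jaxpr and tallies the primitives it contains. It relies only on what is public in JAX today (jax.make_jaxpr, the eqns attribute of a jaxpr, and each equation's primitive.name). The names f and count_primitives are hypothetical, chosen for illustration, and are not part of any proposed jex.core API.

```python
import jax
from jax import lax


def f(x, y):
    # A small function to stage out: sin(x) * y + 1
    return lax.sin(x) * y + 1.0


# Stage the Python function out to a ClosedJaxpr.
closed_jaxpr = jax.make_jaxpr(f)(1.0, 2.0)


def count_primitives(jaxpr):
    """Hypothetical helper: count how often each primitive appears."""
    counts = {}
    for eqn in jaxpr.eqns:
        name = eqn.primitive.name
        counts[name] = counts.get(name, 0) + 1
    return counts


counts = count_primitives(closed_jaxpr.jaxpr)
print(closed_jaxpr)  # pretty-printed IR
print(counts)        # per-primitive tally
```

A library-style jex.core could make this kind of traversal depend on documented types rather than on attributes that callers discover by reading JAX's source.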
At initialization, this module will contain many more symbols than what’s needed to define primitives and rules, including various names used in setting up “final-style transformations”, such as the current jax._src.core.Trace and Tracer classes. We can revisit whether jex.core should also support final-style extensions alongside initial-style approaches, and whether it can do so by a more narrow API than exposing Trace and Tracer entirely. Oryx might help guide these decisions.
We can also consider relocating make_jaxpr itself to jex.core.
jax.extend.interpreters#
This module would provide a means of registering various transformation rules for primitives, defining their behavior under AD, batching, lowering, etc.
It would initially reflect jax._src.interpreters in providing the modules ad, batching, partial_eval (for staging Python to Jaxpr, and for linearization in AD), mlir, pxla, and xla. The first three might be replaceable by a single primitive extension API in jex.core. The latter three, used for lowering, could be simplified into one module, maybe.
Today, to write transformation rules, e.g. for AD and batching, callers may need symbols relating to tracers, e.g. JVPTracer and BatchTracer. This may be avoidable later on, and allow us to remove tracer types from jex.
This module plus jex.core ought to suffice for replicating today’s custom primitive tutorials (e.g. ours and dfm’s). For instance, defining a primitive and its behavior under jax.jit would be possible as follows (in the immediate term):
    from jax.extend import core                 # Previously: from jax import core
    from jax.extend.interpreters import mlir    # ... and similarly

    mul_add_p = core.Primitive('mul_add')
    mul_add_p.def_impl(lambda x, y, z: x * y + z)

    @mul_add_p.def_abstract_eval
    def mul_add_abstract(x_sa, y_sa, z_sa):
      return core.ShapedArray(x_sa.shape, x_sa.dtype)

    def mul_add_mlir(ctx, xc, yc, zc):
      add = mlir.hlo.AddOp
      mul = mlir.hlo.MulOp
      return add(mul(xc, yc), zc).results

    mlir.register_lowering(mul_add_p, mul_add_mlir)

    import jax
    print(mul_add_p.bind(2, 3, 4))            # -> 10
    print(jax.jit(mul_add_p.bind)(2, 3, 4))   # -> Array(10, dtype=int32)
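Registering a rule for another transformation follows the same pattern. The sketch below gives a mul_add primitive a batching rule so that it works under jax.vmap. It is written against today's modules rather than the proposed ones: the registry jax.interpreters.batching.primitive_batchers is what current JAX exposes, and the import of Primitive falls back to jax.core for older releases. The rule name mul_add_batch is hypothetical, and the rule is deliberately minimal: it assumes every argument is mapped and gives up otherwise.

```python
import jax
import jax.numpy as jnp
from jax.interpreters import batching

try:
    from jax.extend.core import Primitive  # newer JAX releases
except ImportError:
    from jax.core import Primitive         # older JAX releases

mul_add_p = Primitive("mul_add")
mul_add_p.def_impl(lambda x, y, z: x * y + z)


def mul_add_batch(batched_args, batch_dims):
    """Hypothetical batching rule: handles fully mapped arguments only."""
    if any(d is None for d in batch_dims):
        raise NotImplementedError("this sketch requires all arguments to be mapped")
    # Move every batch dimension to the front; the op itself is elementwise.
    x, y, z = (jnp.moveaxis(a, d, 0) for a, d in zip(batched_args, batch_dims))
    # Return the batched result and the position of its batch dimension.
    return mul_add_p.bind(x, y, z), 0


batching.primitive_batchers[mul_add_p] = mul_add_batch

xs = jnp.arange(3.0)
ys = jnp.full(3, 2.0)
zs = jnp.ones(3)
out = jax.vmap(mul_add_p.bind)(xs, ys, zs)
print(out)
```

Under the proposal, the same registration would presumably be reachable through jex.interpreters, or through a single primitive extension API in jex.core if the modules are consolidated.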
jax.extend.random#
This module could expose our mechanism for defining new RNG implementations, and functions for working with PRNG key internals (see issue #9263), such as the current jax._src.prng.random_wrap and random_unwrap.
It could also expose the keyed hash functions that underlie the built-in RNG implementations, such as jax._src.prng.threefry_2x32.
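To make "working with PRNG key internals" concrete, the sketch below round-trips a key through its raw data. It uses jax.random.key_data and jax.random.wrap_key_data, which recent JAX releases expose publicly. Treating them as counterparts of the internal random_unwrap and random_wrap is an assumption made for illustration, not a statement of what jex.random would contain.

```python
import jax
import jax.numpy as jnp

# A typed key array using the default RNG implementation.
key = jax.random.key(0)

# Unwrap: expose the raw integer data underlying the key.
raw = jax.random.key_data(key)

# Wrap: rebuild a typed key from raw data (default implementation assumed).
rewrapped = jax.random.wrap_key_data(raw)

# The round trip should preserve both the key data and the samples drawn.
same_data = bool(jnp.all(jax.random.key_data(rewrapped) == raw))
same_draw = bool(jax.random.uniform(key) == jax.random.uniform(rewrapped))
print(raw.dtype, raw.shape, same_data, same_draw)
```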
jax.extend.sharding#
This module could expose low-level utilities for sharding distributed arrays.
We have only one item in mind for now. The XLA compiler’s array sharding format is more expressive than those provided by JAX. We could provide this as jex.sharding.XlaOpShardingProto, corresponding to today’s jax._src.lib.xla_client.OpSharding internally.
