Quax

Multiple dispatch over abstract array types in JAX.
For example, this can mean overloading matrix multiplication to exploit sparsity or structure, or automatically rewriting a LoRA's matmul `(W + AB)v` into the more efficient `Wv + ABv`.
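To see why that rewrite is worthwhile, here is a small standalone numerical sketch (plain JAX, not using Quax itself; the shapes and rank are made up for illustration):

```python
import jax.numpy as jnp
import jax.random as jr

# Hypothetical sizes: a large frozen weight W and a rank-4 LoRA update A @ B.
key_w, key_a, key_b, key_v = jr.split(jr.PRNGKey(0), 4)
W = jr.normal(key_w, (1000, 1000))
A = jr.normal(key_a, (1000, 4))
B = jr.normal(key_b, (4, 1000))
v = jr.normal(key_v, (1000,))

naive = (W + A @ B) @ v          # materialises the dense 1000x1000 update A @ B
efficient = W @ v + A @ (B @ v)  # never forms A @ B: just two extra skinny matvecs

assert jnp.allclose(naive, efficient, rtol=1e-4, atol=1e-4)
```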
Applications include:
- LoRA weight matrices
- symbolic zeros
- arrays with named dimensions
- structured (e.g. tridiagonal) matrices
- sparse arrays
- quantised arrays
- arrays with physical units attached
- etc! (See the built-in `quax.examples` library for most of the above!)
This works via a custom JAX transform. Take an existing JAX program, wrap it in a `quax.quaxify`, and then pass in the custom array-ish objects. This means it will work even with existing programs that were not written to accept such array-ish objects!

(Just like how `jax.vmap` takes a program but reinterprets each operation as its batched version, so too will `quax.quaxify` take a program and reinterpret each operation according to what array-ish types are passed.)
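For instance, `jax.numpy.matmul` knows nothing about LoRA, but quaxifying it makes it respect the dispatch rules registered for the built-in `lora.LoraArray`. A minimal sketch; the shapes here are made up, and `jnp.matmul` simply stands in for any existing JAX code:

```python
import jax.numpy as jnp
import jax.random as jr
import quax
import quax.examples.lora as lora

key1, key2 = jr.split(jr.PRNGKey(0))
# A (12, 10) weight, stored in low-rank LoRA form rather than as a dense array.
w = lora.LoraArray(jr.normal(key1, (12, 10)), rank=2, key=key2)
v = jr.normal(jr.PRNGKey(1), (10,))

# `jnp.matmul` was written for ordinary arrays; `quax.quaxify` reinterprets the
# underlying matmul primitive using the rule registered for `LoraArray`.
out = quax.quaxify(jnp.matmul)(w, v)  # shape (12,)
```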
Installation

pip install quax
Documentation

Available at https://docs.kidger.site/quax.
Quick example

This example demonstrates everything you need to use the built-in `quax.examples.lora` library.
```python
import equinox as eqx
import jax.random as jr
import quax
import quax.examples.lora as lora

#
# Start off with any JAX program: here, the forward pass through a linear layer.
#
key1, key2, key3 = jr.split(jr.PRNGKey(0), 3)
linear = eqx.nn.Linear(10, 12, key=key1)
vector = jr.normal(key2, (10,))

def run(model, x):
    return model(x)

run(linear, vector)  # can call this as normal

#
# Now let's Lora-ify it.
#

# Step 1: make the weight be a LoraArray.
lora_weight = lora.LoraArray(linear.weight, rank=2, key=key3)
lora_linear = eqx.tree_at(lambda l: l.weight, linear, lora_weight)

# Step 2: quaxify and call the original function. The transform will call the
# original function, whilst looking up any multiple dispatch rules registered.
# (In this case for doing matmuls against LoraArrays.)
quax.quaxify(run)(lora_linear, vector)

# Appendix: Quax includes a helper to automatically apply Step 1 to all
# `eqx.nn.Linear` layers in a model.
lora_linear = lora.loraify(linear, rank=2, key=key3)
```
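As a short follow-up, here is a sketch of how this composes, reusing `run`, `vector`, `key1` and `key3` from the example above (`eqx.nn.MLP` and `eqx.filter_jit` are ordinary Equinox, not part of Quax, and the composition with JIT shown here is an illustration rather than an official recipe):

```python
# `loraify` works on whole models, not just single weights: every
# `eqx.nn.Linear` inside this MLP gets a LoRA weight.
mlp = eqx.nn.MLP(in_size=10, out_size=12, width_size=32, depth=2, key=key1)
lora_mlp = lora.loraify(mlp, rank=2, key=key3)
quax.quaxify(run)(lora_mlp, vector)

# The quaxified function is still an ordinary JAX-callable function, so it can
# be JIT-compiled in the usual (Equinox-filtered) way.
fast_run = eqx.filter_jit(quax.quaxify(run))
fast_run(lora_mlp, vector)
```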
Right now, the following are not supported:

- `jax.lax.scan_p`
- `jax.custom_vjp`

It should be fairly straightforward to add support for these; open an issue or pull request. (We've already got `jax.custom_jvp`, `jax.lax.cond_p`, and `jax.lax.while_p`. :) )
See also: other libraries in the JAX ecosystem

Always useful
- Equinox: neural networks and everything not already in core JAX!
- jaxtyping: type annotations for shape/dtype of arrays.
Deep learning
- Optax: first-order gradient (SGD, Adam, ...) optimisers.
- Orbax: checkpointing (async/multi-host/multi-device).
- Levanter: scalable+reliable training of foundation models (e.g. LLMs).
Scientific computing
- Diffrax: numerical differential equation solvers.
- Optimistix: root finding, minimisation, fixed points, and least squares.
- Lineax: linear solvers.
- BlackJAX: probabilistic+Bayesian sampling.
- sympy2jax: SymPy<->JAX conversion; train symbolic expressions via gradient descent.
- PySR: symbolic regression. (Non-JAX honourable mention!)
Built on Quax
- Quaxed: a namespace of already-wrapped `quaxify(jnp.foo)` operations.
- unxt: Unitful Quantities.
Awesome JAX
- Awesome JAX: a longer list of other JAX projects.
Acknowledgements

Significantly inspired by https://github.com/davisyoshida/qax, https://github.com/stanford-crfm/levanter, and `jax.experimental.sparse`.