Multiple dispatch over abstract array types in JAX.
patrick-kidger/quax
For example, this can mean overloading matrix multiplication to exploit sparsity or structure, or automatically rewriting a LoRA's matmul `(W + AB)v` into the more-efficient `Wv + ABv`.
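To make the LoRA rewrite concrete, here is a minimal sketch in plain JAX (no Quax involved) that checks the two expressions agree. The sizes, key handling, and variable names are assumptions chosen for illustration, and evaluating the low-rank term as `A @ (B @ v)` is one reasonable ordering rather than a statement about what Quax does internally.

```python
import jax.numpy as jnp
import jax.random as jr

# Hypothetical sizes, chosen only for illustration.
out_size, in_size, rank = 12, 10, 2

k_w, k_a, k_b, k_v = jr.split(jr.PRNGKey(0), 4)
W = jr.normal(k_w, (out_size, in_size))  # dense base weight
A = jr.normal(k_a, (out_size, rank))     # low-rank factor
B = jr.normal(k_b, (rank, in_size))      # low-rank factor
v = jr.normal(k_v, (in_size,))

# Naive form: materialise the dense (out_size, in_size) update A @ B first.
naive = (W + A @ B) @ v

# Rewritten form: never builds A @ B; the low-rank path only needs
# two small matrix-vector products.
efficient = W @ v + A @ (B @ v)

# Largest elementwise gap between the two results (floating-point noise only).
max_abs_diff = jnp.max(jnp.abs(naive - efficient))
```

The two results should match up to floating-point rounding. The saving comes from skipping the dense `A @ B` product, which matters more as the weight matrix grows relative to the rank.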
Applications include:
- LoRA weight matrices
- symbolic zeros
- arrays with named dimensions
- structured (e.g. tridiagonal) matrices
- sparse arrays
- quantised arrays
- arrays with physical units attached
- etc! (See the built-in `quax.examples` library for most of the above!)
This works via a custom JAX transform. Take an existing JAX program, wrap it in a `quax.quaxify`, and then pass in the custom array-ish objects. This means it will work even with existing programs that were not written to accept such array-ish objects!
(Just like how `jax.vmap` takes a program, but reinterprets each operation as its batched version, so too will `quax.quaxify` take a program and reinterpret each operation according to what array-ish types are passed.)
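The `jax.vmap` half of that analogy can be sketched as follows. This shows only the standard JAX transform, not Quax itself; the function name `dot` and the input values are made up for illustration.

```python
import jax
import jax.numpy as jnp


def dot(x, y):
    # Written for a single pair of vectors; knows nothing about batching.
    return jnp.dot(x, y)


x = jnp.arange(6.0).reshape(2, 3)  # a batch of two length-3 vectors
y = jnp.ones((2, 3))

# vmap leaves `dot` untouched and reinterprets each operation inside it
# as its batched counterpart, mapping over the leading axis of x and y.
batched = jax.vmap(dot)(x, y)
```

Here `batched` has shape `(2,)`, one dot product per row. `quax.quaxify` plays the same role for array-ish types: the wrapped function is left as written, and what changes is how each operation inside it is interpreted.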
pip install quax
Documentation is available at https://docs.kidger.site/quax.
This example demonstrates everything you need to use the built-in `quax.examples.lora` library.
    import equinox as eqx
    import jax.random as jr
    import quax
    import quax.examples.lora as lora

    #
    # Start off with any JAX program: here, the forward pass through a linear layer.
    #

    key1, key2, key3 = jr.split(jr.PRNGKey(0), 3)
    linear = eqx.nn.Linear(10, 12, key=key1)
    vector = jr.normal(key2, (10,))

    def run(model, x):
        return model(x)

    run(linear, vector)  # can call this as normal

    #
    # Now let's Lora-ify it.
    #

    # Step 1: make the weight be a LoraArray.
    lora_weight = lora.LoraArray(linear.weight, rank=2, key=key3)
    lora_linear = eqx.tree_at(lambda l: l.weight, linear, lora_weight)
    # Step 2: quaxify and call the original function. The transform will call the
    # original function, whilst looking up any multiple dispatch rules registered.
    # (In this case for doing matmuls against LoraArrays.)
    quax.quaxify(run)(lora_linear, vector)

    # Appendix: Quax includes a helper to automatically apply Step 1 to all
    # `eqx.nn.Linear` layers in a model.
    lora_linear = lora.loraify(linear, rank=2, key=key3)
Right now, the following are not supported:
- `jax.lax.scan_p`
- `jax.custom_vjp`

It should be fairly straightforward to add support for these; open an issue or pull request. (We've already got `jax.custom_jvp`, `jax.lax.cond_p`, and `jax.lax.while_p`. :) )
Always useful
Equinox: neural networks and everything not already in core JAX!
jaxtyping: type annotations for shape/dtype of arrays.
Deep learning
Optax: first-order gradient (SGD, Adam, ...) optimisers.
Orbax: checkpointing (async/multi-host/multi-device).
Levanter: scalable+reliable training of foundation models (e.g. LLMs).
Scientific computing
Diffrax: numerical differential equation solvers.
Optimistix: root finding, minimisation, fixed points, and least squares.
Lineax: linear solvers.
BlackJAX: probabilistic+Bayesian sampling.
sympy2jax: SymPy<->JAX conversion; train symbolic expressions via gradient descent.
PySR: symbolic regression. (Non-JAX honourable mention!)
Built on Quax
Quaxed: a namespace of already-wrapped `quaxify(jnp.foo)` operations.
unxt: Unitful Quantities.
Awesome JAX
Awesome JAX: a longer list of other JAX projects.
Significantly inspired by https://github.com/davisyoshida/qax, https://github.com/stanford-crfm/levanter, and `jax.experimental.sparse`.