Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
jax-ml/jax
Transformations | Scaling | Install guide | Change logs | Reference docs
JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning.
JAX can automatically differentiate native Python and NumPy functions. It can differentiate through loops, branches, recursion, and closures, and it can take derivatives of derivatives of derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation) via `jax.grad` as well as forward-mode differentiation, and the two can be composed arbitrarily to any order.
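A minimal sketch of both modes and their composition: `jax.jvp` pushes a tangent forward through a function, `jax.grad` pulls a cotangent backward, and stacking `jax.jacfwd` on `jax.grad` gives a forward-over-reverse second derivative. The function `f(x) = x**2 * sin(x)` is an arbitrary example, chosen so the results can be checked against hand-derived derivatives.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Arbitrary example function: f(x) = x^2 sin(x)
    return x ** 2 * jnp.sin(x)

x = jnp.float32(1.5)

# Forward mode: evaluate f and its directional derivative along tangent 1.0.
y, dy_fwd = jax.jvp(f, (x,), (jnp.float32(1.0),))

# Reverse mode (backpropagation).
dy_rev = jax.grad(f)(x)

# Composing the modes for second derivatives.
d2y_fwd_rev = jax.jacfwd(jax.grad(f))(x)  # forward-over-reverse
d2y_rev_rev = jax.grad(jax.grad(f))(x)    # reverse-over-reverse

# Hand-derived references: f'(x) = x^2 cos x + 2x sin x,
# f''(x) = -x^2 sin x + 4x cos x + 2 sin x
dy_ref = x ** 2 * jnp.cos(x) + 2 * x * jnp.sin(x)
d2y_ref = -(x ** 2) * jnp.sin(x) + 4 * x * jnp.cos(x) + 2 * jnp.sin(x)
print(dy_fwd, dy_rev, d2y_fwd_rev, d2y_rev_rev)
```

Forward mode costs one pass per input direction and reverse mode one pass per output, which is why the usual choice for a scalar loss is reverse mode, while forward-over-reverse is a common way to get second derivatives.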
JAX uses XLA to compile and scale your NumPy programs on TPUs, GPUs, and other hardware accelerators. You can compile your own pure functions with `jax.jit`. Compilation and automatic differentiation can be composed arbitrarily.
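Why "pure" matters: `jax.jit` traces a function once per combination of input shapes and dtypes and caches the compiled result, so Python side effects run only while tracing. The sketch below makes that visible with a hypothetical `trace_log` list; it illustrates the caching behavior and is not a pattern to copy.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical: records each time Python actually runs the body

@jax.jit
def impure_scale(x):
    trace_log.append(x.shape)  # side effect: only happens during tracing
    return 2.0 * x

impure_scale(jnp.ones(3))        # first call: traced and compiled
impure_scale(jnp.ones(3))        # same shape and dtype: cached, body not re-run
out = impure_scale(jnp.ones(4))  # new shape: traced again
print(len(trace_log))  # 2
```

Side effects you need on every call (logging, counters, I/O) therefore belong outside the jitted function, or in a JAX-provided mechanism such as `jax.debug.print`.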
Dig a little deeper, and you'll see that JAX is really an extensible system for composable function transformations at scale.
This is a research project, not an official Google product. Expect sharp edges. Please help by trying it out, reporting bugs, and letting us know what you think!
```python
import jax
import jax.numpy as jnp

def predict(params, inputs):
    for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.tanh(outputs)  # inputs to the next layer
    return outputs                  # no activation on last layer

def loss(params, inputs, targets):
    preds = predict(params, inputs)
    return jnp.sum((preds - targets)**2)

grad_loss = jax.jit(jax.grad(loss))  # compiled gradient evaluation function
perex_grads = jax.jit(jax.vmap(grad_loss, in_axes=(None, 0, 0)))  # fast per-example grads
```
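The snippet above only defines functions. Here is a minimal usage sketch under stated assumptions: a hypothetical `init_params` helper (small Gaussian weights, zero biases), a 3-16-2 network, and random inputs with zero targets. It repeats the definitions so it runs on its own, then checks that per-example gradients carry a leading batch axis and sum to the full-batch gradient, which holds here because `loss` is a sum over examples.

```python
import jax
import jax.numpy as jnp

def predict(params, inputs):
    for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.tanh(outputs)
    return outputs

def loss(params, inputs, targets):
    preds = predict(params, inputs)
    return jnp.sum((preds - targets) ** 2)

def init_params(key, sizes):
    # Hypothetical helper: one (W, b) pair per layer.
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, subkey = jax.random.split(key)
        params.append((0.1 * jax.random.normal(subkey, (n_in, n_out)), jnp.zeros(n_out)))
    return params

params = init_params(jax.random.key(0), [3, 16, 2])
inputs = jax.random.normal(jax.random.key(1), (32, 3))
targets = jnp.zeros((32, 2))

grad_loss = jax.jit(jax.grad(loss))
perex_grads = jax.jit(jax.vmap(grad_loss, in_axes=(None, 0, 0)))

batch_grads = grad_loss(params, inputs, targets)      # same structure as params
example_grads = perex_grads(params, inputs, targets)  # extra leading axis of size 32
print(example_grads[0][0].shape)  # (32, 3, 16)
```

`in_axes=(None, 0, 0)` tells `vmap` to share `params` across the batch while mapping over the rows of `inputs` and `targets`.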
- Transformations
- Scaling
- Current gotchas
- Installation
- Neural net libraries
- Citing JAX
- Reference documentation
At its core, JAX is an extensible system for transforming numerical functions. Here are three: `jax.grad`, `jax.jit`, and `jax.vmap`.
Use `jax.grad` to efficiently compute reverse-mode gradients:
```python
import jax
import jax.numpy as jnp

def tanh(x):
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

grad_tanh = jax.grad(tanh)
print(grad_tanh(1.0))
# prints 0.4199743
```
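Since this `tanh` is the ordinary hyperbolic tangent, its derivative has the closed form `1 - tanh(x)**2`, which gives a quick sanity check of the gradient. A sketch comparing the two on an arbitrary grid of points `xs` (batched with `jax.vmap`, covered below):

```python
import jax
import jax.numpy as jnp

def tanh(x):
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

grad_tanh = jax.grad(tanh)

xs = jnp.linspace(-3.0, 3.0, 7)          # arbitrary test points
autodiff = jax.vmap(grad_tanh)(xs)       # gradient at every point
analytic = 1.0 - jnp.tanh(xs) ** 2       # d/dx tanh(x) = 1 - tanh(x)^2
print(jnp.max(jnp.abs(autodiff - analytic)))
```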
You can differentiate to any order with `grad`:
```python
print(jax.grad(jax.grad(jax.grad(tanh)))(1.0))
# prints 0.62162673
```
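The third derivative can also be worked out by hand. Writing t = tanh(x), we get f' = 1 - t², then f'' = -2t(1 - t²), and differentiating once more gives f''' = (1 - t²)(6t² - 2). A sketch checking the nested `jax.grad` against that closed form:

```python
import jax
import jax.numpy as jnp

def tanh(x):
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

third = jax.grad(jax.grad(jax.grad(tanh)))(1.0)

t = jnp.tanh(1.0)
closed_form = (1.0 - t ** 2) * (6.0 * t ** 2 - 2.0)  # f'''(x) in terms of t = tanh(x)
print(third, closed_form)
```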
You're free to use differentiation with Python control flow:
```python
def abs_val(x):
    if x > 0:
        return x
    else:
        return -x

abs_val_grad = jax.grad(abs_val)
print(abs_val_grad(1.0))   # prints 1.0
print(abs_val_grad(-1.0))  # prints -1.0 (abs_val is re-evaluated)
```
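Loops work the same way. In this sketch, `repeated_double` is a made-up example whose number of iterations depends on the input value; `jax.grad` differentiates whichever path the Python code actually takes, so the derivative is 2 raised to the number of doublings performed.

```python
import jax

def repeated_double(x):
    # Data-dependent Python loop: keep doubling until x reaches 10.
    while x < 10.0:
        x = 2.0 * x
    return x

d_repeated_double = jax.grad(repeated_double)
print(d_repeated_double(1.0))  # path 1 -> 2 -> 4 -> 8 -> 16: four doublings, derivative 16
print(d_repeated_double(3.0))  # path 3 -> 6 -> 12: two doublings, derivative 4
```

This relies on `jax.grad` alone seeing concrete values; wrapping the function in `jax.jit` would make the loop condition abstract, which is the restriction discussed in the JIT section.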
See the JAX Autodiff Cookbook and the reference docs on automatic differentiation for more.
Use XLA to compile your functions end-to-end with `jit`, used either as an `@jit` decorator or as a higher-order function.
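```python
import jax
import jax.numpy as jnp

def slow_f(x):
    # Element-wise ops see a large benefit from fusion
    return x * x + x * 2.0

x = jnp.ones((5000, 5000))
fast_f = jax.jit(slow_f)

# IPython magics. JAX dispatches work asynchronously, so block_until_ready()
# makes the timings include the computation itself, not just its dispatch.
%timeit -n10 -r3 fast_f(x).block_until_ready()
%timeit -n10 -r3 slow_f(x).block_until_ready()
```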
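Outside IPython, the same comparison can be written in plain Python. This sketch shows both ways of applying `jit` (wrapping an existing function and decorating a definition) plus a small hypothetical `mean_seconds` helper that does a warm-up call, so compilation is excluded, and waits on each result with `block_until_ready()`. The array size is an arbitrary choice.

```python
import time

import jax
import jax.numpy as jnp

def slow_f(x):
    return x * x + x * 2.0

# Higher-order-function form: wrap an existing function.
fast_f = jax.jit(slow_f)

# Decorator form: the same transformation applied at definition time.
@jax.jit
def fast_f_decorated(x):
    return x * x + x * 2.0

def mean_seconds(fn, x, repeats=10):
    """Hypothetical helper: average wall-clock seconds per call."""
    fn(x).block_until_ready()  # warm-up; for jitted fns this triggers compilation
    start = time.perf_counter()
    for _ in range(repeats):
        fn(x).block_until_ready()  # wait for the asynchronously dispatched result
    return (time.perf_counter() - start) / repeats

x = jnp.ones((1000, 1000))
for name, fn in [("slow_f", slow_f), ("fast_f", fast_f), ("fast_f_decorated", fast_f_decorated)]:
    print(f"{name}: {mean_seconds(fn, x) * 1e3:.3f} ms")
```

Actual speedups depend on hardware and array size, so no numbers are claimed here.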
Using `jax.jit` constrains the kind of Python control flow the function can use; see the tutorial on Control Flow and Logical Operators with JIT for more.
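A sketch of what that constraint looks like in practice. `relu_python_if` (a hypothetical example) branches on a traced value and fails under `jax.jit`, because tracing works with abstract values that have no concrete truth value. Two common replacements are `jnp.where`, which computes both sides and selects elementwise, and `jax.lax.cond`, which picks one branch at run time. The error is caught as `jax.errors.ConcretizationTypeError`; recent releases raise its more specific subclass `TracerBoolConversionError`.

```python
import jax
import jax.numpy as jnp

def relu_python_if(x):
    if x > 0:  # needs a concrete bool: fails when x is a traced value
        return x
    return 0.0 * x

def relu_where(x):
    return jnp.where(x > 0, x, 0.0)  # both branches evaluated, then selected

def relu_cond(x):
    return jax.lax.cond(x > 0, lambda v: v, lambda v: 0.0 * v, x)  # one branch runs

try:
    jax.jit(relu_python_if)(1.0)
    python_if_failed = False
except jax.errors.ConcretizationTypeError:
    python_if_failed = True

print(python_if_failed)                                   # True
print(jax.jit(relu_where)(-2.0), jax.jit(relu_cond)(3.0))
```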
`vmap` maps a function along array axes. But instead of just looping over function applications, it pushes the loop down onto the function’s primitive operations, e.g. turning matrix-vector multiplies into matrix-matrix multiplies for better performance.
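A small check of the matrix-vector example, with arbitrary random shapes: vmapping a matrix-vector product over a batch of vectors gives the same result as a single matrix-matrix product. `jax.make_jaxpr` prints the traced program, where the batched version should appear as a `dot_general` over the whole batch rather than a loop.

```python
import jax
import jax.numpy as jnp

W = jax.random.normal(jax.random.key(0), (4, 3))   # arbitrary weight matrix
xs = jax.random.normal(jax.random.key(1), (5, 3))  # batch of 5 vectors

def matvec(x):
    return jnp.dot(W, x)  # written for a single vector of shape (3,)

batched = jax.vmap(matvec)(xs)   # shape (5, 4)
reference = jnp.dot(xs, W.T)     # one matrix-matrix product

print(jax.make_jaxpr(jax.vmap(matvec))(xs))
```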
Using `vmap` can save you from having to carry around batch dimensions in your code:
```python
import jax
import jax.numpy as jnp

def l1_distance(x, y):
    assert x.ndim == y.ndim == 1  # only works on 1D inputs
    return jnp.sum(jnp.abs(x - y))

def pairwise_distances(dist1D, xs):
    return jax.vmap(jax.vmap(dist1D, (0, None)), (None, 0))(xs, xs)

xs = jax.random.normal(jax.random.key(0), (100, 3))
dists = pairwise_distances(l1_distance, xs)
dists.shape  # (100, 100)
```
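To see what the nested `vmap` computes, this sketch compares it with the hand-batched broadcasting version that `vmap` spares you from writing. The inner `vmap` maps over the first argument and the outer over the second, so entry `[j, i]` holds the distance between `xs[i]` and `xs[j]`; because L1 distance is symmetric, the two results agree entry for entry.

```python
import jax
import jax.numpy as jnp

def l1_distance(x, y):
    assert x.ndim == y.ndim == 1  # only works on 1D inputs
    return jnp.sum(jnp.abs(x - y))

def pairwise_distances(dist1D, xs):
    return jax.vmap(jax.vmap(dist1D, (0, None)), (None, 0))(xs, xs)

xs = jax.random.normal(jax.random.key(0), (100, 3))
dists = pairwise_distances(l1_distance, xs)

# Hand-batched equivalent: broadcast (100, 1, 3) against (1, 100, 3), reduce the last axis.
dists_broadcast = jnp.sum(jnp.abs(xs[:, None, :] - xs[None, :, :]), axis=-1)
print(jnp.max(jnp.abs(dists - dists_broadcast)))
```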
By composing `jax.vmap` with `jax.grad` and `jax.jit`, we can get efficient Jacobian matrices, or per-example gradients:
```python
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))
```
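For Jacobians the same idea applies to vector-Jacobian products: each VJP with a standard basis cotangent yields one Jacobian row, and `jax.vmap` evaluates all rows in one batched call, which is essentially the strategy behind `jax.jacrev`. A sketch with an arbitrary example function `f` from R³ to R², checked against `jax.jacrev`:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Arbitrary example map from R^3 to R^2
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 3.0])

_, f_vjp = jax.vjp(f, x)                   # f_vjp(ct) returns a tuple (ct @ J,)
(jac_rows,) = jax.vmap(f_vjp)(jnp.eye(2))  # one row per basis cotangent

jac_ref = jax.jacrev(f)(x)
print(jac_rows)  # rows: [x1, x0, 0] and [0, 0, cos(x2)]
```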
To scale your computations across thousands of devices, you can use any composition of these:
- Compiler-based automatic parallelization, where you program as if using a single global machine, and the compiler chooses how to shard data and partition computation (with some user-provided constraints);
- Explicit sharding and automatic partitioning, where you still have a global view but data shardings are explicit in JAX types, inspectable using `jax.typeof`;
- Manual per-device programming, where you have a per-device view of data and computation, and can communicate with explicit collectives.
Mode | View? | Explicit sharding? | Explicit Collectives? |
---|---|---|---|
Auto | Global | ❌ | ❌ |
Explicit | Global | ✅ | ❌ |
Manual | Per-device | ✅ | ✅ |
```python
# Assumes `params`, `inputs`, `targets`, and `loss` from the example at the top.
import jax
from jax.sharding import set_mesh, AxisType, PartitionSpec as P

mesh = jax.make_mesh((8,), ('data',), axis_types=(AxisType.Explicit,))
set_mesh(mesh)

# parameters are sharded for FSDP:
for W, b in params:
    print(f'{jax.typeof(W)}')  # f32[512@data,512]
    print(f'{jax.typeof(b)}')  # f32[512]

# shard data for batch parallelism:
inputs, targets = jax.device_put((inputs, targets), P('data'))

# evaluate gradients, automatically parallelized!
gradfun = jax.jit(jax.grad(loss))
param_grads = gradfun(params, inputs, targets)
```
See the tutorial and advanced guides for more.
See the Gotchas Notebook.
Hardware | Linux x86_64 | Linux aarch64 | Mac aarch64 | Windows x86_64 | Windows WSL2 x86_64 |
---|---|---|---|---|---|
CPU | yes | yes | yes | yes | yes |
NVIDIA GPU | yes | yes | n/a | no | experimental |
Google TPU | yes | n/a | n/a | n/a | n/a |
AMD GPU | yes | no | n/a | no | no |
Apple GPU | n/a | no | experimental | n/a | n/a |
Intel GPU | experimental | n/a | n/a | no | no |
Platform | Instructions |
---|---|
CPU | pip install -U jax |
NVIDIA GPU | pip install -U "jax[cuda12]" |
Google TPU | pip install -U "jax[tpu]" |
AMD GPU (Linux) | Follow AMD's instructions. |
Mac GPU | Follow Apple's instructions. |
Intel GPU | Follow Intel's instructions. |
See the documentation for information on alternative installation strategies. These include compiling from source, installing with Docker, using other versions of CUDA, a community-supported conda build, and answers to some frequently-asked questions.
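After installing, a quick way to confirm which version and backend JAX picked up is to ask it directly. `jax.default_backend()` returns the platform name of the default backend, so a GPU install that has silently fallen back to CPU shows up here.

```python
import jax

print("JAX version:", jax.__version__)
print("Default backend:", jax.default_backend())  # e.g. "cpu", "gpu", or "tpu"
print("Devices:", jax.devices())
```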
To cite this repository:
```
@software{jax2018github,
  author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
  title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
  url = {http://github.com/jax-ml/jax},
  version = {0.3.13},
  year = {2018},
}
```
In the above bibtex entry, names are in alphabetical order, the version number is intended to be that from `jax/version.py`, and the year corresponds to the project's open-source release.
A nascent version of JAX, supporting only automatic differentiation and compilation to XLA, was described in a paper that appeared at SysML 2018. We're currently working on covering JAX's ideas and capabilities in a more comprehensive and up-to-date paper.
For details about the JAX API, see the reference documentation.
For getting started as a JAX developer, see the developer documentation.