Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
Quickstart | Transformations | Install guide | Neural net libraries | Change logs | Reference docs

JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning.
With its updated version of Autograd, JAX can automatically differentiate native Python and NumPy functions. It can differentiate through loops, branches, recursion, and closures, and it can take derivatives of derivatives of derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation) via `grad` as well as forward-mode differentiation, and the two can be composed arbitrarily to any order.
What's new is that JAX uses XLA to compile and run your NumPy programs on GPUs and TPUs. Compilation happens under the hood by default, with library calls getting just-in-time compiled and executed. But JAX also lets you just-in-time compile your own Python functions into XLA-optimized kernels using a one-function API, `jit`. Compilation and automatic differentiation can be composed arbitrarily, so you can express sophisticated algorithms and get maximal performance without leaving Python. You can even program multiple GPUs or TPU cores at once using `pmap`, and differentiate through the whole thing.
Dig a little deeper, and you'll see that JAX is really an extensible system for composable function transformations. Both `grad` and `jit` are instances of such transformations. Others are `vmap` for automatic vectorization and `pmap` for single-program multiple-data (SPMD) parallel programming of multiple accelerators, with more to come.
This is a research project, not an official Google product. Expect sharp edges. Please help by trying it out, reporting bugs, and letting us know what you think!
```python
import jax.numpy as jnp
from jax import grad, jit, vmap

def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jnp.tanh(outputs)  # inputs to the next layer
  return outputs                # no activation on last layer

def loss(params, inputs, targets):
  preds = predict(params, inputs)
  return jnp.sum((preds - targets)**2)

grad_loss = jit(grad(loss))  # compiled gradient evaluation function
perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0)))  # fast per-example grads
```
- Quickstart: Colab in the Cloud
- Transformations
- Current gotchas
- Installation
- Neural net libraries
- Citing JAX
- Reference documentation
Jump right in using a notebook in your browser, connected to a Google Cloud GPU. Here are some starter notebooks:
- The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization
- Training a Simple Neural Network, with TensorFlow Dataset Data Loading
JAX now runs on Cloud TPUs. To try out the preview, see the Cloud TPU Colabs.
For a deeper dive into JAX:
- The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX
- Common gotchas and sharp edges
- See the full list of notebooks.
At its core, JAX is an extensible system for transforming numerical functions. Here are four transformations of primary interest: `grad`, `jit`, `vmap`, and `pmap`.
JAX has roughly the same API as Autograd. The most popular function is `grad` for reverse-mode gradients:
```python
from jax import grad
import jax.numpy as jnp

def tanh(x):  # Define a function
  y = jnp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)  # Obtain its gradient function
print(grad_tanh(1.0))   # Evaluate it at x = 1.0
# prints 0.4199743
```
You can differentiate to any order with `grad`.
```python
print(grad(grad(grad(tanh)))(1.0))
# prints 0.62162673
```
For more advanced autodiff, you can use `jax.vjp` for reverse-mode vector-Jacobian products and `jax.jvp` for forward-mode Jacobian-vector products. The two can be composed arbitrarily with one another, and with other JAX transformations. Here's one way to compose those to make a function that efficiently computes full Hessian matrices:
```python
from jax import jit, jacfwd, jacrev

def hessian(fun):
  return jit(jacfwd(jacrev(fun)))
```
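To make the `jvp`/`vjp` relationship concrete, here is a minimal sketch on a toy scalar function of our own choosing (the function `f` below is just for illustration, not part of JAX):

```python
import jax.numpy as jnp
from jax import jvp, vjp

def f(x):
  return jnp.sin(x) * x  # toy scalar function for illustration

# Forward mode: Jacobian-vector product of f at x = 1.0 with tangent 1.0
primal_out, tangent_out = jvp(f, (1.0,), (1.0,))

# Reverse mode: vector-Jacobian product of f at the same point
primal_out2, f_vjp = vjp(f, 1.0)
(cotangent_out,) = f_vjp(1.0)

# For a scalar function both modes recover the same derivative,
# d/dx [sin(x) * x] = cos(x) * x + sin(x)
print(tangent_out, cotangent_out)
```

For scalar-input, scalar-output functions the two modes agree; the difference between them matters when the input and output dimensions differ.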
As with Autograd, you're free to use differentiation with Python control structures:
```python
def abs_val(x):
  if x > 0:
    return x
  else:
    return -x

abs_val_grad = grad(abs_val)
print(abs_val_grad(1.0))   # prints 1.0
print(abs_val_grad(-1.0))  # prints -1.0 (abs_val is re-evaluated)
```
See the reference docs on automatic differentiation and the JAX Autodiff Cookbook for more.
You can use XLA to compile your functions end-to-end with `jit`, used either as an `@jit` decorator or as a higher-order function.
```python
import jax.numpy as jnp
from jax import jit

def slow_f(x):
  # Element-wise ops see a large benefit from fusion
  return x * x + x * 2.0

x = jnp.ones((5000, 5000))
fast_f = jit(slow_f)
%timeit -n10 -r3 fast_f(x)  # ~ 4.5 ms / loop on Titan X
%timeit -n10 -r3 slow_f(x)  # ~ 14.5 ms / loop (also on GPU via JAX)
```
You can mix `jit` and `grad` and any other JAX transformation however you like.
Using `jit` puts constraints on the kind of Python control flow the function can use; see the tutorial on Control Flow and Logical Operators with JIT for more.
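For instance, a Python loop whose trip count depends on an argument can be handled by marking that argument static. The `apply_n_times` helper below is a hypothetical example, sketched here for illustration:

```python
from functools import partial
import jax.numpy as jnp
from jax import jit

@partial(jit, static_argnums=(1,))
def apply_n_times(x, n):
  # Because `n` is static, this Python loop is unrolled at trace time;
  # calling with a different `n` triggers a fresh compilation.
  for _ in range(n):
    x = jnp.sin(x)
  return x

print(apply_n_times(1.0, 3))
```

The trade-off is recompilation per distinct static value, so this pattern fits best when the argument takes only a few values.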
`vmap` is the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a function's primitive operations for better performance.
Using `vmap` can save you from having to carry around batch dimensions in your code. For example, consider this simple unbatched neural network prediction function:
```python
def predict(params, input_vec):
  assert input_vec.ndim == 1
  activations = input_vec
  for W, b in params:
    outputs = jnp.dot(W, activations) + b  # `activations` on the right-hand side!
    activations = jnp.tanh(outputs)        # inputs to the next layer
  return outputs                           # no activation on last layer
```
We often instead write `jnp.dot(activations, W)` to allow for a batch dimension on the left side of `activations`, but we've written this particular prediction function to apply only to single input vectors. If we wanted to apply this function to a batch of inputs at once, semantically we could just write
```python
from functools import partial
predictions = jnp.stack(list(map(partial(predict, params), input_batch)))
```
But pushing one example through the network at a time would be slow! It's better to vectorize the computation, so that at every layer we're doing matrix-matrix multiplication rather than matrix-vector multiplication.
The `vmap` function does that transformation for us. That is, if we write
```python
from jax import vmap
predictions = vmap(partial(predict, params))(input_batch)
# or, alternatively
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
```
then the `vmap` function will push the outer loop inside the function, and our machine will end up executing matrix-matrix multiplications exactly as if we'd done the batching by hand.
It's easy enough to manually batch a simple neural network without `vmap`, but in other cases manual vectorization can be impractical or impossible. Take the problem of efficiently computing per-example gradients: that is, for a fixed set of parameters, we want to compute the gradient of our loss function evaluated separately at each example in a batch. With `vmap`, it's easy:
```python
per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets)
```
Of course, `vmap` can be arbitrarily composed with `jit`, `grad`, and any other JAX transformation! We use `vmap` with both forward- and reverse-mode automatic differentiation for fast Jacobian and Hessian matrix calculations in `jax.jacfwd`, `jax.jacrev`, and `jax.hessian`.
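As a small sketch (using a made-up two-output function, not from this README), the two Jacobian transformations produce the same matrix and differ only in how they compute it:

```python
import jax.numpy as jnp
from jax import jacfwd, jacrev

def f(x):
  return jnp.array([x[0]**2, x[0] * x[1]])  # toy function for illustration

x = jnp.array([1.0, 2.0])
# jacfwd builds the Jacobian column by column (forward mode),
# jacrev row by row (reverse mode); the results agree.
print(jacfwd(f)(x))
print(jacrev(f)(x))
```

As a rule of thumb, `jacfwd` is efficient for "tall" Jacobians (more outputs than inputs) and `jacrev` for "wide" ones.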
For parallel programming of multiple accelerators, like multiple GPUs, use `pmap`. With `pmap` you write single-program multiple-data (SPMD) programs, including fast parallel collective communication operations. Applying `pmap` will mean that the function you write is compiled by XLA (similarly to `jit`), then replicated and executed in parallel across devices.
Here's an example on an 8-GPU machine:
```python
from jax import random, pmap
import jax.numpy as jnp

# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random.split(random.key(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)

# Run a local matmul on each device in parallel (no data transfer)
result = pmap(lambda x: jnp.dot(x, x.T))(mats)  # result.shape is (8, 5000, 5000)

# Compute the mean on each device in parallel and print the result
print(pmap(jnp.mean)(result))
# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]
```
In addition to expressing pure maps, you can use fast collective communication operations between devices:
```python
from functools import partial
from jax import lax

@partial(pmap, axis_name='i')
def normalize(x):
  return x / lax.psum(x, 'i')

print(normalize(jnp.arange(4.)))
# prints [0.         0.16666667 0.33333334 0.5       ]
```
You can even nest `pmap` functions for more sophisticated communication patterns.
It all composes, so you're free to differentiate through parallel computations:
```python
from jax import grad

@pmap
def f(x):
  y = jnp.sin(x)

  @pmap
  def g(z):
    return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()

  return grad(lambda w: jnp.sum(g(w)))(x)

print(f(x))
# [[ 0.        , -0.7170853 ],
#  [-3.1085174 , -0.4824318 ],
#  [10.366636  , 13.135289  ],
#  [ 0.22163185, -0.52112055]]

print(grad(lambda x: jnp.sum(f(x)))(x))
# [[ -3.2369726,  -1.6356447],
#  [  4.7572474,  11.606951 ],
#  [-98.524414 ,  42.76499  ],
#  [ -1.6007166,  -1.2568436]]
```
When reverse-mode differentiating a `pmap` function (e.g. with `grad`), the backward pass of the computation is parallelized just like the forward pass.
See the SPMD Cookbook and the SPMD MNIST classifier from scratch example for more.
For a more thorough survey of current gotchas, with examples and explanations, we highly recommend reading the Gotchas Notebook. Some standouts:
- JAX transformations only work on pure functions, which don't have side-effects and respect referential transparency (i.e. object identity testing with `is` isn't preserved). If you use a JAX transformation on an impure Python function, you might see an error like `Exception: Can't lift Traced...` or `Exception: Different traces at same level`.
- In-place mutating updates of arrays, like `x[i] += y`, aren't supported, but there are functional alternatives. Under a `jit`, those functional alternatives will reuse buffers in-place automatically.
- Random numbers are different, but for good reasons.
- If you're looking for convolution operators, they're in the `jax.lax` package.
- JAX enforces single-precision (32-bit, e.g. `float32`) values by default, and to enable double-precision (64-bit, e.g. `float64`) one needs to set the `jax_enable_x64` variable at startup (or set the environment variable `JAX_ENABLE_X64=True`). On TPU, JAX uses 32-bit values by default for everything except internal temporary variables in 'matmul-like' operations, such as `jax.numpy.dot` and `lax.conv`. Those ops have a `precision` parameter which can be used to approximate 32-bit operations via three bfloat16 passes, with a cost of possibly slower runtime. Non-matmul operations on TPU lower to implementations that often emphasize speed over accuracy, so in practice computations on TPU will be less precise than similar computations on other backends.
- Some of NumPy's dtype promotion semantics involving a mix of Python scalars and NumPy types aren't preserved, namely `np.add(1, np.array([2], np.float32)).dtype` is `float64` rather than `float32`.
- Some transformations, like `jit`, constrain how you can use Python control flow. You'll always get loud errors if something goes wrong. You might have to use `jit`'s `static_argnums` parameter, structured control flow primitives like `lax.scan`, or just use `jit` on smaller subfunctions.
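For the in-place update gotcha, the functional alternative is `jax.numpy`'s indexed-update syntax. A minimal sketch:

```python
import jax.numpy as jnp

x = jnp.zeros(4)
# x[1] += 5.0 would raise an error on a JAX array; the functional
# equivalent returns a new array and leaves x unchanged:
y = x.at[1].add(5.0)
print(y)  # [0. 5. 0. 0.]
print(x)  # [0. 0. 0. 0.]
```

Analogous operators exist for other updates, e.g. `x.at[i].set(v)` and `x.at[i].mul(v)`.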
| | Linux x86_64 | Linux aarch64 | Mac x86_64 | Mac aarch64 | Windows x86_64 | Windows WSL2 x86_64 |
|---|---|---|---|---|---|---|
| CPU | yes | yes | yes | yes | yes | yes |
| NVIDIA GPU | yes | yes | no | n/a | no | experimental |
| Google TPU | yes | n/a | n/a | n/a | n/a | n/a |
| AMD GPU | yes | no | experimental | n/a | no | no |
| Apple GPU | n/a | no | n/a | experimental | n/a | n/a |
| Intel GPU | experimental | n/a | n/a | n/a | no | no |
| Platform | Instructions |
|---|---|
| CPU | `pip install -U jax` |
| NVIDIA GPU | `pip install -U "jax[cuda12]"` |
| Google TPU | `pip install -U "jax[tpu]"` |
| AMD GPU (Linux) | Follow AMD's instructions. |
| Mac GPU | Follow Apple's instructions. |
| Intel GPU | Follow Intel's instructions. |
See the documentation for information on alternative installation strategies. These include compiling from source, installing with Docker, using other versions of CUDA, a community-supported conda build, and answers to some frequently-asked questions.
Multiple Google research groups at Google DeepMind and Alphabet develop and share libraries for training neural networks in JAX. If you want a fully featured library for neural network training with examples and how-to guides, try Flax and its documentation site.
Check out the JAX Ecosystem section on the JAX documentation site for a list of JAX-based network libraries, which includes Optax for gradient processing and optimization, chex for reliable code and testing, and Equinox for neural networks. (Watch the NeurIPS 2020 JAX Ecosystem at DeepMind talk here for additional details.)
To cite this repository:
```bibtex
@software{jax2018github,
  author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
  title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
  url = {http://github.com/jax-ml/jax},
  version = {0.3.13},
  year = {2018},
}
```
In the above bibtex entry, names are in alphabetical order, the version number is intended to be that from jax/version.py, and the year corresponds to the project's open-source release.
A nascent version of JAX, supporting only automatic differentiation and compilation to XLA, was described in a paper that appeared at SysML 2018. We're currently working on covering JAX's ideas and capabilities in a more comprehensive and up-to-date paper.
For details about the JAX API, see the reference documentation.

For getting started as a JAX developer, see the developer documentation.