jax-ml/jax

Transformable numerical computing at scale


Quickstart | Transformations | Install guide | Neural net libraries | Change logs | Reference docs

What is JAX?

JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning.

With its updated version of Autograd, JAX can automatically differentiate native Python and NumPy functions. It can differentiate through loops, branches, recursion, and closures, and it can take derivatives of derivatives of derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation) via grad as well as forward-mode differentiation, and the two can be composed arbitrarily to any order.

What’s new is that JAX uses XLA to compile and run your NumPy programs on GPUs and TPUs. Compilation happens under the hood by default, with library calls getting just-in-time compiled and executed. But JAX also lets you just-in-time compile your own Python functions into XLA-optimized kernels using a one-function API, jit. Compilation and automatic differentiation can be composed arbitrarily, so you can express sophisticated algorithms and get maximal performance without leaving Python. You can even program multiple GPUs or TPU cores at once using pmap, and differentiate through the whole thing.

Dig a little deeper, and you'll see that JAX is really an extensible system for composable function transformations. Both grad and jit are instances of such transformations. Others are vmap for automatic vectorization and pmap for single-program multiple-data (SPMD) parallel programming of multiple accelerators, with more to come.

This is a research project, not an official Google product. Expect sharp edges. Please help by trying it out, reporting bugs, and letting us know what you think!

import jax.numpy as jnp
from jax import grad, jit, vmap

def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jnp.tanh(outputs)  # inputs to the next layer
  return outputs                # no activation on last layer

def loss(params, inputs, targets):
  preds = predict(params, inputs)
  return jnp.sum((preds - targets)**2)

grad_loss = jit(grad(loss))  # compiled gradient evaluation function
perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0)))  # fast per-example grads

Contents

Quickstart: Colab in the Cloud | Transformations | Current gotchas | Installation | Neural network libraries | Citing JAX | Reference documentation

Quickstart: Colab in the Cloud

Jump right in using a notebook in your browser, connected to a Google Cloud GPU; the project provides several starter notebooks.

JAX now runs on Cloud TPUs. To try out the preview, see the Cloud TPU Colabs.

For a deeper dive into JAX, see the reference documentation.

Transformations

At its core, JAX is an extensible system for transforming numerical functions. Here are four transformations of primary interest: grad, jit, vmap, and pmap.

Automatic differentiation with grad

JAX has roughly the same API as Autograd. The most popular function is grad for reverse-mode gradients:

from jax import grad
import jax.numpy as jnp

def tanh(x):  # Define a function
  y = jnp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)  # Obtain its gradient function
print(grad_tanh(1.0))   # Evaluate it at x = 1.0
# prints 0.4199743

You can differentiate to any order with grad.

print(grad(grad(grad(tanh)))(1.0))
# prints 0.62162673

For more advanced autodiff, you can use jax.vjp for reverse-mode vector-Jacobian products and jax.jvp for forward-mode Jacobian-vector products. The two can be composed arbitrarily with one another, and with other JAX transformations. Here's one way to compose those to make a function that efficiently computes full Hessian matrices:

from jax import jit, jacfwd, jacrev

def hessian(fun):
  return jit(jacfwd(jacrev(fun)))
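As a quick sanity check (this example is not from the README; the function f below is just an illustration), the composed transformation produces the expected second derivatives. For f(x) = sum(x**3), the Hessian is the diagonal matrix diag(6 * x):

import jax.numpy as jnp
from jax import jit, jacfwd, jacrev

def hessian(fun):
  return jit(jacfwd(jacrev(fun)))

def f(x):                 # illustrative function, not from the README
  return jnp.sum(x ** 3)

print(hessian(f)(jnp.array([1.0, 2.0])))
# [[ 6.  0.]
#  [ 0. 12.]]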

As with Autograd, you're free to use differentiation with Python control structures:

def abs_val(x):
  if x > 0:
    return x
  else:
    return -x

abs_val_grad = grad(abs_val)
print(abs_val_grad(1.0))   # prints 1.0
print(abs_val_grad(-1.0))  # prints -1.0 (abs_val is re-evaluated)

See the reference docs on automatic differentiation and the JAX Autodiff Cookbook for more.

Compilation with jit

You can use XLA to compile your functions end-to-end with jit, used either as an @jit decorator or as a higher-order function.

import jax.numpy as jnp
from jax import jit

def slow_f(x):
  # Element-wise ops see a large benefit from fusion
  return x * x + x * 2.0

x = jnp.ones((5000, 5000))
fast_f = jit(slow_f)
%timeit -n10 -r3 fast_f(x)  # ~ 4.5 ms / loop on Titan X
%timeit -n10 -r3 slow_f(x)  # ~ 14.5 ms / loop (also on GPU via JAX)

You can mix jit and grad and any other JAX transformation however you like.

Using jit puts constraints on the kind of Python control flow the function can use; see the tutorial on Control Flow and Logical Operators with JIT for more.
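For illustration, here is a minimal sketch (not from the README; the function power is hypothetical) of the static_argnums workaround also mentioned in the gotchas below: marking an argument as static lets ordinary Python control flow depend on it, at the cost of one compilation per distinct static value.

from functools import partial
import jax.numpy as jnp
from jax import jit

@partial(jit, static_argnums=(1,))      # treat `n` as a compile-time constant
def power(x, n):
  acc = jnp.ones_like(x)
  for _ in range(n):                    # a Python loop over a static value is fine
    acc = acc * x
  return acc

print(power(jnp.array([2.0, 3.0]), 3))  # [ 8. 27.]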

Auto-vectorization with vmap

vmap is the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a function’s primitive operations for better performance.

Using vmap can save you from having to carry around batch dimensions in your code. For example, consider this simple unbatched neural network prediction function:

def predict(params, input_vec):
  assert input_vec.ndim == 1
  activations = input_vec
  for W, b in params:
    outputs = jnp.dot(W, activations) + b  # `activations` on the right-hand side!
    activations = jnp.tanh(outputs)        # inputs to the next layer
  return outputs                           # no activation on last layer

We often instead write jnp.dot(activations, W) to allow for a batch dimension on the left side of activations, but we’ve written this particular prediction function to apply only to single input vectors. If we wanted to apply this function to a batch of inputs at once, semantically we could just write

from functools import partial
predictions = jnp.stack(list(map(partial(predict, params), input_batch)))

But pushing one example through the network at a time would be slow! It’s better to vectorize the computation, so that at every layer we’re doing matrix-matrix multiplication rather than matrix-vector multiplication.

The vmap function does that transformation for us. That is, if we write

from jax import vmap
predictions = vmap(partial(predict, params))(input_batch)
# or, alternatively
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)

then the vmap function will push the outer loop inside the function, and our machine will end up executing matrix-matrix multiplications exactly as if we’d done the batching by hand.

It’s easy enough to manually batch a simple neural network without vmap, but in other cases manual vectorization can be impractical or impossible. Take the problem of efficiently computing per-example gradients: that is, for a fixed set of parameters, we want to compute the gradient of our loss function evaluated separately at each example in a batch. With vmap, it’s easy:

per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets)

Of course, vmap can be arbitrarily composed with jit, grad, and any other JAX transformation! We use vmap with both forward- and reverse-mode automatic differentiation for fast Jacobian and Hessian matrix calculations in jax.jacfwd, jax.jacrev, and jax.hessian.

SPMD programming with pmap

For parallel programming of multiple accelerators, like multiple GPUs, use pmap. With pmap you write single-program multiple-data (SPMD) programs, including fast parallel collective communication operations. Applying pmap will mean that the function you write is compiled by XLA (similarly to jit), then replicated and executed in parallel across devices.

Here's an example on an 8-GPU machine:

from jax import random, pmap
import jax.numpy as jnp

# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random.split(random.key(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)

# Run a local matmul on each device in parallel (no data transfer)
result = pmap(lambda x: jnp.dot(x, x.T))(mats)  # result.shape is (8, 5000, 5000)

# Compute the mean on each device in parallel and print the result
print(pmap(jnp.mean)(result))
# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]

In addition to expressing pure maps, you can use fast collective communication operations between devices:

from functools import partial
from jax import lax

@partial(pmap, axis_name='i')
def normalize(x):
  return x / lax.psum(x, 'i')

print(normalize(jnp.arange(4.)))
# prints [0.         0.16666667 0.33333334 0.5       ]

You can even nest pmap functions for more sophisticated communication patterns.
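As a rough sketch of what nesting can look like (this example is not from the README; it assumes a machine with at least eight devices, viewed here as a 4 x 2 logical grid, and the axis names are chosen purely for illustration):

from functools import partial
import jax.numpy as jnp
from jax import pmap, lax

# Hypothetical example: normalize across the inner 'cols' axis on each row of devices.
@partial(pmap, axis_name='rows')
@partial(pmap, axis_name='cols')
def normalize_rows(x):
  return x / lax.psum(x, 'cols')

x = jnp.arange(8.).reshape((4, 2))
print(normalize_rows(x))
# each row is divided by its own sum, e.g. the second row [2. 3.] becomes [0.4 0.6]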

It all composes, so you're free to differentiate through parallel computations:

from jax import grad

@pmap
def f(x):
  y = jnp.sin(x)
  @pmap
  def g(z):
    return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()
  return grad(lambda w: jnp.sum(g(w)))(x)

print(f(x))
# [[ 0.        , -0.7170853 ],
#  [-3.1085174 , -0.4824318 ],
#  [10.366636  , 13.135289  ],
#  [ 0.22163185, -0.52112055]]

print(grad(lambda x: jnp.sum(f(x)))(x))
# [[ -3.2369726,  -1.6356447],
#  [  4.7572474,  11.606951 ],
#  [-98.524414 ,  42.76499  ],
#  [ -1.6007166,  -1.2568436]]

When reverse-mode differentiating a pmap function (e.g. with grad), the backward pass of the computation is parallelized just like the forward pass.

See the SPMD Cookbook and the SPMD MNIST classifier from scratch example for more.

Current gotchas

For a more thorough survey of current gotchas, with examples and explanations, we highly recommend reading the Gotchas Notebook. Some standouts:

  1. JAX transformations only work on pure functions, which don't have side-effects and respect referential transparency (i.e. object identity testing with is isn't preserved). If you use a JAX transformation on an impure Python function, you might see an error like Exception: Can't lift Traced... or Exception: Different traces at same level.
  2. In-place mutating updates of arrays, like x[i] += y, aren't supported, but there are functional alternatives (see the sketch after this list). Under a jit, those functional alternatives will reuse buffers in-place automatically.
  3. Random numbers are different, but for good reasons.
  4. If you're looking for convolution operators, they're in the jax.lax package.
  5. JAX enforces single-precision (32-bit, e.g. float32) values by default, and to enable double-precision (64-bit, e.g. float64) one needs to set the jax_enable_x64 variable at startup (or set the environment variable JAX_ENABLE_X64=True). On TPU, JAX uses 32-bit values by default for everything except internal temporary variables in 'matmul-like' operations, such as jax.numpy.dot and lax.conv. Those ops have a precision parameter which can be used to approximate 32-bit operations via three bfloat16 passes, with a cost of possibly slower runtime. Non-matmul operations on TPU lower to implementations that often emphasize speed over accuracy, so in practice computations on TPU will be less precise than similar computations on other backends.
  6. Some of NumPy's dtype promotion semantics involving a mix of Python scalars and NumPy types aren't preserved, namely np.add(1, np.array([2], np.float32)).dtype is float64 rather than float32.
  7. Some transformations, like jit, constrain how you can use Python control flow. You'll always get loud errors if something goes wrong. You might have to use jit's static_argnums parameter, structured control flow primitives like lax.scan, or just use jit on smaller subfunctions.
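To make gotchas 2 and 5 concrete, here is a small illustrative snippet (not from the README; the particular indices and values are just examples):

import jax
import jax.numpy as jnp

# Gotcha 2: x[i] += y is not allowed on JAX arrays; use the functional .at[] API.
x = jnp.zeros(4)
y = x.at[1].add(10.0)     # returns a new array; under jit the buffer is reused in place
print(x)  # [0. 0. 0. 0.]      -- the original array is unchanged
print(y)  # [ 0. 10.  0.  0.]

# Gotcha 5: enable 64-bit values; best done at the very start of the program.
jax.config.update("jax_enable_x64", True)
print(jnp.ones(1).dtype)  # float64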

Installation

Supported platforms

|            | Linux x86_64 | Linux aarch64 | Mac x86_64   | Mac aarch64  | Windows x86_64 | Windows WSL2 x86_64 |
|------------|--------------|---------------|--------------|--------------|----------------|---------------------|
| CPU        | yes          | yes           | yes          | yes          | yes            | yes                 |
| NVIDIA GPU | yes          | yes           | no           | n/a          | no             | experimental        |
| Google TPU | yes          | n/a           | n/a          | n/a          | n/a            | n/a                 |
| AMD GPU    | yes          | no            | experimental | n/a          | no             | no                  |
| Apple GPU  | n/a          | no            | n/a          | experimental | n/a            | n/a                 |
| Intel GPU  | experimental | n/a           | n/a          | n/a          | no             | no                  |

Instructions

| Platform        | Instructions                 |
|-----------------|------------------------------|
| CPU             | pip install -U jax           |
| NVIDIA GPU      | pip install -U "jax[cuda12]" |
| Google TPU      | pip install -U "jax[tpu]"    |
| AMD GPU (Linux) | Follow AMD's instructions.   |
| Mac GPU         | Follow Apple's instructions. |
| Intel GPU       | Follow Intel's instructions. |

See the documentation for information on alternative installation strategies. These include compiling from source, installing with Docker, using other versions of CUDA, a community-supported conda build, and answers to some frequently-asked questions.

Neural network libraries

Multiple research groups at Google DeepMind and Alphabet develop and share libraries for training neural networks in JAX. If you want a fully featured library for neural network training with examples and how-to guides, try Flax and its documentation site.

Check out the JAX Ecosystem section on the JAX documentation site for a list of JAX-based network libraries, which includes Optax for gradient processing and optimization, chex for reliable code and testing, and Equinox for neural networks. (Watch the NeurIPS 2020 JAX Ecosystem at DeepMind talk here for additional details.)

Citing JAX

To cite this repository:

@software{jax2018github,
  author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
  title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
  url = {http://github.com/jax-ml/jax},
  version = {0.3.13},
  year = {2018},
}

In the above bibtex entry, names are in alphabetical order, the version number is intended to be that from jax/version.py, and the year corresponds to the project's open-source release.

A nascent version of JAX, supporting only automatic differentiation and compilation to XLA, was described in a paper that appeared at SysML 2018. We're currently working on covering JAX's ideas and capabilities in a more comprehensive and up-to-date paper.

Reference documentation

For details about the JAX API, see the reference documentation.

For getting started as a JAX developer, see the developer documentation.

