
Quickstart: How to think in JAX


JAX is a library for array-oriented numerical computation (à la NumPy), with automatic differentiation and JIT compilation to enable high-performance machine learning research.

This document provides a quick overview of essential JAX features, so you can get started with JAX:

  • JAX provides a unified NumPy-like interface to computations that run on CPU, GPU, or TPU, in local or distributed settings.

  • JAX features built-in Just-In-Time (JIT) compilation via OpenXLA, an open-source machine learning compiler ecosystem.

  • JAX functions support efficient evaluation of gradients via JAX's automatic differentiation transformations.

  • JAX functions can be automatically vectorized to efficiently map them over arrays representing batches of inputs.

Installation

JAX can be installed for CPU on Linux, Windows, and macOS directly from the Python Package Index:

```shell
pip install jax
```

or, for NVIDIA GPU:

```shell
pip install -U "jax[cuda13]"
```

For more detailed platform-specific installation information, check out Installation.
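A quick way to confirm the install is to ask JAX which devices and backend it found. This is a minimal check that assumes a standard install. A CPU-only install reports CPU devices, and a working CUDA install should list GPU devices:

```python
import jax

# Print the installed version and the devices JAX can see.
print(jax.__version__)
print(jax.devices())

# Name of the platform JAX uses by default for computations.
print(jax.default_backend())
```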

JAX vs. NumPy

Key concepts:

  • JAX provides a NumPy-inspired interface for convenience.

  • Through duck-typing, JAX arrays can often be used as drop-in replacements for NumPy arrays.

  • Unlike NumPy arrays, JAX arrays are always immutable.

NumPy provides a well-known, powerful API for working with numerical data. For convenience, JAX provides jax.numpy, which closely mirrors the NumPy API and provides easy entry into JAX. Almost anything that can be done with numpy can be done with jax.numpy, which is typically imported under the jnp alias:

```python
import jax.numpy as jnp
```

With this import, you can immediately use JAX in a similar manner to typical NumPy programs, including using NumPy-style array creation functions, Python functions and operators, and array attributes and methods:

```python
import matplotlib.pyplot as plt

x_jnp = jnp.linspace(0, 10, 1000)
y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)
plt.plot(x_jnp, y_jnp);
```

[Figure: line plot of y = 2 sin(x) cos(x) for x from 0 to 10]

The code blocks are identical to what you would expect with NumPy, aside from replacing np with jnp, and the results are the same. As we can see, JAX arrays can often be used directly in place of NumPy arrays for things like plotting.

The arrays themselves are implemented as different Python types:

```python
import numpy as np
import jax.numpy as jnp

x_np = np.linspace(0, 10, 1000)
x_jnp = jnp.linspace(0, 10, 1000)
```

```python
type(x_np)
```

```
numpy.ndarray
```

```python
type(x_jnp)
```

```
jaxlib._jax.ArrayImpl
```
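Because both types often appear in the same program, it helps to know how to convert between them explicitly. Here is a minimal sketch using np.asarray and jnp.asarray. Both are standard functions; the variable names are only illustrative:

```python
import numpy as np
import jax
import jax.numpy as jnp

x_jnp = jnp.linspace(0, 10, 5)

# jax.Array -> numpy.ndarray (the data is copied to host memory if needed).
x_np = np.asarray(x_jnp)

# numpy.ndarray -> jax.Array (the data is placed on JAX's default device).
x_back = jnp.asarray(x_np)

print(type(x_np))
print(isinstance(x_back, jax.Array))
```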

Python’s duck-typing allows JAX arrays and NumPy arrays to be used interchangeably in many places. However, there is one important difference between JAX and NumPy arrays: JAX arrays are immutable, meaning that once created their contents cannot be changed.

Here is an example of mutating an array in NumPy:

```python
# NumPy: mutable arrays
x = np.arange(10)
x[0] = 10
print(x)
```

```
[10  1  2  3  4  5  6  7  8  9]
```

The equivalent in JAX results in an error, as JAX arrays are immutable:

```python
%xmode minimal
```

```
Exception reporting mode: Minimal
```

```python
# JAX: immutable arrays
x = jnp.arange(10)
x[0] = 10
```

```
TypeError: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html
```

For updating individual elements, JAX provides an indexed update syntax that returns an updated copy:

```python
y = x.at[0].set(10)
print(x)
print(y)
```

```
[0 1 2 3 4 5 6 7 8 9]
[10  1  2  3  4  5  6  7  8  9]
```
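.set is only one of the .at[] methods mentioned in the error message above. The sketch below exercises a few others (add, multiply, max, get) on small arrays. Each call returns a new array and leaves x unchanged. One subtlety is worth seeing: unlike NumPy's `x[idx] += y`, `.at[idx].add(y)` accumulates contributions from repeated indices.

```python
import jax.numpy as jnp

x = jnp.arange(5)  # [0 1 2 3 4]

added = x.at[1].add(10)        # add 10 at index 1
scaled = x.at[2:].multiply(2)  # multiply a slice by 2
clipped = x.at[3].max(7)       # elementwise max of x[3] and 7
picked = x.at[4].get()         # equivalent to x[4]

# Repeated indices accumulate: index 0 receives 1.0 twice.
dup = jnp.zeros(3).at[jnp.array([0, 0])].add(1.0)

print(added, scaled, clipped, picked, dup)
print(x)  # x itself is never modified
```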

You’ll find a few differences between JAX arrays and NumPy arrays once you begin digging in. See also:

JAX arrays (jax.Array)

Key concepts:

  • Create arrays using JAX API functions.

  • JAX array objects have a devices method that indicates where the array is stored.

  • JAX arrays can be sharded across multiple devices for parallel computation.

The default array implementation in JAX is jax.Array. In many ways it is similar to the numpy.ndarray type that you may be familiar with from the NumPy package, but it has some important differences.

Array creation

We typically don’t call the jax.Array constructor directly, but rather create arrays via JAX API functions. For example, jax.numpy provides familiar NumPy-style array construction functionality such as jax.numpy.zeros, jax.numpy.linspace, jax.numpy.arange, etc.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(5)
isinstance(x, jax.Array)
```

```
True
```

If you use Python type annotations in your code, jax.Array is the appropriate annotation for JAX array objects (see jax.typing for more discussion).
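This is a minimal sketch of annotations in practice. It assumes a common convention: inputs are annotated with jax.typing.ArrayLike, so they accept JAX arrays, NumPy arrays, and scalars, and outputs are annotated with jax.Array. The function center is a hypothetical example:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike

def center(x: ArrayLike) -> jax.Array:
    # Hypothetical helper: convert any array-like input to a jax.Array,
    # then subtract its mean. The result is always a jax.Array.
    x = jnp.asarray(x)
    return x - x.mean()

out = center(np.array([0.0, 1.0, 2.0, 3.0]))
print(out)
```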

Array devices and sharding

JAX Array objects have a devices method that lets you inspect where the contents of the array are stored. In the simplest cases, this will be a single CPU device:

```python
x.devices()
```

```
{CpuDevice(id=0)}
```

In general, an array may be sharded across multiple devices, in a manner that can be inspected via the sharding attribute:

```python
x.sharding
```

```
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)
```

Here the array lives on a single device, but a JAX array can also span multiple devices or even multiple hosts. To read more about sharded arrays and parallel computation, refer to Introduction to parallel programming.
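The simplest way to control placement yourself is jax.device_put, which puts an array on a chosen device. This minimal sketch uses whatever device jax.devices() lists first, so it runs on any machine:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(5)

# Explicitly place (or move) the array onto a specific device.
device = jax.devices()[0]
y = jax.device_put(x, device)

print(y.devices())  # a set containing `device`
print(y.sharding)   # describes how y is laid out across devices
```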

Just-in-time compilation with jax.jit

Key concepts:

  • By default JAX executes operations one at a time, in sequence.

  • Using a just-in-time (JIT) compilation decorator, sequences of operations can be optimized together and run at once.

  • Not all JAX code can be JIT-compiled, because JIT compilation requires array shapes to be static and known at compile time.

JAX runs transparently on the GPU or TPU (falling back to CPU if you don’t have one), with all JAX operations being expressed in terms of XLA. If we have a sequence of operations, we can use the jax.jit function to compile this sequence of operations together using the XLA compiler.

For example, consider this function that standardizes each column of a 2D matrix (subtracting the column mean and dividing by the column standard deviation), expressed in terms of jax.numpy operations:

```python
import jax.numpy as jnp

def norm(X):
    X = X - X.mean(0)
    return X / X.std(0)
```

A just-in-time compiled version of the function can be created using the jax.jit transform:

```python
from jax import jit

norm_compiled = jit(norm)
```

This function returns the same results as the original, up to standard floating-point accuracy:

```python
np.random.seed(1701)
X = jnp.array(np.random.rand(10000, 10))
np.allclose(norm(X), norm_compiled(X), atol=1E-6)
```

```
True
```

Compilation fuses operations, avoids allocating temporary arrays, and applies many other optimizations. As a result, JIT-compiled code can run much faster, sometimes by orders of magnitude, though the gain depends on the function, the array sizes, and the hardware. In the run recorded below, on CPU with this small function, the two timings are close. We can use IPython’s %timeit to quickly benchmark our function, using block_until_ready() to account for JAX’s asynchronous dispatch:

```python
%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()
```

```
217 μs ± 21 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
205 μs ± 2.17 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```
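Outside IPython you can time things manually. Keep in mind that the first call to a jitted function includes tracing and compilation, so benchmarks usually time the cached calls that come after it. The following is a minimal sketch using time.perf_counter. It redefines norm so it runs on its own, and it calls block_until_ready() so the timer waits for the asynchronous computation to finish:

```python
import time
import numpy as np
import jax
import jax.numpy as jnp

def norm(X):
    X = X - X.mean(0)
    return X / X.std(0)

norm_compiled = jax.jit(norm)
X = jnp.array(np.random.rand(10000, 10))

# First call: traces and compiles, then runs.
t0 = time.perf_counter()
norm_compiled(X).block_until_ready()
first_call = time.perf_counter() - t0

# Later calls with the same shape and dtype reuse the cached executable.
n = 100
t0 = time.perf_counter()
for _ in range(n):
    norm_compiled(X).block_until_ready()
steady = (time.perf_counter() - t0) / n

print(f"first call (incl. compile): {first_call * 1e6:.0f} us")
print(f"steady state per call:      {steady * 1e6:.0f} us")
```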

That said, jax.jit does have limitations: in particular, it requires all arrays to have static shapes. That means that some JAX operations are incompatible with JIT compilation.

For example, this operation can be executed in op-by-op mode:

```python
def get_negatives(x):
    return x[x < 0]

x = jnp.array(np.random.randn(10))
get_negatives(x)
```

```
Array([-0.10570311, -0.59403396, -0.8680282 , -0.23489487], dtype=float32)
```

But it raises an error if you attempt to execute it in JIT mode:

```python
jit(get_negatives)(x)
```

```
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got bool[10]

See https://docs.jax.dev/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError
```

This is because the function generates an array whose shape is not known at compile time: the size of the output depends on the values of the input array, and so it is not compatible with JIT.
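A common workaround is to reformulate the computation so that its output shape no longer depends on the data. The sketch below shows two such variants. Both function names are hypothetical. The first keeps every position and masks out the non-negative values with jnp.where. The second uses the size and fill_value arguments of jnp.nonzero to pad the index list to a fixed length:

```python
import jax
import jax.numpy as jnp

@jax.jit
def negatives_masked(x):
    mask = x < 0
    # Same shape as x: negative entries are kept, the rest become 0.0.
    values = jnp.where(mask, x, 0.0)
    return values, mask.sum()

@jax.jit
def negative_indices(x):
    # size= fixes the output length at compile time; unused slots hold -1.
    (idx,) = jnp.nonzero(x < 0, size=x.shape[0], fill_value=-1)
    return idx

x = jnp.array([1.0, -2.0, 3.0, -4.0])
values, count = negatives_masked(x)
idx = negative_indices(x)
print(values, count, idx)
```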

For more on JIT compilation in JAX, check out Just-in-time compilation.

Taking derivatives with jax.grad

Key concepts:

  • JAX provides automatic differentiation via the jax.grad transformation.

  • The jax.grad and jax.jit transformations compose and can be mixed arbitrarily.

In addition to transforming functions via JIT compilation, JAX also provides other transformations. One such transformation is jax.grad, which performs automatic differentiation (autodiff):

```python
from jax import grad

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))
```

```
[0.25       0.19661197 0.10499357]
```

Let’s verify with finite differences that our result is correct.

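```python
def first_finite_differences(f, x, eps=1E-3):
    # Central differences along each coordinate direction (rows of the identity).
    return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                      for v in jnp.eye(len(x))])

print(first_finite_differences(sum_logistic, x_small))
```

```
[0.24998187 0.1964569  0.10502338]
```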
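Since sum_logistic is a sum of logistic sigmoids, there is also a closed form to check against. The derivative of sigmoid(x) is sigmoid(x) * (1 - sigmoid(x)), applied elementwise. This minimal sketch compares jax.grad with that formula, computed via jax.nn.sigmoid:

```python
import jax
import jax.numpy as jnp

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
autodiff = jax.grad(sum_logistic)(x_small)

# Closed-form derivative of the sigmoid, elementwise.
s = jax.nn.sigmoid(x_small)
analytic = s * (1 - s)

print(autodiff, analytic)
```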

The jax.grad and jax.jit transformations compose and can be mixed arbitrarily. For instance, while the sum_logistic function was differentiated directly in the previous example, it could also be JIT-compiled, and these operations can be combined. We can go further; nesting grad three times (with jit in between) computes the third derivative of sum_logistic at 1.0:

```python
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
```

```
-0.0353256
```

Beyond scalar-valued functions, the jax.jacobian transformation can be used to compute the full Jacobian matrix for vector-valued functions:

```python
from jax import jacobian

print(jacobian(jnp.exp)(x_small))
```

```
[[1.        0.        0.       ]
 [0.        2.7182817 0.       ]
 [0.        0.        7.389056 ]]
```
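jnp.exp maps a length-3 vector to a length-3 vector, so its Jacobian above is square, and diagonal because the function is elementwise. In general, for a function from R^n to R^m the Jacobian has shape (m, n): one row per output and one column per input. Here is a minimal sketch with a hypothetical function f from R^3 to R^2:

```python
import numpy as np
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical map R^3 -> R^2: (x0 * x1, sin(x2)).
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 3.0])
J = jax.jacobian(f)(x)
print(J.shape)  # (2, 3)
print(J)        # [[x1, x0, 0], [0, 0, cos(x2)]] evaluated at x
```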

For more advanced autodiff operations, you can use jax.vjp for reverse-mode vector-Jacobian products, and jax.jvp and jax.linearize for forward-mode Jacobian-vector products. The two can be composed arbitrarily with one another, and with other JAX transformations. For example, jax.jvp and jax.vjp are used to define the forward-mode jax.jacfwd and reverse-mode jax.jacrev for computing Jacobians in forward- and reverse-mode, respectively. Here’s one way to compose them to make a function that efficiently computes full Hessian matrices:

```python
from jax import jacfwd, jacrev

def hessian(fun):
    return jit(jacfwd(jacrev(fun)))

print(hessian(sum_logistic)(x_small))
```

```
[[-0.         -0.         -0.        ]
 [-0.         -0.09085776 -0.        ]
 [-0.         -0.         -0.07996249]]
```

This kind of composition produces efficient code in practice; this is more or less how JAX’s built-in jax.hessian function is implemented.
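The sketch below checks that claim numerically by comparing the hand-built hessian with jax.hessian. It also shows the jvp and vjp building blocks directly. jvp pushes a tangent vector v forward, giving J @ v. vjp pulls a cotangent back, giving v @ J. The vector-valued function f is a hypothetical example chosen to have a non-diagonal Jacobian:

```python
import jax
import jax.numpy as jnp
from jax import jacfwd, jacrev, jit

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

def hessian(fun):
    return jit(jacfwd(jacrev(fun)))

x_small = jnp.arange(3.)
ours = hessian(sum_logistic)(x_small)
builtin = jax.hessian(sum_logistic)(x_small)

def f(x):
    # Hypothetical vector-valued function with a non-diagonal Jacobian.
    return jnp.sin(jnp.cumsum(x))

v = jnp.array([1.0, 2.0, 3.0])
J = jax.jacobian(f)(x_small)

_, jv = jax.jvp(f, (x_small,), (v,))  # forward mode: J @ v
_, vjp_fn = jax.vjp(f, x_small)
(vj,) = vjp_fn(v)                     # reverse mode: v @ J

print(ours, builtin, jv, vj, sep="\n")
```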

For more on automatic differentiation in JAX, check out Automatic differentiation.

Auto-vectorization with jax.vmap

Key concepts:

  • JAX provides automatic vectorization via the jax.vmap transformation.

  • jax.vmap can be composed with jax.jit to produce efficient vectorized code.

Another useful transformation is jax.vmap, the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of explicitly looping over function calls, it transforms the function into a natively vectorized version for better performance. When composed with jax.jit, it can be just as performant as manually rewriting your function to operate over an extra batch dimension.

We’re going to work with a simple example, and promote matrix-vector products into matrix-matrix products using jax.vmap. Although this is easy to do by hand in this specific case, the same technique can apply to more complicated functions.

```python
from jax import random

key = random.key(1701)
key1, key2 = random.split(key)
mat = random.normal(key1, (150, 100))
batched_x = random.normal(key2, (10, 100))

def apply_matrix(x):
    return jnp.dot(mat, x)
```

The apply_matrix function maps a vector to a vector, but we may want to apply it row-wise across a matrix. We could do this by looping over the batch dimension in Python, but this usually results in poor performance.

```python
def naively_batched_apply_matrix(v_batched):
    return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
```

```
Naively batched
417 μs ± 2.86 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```

A programmer familiar with the jnp.dot function might recognize that apply_matrix can be rewritten to avoid explicit looping, using the built-in batching semantics of jnp.dot:

```python
import numpy as np

@jit
def batched_apply_matrix(batched_x):
    return jnp.dot(batched_x, mat.T)

np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
                           batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()
```

```
Manually batched
13.8 μs ± 158 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
```

However, as functions become more complicated, this kind of manual batching becomes more difficult and error-prone. The jax.vmap transformation is designed to automatically transform a function into a batch-aware version:

```python
from jax import vmap

@jit
def vmap_batched_apply_matrix(batched_x):
    return vmap(apply_matrix)(batched_x)

np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
                           vmap_batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
```

```
Auto-vectorized with vmap
14.6 μs ± 90.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
```
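By default, jax.vmap maps over axis 0 of every argument and stacks outputs along axis 0. The in_axes and out_axes parameters let you choose other axes, or pass None to broadcast an argument unchanged. This is a minimal sketch using a hypothetical two-argument variant of apply_matrix that takes the matrix explicitly instead of closing over it:

```python
import jax
import jax.numpy as jnp

def apply_matrix_to(mat, x):
    # Hypothetical two-argument variant of apply_matrix.
    return jnp.dot(mat, x)

key1, key2 = jax.random.split(jax.random.key(0))
mat = jax.random.normal(key1, (150, 100))
batched_x = jax.random.normal(key2, (10, 100))

# in_axes=(None, 0): reuse `mat` for every element, map over rows of batched_x.
out = jax.vmap(apply_matrix_to, in_axes=(None, 0))(mat, batched_x)

# Vectors stored as columns instead: map over axis 1, stack outputs along axis 1.
out_cols = jax.vmap(apply_matrix_to, in_axes=(None, 1), out_axes=1)(mat, batched_x.T)

print(out.shape, out_cols.shape)  # (10, 150) (150, 10)
```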

As you would expect, jax.vmap can be arbitrarily composed with jax.jit, jax.grad, and any other JAX transformation.
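A typical use of this composition is computing per-example gradients. You take the gradient of a single-example loss, vectorize it over the batch with vmap, and compile the result with jit. This is a minimal sketch; the squared-error loss and the data are hypothetical:

```python
import numpy as np
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Hypothetical squared-error loss for a single (x, y) example.
    return (jnp.dot(w, x) - y) ** 2

w = jnp.array([1.0, -1.0])
xs = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
ys = jnp.array([0.0, 1.0, 2.0])

# grad w.r.t. w for one example; vmap shares w and maps over (xs, ys).
per_example_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))
grads = per_example_grad(w, xs, ys)
print(grads.shape)  # (3, 2): one gradient per example
```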

For more on automatic vectorization in JAX, check out Automatic vectorization.

Pseudorandom numbers

Key concepts:

  • JAX uses a different model for pseudorandom number generation than NumPy.

  • JAX random functions consume a random key that must be split to generate new independent keys.

  • JAX’s random key model is thread-safe and avoids issues with global state.

Generally, JAX strives to be compatible with NumPy, but pseudorandom number generation is a notable exception. NumPy supports a method of pseudorandom number generation that is based on a global state, which can be set using numpy.random.seed. Global random state interacts poorly with JAX’s compute model and makes it difficult to enforce reproducibility across different threads, processes, and devices. JAX instead tracks state explicitly via a random key:

```python
from jax import random

key = random.key(43)
print(key)
```

```
Array((), dtype=key<fry>) overlaying:
[ 0 43]
```

The key is effectively a stand-in for NumPy’s hidden state object, but we pass it explicitly to jax.random functions. Importantly, random functions consume the key, but do not modify it: feeding the same key object to a random function will always result in the same sample being generated.

```python
print(random.normal(key))
print(random.normal(key))
```

```
0.07520543
0.07520543
```

The rule of thumb is: never reuse keys (unless you want identical outputs).

In order to generate different and independent samples, you must split the key explicitly with jax.random.split before passing it to a random function:

```python
for i in range(3):
    new_key, subkey = random.split(key)
    del key  # The old key is consumed by split(); we must never use it again.

    val = random.normal(subkey)
    del subkey  # The subkey is consumed by normal().

    print(f"draw {i}: {val}")
    key = new_key  # new_key is safe to use in the next iteration.
```

```
draw 0: -1.9133632183074951
draw 1: -1.4749839305877686
draw 2: -0.36703771352767944
```

Note that this code is thread-safe, since the local random state eliminates possible race conditions involving global state. jax.random.split is a deterministic function that converts one key into several independent (in the pseudorandomness sense) keys.
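jax.random.split also takes a num argument, so you can produce several keys in one call instead of splitting in a loop. The following is a minimal sketch of that pattern. It also checks the two properties described above: distinct keys give (almost surely) distinct draws, and reusing a key reproduces the same draw:

```python
from jax import random

key = random.key(43)

# One call yields 4 keys: keep one for later use, consume the other three.
key, *subkeys = random.split(key, 4)
samples = [float(random.normal(k)) for k in subkeys]
print(samples)

# Feeding the same key again gives the identical sample.
same = float(random.normal(subkeys[0]))
print(same == samples[0])
```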

For more on pseudorandom numbers in JAX, see the Pseudorandom numbers tutorial.

Debugging

Debugging JAX code can be challenging due to its functional programming model and the fact that JAX code is often transformed via JIT compilation or vectorization. However, JAX provides several tools to help with debugging.

jax.debug.print

For simple inspection, use jax.debug.print.

Python’s built-in print executes at trace time, before the runtime values exist. Because of this, print will only show tracer values within jax.jit-decorated code.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    print("print(x) ->", x)
    y = jnp.sin(x)
    print("print(y) ->", y)
    return y

result = f(2.)
```

```
print(x) -> JitTracer(~float32[])
print(y) -> JitTracer(~float32[])
```

If you want to print the actual runtime values, you can use jax.debug.print:

```python
@jax.jit
def f(x):
    jax.debug.print("jax.debug.print(x) -> {x}", x=x)
    y = jnp.sin(x)
    jax.debug.print("jax.debug.print(y) -> {y}", y=y)
    return y

result = f(2.)
```

```
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.9092974066734314
```
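The trace-time behavior of ordinary Python code has another consequence. Python side effects inside a jitted function run only when JAX traces it, and later calls with the same input shapes and dtypes reuse the compiled version. This sketch makes that visible with a hypothetical global counter. Treat it as an illustration, not a recommended pattern:

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, used only to observe tracing

@jax.jit
def f(x):
    global trace_count
    trace_count += 1  # plain Python: executes only while f is being traced
    return jnp.sin(x)

f(2.0)
f(3.0)  # same shape and dtype as before: cached, no retrace
after_scalars = trace_count
print(after_scalars)  # 1

f(jnp.ones(3))  # new input shape: f is traced again
print(trace_count)  # 2
```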

Debugging flags

JAX offers flags and context managers that make errors easier to catch. For example, you can enable the jax_debug_nans flag to automatically detect when NaNs are produced in jax.jit-compiled code. You can also enable the jax_disable_jit flag to disable JIT compilation, enabling use of traditional Python debugging tools like print and pdb.
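The sketch below shows both flags in use. It assumes you toggle jax_debug_nans through jax.config.update and reset it afterwards, and that you use the jax.disable_jit() context manager for the second flag. The functions log_of and g are hypothetical examples:

```python
import jax
import jax.numpy as jnp

@jax.jit
def log_of(x):
    return jnp.log(x)  # log of a negative number produces NaN

# With jax_debug_nans enabled, producing a NaN raises FloatingPointError.
jax.config.update("jax_debug_nans", True)
try:
    log_of(-1.0)
    raised = False
except FloatingPointError:
    raised = True
finally:
    jax.config.update("jax_debug_nans", False)
print(raised)

@jax.jit
def g(x):
    print("x =", x)  # under disable_jit this shows the concrete value
    return x * 2

# Inside disable_jit(), jitted functions run eagerly as ordinary Python.
with jax.disable_jit():
    doubled = g(3.0)
```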

For more details, see Introduction to debugging.


This is just a taste of what JAX can do. We’re really excited to see what you do with it!

