
Key concepts

This section briefly introduces some key concepts of the JAX package.

Transformations

Along with functions to operate on arrays, JAX includes a number of transformations which operate on JAX functions. These include

- jax.jit(): just-in-time compilation
- jax.vmap(): automatic vectorization
- jax.grad(): automatic differentiation

as well as several others. Transformations accept a function as an argument, and return a new transformed function. For example, here’s how you might JIT-compile a simple SELU function:

import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)
print(selu_jit(1.0))
1.05

Often you’ll see transformations applied using Python’s decorator syntax for convenience:

@jax.jit
def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

Tracing

The magic behind transformations is the notion of a Tracer. Tracers are abstract stand-ins for array objects, and are passed to JAX functions in order to extract the sequence of operations that the function encodes.

You can see this by printing any array value within transformed JAX code; for example:

@jax.jit
def f(x):
    print(x)
    return x + 1

x = jnp.arange(5)
result = f(x)
JitTracer(int32[5])

The value printed is not the array x, but a Tracer instance that represents essential attributes of x, such as its shape and dtype. By executing the function with traced values, JAX can determine the sequence of operations encoded by the function before those operations are actually executed: transformations like jit(), vmap(), and grad() can then map this sequence of input operations to a transformed sequence of operations.

Static vs traced operations: Just as values can be either static or traced, operations can be static or traced. Static operations are evaluated at compile-time in Python; traced operations are compiled & evaluated at run-time in XLA.
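As a minimal sketch of this distinction (the function f below is hypothetical, not from the text above), a shape inspection is resolved while tracing in Python, whereas array arithmetic is staged out for XLA:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Static: x.shape is a plain Python tuple known at trace time,
    # so this branch is decided during tracing, not at run time.
    if x.shape[0] > 2:
        # Traced: jnp.sum becomes an operation in the compiled program.
        return jnp.sum(x)
    return x

print(f(jnp.arange(5)))  # 10
```

Because the branch depends only on a static attribute, the compiled program contains just one of the two code paths.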

For more details, see Tracing.

Jaxprs

JAX has its own intermediate representation for sequences of operations, known as a jaxpr. A jaxpr (short for JAX exPRession) is a simple representation of a functional program, comprising a sequence of primitive operations.

For example, consider the selu function we defined above:

def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

We can use the jax.make_jaxpr() utility to convert this function into a jaxpr given a particular input:

x = jnp.arange(5.0)
jax.make_jaxpr(selu)(x)
{ lambda ; a:f32[5]. let
    b:bool[5] = gt a 0.0:f32[]
    c:f32[5] = exp a
    d:f32[5] = mul 1.6699999570846558:f32[] c
    e:f32[5] = sub d 1.6699999570846558:f32[]
    f:f32[5] = jit[
      name=_where
      jaxpr={ lambda ; b:bool[5] a:f32[5] e:f32[5]. let
          f:f32[5] = select_n b e a
        in (f,) }
    ] b a e
    g:f32[5] = mul 1.0499999523162842:f32[] f
  in (g,) }

Comparing this to the Python function definition, we see that it encodes the precise sequence of operations that the function represents. We’ll go into more depth about jaxprs later in JAX internals: The jaxpr language.
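Because transformations return ordinary functions, make_jaxpr composes with them. As a small sketch (the function f here is hypothetical), the jaxpr of grad(f) records the derivative operations rather than the original ones:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0

# The jaxpr of the transformed function contains the derivative
# primitive (cos), since d/dx [2 sin(x)] = 2 cos(x).
print(jax.make_jaxpr(jax.grad(f))(1.0))
```

This is one convenient way to inspect what a transformation actually does to your program.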

Pytrees

JAX functions and transformations fundamentally operate on arrays, but in practice it is convenient to write code that works with collections of arrays: for example, a neural network might organize its parameters in a dictionary of arrays with meaningful keys. Rather than handle such structures on a case-by-case basis, JAX relies on the pytree abstraction to treat such collections in a uniform manner.

Here are some examples of objects that can be treated as pytrees:

# (nested) list of parameters
params = [1, 2, (jnp.arange(3), jnp.ones(2))]

print(jax.tree.structure(params))
print(jax.tree.leaves(params))

PyTreeDef([*, *, (*, *)])
[1, 2, Array([0, 1, 2], dtype=int32), Array([1., 1.], dtype=float32)]

# Dictionary of parameters
params = {'n': 5, 'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}

print(jax.tree.structure(params))
print(jax.tree.leaves(params))

PyTreeDef({'W': *, 'b': *, 'n': *})
[Array([[1., 1.],
       [1., 1.]], dtype=float32), Array([0., 0.], dtype=float32), 5]

# Named tuple of parameters
from typing import NamedTuple

class Params(NamedTuple):
    a: int
    b: float

params = Params(1, 5.0)

print(jax.tree.structure(params))
print(jax.tree.leaves(params))

PyTreeDef(CustomNode(namedtuple[Params], [*, *]))
[1, 5.0]

JAX has a number of general-purpose utilities for working with pytrees; for example, the function jax.tree.map() can be used to map a function to every leaf in a tree, and jax.tree.reduce() can be used to apply a reduction across the leaves in a tree.
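A minimal sketch of these two utilities on a small, hypothetical parameter tree:

```python
import jax

params = {'W': [1.0, 2.0], 'b': 3.0}

# jax.tree.map applies the function to every leaf,
# preserving the tree structure.
doubled = jax.tree.map(lambda x: x * 2, params)
print(doubled)  # {'W': [2.0, 4.0], 'b': 6.0}

# jax.tree.reduce folds a binary function across all leaves.
total = jax.tree.reduce(lambda a, b: a + b, params)
print(total)  # 6.0
```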

You can learn more in the Pytrees tutorial.

JAX API layering: NumPy, lax & XLA

All JAX operations are implemented in terms of operations in XLA – the Accelerated Linear Algebra compiler. If you look at the source of jax.numpy, you’ll see that all the operations are eventually expressed in terms of functions defined in jax.lax. While jax.numpy is a high-level wrapper that provides a familiar interface, you can think of jax.lax as a stricter, but often more powerful, lower-level API for working with multi-dimensional arrays.

For example, while jax.numpy will implicitly promote arguments to allow operations between mixed data types, jax.lax will not:

import jax.numpy as jnp

jnp.add(1, 1.0)  # jax.numpy API implicitly promotes mixed types.
Array(2., dtype=float32, weak_type=True)
from jax import lax

lax.add(1, 1.0)  # jax.lax API requires explicit type promotion.
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[10], line 2
      1 from jax import lax
----> 2 lax.add(1, 1.0)  # jax.lax API requires explicit type promotion.

TypeError: lax.add requires arguments to have the same dtypes, got int32, float32. (Tip: jnp.add is a similar function that does automatic type promotion on inputs).

If using jax.lax directly, you’ll have to do type promotion explicitly in such cases:

lax.add(jnp.float32(1), 1.0)
Array(2., dtype=float32)

Along with this strictness, jax.lax also provides efficient APIs for some more general operations than are supported by NumPy.

For example, consider a 1D convolution, which can be expressed in NumPy this way:

x = jnp.array([1, 2, 1])
y = jnp.ones(10)
jnp.convolve(x, y)
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)

Under the hood, this NumPy operation is translated to a much more general convolution implemented by lax.conv_general_dilated:

from jax import lax

result = lax.conv_general_dilated(
    x.reshape(1, 1, 3).astype(float),  # note: explicit promotion
    y.reshape(1, 1, 10),
    window_strides=(1,),
    padding=[(len(y) - 1, len(y) - 1)])  # equivalent of padding='full' in NumPy
result[0, 0]
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)

This is a batched convolution operation designed to be efficient for the types of convolutions often used in deep neural nets. It requires much more boilerplate, but is far more flexible and scalable than the convolution provided by NumPy (see Convolutions in JAX for more detail on JAX convolutions).

At their heart, all jax.lax operations are Python wrappers for operations in XLA; here, for example, the convolution implementation is provided by XLA:ConvWithGeneralPadding. Every JAX operation is eventually expressed in terms of these fundamental XLA operations, which is what enables just-in-time (JIT) compilation.
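One way to see this correspondence directly (a sketch; the function f below is hypothetical) is to lower a jitted function and inspect the compiler-level program that XLA receives:

```python
import jax
import jax.numpy as jnp

def f(x):
    return x + 1

# jax.jit(f).lower(...) traces the function and produces the
# StableHLO representation that XLA then compiles.
lowered = jax.jit(f).lower(jnp.arange(3))
print(lowered.as_text())
```

The printed text is the staged-out program: each jax.lax primitive used by f appears as a corresponding compiler operation.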

