
Ref: mutable arrays for data plumbing and memory control

JAX Arrays are immutable, representing mathematical values. Immutability can make code easier to reason about, and is useful for optimized compilation, parallelization, rematerialization, and transformations like autodiff.

But immutability is constraining too:

  • expressiveness — plumbing out intermediate data or maintaining state, e.g. for normalization statistics or metrics, can feel heavyweight;

  • performance — it’s more difficult to reason about performance, like memory lifetimes and in-place updates.

Refs can help! They represent mutable arrays that can be read and written in-place. These array references are compatible with JAX transformations, like jax.jit and jax.grad:

import jax
import jax.numpy as jnp

x_ref = jax.new_ref(jnp.zeros(3))  # new array ref, with initial value [0., 0., 0.]

@jax.jit
def f():
  x_ref[1] += 1.  # indexed add-update

print(x_ref)  # Ref([0., 0., 0.])
f()
f()
print(x_ref)  # Ref([0., 2., 0.])
Ref([0., 0., 0.], dtype=float32)
Ref([0., 2., 0.], dtype=float32)

The indexing syntax follows NumPy’s. For a Ref called x_ref, we can read its entire value into an Array by writing x_ref[...], and write its entire value using x_ref[...] = A for some Array-valued expression A:

def g(x):
  x_ref = jax.new_ref(0.)
  x_ref[...] = jnp.sin(x)
  return x_ref[...]

print(jax.grad(g)(1.0))  # 0.54
0.5403023

Ref is a distinct type from Array, and it comes with some important constraints and limitations. In particular, indexed reading and writing is just about the only thing you can do with a Ref. References can’t be passed where Arrays are expected:

x_ref = jax.new_ref(1.0)

try:
  jnp.sin(x_ref)  # error! can't do math on refs
except Exception as e:
  print(e)
Attempting to pass a Ref Ref{float32[]} to a primitive: sin -- did you forget to unpack ([...]) the ref?

To do math, you need to read the ref’s value first, like jnp.sin(x_ref[...]).
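For instance, continuing with the x_ref defined just above (a quick check, not part of the original page):

print(jnp.sin(x_ref[...]))  # reads the value out of the ref, then applies sin: ~0.84147096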

So what can you do with Ref? Read on for the details, and some useful recipes.

API

If you’ve ever used Pallas, then Ref should look familiar. A big difference is that you can create new Refs yourself directly using jax.new_ref:

from jax import Array, Ref

def new_ref(init_val: Array) -> Ref:
  """Introduce a new reference with given initial value."""

jax.freeze is its antithesis, invalidating the given ref (so that accessing it afterwards is an error) and producing its final value:

def freeze(ref: Ref) -> Array:
  """Invalidate given reference and produce its final value."""
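As a minimal sketch of the lifecycle (assuming only the two signatures above, and not taken from the original page), you can create a ref, mutate it, and then freeze it in its creation scope to recover a plain Array:

@jax.jit
def double(x):
  ref = jax.new_ref(x)            # create the ref in this scope
  ref[...] = ref[...] + ref[...]  # mutate it in place
  return jax.freeze(ref)          # invalidate the ref and return its final value

print(double(jnp.arange(3.)))     # [0. 2. 4.]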

In between creating and destroying them, you can perform indexed reads and writes on refs. You can read and write using the functions jax.ref.get and jax.ref.swap, but usually you’d just use NumPy-style array indexing syntax:

import types

Index = int | slice | Array | types.EllipsisType
Indexer = Index | tuple[Index, ...]

def get(ref: Ref, idx: Indexer) -> Array:
  """Returns `ref[idx]` for NumPy-style indexer `idx`."""

def swap(ref: Ref, idx: Indexer, val: Array) -> Array:
  """Performs `newval, ref[idx] = ref[idx], val` and returns `newval`."""

Here, Indexer can be any NumPy indexing expression:

x_ref = jax.new_ref(jnp.arange(12.).reshape(3, 4))

# int indexing
row = x_ref[0]
x_ref[1] = row

# tuple indexing
val = x_ref[1, 2]
x_ref[2, 3] = val

# slice indexing
col = x_ref[:, 1]
x_ref[0, :3] = col

# advanced int array indexing
vals = x_ref[jnp.array([0, 0, 1]), jnp.array([1, 2, 3])]
x_ref[jnp.array([1, 2, 1]), jnp.array([0, 0, 1])] = vals

As with Arrays, indexing mostly follows NumPy behavior, except for out-of-bounds indexing, which behaves in the usual way for JAX Arrays.
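For completeness, here is a small sketch of the functional forms from the get/swap signatures above, assuming they are importable as jax.ref.get and jax.ref.swap; this is an illustration of the stated equivalence, not code from the original page:

r = jax.new_ref(jnp.arange(4.))
a = jax.ref.get(r, slice(0, 2))  # same as r[0:2]
old = jax.ref.swap(r, 3, 7.0)    # same as: old, r[3] = r[3], 7.0
print(a)    # [0. 1.]
print(old)  # 3.0
print(r)    # Ref([0., 1., 2., 7.], dtype=float32)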

Pure and impure functions

A function that takes a ref as an argument (either explicitly or by lexical closure) is considered impure. For example:

# takes ref as an argument => impure
@jax.jit
def impure1(x_ref, y_ref):
  x_ref[...] = y_ref[...]

# closes over ref => impure
y_ref = jax.new_ref(0)

@jax.jit
def impure2(x):
  y_ref[...] = x

If a function only uses refs internally, it is still considered pure. Purity is in the eye of the caller. For example:

# internal refs => still pure
@jax.jit
def pure1(x):
  ref = jax.new_ref(x)
  ref[...] = ref[...] + ref[...]
  return ref[...]

Pure functions, even those that use refs internally, are familiar: for example, they work with transformations like jax.grad, jax.vmap, jax.shard_map, and others in the usual way.
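For instance, the pure1 function defined above composes with these transformations (a quick illustration, not from the original page):

print(jax.grad(pure1)(3.0))             # 2.0
print(jax.vmap(pure1)(jnp.arange(3.)))  # [0. 2. 4.]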

Impure functions are sequenced in Python program order.
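Concretely, that means two impure calls that touch the same ref take effect in the order they appear in the Python program. A hypothetical sketch (names are ours, not from the original page):

acc_ref = jax.new_ref(1.0)

@jax.jit
def add(x):
  acc_ref[...] = acc_ref[...] + x

@jax.jit
def scale(x):
  acc_ref[...] = acc_ref[...] * x

add(2.0)             # acc becomes 3.0
scale(4.0)           # then acc becomes 12.0
print(acc_ref[...])  # 12.0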

Restrictions

Refs are second-class, in the sense that there are restrictions on their use:

  • Can’t return refs from jit-decorated functions or the bodies of higher-order primitives like jax.lax.scan, jax.lax.while_loop, or jax.lax.cond

  • Can’t pass a ref as an argument more than once to jit-decorated functions or higher-order primitives

  • Can only freeze in creation scope

  • No higher-order refs (refs-to-refs)

For example, these are errors:

x_ref = jax.new_ref(0.)

# can't return refs
@jax.jit
def err1(x_ref):
  x_ref[...] = 5.
  return x_ref  # error!

try:
  err1(x_ref)
except Exception as e:
  print(e)

# can't pass a ref as an argument more than once
@jax.jit
def err2(x_ref, y_ref):
  ...

try:
  err2(x_ref, x_ref)  # error!
except Exception as e:
  print(e)

# can't pass and close over the same ref
@jax.jit
def err3(y_ref):
  y_ref[...] = x_ref[...]

try:
  err3(x_ref)  # error!
except Exception as e:
  print(e)

# can only freeze in creation scope
@jax.jit
def err4(x_ref):
  jax.freeze(x_ref)

try:
  err4(x_ref)  # error!
except Exception as e:
  print(e)
function err1 at /tmp/ipykernel_1340/3915325362.py:4 traced for jit returned a mutable array reference of type Ref{float32[]} at output tree path result, but mutable array references cannot be returned.
The returned mutable array was passed in as the argument x_ref.
only one reference to a mutable array may be passed as an argument to a function, but when tracing err2 at /tmp/ipykernel_1340/3915325362.py:14 for jit the mutable array reference of type Ref{float32[]} appeared at both x_ref and y_ref.
when tracing err3 at /tmp/ipykernel_1340/3915325362.py:23 for jit, a mutable array reference of type Ref{float32[]} was both closed over and passed as the argument y_ref

These restrictions exist to rule out aliasing, where two refs might refer to the same mutable memory, making programs harder to reason about and transform. Weaker restrictions would also suffice, so some of these restrictions may be lifted as we improve JAX’s ability to verify that no aliasing is present.

There are also restrictions stemming from undefined semantics, e.g. in the presence of parallelism or rematerialization:

  • Can’t vmap or shard_map a function that closes over refs

  • Can’t apply jax.remat / jax.checkpoint to an impure function

For example, here are ways you can and can’t use vmap with impure functions:

# vmap over ref args is okay
def dist(x, y, out_ref):
  assert x.ndim == y.ndim == 1
  assert out_ref.ndim == 0
  out_ref[...] = jnp.sum((x - y) ** 2)

vecs = jnp.arange(12.).reshape(3, 4)
out_ref = jax.new_ref(jnp.zeros((3, 3)))
jax.vmap(jax.vmap(dist, (0, None, 0)), (None, 0, 0))(vecs, vecs, out_ref)  # ok!
print(out_ref)
Ref([[  0.,  64., 256.],
     [ 64.,   0.,  64.],
     [256.,  64.,   0.]], dtype=float32)
# vmap with a closed-over ref is not
x_ref = jax.new_ref(0.)

def err5(x):
  x_ref[...] = x

try:
  jax.vmap(err5)(jnp.arange(3.))  # error!
except Exception as e:
  print(e)
performing a set/swap operation with vmapped value on an unbatched array reference of type Ref{float32[]}. Move the array reference to be an argument to the vmapped function?

The latter is an error because it’s not clear which value x_ref should be after we run jax.vmap(err5).
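As the error message suggests, one fix is to move the ref into the vmapped function’s arguments, so each program instance writes its own slice of a batched ref. This sketch (with hypothetical names, not from the original page) mirrors the dist example above:

outs_ref = jax.new_ref(jnp.zeros(3))

def ok5(x, o_ref):
  o_ref[...] = x  # each instance writes its own element of outs_ref

jax.vmap(ok5)(jnp.arange(3.), outs_ref)  # ok!
print(outs_ref)  # Ref([0., 1., 2.], dtype=float32)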

Refs and automatic differentiation

Autodiff can be applied to pure functions as before, even if they use array refs internally. For example:

@jax.jit
def pure2(x):
  ref = jax.new_ref(x)
  ref[...] = ref[...] + ref[...]
  return ref[...]

print(jax.grad(pure2)(3.0))  # 2.0
2.0

Autodiff can also be applied to functions that take array refs as arguments, if those arguments are only used for plumbing and not involved in differentiation:

# error
def err6(x, some_plumbing_ref):
  y = x + x
  some_plumbing_ref[...] += y
  return y

# fine
def foo(x, some_plumbing_ref):
  y = x + x
  some_plumbing_ref[...] += jax.lax.stop_gradient(y)
  return y
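For example, differentiating foo above with respect to its first argument works, and the ref just carries data along for the ride (a sketch with values we computed, not from the original page):

plumb_ref = jax.new_ref(0.)
print(jax.grad(foo)(3.0, plumb_ref))  # 2.0, the derivative of x + x
print(plumb_ref[...])                 # 6.0, the stop_gradient'd value of y stashed during the forward pass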

You can combine plumbing refs with custom_vjp to plumb data out of the backward pass of a differentiated function:

# First, define the helper `stash_grads`:
@jax.custom_vjp
def stash_grads(grads_ref, x):
  return x

def stash_grads_fwd(grads_ref, x):
  return x, grads_ref

def stash_grads_bwd(grads_ref, g):
  grads_ref[...] = g
  return None, g

stash_grads.defvjp(stash_grads_fwd, stash_grads_bwd)
# Now, use `stash_grads` to stash intermediate gradients:
def f(x, grads_ref):
  x = jnp.sin(x)
  x = stash_grads(grads_ref, x)
  return x

grads_ref = jax.new_ref(0.)
f(1., grads_ref)
print(grads_ref)
Ref(0., dtype=float32, weak_type=True)

Notice stash_grads_fwd is returning a Ref here. That’s a special allowance for custom_vjp fwd rules: it’s really syntax for indicating which ref arguments should be shared by both the fwd and bwd rules. So any refs returned by a fwd rule must be arguments to that fwd rule.
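For instance, under jax.grad the bwd rule fires, and the cotangent that reaches stash_grads is written into grads_ref (a sketch continuing the example above; the values are what this sketch computes, not output from the original page):

grads_ref = jax.new_ref(0.)
print(jax.grad(f)(1., grads_ref))  # cos(1.0) ~= 0.5403, the usual derivative
print(grads_ref[...])              # 1.0, the cotangent stashed by the bwd rule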

Refs and performance

At the top level, when calling jit-decorated functions, Refs obviate the need for donation, since they are effectively always donated:

@jax.jit
def sin_inplace(x_ref):
  x_ref[...] = jnp.sin(x_ref[...])

x_ref = jax.new_ref(jnp.arange(3.))
print(x_ref.unsafe_buffer_pointer(), x_ref)
sin_inplace(x_ref)
print(x_ref.unsafe_buffer_pointer(), x_ref)
98085754563392 Ref([0., 1., 2.], dtype=float32)
98085754563392 Ref([0.        , 0.84147096, 0.9092974 ], dtype=float32)

Here sin_inplace operates in-place, updating the buffer backing x_ref so that its address stays the same.

Under a jit, you should expect array references to point to fixed buffer addresses, and for indexed updates to be performed in-place.

Temporary caveat: dispatch from Python to impure jit-compiled functions that take Ref inputs is currently slower than dispatch to pure jit-compiled functions, since it takes a less optimized path.

foreach, a new way to write scan

As you may know, jax.lax.scan is a loop construct with a built-in fixed access pattern for scanned-over inputs and outputs. The access pattern is built in for autodiff reasons: if we were instead to slice into immutable inputs directly, reverse-mode autodiff would end up creating one-hot gradients and summing them up, which can be asymptotically inefficient. See Sec 5.3.3 of the Dex paper.

But reading slices of Refs doesn’t have this efficiency problem: when we apply reverse-mode autodiff, we always generate in-place accumulation operations. As a result, we no longer need to be constrained by scan’s fixed access pattern. We can write more flexible loops, e.g. with non-sequential access.

Moreover, having mutation available allows for some syntax tricks, like in this recipe for a foreach decorator:

import jax
import jax.numpy as jnp
from jax.lax import scan

def foreach(*args):
  def decorator(body):
    return scan(lambda _, elts: (None, body(*elts)), None, args)[1]
  return decorator
r = jax.new_ref(0)
xs = jnp.arange(10)

@foreach(xs)
def ys(x):
  r[...] += x
  return x * 2

print(r)   # Ref(45, dtype=int32)
print(ys)  # [ 0  2  4  6  8 10 12 14 16 18]
Ref(45, dtype=int32)
[ 0  2  4  6  8 10 12 14 16 18]

Here, the loop runs immediately, updating r in-place and binding ys to be the mapped result.
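The same recipe also zips multiple scanned inputs, since scan carries the whole args tuple (a small extension of the example above, not from the original page):

as_ = jnp.arange(3)
bs = jnp.arange(3, 6)

@foreach(as_, bs)
def sums(a, b):
  return a + b  # elementwise over the zipped leading axes

print(sums)  # [3 5 7]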

