
Omnistaging#

mattjj@ Sept 25 2020

This is more of an upgrade guide than a design doc.

tl;dr#

What’s going on?#

A change to JAX’s tracing infrastructure called “omnistaging” (jax-ml/jax#3370) was switched on in jax==0.2.0. This change improves memory performance and trace execution time, and simplifies jax internals, but may cause some existing code to break. Breakage is usually a result of buggy code, so long-term it’s best to fix the bugs, but omnistaging can also be disabled as a temporary workaround. And we’re happy to help you with fixes!

How do I know if omnistaging broke my code?#

The easiest way to tell if omnistaging is responsible is to disable omnistaging and see if the issues go away. See the What issues can arise when omnistaging is switched on? section below.

How can I disable omnistaging for now?#

Note: this applies to JAX versions 0.2.0 through 0.2.11; omnistaging cannot be disabled in JAX versions 0.2.12 and higher.

It is temporarily possible to disable omnistaging by

  1. setting the shell environment variable JAX_OMNISTAGING to something falsey;

  2. setting the boolean flag jax_omnistaging to something falsey if your code parses flags with absl;

  3. using this statement near the top of your main file:

jax.config.disable_omnistaging()
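For the first mechanism, a shell session might look like the sketch below. This is only an illustration: any falsey value works, and the variable has effect only on JAX 0.2.0 through 0.2.11.

```shell
# Disable omnistaging via the environment (JAX 0.2.0 through 0.2.11 only);
# any falsey value such as 0 works.
export JAX_OMNISTAGING=0
```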

How do I fix bugs exposed by omnistaging?#

By far the most common issue with omnistaging is using jax.numpy to compute shape values or other trace-time constants. See the code block below for a quick example, and for full details along with other issues see the section What issues can arise when omnistaging is switched on?.

Instead of this:

@jit
def f(x):
    input_size = jnp.prod(x.shape)
    if input_size > 100:
        ...

do this:

import numpy as np

@jit
def f(x):
    input_size = np.prod(x.shape)
    if input_size > 100:
        ...

Instead of thinking of jax.numpy as a drop-in replacement for numpy, it’s now better to think of using jax.numpy operations only when you want to perform a computation on an accelerator (like your GPU).
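The rule of thumb can be checked with plain NumPy alone (a minimal sketch, no JAX required): host-side shape arithmetic yields ordinary concrete scalars, which is why it can safely drive Python control flow at trace time.

```python
import numpy as np

shape = (3, 4)
input_size = int(np.prod(shape))  # a concrete Python int, not a tracer
print(input_size)  # 12
if input_size > 10:               # fine: plain Python control flow
    branch = "large"
```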

What is “omnistaging” and why is it useful?#

Omnistaging is the name for a JAX core upgrade aimed at staging out more computation from op-by-op Python to XLA, and avoiding any “trace-time constant folding” in jit, pmap, and control flow primitives. As a result, omnistaging improves JAX’s memory performance (sometimes dramatically) both by reducing fragmentation during tracing and by producing fewer large compile-time constants for XLA. It can also improve tracing performance by eliminating op-by-op execution at tracing time. Further, omnistaging simplifies JAX core internals, fixing many outstanding bugs and setting the stage for important upcoming features.

The name “omnistaging” means staging out everything possible.

Toy example#

JAX transformations like jit and pmap stage out computations to XLA. That is, we apply them to functions comprising multiple primitive operations so that, rather than being executed one at a time from Python, the operations are all part of one end-to-end optimized XLA computation.

But exactly which operations get staged out? Until omnistaging, JAX staged out computation based on data dependence only. Here’s an example function, followed by the XLA HLO program it stages out before the omnistaging change:

from jax import jit
import jax.numpy as jnp

@jit
def f(x):
    y = jnp.add(1, 1)
    return x * y

f(3)

ENTRY jit_f.6 {
  constant.2 = pred[] constant(false)
  parameter.1 = s32[] parameter(0)
  constant.3 = s32[] constant(2)
  multiply.4 = s32[] multiply(parameter.1, constant.3)
  ROOT tuple.5 = (s32[]) tuple(multiply.4)
}

Notice that the add operation is not staged out. Instead, we only see a multiply.

Here’s the HLO generated from this function after the omnistaging change:

ENTRY jit_f.8 {
  constant.2 = pred[] constant(false)
  parameter.1 = s32[] parameter(0)
  constant.3 = s32[] constant(1)
  constant.4 = s32[] constant(1)
  add.5 = s32[] add(constant.3, constant.4)
  multiply.6 = s32[] multiply(parameter.1, add.5)
  ROOT tuple.7 = (s32[]) tuple(multiply.6)
}
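The same difference can be observed without reading HLO by printing the jaxpr that tracing records (a sketch, assuming a post-omnistaging JAX where this is the only behavior): the add appears in the staged-out program even though it has no data dependence on the argument.

```python
import jax
import jax.numpy as jnp

def f(x):
    y = jnp.add(1, 1)  # no data dependence on x
    return x * y

# With omnistaging, the jaxpr records both the add and the mul.
print(jax.make_jaxpr(f)(3))
```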

Slightly less toy example#

Here’s a less toy example which can arise in practice when we want to create boolean masks:

import numpy as np
import jax.numpy as jnp
from jax import lax, jit

@jit
def select_tril(x):
    mask = jnp.arange(x.shape[0])[:, None] > jnp.arange(x.shape[1])
    return lax.select(mask, x, jnp.zeros_like(x))  # lax.select is like jnp.where

x = np.arange(12).reshape((3, 4))
select_tril(x)

Before omnistaging:

ENTRY jit_select_tril.8 {
  constant.3 = pred[] constant(false)
  constant.1 = pred[3,4]{1,0} constant({...})
  parameter.2 = s32[3,4]{1,0} parameter(0)
  constant.4 = s32[] constant(0)
  broadcast.5 = s32[3,4]{1,0} broadcast(constant.4), dimensions={}
  select.6 = s32[3,4]{1,0} select(constant.1, parameter.2, broadcast.5)
  ROOT tuple.7 = (s32[3,4]{1,0}) tuple(select.6)
}

The select operation is staged out, but the operations for constructing the constant mask are not. Rather than being staged out, the operations that construct mask are executed op-by-op at Python tracing time, and XLA only sees a compile-time constant constant.1 representing the value of mask. That’s unfortunate, because if we had staged out the operations for constructing mask, XLA could have fused them into the select and avoided materializing the result at all. As a result we end up wasting memory with a potentially large constant, wasting time dispatching multiple un-fused op-by-op XLA computations, and potentially even fragmenting memory.

(The broadcast that corresponds to the construction of the zeros array for jnp.zeros_like(x) is staged out because JAX is lazy about very simple expressions from jax-ml/jax#1668. After omnistaging, we can remove that lazy sublanguage and simplify JAX internals.)

The reason the creation of mask is not staged out is that, before omnistaging, jit operates based on data dependence. That is, jit stages out only those operations in a function that have a data dependence on an argument. Control flow primitives and pmap behave similarly. In the case of select_tril, the operations to construct the constant mask do not have a data dependence on the argument x, so they are not staged out; only the lax.select call has a data dependence.

With omnistaging, all jax.numpy calls in the dynamic context of a jit-transformed function are staged out to XLA. That is, after omnistaging the computation XLA sees for select_tril is

ENTRY jit_select_tril.16 {
  constant.4 = pred[] constant(false)
  iota.1 = s32[3]{0} iota(), iota_dimension=0
  broadcast.5 = s32[3,1]{1,0} broadcast(iota.1), dimensions={0}
  reshape.7 = s32[3]{0} reshape(broadcast.5)
  broadcast.8 = s32[3,4]{1,0} broadcast(reshape.7), dimensions={0}
  iota.2 = s32[4]{0} iota(), iota_dimension=0
  broadcast.6 = s32[1,4]{1,0} broadcast(iota.2), dimensions={1}
  reshape.9 = s32[4]{0} reshape(broadcast.6)
  broadcast.10 = s32[3,4]{1,0} broadcast(reshape.9), dimensions={1}
  compare.11 = pred[3,4]{1,0} compare(broadcast.8, broadcast.10), direction=GT
  parameter.3 = s32[3,4]{1,0} parameter(0)
  constant.12 = s32[] constant(0)
  broadcast.13 = s32[3,4]{1,0} broadcast(constant.12), dimensions={}
  select.14 = s32[3,4]{1,0} select(compare.11, parameter.3, broadcast.13)
  ROOT tuple.15 = (s32[3,4]{1,0}) tuple(select.14)
}

What issues can arise when omnistaging is switched on?#

As a consequence of staging out all jax.numpy operations from Python to XLA when in the dynamic context of a jit or pmap, some code that worked previously can start raising loud errors. As explained below, these behaviors were already buggy before omnistaging, but omnistaging makes them into hard errors.

Using jax.numpy for shape computations#

Example#

from jax import jit
import jax.numpy as jnp

@jit
def ex1(x):
    size = jnp.prod(jnp.array(x.shape))
    return x.reshape((size,))

ex1(jnp.ones((3, 4)))

Error message#

[... full traceback ...]
  File "/home/mattjj/packages/jax/jax/core.py", line 862, in raise_concretization_error
    raise ConcretizationTypeError(msg)
jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

The error arose in jax.numpy.reshape.

While tracing the function ex1 at ex1.py:4, this value became a tracer due to JAX operations on these lines:

  operation c:int32[] = reduce_prod[ axes=(0,) ] b:int32[2]
    from line ex1.py:6 (ex1)

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.

See https://docs.jax.dev/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>

Explanation#

With omnistaging, we can’t use jax.numpy for shape computations as in the use of jnp.prod above, because in the dynamic context of a jit function those operations will be staged out of Python as values to be computed at execution time, yet we need them to be compile-time (and hence trace-time) constants.

Before omnistaging, this code wouldn’t have raised an error, but it was a common performance bug: the jnp.prod computation would have been executed on the device at tracing time, meaning extra compilation, transfers, synchronization, allocations, and potentially memory fragmentation.

Solution#

The solution is simply to use the original numpy for shape calculations like these. Not only do we avoid the error, but we also keep the computations on the host (and with lower overheads).

This issue was common enough in code that we tried to make the error message especially good. In addition to the stack trace showing where an abstract tracer value caused a problem (the jnp.reshape line in the full stack trace, on omni.py:10), we also explain why this value became a tracer in the first place by pointing to the upstream primitive operation that caused it to become an abstract tracer (the reduce_prod from jnp.prod on omni.py:9) and to which jit-decorated function the tracer belongs (ex1 on omni.py:6).
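The static_argnums suggestion in the error message can be sketched as follows. This is an illustration only: the function g and its argument layout are made up, and the general point is that a static argument stays a concrete Python value at trace time (at the cost of recompiling for each distinct value).

```python
from functools import partial
import jax
import jax.numpy as jnp

# Hypothetical example: marking argument 0 static makes it a concrete
# Python value during tracing, so it can steer Python control flow.
@partial(jax.jit, static_argnums=0)
def g(n, x):
    if n > 100:        # n is a plain Python int here, not a tracer
        return x * 2
    return x

out = g(200, jnp.ones(3))
print(out)
```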

Side-effects#

Example#

from jax import jit
from jax import random

key = random.PRNGKey(0)

def init():
    global key
    key, subkey = random.split(key)
    return random.normal(subkey, ())

print(init())  # -1.2515389
print(init())  # -0.58665067

init = jit(init)
print(init())  # 0.48648298
print(init())  # 0.48648298  !!

That last call has repeated randomness but no hard error, because we aren’t re-executing the Python. But if we look at key, we see an escaped tracer when omnistaging is on:

print(key)  # Traced<ShapedArray(uint32[2])>with<DynamicJaxprTrace(level=0/1)>

Before omnistaging, the random.split call would not be staged out and so we wouldn’t get an escaped tracer. The code would still be buggy in that the jitted function wouldn’t reproduce the semantics of the original function (because of the repeated use of the same PRNG key), ultimately due to the side effect.

With omnistaging on, if we touch key again, we’ll get an escaped tracer error:

random.normal(key, ())

Error message#

[... full stack trace ...]
  File "/home/mattjj/packages/jax/jax/interpreters/partial_eval.py", line 836, in _assert_live
    raise core.escaped_tracer_error(msg)
jax.core.UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state. Detail: tracer created on line example.py:8 (init).

Explanation#

The second largest category of omnistaging issues we found had to do with side-effecting code. This code already voided the JAX warranty by transforming effectful functions, but due to pre-omnistaging “trace-time constant folding” behavior, some side-effecting functions could nevertheless behave correctly. Omnistaging catches more of these errors.

Solution#

The solution is to identify JAX-transformed functions that rely on side effects,and to rewrite them not to be effectful.
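For the init example above, one standard effect-free rewrite is explicit key threading (a sketch, not from the original text): the PRNG key becomes an argument and a return value instead of mutated global state, so the jitted and un-jitted functions agree.

```python
import jax
from jax import random

# Thread the PRNG key explicitly instead of mutating a global.
@jax.jit
def init(key):
    key, subkey = random.split(key)
    return random.normal(subkey, ()), key

key = random.PRNGKey(0)
val1, key = init(key)
val2, key = init(key)  # a fresh key each call, so fresh randomness under jit
print(val1, val2)
```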

Small numerical differences based on XLA optimizations#

Because omnistaging stages more computations out to XLA, rather than executing some at trace time, XLA’s optimizer can reorder floating point operations. As a result, we’ve seen numerical behaviors change in a way that causes tests with overly tight tolerances to fail when omnistaging is switched on.
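Why reordering matters can be shown with plain NumPy (an illustration, not JAX-specific): floating point addition is not associative, so an optimizer that reassociates operations can change results slightly, which is why tests should compare with tolerances rather than exact equality.

```python
import numpy as np

a = np.float32(1e8)
b = np.float32(1.0)
# Mathematically identical, but associated differently:
left = (a + b) - a   # b is absorbed into the large a, giving 0.0
right = b + (a - a)  # the large terms cancel first, giving 1.0
print(left, right)   # 0.0 1.0
```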

Dependence on JAX internal APIs that changed#

Omnistaging involved some big revisions to JAX’s core code, including removing or changing internal functions. Any code that relies on such internal JAX APIs can break when omnistaging is switched on, either with build errors (from pytype) or runtime errors.

Triggering XLA compile time bugs#

Because omnistaging involves staging out more code to XLA, we’ve seen it triggerpre-existing XLA compile-time bugs on some backends. The best thing to do withthese is to report them so we can work with the XLA teams on fixes.
