Errors

This page lists a few of the errors you might encounter when using JAX, along with representative examples of how one might fix them.

class jax.errors.JaxRuntimeError

Runtime errors thrown by the JAX runtime. While the JAX runtime may raise other exceptions as well, most exceptions thrown by the runtime are instances of this class.

class jax.errors.JAXTypeError(message)

JAX-specific TypeError

Parameters:

message (str)

class jax.errors.JAXIndexError(message)

JAX-specific IndexError

Parameters:

message (str)

class jax.errors.ConcretizationTypeError(tracer, context='')

This error occurs when a JAX Tracer object is used in a context where a concrete value is required (see Different kinds of JAX values for more on what a Tracer is). In some situations, it can be easily fixed by marking problematic values as static; in others, it may indicate that your program is doing operations that are not directly supported by JAX’s JIT compilation model.

Examples:

Traced value where static value is expected

One common cause of this error is using a traced value where a static value is required. For example:

>>> from functools import partial
>>> from jax import jit
>>> import jax.numpy as jnp
>>> @jit
... def func(x, axis):
...     return x.min(axis)

>>> func(jnp.arange(4), 0)
Traceback (most recent call last):
    ...
ConcretizationTypeError: Abstract tracer value encountered where concrete
value is expected: axis argument to jnp.min().

This can often be fixed by marking the problematic argument as static:

>>> @partial(jit, static_argnums=1)
... def func(x, axis):
...     return x.min(axis)

>>> func(jnp.arange(4), 0)
Array(0, dtype=int32)

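The same fix can also be written with static_argnames, which selects the static argument by name rather than by position. The following is a minimal sketch; the function name reduce_min and the sample array are illustrative assumptions, not part of the original example. Keep in mind that each distinct value of a static argument triggers a separate compilation.

```python
from functools import partial

import jax
import jax.numpy as jnp


# `axis` is treated as a compile-time constant, so x.min(axis) sees a plain int.
@partial(jax.jit, static_argnames="axis")
def reduce_min(x, axis):  # hypothetical name, chosen for illustration
    return x.min(axis)


x = jnp.arange(6).reshape(2, 3)   # [[0, 1, 2], [3, 4, 5]]
by_rows = reduce_min(x, axis=0)   # minimum over rows, one value per column
by_cols = reduce_min(x, axis=1)   # minimum over columns, one value per row
```
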
Shape depends on Traced Value

Such an error may also arise when a shape in your JIT-compiled computation depends on the values within a traced quantity. For example:

>>> @jit
... def func(x):
...     return jnp.where(x < 0)

>>> func(jnp.arange(4))
Traceback (most recent call last):
    ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected:
The error arose in jnp.nonzero.

This is an example of an operation that is incompatible with JAX’s JIT compilation model, which requires array sizes to be known at compile-time. Here the size of the returned array depends on the contents of x, and such code cannot be JIT compiled.

In many cases it is possible to work around this by modifying the logic used in the function; for example here is code with a similar issue:

>>> @jit
... def func(x):
...     indices = jnp.where(x > 1)
...     return x[indices].sum()

>>> func(jnp.arange(4))
Traceback (most recent call last):
    ...
ConcretizationTypeError: Abstract tracer value encountered where concrete
value is expected: The error arose in jnp.nonzero.

And here is how you might express the same operation in a way that avoids creation of a dynamically-sized index array:

>>> @jit
... def func(x):
...     return jnp.where(x > 1, x, 0).sum()

>>> func(jnp.arange(4))
Array(5, dtype=int32)
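
When the indices themselves are needed, another option is to give the result a fixed length with the size argument of jnp.nonzero. The sketch below assumes that padding unused slots with -1 is acceptable for the caller; the function name negative_indices and the input values are illustrative and not part of the original page.

```python
import jax
import jax.numpy as jnp


@jax.jit
def negative_indices(x):  # hypothetical helper for illustration
    # size= fixes the output length at trace time (x.shape[0] is static under jit);
    # slots beyond the number of matches are filled with fill_value.
    (idx,) = jnp.nonzero(x < 0, size=x.shape[0], fill_value=-1)
    return idx


result = negative_indices(jnp.array([3, -1, 4, -1, -5]))
print(result)  # indices of the negative entries, padded with -1
```

The caller is then responsible for ignoring the padded entries, for example by masking on idx >= 0.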

To understand more subtleties having to do with tracers vs. regular values, and concrete vs. abstract values, you may want to read Different kinds of JAX values.

Parameters:
  • tracer (core.Tracer)

  • context (str)

class jax.errors.KeyReuseError(message)

This error occurs when a PRNG key is reused in an unsafe manner. Key reuse is checked only when jax_debug_key_reuse is set to True.

Here is a simple example of code that would lead to such an error:

>>> import jax
>>> with jax.debug_key_reuse(True):
...     key = jax.random.key(0)
...     value = jax.random.uniform(key)
...     new_value = jax.random.uniform(key)
...
---------------------------------------------------------------------------
KeyReuseError                             Traceback (most recent call last)
...
KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0

This sort of key reuse is problematic because the JAX PRNG is stateless, and keys must be manually split. For more information on this, see the Pseudorandom Numbers tutorial.
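
As a minimal sketch of the usual remedy, the snippet below splits the key so that each random draw consumes its own subkey. The variable names subkey1 and subkey2 are illustrative, and the snippet does not enable the key-reuse checker itself.

```python
import jax

key = jax.random.key(0)

# Derive fresh subkeys instead of passing `key` to two sampling calls.
key, subkey1, subkey2 = jax.random.split(key, 3)

value = jax.random.uniform(subkey1)      # consumes subkey1 only
new_value = jax.random.uniform(subkey2)  # consumes subkey2 only
```

The updated key returned by the split can be carried forward and split again whenever more random values are needed.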

Parameters:

message (str)

class jax.errors.NonConcreteBooleanIndexError(tracer)

This error occurs when a program attempts to use non-concrete boolean indices in a traced indexing operation. Under JIT compilation, JAX arrays must have static shapes (i.e. shapes that are known at compile-time) and so boolean masks must be used carefully. Some logic implemented via boolean masking is simply not possible in a jax.jit() function; in other cases, the logic can be re-expressed in a JIT-compatible way, often using the three-argument version of where().

Following are a few examples of when this error might arise.

Constructing arrays via boolean masking

This most commonly arises when attempting to create an array via a boolean mask within a JIT context. For example:

>>> import jax
>>> import jax.numpy as jnp

>>> @jax.jit
... def positive_values(x):
...     return x[x > 0]

>>> positive_values(jnp.arange(-5, 5))
Traceback (most recent call last):
    ...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])

This function is attempting to return only the positive values in the input array; the size of this returned array cannot be determined at compile-time unless x is marked as static, and so operations like this cannot be performed under JIT compilation.

Reexpressible Boolean Logic

Although creating dynamically sized arrays is not supported directly, in many cases it is possible to re-express the logic of the computation in terms of a JIT-compatible operation. For example, here is another function that fails under JIT for the same reason:

>>> @jax.jit
... def sum_of_positive(x):
...     return x[x > 0].sum()

>>> sum_of_positive(jnp.arange(-5, 5))
Traceback (most recent call last):
    ...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])

In this case, however, the problematic array is only an intermediate value, and we can instead express the same logic in terms of the JIT-compatible three-argument version of jax.numpy.where():

>>> @jax.jit
... def sum_of_positive(x):
...     return jnp.where(x > 0, x, 0).sum()

>>> sum_of_positive(jnp.arange(-5, 5))
Array(10, dtype=int32)

This pattern of replacing boolean masking with three-argument where() is a common solution to this sort of problem.
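
The same pattern extends to other reductions over a masked subset. The sketch below computes the mean of the positive entries by combining a masked sum with a count of the mask; the function name mean_of_positive is an illustrative assumption, and the code assumes at least one entry is positive (otherwise the division yields nan).

```python
import jax
import jax.numpy as jnp


@jax.jit
def mean_of_positive(x):  # hypothetical name, not from the original page
    mask = x > 0
    total = jnp.where(mask, x, 0).sum()  # masked sum, shape is static
    count = mask.sum()                   # number of selected entries
    return total / count


result = mean_of_positive(jnp.arange(-5, 5))  # positives are 1, 2, 3, 4
```
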

Boolean indexing into JAX arrays

The other situation where this error often arises is when using boolean indices, such as with .at[...].set(...). Here is a simple example:

>>> @jax.jit
... def manual_clip(x):
...     return x.at[x < 0].set(0)

>>> manual_clip(jnp.arange(-2, 2))
Traceback (most recent call last):
    ...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[4])

This function is attempting to set values smaller than zero to a scalar fill value. As above, this can be addressed by re-expressing the logic in terms of where():

>>> @jax.jit
... def manual_clip(x):
...     return jnp.where(x < 0, 0, x)

>>> manual_clip(jnp.arange(-2, 2))
Array([0, 0, 0, 1], dtype=int32)

Parameters:

tracer (core.Tracer)

class jax.errors.TracerArrayConversionError(tracer)

This error occurs when a program attempts to convert a JAX Tracer object into a standard NumPy array (see Different kinds of JAX values for more on what a Tracer is). It typically occurs in one of a few situations.

Using non-JAX functions in JAX transformations

This error can occur if you attempt to use a non-JAX library like numpy or scipy inside a JAX transformation (jit(), grad(), jax.vmap(), etc.). For example:

>>> from jax import jit
>>> import numpy as np

>>> @jit
... def func(x):
...     return np.sin(x)

>>> func(np.arange(4))
Traceback (most recent call last):
    ...
TracerArrayConversionError: The numpy.ndarray conversion method
__array__() was called on traced array with shape int32[4]

In this case, you can fix the issue by using jax.numpy.sin() in place of numpy.sin():

>>> import jax.numpy as jnp
>>> @jit
... def func(x):
...     return jnp.sin(x)

>>> func(jnp.arange(4))
Array([0.        , 0.84147096, 0.9092974 , 0.14112   ], dtype=float32)

See also External Callbacks for options for calling back to host-side computations from transformed JAX code.
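
As a rough sketch of the callback route, the snippet below wraps a NumPy function with jax.pure_callback so that it runs on the host while the surrounding function is still JIT-compiled. The helper name host_sin is hypothetical, and the sketch assumes the callback is pure and returns an array with the same shape and dtype as its input.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_sin(x):  # hypothetical host-side function using plain NumPy
    return np.sin(np.asarray(x))


@jax.jit
def func(x):
    # Describe the callback's output so JAX knows its shape and dtype while tracing.
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sin, result_shape, x)


out = func(jnp.arange(4, dtype=jnp.float32))
```

Where a jax.numpy equivalent exists, using it directly is usually simpler and faster than a host callback.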

Indexing a numpy array with a tracer

If this error arises on a line that involves array indexing, it may be that the array being indexed x is a standard numpy.ndarray while the indices idx are traced JAX arrays. For example:

>>> x = np.arange(10)

>>> @jit
... def func(i):
...     return x[i]

>>> func(0)
Traceback (most recent call last):
    ...
TracerArrayConversionError: The numpy.ndarray conversion method
__array__() was called on traced array with shape int32[0]

Depending on the context, you may fix this by converting the numpy array into a JAX array:

>>> @jit
... def func(i):
...     return jnp.asarray(x)[i]

>>> func(0)
Array(0, dtype=int32)

or by declaring the index as a static argument:

>>> from functools import partial
>>> @partial(jit, static_argnums=(0,))
... def func(i):
...     return x[i]

>>> func(0)
Array(0, dtype=int32)

To understand more subtleties having to do with tracers vs. regular values, and concrete vs. abstract values, you may want to read Different kinds of JAX values.

Parameters:

tracer (core.Tracer)

class jax.errors.TracerBoolConversionError(tracer)

This error occurs when a traced value in JAX is used in a context where a boolean value is expected (see Different kinds of JAX values for more on what a Tracer is).

The boolean cast may be explicit (e.g. bool(x)) or implicit, through use of control flow (e.g. if x > 0 or while x), use of Python boolean operators (e.g. z = x and y, z = x or y, z = not x) or functions that use them (e.g. z = max(x, y), z = min(x, y) etc.).

In some situations, this problem can be easily fixed by marking traced values as static; in others, it may indicate that your program is doing operations that are not directly supported by JAX’s JIT compilation model.
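
For the implicit casts caused by Python boolean operators and built-ins, one common approach is to swap in the elementwise jax.numpy counterparts. The sketch below is illustrative only; the function names in_range and larger are assumptions introduced for this example.

```python
import jax
import jax.numpy as jnp


@jax.jit
def in_range(x, lo, hi):  # hypothetical helper
    # `x >= lo and x <= hi` would call bool() on a tracer;
    # jnp.logical_and combines the two comparisons elementwise instead.
    return jnp.logical_and(x >= lo, x <= hi)


@jax.jit
def larger(x, y):  # hypothetical helper
    # Built-in max(x, y) compares and then converts the result to bool;
    # jnp.maximum computes the elementwise maximum without that conversion.
    return jnp.maximum(x, y)


mask = in_range(jnp.arange(5), 1, 3)
biggest = larger(jnp.array([1, 5]), jnp.array([4, 2]))
```
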

Examples:

Traced value used in control flow

One case where this often arises is when a traced value is used in Python control flow. For example:

>>> from jax import jit
>>> import jax.numpy as jnp
>>> @jit
... def func(x, y):
...     return x if x.sum() < y.sum() else y

>>> func(jnp.ones(4), jnp.zeros(4))
Traceback (most recent call last):
    ...
TracerBoolConversionError: Attempted boolean conversion of JAX Tracer [...]

We could mark both inputs x and y as static, but that would defeat the purpose of using jax.jit() here. Another option is to re-express the if statement in terms of the three-term jax.numpy.where():

>>> @jit
... def func(x, y):
...     return jnp.where(x.sum() < y.sum(), x, y)

>>> func(jnp.ones(4), jnp.zeros(4))
Array([0., 0., 0., 0.], dtype=float32)

For more complicated control flow including loops, see Control flow operators.
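
As a minimal sketch of what those operators look like, the snippet below rewrites the branch above with jax.lax.cond and shows a data-dependent loop with jax.lax.while_loop. The function names pick_smaller and halve_until_small are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def pick_smaller(x, y):  # hypothetical name
    # lax.cond traces both branches and selects one at run time.
    return lax.cond(x.sum() < y.sum(), lambda a, b: a, lambda a, b: b, x, y)


@jax.jit
def halve_until_small(x):  # hypothetical name
    # Traced equivalent of: while x > 1: x = x / 2
    return lax.while_loop(lambda v: v > 1.0, lambda v: v / 2.0, x)


chosen = pick_smaller(jnp.ones(4), jnp.zeros(4))
small = halve_until_small(jnp.float32(8.0))
```

Both branches of lax.cond must return values of the same shape and dtype, and the loop body of lax.while_loop must return a value with the same structure as its input.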

Control flow on traced values

Another common cause of this error is if you inadvertently trace over a boolean flag. For example:

>>> @jit
... def func(x, normalize=True):
...     if normalize:
...         return x / x.sum()
...     return x

>>> func(jnp.arange(5), True)
Traceback (most recent call last):
    ...
TracerBoolConversionError: Attempted boolean conversion of JAX Tracer ...

Here because the flag normalize is traced, it cannot be used in Python control flow. In this situation, the best solution is probably to mark this value as static:

>>> from functools import partial
>>> @partial(jit, static_argnames=['normalize'])
... def func(x, normalize=True):
...     if normalize:
...         return x / x.sum()
...     return x

>>> func(jnp.arange(5), True)
Array([0. , 0.1, 0.2, 0.3, 0.4], dtype=float32)

For more on static_argnums and static_argnames, see the documentation of jax.jit().

Using non-JAX aware functions

Another common cause of this error is using non-JAX aware functions within JAX code. For example:

>>> @jit
... def func(x):
...     return min(x, 0)

>>> func(2)
Traceback (most recent call last):
    ...
TracerBoolConversionError: Attempted boolean conversion of JAX Tracer ...

In this case, the error occurs because Python’s built-in min function is not compatible with JAX transforms. This can be fixed by replacing it with jnp.minimum:

>>> @jit
... def func(x):
...     return jnp.minimum(x, 0)

>>> print(func(2))
0

To understand more subtleties having to do with tracers vs. regular values, and concrete vs. abstract values, you may want to read Different kinds of JAX values.

Parameters:

tracer (core.Tracer)

class jax.errors.TracerIntegerConversionError(tracer)

This error can occur when a JAX Tracer object is used in a context where a Python integer is expected (see Different kinds of JAX values for more on what a Tracer is). It typically occurs in a few situations.

Passing a tracer in place of an integer

This error can occur if you attempt to pass a traced value to a function that requires a static integer argument; for example:

>>> from jax import jit
>>> import numpy as np

>>> @jit
... def func(x, axis):
...     return np.split(x, 2, axis)

>>> func(np.arange(4), 0)
Traceback (most recent call last):
    ...
TracerIntegerConversionError: The __index__() method was called on
traced array with shape int32[0]

When this happens, the solution is often to mark the problematic argument as static:

>>> from functools import partial
>>> @partial(jit, static_argnums=1)
... def func(x, axis):
...     return np.split(x, 2, axis)

>>> func(np.arange(10), 0)
[Array([0, 1, 2, 3, 4], dtype=int32),
 Array([5, 6, 7, 8, 9], dtype=int32)]

An alternative is to apply the transformation to a closure that encapsulates the arguments to be protected, either manually as below or by using functools.partial():

>>> jit(lambda arr: np.split(arr, 2, 0))(np.arange(4))
[Array([0, 1], dtype=int32), Array([2, 3], dtype=int32)]

Note that a new closure is created at every invocation, which defeats the compilation caching mechanism; this is why static_argnums is preferred.
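
For the functools.partial() variant mentioned above, a minimal sketch might look like the following. It uses jnp.split rather than np.split, and the names split_in_two and split_axis0 are assumptions made for illustration. Building the jitted partial once and reusing it avoids creating a new closure on every call.

```python
from functools import partial

import jax
import jax.numpy as jnp


def split_in_two(axis, x):  # hypothetical helper; axis comes first so partial can bind it
    return jnp.split(x, 2, axis)


# Bind the axis as a plain Python int, then transform the resulting callable once.
split_axis0 = jax.jit(partial(split_in_two, 0))

parts = split_axis0(jnp.arange(4))
more_parts = split_axis0(jnp.arange(4) + 10)  # same callable, same input shape
```
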

Indexing a list with a Tracer

This error can occur if you attempt to index a Python list with a traced quantity. For example:

>>> import jax.numpy as jnp
>>> from jax import jit

>>> L = [1, 2, 3]

>>> @jit
... def func(i):
...     return L[i]

>>> func(0)
Traceback (most recent call last):
    ...
TracerIntegerConversionError: The __index__() method was called on
traced array with shape int32[0]

Depending on the context, you can generally fix this either by converting the list to a JAX array:

>>> @jit
... def func(i):
...     return jnp.array(L)[i]

>>> func(0)
Array(1, dtype=int32)

or by declaring the index as a static argument:

>>> from functools import partial
>>> @partial(jit, static_argnums=0)
... def func(i):
...     return L[i]

>>> func(0)
Array(1, dtype=int32, weak_type=True)

To understand more subtleties having to do with tracers vs. regular values, and concrete vs. abstract values, you may want to read Different kinds of JAX values.

Parameters:

tracer (core.Tracer)

class jax.errors.UnexpectedTracerError(msg)

This error occurs when you use a JAX value that has leaked out of a function. What does it mean to leak a value? If you use a JAX transformation on a function f that stores, in some scope outside of f, a reference to an intermediate value, that value is considered to have been leaked. Leaking values is a side effect. (Read more about avoiding side effects in Pure Functions.)

JAX detects leaks when you then use the leaked value in another operation later on, at which point it raises an UnexpectedTracerError. To fix this, avoid side effects: if a function computes a value needed in an outer scope, return that value from the transformed function explicitly.

Specifically, a Tracer is JAX’s internal representation of a function’s intermediate values during transformations, e.g. within jit(), pmap(), vmap(), etc. Encountering a Tracer outside of a transformation implies a leak.

Life-cycle of a leaked value

Consider the following example of a transformed function which leaks a value to an outer scope:

>>> from jax import jit
>>> import jax.numpy as jnp

>>> outs = []
>>> @jit                   # 1
... def side_effecting(x):
...     y = x + 1          # 3
...     outs.append(y)     # 4

>>> x = 1
>>> side_effecting(x)      # 2
>>> outs[0] + 1            # 5
Traceback (most recent call last):
    ...
UnexpectedTracerError: Encountered an unexpected tracer.

In this example we leak a Traced value from an inner transformed scope to an outer scope. We get an UnexpectedTracerError when the leaked value is used, not when the value is leaked.

This example also demonstrates the life-cycle of a leaked value:

  1. A function is transformed (in this case, by jit())

  2. The transformed function is called (initiating an abstract trace of the function and turning x into a Tracer)

  3. The intermediate value y, which will later be leaked, is created (an intermediate value of a traced function is also a Tracer)

  4. The value is leaked (appended to a list in an outer scope, escaping the function through a side-channel)

  5. The leaked value is used, and an UnexpectedTracerError is raised.

The UnexpectedTracerError message tries to point to these locations in your code by including information about each stage. Respectively:

  1. The name of the transformed function (side_effecting) and which transform kicked off the trace (jit()).

  2. A reconstructed stack trace of where the leaked Tracer was created, which includes where the transformed function was called. (When the Tracer was created, the final 5 stack frames were...).

  3. From the reconstructed stack trace, the line of code that created the leaked Tracer.

  4. The leak location is not included in the error message because it is difficult to pin down! JAX can only tell you what the leaked value looks like (what shape it has and where it was created) and what boundary it was leaked over (the name of the transformation and the name of the transformed function).

  5. The current error’s stack trace points to where the value is used.

The error can be fixed by returning the value out of the transformed function:

>>> from jax import jit
>>> import jax.numpy as jnp

>>> outs = []
>>> @jit
... def not_side_effecting(x):
...     y = x + 1
...     return y

>>> x = 1
>>> y = not_side_effecting(x)
>>> outs.append(y)
>>> outs[0] + 1  # all good! no longer a leaked value.
Array(3, dtype=int32, weak_type=True)

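The same idea applies when an intermediate value is wanted from inside a differentiated function. One option, sketched below under illustrative assumptions (the function loss_with_aux and its toy inputs are not from the original page), is to return the intermediate as an auxiliary output with has_aux=True instead of storing it in an outer scope.

```python
import jax
import jax.numpy as jnp


def loss_with_aux(w, x):  # hypothetical toy loss
    pred = w * x
    loss = jnp.sum(pred ** 2)
    # Return the intermediate explicitly rather than appending it to a global list.
    return loss, pred


# has_aux=True tells jax.grad that the function returns (value_to_differentiate, extra).
grad_fn = jax.jit(jax.grad(loss_with_aux, has_aux=True))

g, pred = grad_fn(2.0, jnp.array([1.0, 2.0]))
```

Here g is the derivative of the loss with respect to w, and pred is an ordinary array that can be used freely outside the transformation.
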
Leak checker

As discussed in points 2 and 3 above, JAX shows a reconstructed stack trace which points to where the leaked value was created. This is because JAX only raises an error when the leaked value is used, not when the value is leaked. This is not the most useful place to raise this error, because you need to know the location where the Tracer was leaked to fix the error.

To make this location easier to track down, you can use the leak checker. When the leak checker is enabled, an error is raised as soon as a Tracer is leaked. (To be more exact, it will raise an error when the transformed function from which the Tracer is leaked returns.)

To enable the leak checker you can use the JAX_CHECK_TRACER_LEAKS environment variable or the with jax.checking_leaks() context manager.

Note

This tool is experimental and may report false positives. It works by disabling some JAX caches, so it will have a negative effect on performance and should only be used when debugging.

Example usage:

>>> import jax
>>> from jax import jit
>>> import jax.numpy as jnp

>>> outs = []
>>> @jit
... def side_effecting(x):
...     y = x + 1
...     outs.append(y)

>>> x = 1
>>> with jax.checking_leaks():
...     y = side_effecting(x)
Traceback (most recent call last):
    ...
Exception: Leaked Trace

Parameters:

msg (str)
