
Control flow and logical operators with JIT

When executing eagerly (outside of jit), JAX code works with Python control flow and logical operators just like NumPy code. Using control flow and logical operators with jit is more complicated.

In a nutshell, Python control flow and logical operators are evaluated at JIT compile time, such that the compiled function represents a single path through the control flow graph (logical operators affect the path via short-circuiting). If the path depends on the values of the inputs, the function (by default) cannot be JIT compiled. The path may depend on the shape or dtype of the inputs, and the function is re-compiled every time it is called on an input with a new shape or dtype.
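One way to observe this caching behavior is to count how often the Python body actually runs. The sketch below assumes a module-level counter, `trace_count`, which is a name introduced here purely for illustration. Python statements in the body run only while JAX traces the function, so the counter should change only when a new shape or dtype forces a re-trace.

```python
import jax.numpy as jnp
from jax import jit

trace_count = 0  # hypothetical counter, bumped only while JAX traces `double`


@jit
def double(x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    return 2 * x


double(jnp.ones(3))    # first call: traces and compiles for float32[3]
double(jnp.zeros(3))   # same shape and dtype: reuses the compiled code
count_same_shape = trace_count

double(jnp.ones(4))    # new shape: the function is traced again
count_new_shape = trace_count

print(count_same_shape, count_new_shape)
```

The counter is only a diagnostic trick; relying on side effects inside jitted functions is generally not advisable.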

from jax import grad, jit
import jax.numpy as jnp

For example, this works:

@jit
def f(x):
  for i in range(3):
    x = 2 * x
  return x

print(f(3))
24

So does this:

@jit
def g(x):
  y = 0.
  for i in range(x.shape[0]):
    y = y + x[i]
  return y

print(g(jnp.array([1., 2., 3.])))
6.0

But this doesn’t, at least by default:

@jit
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

# This will fail!
f(2)
---------------------------------------------------------------------------
TracerBoolConversionError                 Traceback (most recent call last)
Cell In[4], line 9
      6     return -4 * x
      8 # This will fail!
----> 9 f(2)

    [... skipping hidden 13 frame]

Cell In[4], line 3, in f(x)
      1 @jit
      2 def f(x):
----> 3   if x < 3:
      4     return 3. * x ** 2
      5   else:

    [... skipping hidden 1 frame]

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:1859, in concretization_function_error.<locals>.error(self, arg)
   1858 def error(self, arg):
-> 1859   raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /tmp/ipykernel_1567/3402096563.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError

Neither does this:

@jit
def g(x):
  return (x > 0) and (x < 3)

# This will fail!
g(2)
---------------------------------------------------------------------------
TracerBoolConversionError                 Traceback (most recent call last)
Cell In[5], line 6
      3   return (x > 0) and (x < 3)
      5 # This will fail!
----> 6 g(2)

    [... skipping hidden 13 frame]

Cell In[5], line 3, in g(x)
      1 @jit
      2 def g(x):
----> 3   return (x > 0) and (x < 3)

    [... skipping hidden 1 frame]

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:1859, in concretization_function_error.<locals>.error(self, arg)
   1858 def error(self, arg):
-> 1859   raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function g at /tmp/ipykernel_1567/543860509.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError

What gives!?

When we jit-compile a function, we usually want to compile a version of the function that works for many different argument values, so that we can cache and reuse the compiled code. That way we don’t have to re-compile on each function evaluation.

For example, if we evaluate an @jit function on the array jnp.array([1., 2., 3.], jnp.float32), we might want to compile code that we can reuse to evaluate the function on jnp.array([4., 5., 6.], jnp.float32) to save on compile time.

To get a view of your Python code that is valid for many different argument values, JAX traces it with the ShapedArray abstraction as input, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value ShapedArray((3,), jnp.float32), we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.

But there’s a tradeoff here: if we trace a Python function on a ShapedArray((), jnp.float32) that isn’t committed to a specific concrete value, when we hit a line like if x < 3, the expression x < 3 evaluates to an abstract ShapedArray((), jnp.bool_) that represents the set {True, False}. When Python attempts to coerce that to a concrete True or False, we get an error: we don’t know which branch to take, and can’t continue tracing! The tradeoff is that with higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace.
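The distinction between computing `x < 3` and branching on it can be seen in a small sketch. It assumes `jax.eval_shape` as a convenient way to inspect the abstract result of the comparison, and uses a hypothetical function name, `branchy`, for the failing example. The comparison itself traces fine; only the Python `if` needs a concrete boolean.

```python
import jax
import jax.numpy as jnp


def branchy(x):
    # Python `if` forces bool(x < 3), which a traced value cannot provide.
    if x < 3:
        return 3. * x ** 2
    else:
        return -4 * x


try:
    jax.jit(branchy)(2.)
    outcome = "traced fine"
except jax.errors.TracerBoolConversionError:
    outcome = "TracerBoolConversionError"

# The comparison alone is traceable: abstractly it is a scalar boolean array.
abstract = jax.eval_shape(lambda x: x < 3, jnp.float32(2.))

print(outcome)
print(abstract.shape, abstract.dtype)
```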

The good news is that you can control this tradeoff yourself. By having jit trace on more refined abstract values, you can relax the traceability constraints. For example, using the static_argnames (or static_argnums) argument to jit, we can specify to trace on concrete values of some arguments. Here’s that example function again:

def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

f = jit(f, static_argnames='x')

print(f(2.))
12.0

Here’s another example, this time involving a loop:

def f(x, n):
  y = 0.
  for i in range(n):
    y = y + x[i]
  return y

f = jit(f, static_argnames='n')

f(jnp.array([2., 3., 4.]), 2)
Array(5., dtype=float32)

In effect, the loop gets statically unrolled. JAX can also trace at higher levels of abstraction, like Unshaped, but that’s not currently the default for any transformation.
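The unrolling can be made visible by inspecting the traced program. The following is a rough sketch, assuming `jax.make_jaxpr` with `static_argnums` and a hypothetical function `repeated_double`; it counts multiply operations in the resulting jaxpr, which should match the static loop length. The exact jaxpr layout may differ between JAX versions.

```python
import jax


def repeated_double(x, n):
    # `n` is static, so this Python loop runs during tracing
    # and each iteration emits its own multiply.
    for _ in range(n):
        x = 2 * x
    return x


# Treat argument 1 (`n`) as static while tracing.
closed_jaxpr = jax.make_jaxpr(repeated_double, static_argnums=1)(3., 4)

num_mul = sum(eqn.primitive.name == "mul" for eqn in closed_jaxpr.jaxpr.eqns)
print(closed_jaxpr)
print(num_mul)
```

With a large `n`, the traced program grows in proportion, which is one reason to prefer structured loop primitives for long loops.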

⚠️ Functions with argument-value dependent shapes

These control-flow issues also come up in a more subtle way: numerical functions we want to jit can’t specialize the shapes of internal arrays on argument values (specializing on argument shapes is ok). As a trivial example, let’s make a function whose output happens to depend on the input variable length.

def example_fun(length, val):
  return jnp.ones((length,)) * val

# un-jit'd works fine
print(example_fun(5, 4))
[4. 4. 4. 4. 4.]
bad_example_jit = jit(example_fun)
# this will fail:
bad_example_jit(10, 4)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[9], line 3
      1 bad_example_jit = jit(example_fun)
      2 # this will fail:
----> 3 bad_example_jit(10, 4)

    [... skipping hidden 13 frame]

Cell In[8], line 2, in example_fun(length, val)
      1 def example_fun(length, val):
----> 2   return jnp.ones((length,)) * val

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/numpy/array_creation.py:138, in ones(shape, dtype, device, out_sharding)
    136   raise TypeError("expected sequence object with len >= 0 or a single integer")
    137 if (m := _check_forgot_shape_tuple("ones", shape, dtype)): raise TypeError(m)
--> 138 shape = canonicalize_shape(shape)
    139 dtype = dtypes.check_and_canonicalize_user_dtype(
    140     float if dtype is None else dtype, "ones")
    141 sharding = util.choose_device_or_out_sharding(
    142     device, out_sharding, 'jnp.ones')

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/numpy/array_creation.py:45, in canonicalize_shape(shape, context)
     43   return core.canonicalize_shape((shape,), context)
     44 else:
---> 45   return core.canonicalize_shape(shape, context)

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:2027, in canonicalize_shape(shape, context)
   2025 except TypeError:
   2026   pass
-> 2027 raise _invalid_shape_error(shape, context)

TypeError: Shapes must be 1D sequences of concrete values of integer type, got (JitTracer<~int32[]>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function example_fun at /tmp/ipykernel_1567/1210496444.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument length.
# static_argnames tells JAX to recompile whenever the values of these named arguments change:
good_example_jit = jit(example_fun, static_argnames='length')
# first compile
print(good_example_jit(10, 4))
# recompiles
print(good_example_jit(5, 4))
[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]
[4. 4. 4. 4. 4.]

static_argnames can be handy if length in our example rarely changes, but it would be disastrous if it changed a lot!
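To get a feel for the cost, the sketch below logs each trace of a function with a static argument. The names `filled` and `traced_lengths` are hypothetical and exist only for this illustration. Changing the non-static `val` reuses the cached code, while every new `length` triggers a fresh trace and compile.

```python
import jax.numpy as jnp
from jax import jit

traced_lengths = []  # hypothetical log of the static values seen while tracing


def filled(length, val):
    traced_lengths.append(length)  # runs only when JAX (re)traces the function
    return jnp.ones((length,)) * val


filled_jit = jit(filled, static_argnames='length')

filled_jit(10, 4)  # first compile, specialized to length=10
filled_jit(10, 7)  # same length, different val: cached code is reused
filled_jit(5, 4)   # new length: traced and compiled again

print(traced_lengths)
```

If `length` took a new value on every call, each call would pay the full compilation cost.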

Lastly, if your function has global side-effects, JAX’s tracer can cause weird things to happen. A common gotcha is trying to print arrays inside jit’d functions:

@jit
def f(x):
  print(x)
  y = 2 * x
  print(y)
  return y

f(2)
JitTracer<~int32[]>
JitTracer<~int32[]>
Array(4, dtype=int32, weak_type=True)
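If the goal is to see runtime values rather than tracers, one option not covered above is `jax.debug.print`, which is staged into the compiled computation and prints the concrete values when the function runs. The snippet is a minimal sketch of that approach, not a recommendation from the original text.

```python
import jax
from jax import jit


@jit
def f(x):
    # Unlike Python's print, this is part of the compiled computation
    # and reports runtime values on every call.
    jax.debug.print("x = {x}", x=x)
    y = 2 * x
    jax.debug.print("y = {y}", y=y)
    return y


result = f(2)
print(result)
```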

Structured control flow primitives

There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that’s traceable, and that avoids un-rolling large loops. Then you can use these 4 structured control flow primitives:

  • lax.cond: differentiable

  • lax.while_loop: fwd-mode-differentiable

  • lax.fori_loop: fwd-mode-differentiable in general; fwd and rev-mode differentiable if endpoints are static.

  • lax.scan: differentiable
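As a small illustration of lax.scan from the list above, here is a sketch of a running sum. The function names `running_sum` and `step` are made up for this example. The step function receives the carried state and one slice of the input, and returns the new carry together with a per-step output that scan stacks into an array.

```python
import jax.numpy as jnp
from jax import jit, lax


@jit
def running_sum(xs):
    def step(carry, x):
        new_carry = carry + x
        # Return the updated carry and the per-step output to be stacked.
        return new_carry, new_carry

    init = jnp.zeros((), xs.dtype)  # scalar carry with the same dtype as xs
    total, partials = lax.scan(step, init, xs)
    return total, partials


total, partials = running_sum(jnp.array([1., 2., 3.]))
print(total)     # final carry
print(partials)  # carry after each step
```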

cond

Python equivalent:

def cond(pred, true_fun, false_fun, operand):
  if pred:
    return true_fun(operand)
  else:
    return false_fun(operand)

from jax import lax

operand = jnp.array([0.])
lax.cond(True, lambda x: x + 1, lambda x: x - 1, operand)
# --> array([1.], dtype=float32)
lax.cond(False, lambda x: x + 1, lambda x: x - 1, operand)
# --> array([-1.], dtype=float32)
Array([-1.], dtype=float32)
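The function `f` that failed under jit earlier can be expressed with lax.cond so that the predicate may depend on a traced value. This is a sketch of one possible rewrite; both branches have to return outputs of the same shape and dtype, which is why the second branch uses `-4.` rather than `-4`.

```python
import jax
from jax import lax


@jax.jit
def f(x):
    # The predicate is a traced boolean; both branches are traced once.
    return lax.cond(x < 3,
                    lambda v: 3. * v ** 2,
                    lambda v: -4. * v,
                    x)


low = f(2.)
high = f(4.)
slope_low = jax.grad(f)(2.)
slope_high = jax.grad(f)(4.)

print(low, high)
print(slope_low, slope_high)
```

Because lax.cond is differentiable, `grad` also works on the jitted function.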

jax.lax provides two other functions that allow branching on dynamic predicates:

  • lax.select is like a batched version of lax.cond, with the choices expressed as pre-computed arrays rather than as functions.

  • lax.switch is like lax.cond, but allows switching between any number of callable choices.

In addition, jax.numpy provides several numpy-style interfaces to these functions:

  • jnp.where with three arguments is the numpy-style wrapper of lax.select.

  • jnp.piecewise is a numpy-style wrapper of lax.switch, but switches on a list of boolean conditions rather than a single scalar index.

  • jnp.select has an API similar to jnp.piecewise, but the choices are given as pre-computed arrays rather than as functions. It is implemented in terms of multiple calls to lax.select.
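A brief sketch of two of these alternatives, using hypothetical function names `piecewise_where` and `pick`. With jnp.where, both candidate arrays are computed for every element and the condition picks between them, so it applies elementwise. With lax.switch, an integer index selects one of several callables.

```python
import jax.numpy as jnp
from jax import jit, lax


@jit
def piecewise_where(x):
    # Both candidate arrays are computed; the condition selects per element.
    return jnp.where(x < 3, 3. * x ** 2, -4. * x)


@jit
def pick(index, x):
    # The index may be a traced value; every branch must return the same type.
    branches = [lambda v: v + 1., lambda v: v * 2., lambda v: -v]
    return lax.switch(index, branches, x)


where_out = piecewise_where(jnp.array([1., 2., 5.]))
switch_out = pick(1, 10.)

print(where_out)
print(switch_out)
```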

while_loop

Python equivalent:

def while_loop(cond_fun, body_fun, init_val):
  val = init_val
  while cond_fun(val):
    val = body_fun(val)
  return val

init_val = 0
cond_fun = lambda x: x < 10
body_fun = lambda x: x + 1
lax.while_loop(cond_fun, body_fun, init_val)
# --> array(10, dtype=int32)
Array(10, dtype=int32, weak_type=True)
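The loop state can also be a tuple, and the stopping condition can depend on a traced argument. The sketch below uses a hypothetical function, `steps_to_exceed`, that counts how many doublings of 1 are needed to reach a limit. The same compiled code serves different limits because the limit is an ordinary traced input.

```python
from jax import jit, lax


@jit
def steps_to_exceed(limit):
    def cond_fun(state):
        value, _ = state
        return value < limit  # depends on a traced input, which is fine here

    def body_fun(state):
        value, steps = state
        return value * 2, steps + 1

    # State is (current value, number of doublings so far).
    _, steps = lax.while_loop(cond_fun, body_fun, (1, 0))
    return steps


steps_10 = steps_to_exceed(10)    # 1 -> 2 -> 4 -> 8 -> 16
steps_100 = steps_to_exceed(100)  # 1 -> 2 -> ... -> 64 -> 128
print(steps_10, steps_100)
```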

fori_loop

Python equivalent:

def fori_loop(start, stop, body_fun, init_val):
  val = init_val
  for i in range(start, stop):
    val = body_fun(i, val)
  return val

init_val = 0
start = 0
stop = 10
body_fun = lambda i, x: x + i
lax.fori_loop(start, stop, body_fun, init_val)
# --> array(45, dtype=int32)
Array(45, dtype=int32, weak_type=True)
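The earlier loop that summed the first `n` elements needed `n` to be static. With lax.fori_loop the bound can stay a traced value, so changing `n` need not trigger recompilation. This is a sketch with a hypothetical function name, `sum_first_n`; since the trip count is not static here, only forward-mode differentiation would be expected to apply.

```python
import jax.numpy as jnp
from jax import jit, lax


@jit
def sum_first_n(x, n):
    def body_fun(i, acc):
        return acc + x[i]  # `i` is a traced index

    init = jnp.zeros((), x.dtype)  # accumulator with the same dtype as x
    return lax.fori_loop(0, n, body_fun, init)


x = jnp.array([2., 3., 4.])
first_two = sum_first_n(x, 2)
all_three = sum_first_n(x, 3)
print(first_two, all_three)
```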

Summary

\[\begin{split}
\begin{array} {r|rr}
\hline
\textrm{construct} & \textrm{jit} & \textrm{grad} \\
\hline
\textrm{if} & ❌ & ✔ \\
\textrm{for} & ✔* & ✔ \\
\textrm{while} & ✔* & ✔ \\
\textrm{lax.cond} & ✔ & ✔ \\
\textrm{lax.while\_loop} & ✔ & \textrm{fwd} \\
\textrm{lax.fori\_loop} & ✔ & \textrm{fwd} \\
\textrm{lax.scan} & ✔ & ✔ \\
\hline
\end{array}
\end{split}\]

\(\ast\) = argument-value-independent loop condition - unrolls the loop

Logical operators

jax.numpy provides logical_and, logical_or, and logical_not, which operate element-wise on arrays and can be evaluated under jit without recompiling. Like their NumPy counterparts, the binary operators do not short circuit. Bitwise operators (&, |, ~) can also be used with jit.

For example, consider a function that checks if its input is a positive even integer. The pure Python and JAX versions give the same answer when the input is scalar.

def python_check_positive_even(x):
  is_even = x % 2 == 0
  # `and` short-circuits, so when `is_even` is `False`, `x > 0` is not evaluated.
  return is_even and (x > 0)

@jit
def jax_check_positive_even(x):
  is_even = x % 2 == 0
  # `logical_and` does not short circuit, so `x > 0` is always evaluated.
  return jnp.logical_and(is_even, x > 0)

print(python_check_positive_even(24))
print(jax_check_positive_even(24))
True
True

When the JAX version with logical_and is applied to an array, it returns elementwise values.

x = jnp.array([-1, 2, 5])
print(jax_check_positive_even(x))
[False  True False]

Python logical operators error when applied to JAX arrays of more than one element, even without jit. This replicates NumPy’s behavior.

print(python_check_positive_even(x))
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[17], line 1
----> 1 print(python_check_positive_even(x))

Cell In[15], line 4, in python_check_positive_even(x)
      2 is_even = x % 2 == 0
      3 # `and` short-circuits, so when `is_even` is `False`, `x > 0` is not evaluated.
----> 4 return is_even and (x > 0)

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/array.py:297, in ArrayImpl.__bool__(self)
    296 def __bool__(self):
--> 297   core.check_bool_conversion(self)
    298   return bool(self._value)

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:883, in check_bool_conversion(arr)
    880   raise ValueError("The truth value of an empty array is ambiguous. Use"
    881                    " `array.size > 0` to check that an array is not empty.")
    882 if arr.size > 1:
--> 883   raise ValueError("The truth value of an array with more than one element"
    884                    " is ambiguous. Use a.any() or a.all()")

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
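The function `g` that failed under jit near the start of this page used the Python `and` operator. A sketch of the same check written with the bitwise `&` operator follows, under the hypothetical name `in_open_interval`. The parentheses around each comparison matter because `&` binds more tightly than `<` and `>`.

```python
import jax.numpy as jnp
from jax import jit


@jit
def in_open_interval(x):
    # `&` is evaluated elementwise and does not short-circuit.
    return (x > 0) & (x < 3)


scalar_result = in_open_interval(2)
array_result = in_open_interval(jnp.array([-1, 2, 5]))

print(scalar_result)
print(array_result)
```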

Python control flow + autodiff

Remember that the above constraints on control flow and logical operators are relevant only with jit. If you just want to apply grad to your Python functions, without jit, you can use regular Python control-flow constructs with no problems, as if you were using Autograd (or PyTorch or TF Eager).

def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

print(grad(f)(2.))  # ok!
print(grad(f)(4.))  # ok!
12.0
-4.0
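The same holds for Python loops. The sketch below uses two hypothetical functions: `power_by_loop`, whose `for` loop runs a fixed number of times, and `halve_until_small`, whose `while` condition depends on the value of the input. Without jit, grad sees concrete values, so both can be differentiated.

```python
from jax import grad


def power_by_loop(x, n):
    # Computes x**n with a plain Python loop.
    result = 1.
    for _ in range(n):
        result = result * x
    return result


def halve_until_small(x):
    # The number of iterations depends on the value of x.
    while x > 1.:
        x = x / 2.
    return x


d_power = grad(power_by_loop)(2., 3)   # d/dx x**3 = 3 * x**2
d_halve = grad(halve_until_small)(8.)  # 8 -> 4 -> 2 -> 1, so the output is x / 8

print(d_power, d_halve)
```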
