Tracing
jax.jit and other JAX transforms work by tracing a function to determine its effect on inputs of a specific shape and type. For a window into tracing, let’s put a few print() statements within a JIT-compiled function and then call the function:

    from jax import jit
    import jax.numpy as jnp
    import numpy as np

    @jit
    def f(x, y):
      print("Running f():")
      print(f"  x = {x}")
      print(f"  y = {y}")
      result = jnp.dot(x + 1, y + 1)
      print(f"  result = {result}")
      return result

    x = np.random.randn(3, 4)
    y = np.random.randn(4)
    f(x, y)

    Running f():
      x = JitTracer<float32[3,4]>
      y = JitTracer<float32[4]>
      result = JitTracer<float32[3]>
Array([1.745958 , 1.0156265, 1.8004583], dtype=float32)
Notice that the print statements execute, but rather than printing the data we passed to the function, they print tracer objects that stand in for it (something like JitTracer<float32[3,4]>; some JAX versions print a form such as Traced<ShapedArray(float32[])>).
These tracer objects are what jax.jit uses to extract the sequence of operations specified by the function. Basic tracers are stand-ins that encode the shape and dtype of the arrays, but are agnostic to the values. This recorded sequence of computations can then be efficiently applied within XLA to new inputs with the same shape and dtype, without having to re-execute the Python code.

When we call the compiled function again on matching inputs, no re-compilation is required and nothing is printed because the result is computed in compiled XLA rather than in Python:

    x2 = np.random.randn(3, 4)
    y2 = np.random.randn(4)
    f(x2, y2)
Array([8.362883 , 0.43910146, 5.0036416 ], dtype=float32)
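One way to observe this caching is to count how often the Python body actually runs. The following is a minimal sketch, not part of the original example: `shifted_sum` and `trace_log` are hypothetical names, and the sketch assumes the default jit cache, which is keyed on input shape and dtype.

```python
import jax
import jax.numpy as jnp
import numpy as np

trace_log = []  # hypothetical helper: one entry per execution of the Python body

@jax.jit
def shifted_sum(x):
    # Python side effect: executes only while JAX traces the function.
    trace_log.append((x.shape, str(x.dtype)))
    return jnp.sum(x + 1)

shifted_sum(np.ones((3, 4), dtype=np.float32))   # first call: traces and compiles
shifted_sum(np.zeros((3, 4), dtype=np.float32))  # same shape and dtype: cached
traces_after_same_shape = len(trace_log)

shifted_sum(np.ones((5,), dtype=np.float32))     # new shape: traces again
traces_after_new_shape = len(trace_log)

print(traces_after_same_shape, traces_after_new_shape)
```

The list grows only when a call triggers a new trace, so a second call with matching shape and dtype leaves it unchanged, while a call with a new shape adds an entry.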
The extracted sequence of operations is encoded in a JAX expression, or jaxpr for short. You can view the jaxpr using the jax.make_jaxpr transformation:

    from jax import make_jaxpr

    def f(x, y):
      return jnp.dot(x + 1, y + 1)

    make_jaxpr(f)(x, y)

    { lambda ; a:f32[3,4] b:f32[4]. let
        c:f32[3,4] = add a 1.0:f32[]
        d:f32[4] = add b 1.0:f32[]
        e:f32[3] = dot_general[
          dimension_numbers=(([1], [0]), ([], []))
          preferred_element_type=float32
        ] c d
      in (e,) }

Note one consequence of this: because JIT compilation is done without information on the content of the array, control flow statements in the function cannot depend on traced values (see Control flow and logical operators with JIT). For example, this fails:

    @jit
    def f(x, neg):
      return -x if neg else x

    f(1, True)

    ---------------------------------------------------------------------------
    TracerBoolConversionError                 Traceback (most recent call last)
    Cell In[4], line 5
          1 @jit
          2 def f(x, neg):
          3   return -x if neg else x
    ----> 5 f(1, True)

        [... skipping hidden 13 frame]

    Cell In[4], line 3, in f(x, neg)
          1 @jit
          2 def f(x, neg):
    ----> 3   return -x if neg else x

        [... skipping hidden 1 frame]

    File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:1859, in concretization_function_error.<locals>.error(self, arg)
       1858 def error(self, arg):
    -> 1859   raise TracerBoolConversionError(arg)

    TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
    The error occurred while tracing the function f at /tmp/ipykernel_3441/2422663986.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument neg.
    See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError
If there are variables that you would not like to be traced, they can be marked as static for the purposes of JIT compilation:

    from functools import partial

    @partial(jit, static_argnums=(1,))
    def f(x, neg):
      return -x if neg else x

    f(1, True)
Array(-1, dtype=int32, weak_type=True)
Note that calling a JIT-compiled function with a different static argument results in re-compilation, so the function still works as expected:
f(1,False)
Array(1, dtype=int32, weak_type=True)
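As a rough illustration of the two options for a flag argument, the sketch below marks the flag static by name (using `static_argnames`, the keyword counterpart of `static_argnums`) and, alternatively, keeps it traced and selects the result with `jnp.where`. The names `negate_static`, `negate_traced`, and `traced_flags` are hypothetical and exist only for this example.

```python
from functools import partial

import jax
import jax.numpy as jnp

traced_flags = []  # hypothetical helper: records the flag seen at trace time

@partial(jax.jit, static_argnames="neg")
def negate_static(x, neg):
    traced_flags.append(neg)  # neg is a plain Python bool here
    return -x if neg else x

@jax.jit
def negate_traced(x, neg):
    # neg is traced, so select between both results instead of branching.
    return jnp.where(neg, -x, x)

negate_static(1, True)
negate_static(2, True)    # same static value: reuses the compiled code
negate_static(1, False)   # new static value: traces again

print(traced_flags)
print(negate_traced(1, True), negate_traced(1, False))
```

The static version compiles once per distinct flag value, whereas the traced version compiles once and evaluates both candidates at run time, so which one fits better depends on how many distinct values the argument takes and how expensive each branch is.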
Static vs traced operations

Just as values can be either static or traced, operations can be static or traced. Static operations are evaluated at compile-time in Python; traced operations are compiled & evaluated at run-time in XLA.

This distinction between static and traced values makes it important to think about how to keep a static value static. Consider this function:

    import jax.numpy as jnp
    from jax import jit

    @jit
    def f(x):
      return x.reshape(jnp.array(x.shape).prod())

    x = jnp.ones((2, 3))
    f(x)

    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    Cell In[7], line 9
          6   return x.reshape(jnp.array(x.shape).prod())
          8 x = jnp.ones((2, 3))
    ----> 9 f(x)

        [... skipping hidden 13 frame]

    Cell In[7], line 6, in f(x)
          4 @jit
          5 def f(x):
    ----> 6   return x.reshape(jnp.array(x.shape).prod())

        [... skipping hidden 2 frame]

    File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:461, in _compute_newshape(arr, newshape)
        459 except:
        460   newshape = [newshape]
    --> 461 newshape = core.canonicalize_shape(newshape)  # type: ignore[arg-type]
        462 neg1s = [i for i, d in enumerate(newshape) if type(d) is int and d == -1]
        463 if len(neg1s) > 1:

    File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:2027, in canonicalize_shape(shape, context)
       2025 except TypeError:
       2026   pass
    -> 2027 raise _invalid_shape_error(shape, context)

    TypeError: Shapes must be 1D sequences of concrete values of integer type, got [JitTracer<int32[]>].
    If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
    The error occurred while tracing the function f at /tmp/ipykernel_3441/1983583872.py:4 for jit. This value became a tracer due to JAX operations on these lines:

      operation a:i32[] = reduce_prod[axes=(0,)] b
        from line /tmp/ipykernel_3441/1983583872.py:6:19 (f)
This fails with an error specifying that a tracer was found instead of a 1D sequence of concrete values of integer type. Let’s add some print statements to the function to understand why this is happening:

    @jit
    def f(x):
      print(f"x = {x}")
      print(f"x.shape = {x.shape}")
      print(f"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}")
      # comment this out to avoid the error:
      # return x.reshape(jnp.array(x.shape).prod())

    f(x)

    x = JitTracer<float32[2,3]>
    x.shape = (2, 3)
    jnp.array(x.shape).prod() = JitTracer<int32[]>

Notice that although x is traced, x.shape is a static value. However, when we use jnp.array and jnp.prod on this static value, it becomes a traced value, at which point it cannot be used in a function like reshape() that requires a static input (recall: array shapes must be static).

A useful pattern is to use numpy for operations that should be static (i.e. done at compile-time), and use jax.numpy for operations that should be traced (i.e. compiled and executed at run-time). For this function, it might look like this:

    from jax import jit
    import jax.numpy as jnp
    import numpy as np

    @jit
    def f(x):
      return x.reshape((np.prod(x.shape),))

    f(x)
Array([1., 1., 1., 1., 1., 1.], dtype=float32)
For this reason, a standard convention in JAX programs is to import numpy as np and import jax.numpy as jnp so that both interfaces are available for finer control over whether operations are performed in a static manner (with numpy, once at compile-time) or a traced manner (with jax.numpy, optimized at run-time).
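To make the difference between the two interfaces visible, the sketch below compares the jaxprs of two otherwise identical functions, one computing a constant with numpy and one with jax.numpy. The functions `scale_static` and `scale_traced` and the helper `primitive_names` are hypothetical names introduced for this illustration; the exact list of primitives printed may differ between JAX versions.

```python
import jax
import jax.numpy as jnp
import numpy as np

def scale_static(x):
    scale = np.sqrt(2.0) * 3   # NumPy: evaluated once, in Python, at trace time
    return jnp.sin(x) * scale

def scale_traced(x):
    scale = jnp.sqrt(2.0) * 3  # jax.numpy: recorded as traced operations
    return jnp.sin(x) * scale

def primitive_names(fn, *args):
    # Hypothetical helper: list the primitives recorded in the jaxpr of fn.
    closed_jaxpr = jax.make_jaxpr(fn)(*args)
    return [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]

static_ops = primitive_names(scale_static, 1.0)
traced_ops = primitive_names(scale_traced, 1.0)
print(static_ops)
print(traced_ops)
```

In the numpy version the square root never reaches the jaxpr because it has already been reduced to a constant, while the jax.numpy version records the extra operations for XLA to execute.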
Understanding which values and operations will be static and which will be traced is a key part of using jax.jit effectively.
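As a small sketch of keeping a shape computation static, the example below flattens an array in two ways that never create a tracer for the size: computing it with Python's `math.prod`, and letting `reshape` infer it from `-1`. The names `flatten`, `flatten_short`, and `size_types` are hypothetical and used only here.

```python
import math

import jax
import jax.numpy as jnp

size_types = []  # hypothetical helper: records the type of the computed size

@jax.jit
def flatten(x):
    n = math.prod(x.shape)  # x.shape is a tuple of Python ints, so n is an int
    size_types.append(type(n))
    return x.reshape((n,))

@jax.jit
def flatten_short(x):
    return x.reshape(-1)    # let reshape infer the static size

x = jnp.ones((2, 3))
print(flatten(x).shape, flatten_short(x).shape, size_types)
```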
Different kinds of JAX values

A tracer value carries an abstract value, e.g., ShapedArray with information about the shape and dtype of an array. We will refer here to such tracers as abstract tracers. Some tracers, e.g., those that are introduced for arguments of autodiff transformations, carry ConcreteArray abstract values that actually include the regular array data, and are used, e.g., for resolving conditionals. We will refer here to such tracers as concrete tracers. Tracer values computed from these concrete tracers, perhaps in combination with regular values, result in concrete tracers. A concrete value is either a regular value or a concrete tracer.
Typically, computations that involve at least one tracer value will produce a tracer value. There are very few exceptions, which arise when a computation can be done entirely using the abstract value carried by a tracer, in which case the result can be a regular Python value. One example is getting the shape of a tracer with a ShapedArray abstract value. Another is explicitly casting a concrete tracer value to a regular type, e.g., int(x) or x.astype(float). A third is bool(x), which produces a Python bool when concreteness makes it possible. That case is especially salient because of how often it arises in control flow.
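The sketch below illustrates two of these exceptions under stated assumptions: reading the shape of an abstract tracer inside jax.jit, and a comparison that resolves to a Python bool when only jax.grad is applied. The names `row_means`, `reciprocal_or_zero`, and `static_facts` are hypothetical and not part of JAX.

```python
import jax
import jax.numpy as jnp

static_facts = []  # hypothetical helper: records values seen during tracing

@jax.jit
def row_means(x):
    # Shape and ndim come from the abstract value, so they are regular values.
    static_facts.append((x.shape, x.ndim))
    return x.sum(axis=1) / x.shape[1]

means = row_means(jnp.ones((2, 3)))

def reciprocal_or_zero(y):
    # Under jax.grad alone, the comparison can be resolved to a Python bool.
    return 3.0 / y if y >= 1.0 else 0.0

slope = jax.grad(reciprocal_or_zero)(2.0)  # derivative of 3/y at y = 2
print(static_facts, means, slope)
```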
Here is how the transformations introduce abstract or concrete tracers:
- jax.jit(): introduces abstract tracers for all positional arguments except those denoted by static_argnums, which remain regular values.
- jax.pmap(): introduces abstract tracers for all positional arguments except those denoted by static_broadcasted_argnums.
- jax.vmap(), jax.make_jaxpr(), xla_computation(): introduce abstract tracers for all positional arguments.
- jax.jvp() and jax.grad() introduce concrete tracers for all positional arguments. An exception is when these transformations are within an outer transformation and the actual arguments are themselves abstract tracers; in that case, the tracers introduced by the autodiff transformations are also abstract tracers.
- All higher-order control-flow primitives (lax.cond(), lax.while_loop(), lax.fori_loop(), lax.scan()) when they process the functionals introduce abstract tracers, whether or not there is a JAX transformation in progress.
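The last point in the list above can be sketched with lax.cond: its branch functions receive tracers even when the call is not wrapped in any transformation. The example assumes that `jax.core.Tracer` is available as the tracer base class in the installed JAX version; `safe_divide`, `do_divide`, `give_zero`, and `branch_saw_tracer` are hypothetical names.

```python
import jax
import jax.numpy as jnp
from jax import lax

branch_saw_tracer = []  # hypothetical helper: records what the branch received

def safe_divide(x, y):
    x = jnp.asarray(x, dtype=jnp.float32)
    y = jnp.asarray(y, dtype=jnp.float32)

    def do_divide(a, b):
        # Assumption: jax.core.Tracer is the base class of tracer objects.
        branch_saw_tracer.append(isinstance(a, jax.core.Tracer))
        return a / b

    def give_zero(a, b):
        return jnp.zeros_like(a)

    # Both branches are traced; the predicate picks one at run time.
    return lax.cond(y >= 1.0, do_divide, give_zero, x, y)

eager = safe_divide(3.0, 2.0)            # no jit, branches still see tracers
jitted = jax.jit(safe_divide)(3.0, 0.5)  # works without static_argnums
print(eager, jitted, branch_saw_tracer)
```

Because the condition is handled by lax.cond rather than a Python `if`, the function can be wrapped in jax.jit without marking `y` as static.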
All of this is relevant when you have code that can operate only on regular Python values, such as code that has conditional control-flow based on data:

    def divide(x, y):
      return x / y if y >= 1. else 0.
If we want to apply jax.jit(), we must specify static_argnums=1 so that y stays a regular value. This is due to the boolean expression y >= 1., which requires concrete values (regular or tracers). The same would happen if we wrote bool(y >= 1.), int(y), or float(y) explicitly.

Interestingly, jax.grad(divide)(3., 2.) works because jax.grad() uses concrete tracers and resolves the conditional using the concrete value of y.
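The three cases discussed for `divide` can be tried side by side. This is a minimal sketch; the variable names `divide_jit`, `quotient`, `jit_error`, and `gradient` are hypothetical, and the exception class is the one named in the error documentation linked earlier.

```python
import jax

def divide(x, y):
    return x / y if y >= 1. else 0.

# y stays a regular Python value, so the conditional runs in Python.
divide_jit = jax.jit(divide, static_argnums=1)
quotient = divide_jit(3., 2.)

# Without static_argnums, y is an abstract tracer and the `if` cannot be resolved.
try:
    jax.jit(divide)(3., 2.)
    jit_error = None
except jax.errors.TracerBoolConversionError as err:
    jit_error = type(err).__name__

# Derivative with respect to x: d/dx (x / y) = 1 / y
gradient = jax.grad(divide)(3., 2.)
print(quotient, jit_error, gradient)
```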
