Frequently asked questions (FAQ)#
We are collecting answers to frequently asked questions here. Contributions welcome!
jit changes the behavior of my function#
If you have a Python function that changes behavior after using jax.jit(), perhaps your function uses global state, or has side-effects. In the following code, the impure_func uses the global y and has a side-effect due to print:

y = 0

# @jit   # Different behavior with jit
def impure_func(x):
  print("Inside:", y)
  return x + y

for y in range(3):
  print("Result:", impure_func(y))

Without jit the output is:

Inside: 0
Result: 0
Inside: 1
Result: 2
Inside: 2
Result: 4

and with jit it is:

Inside: 0
Result: 0
Result: 1
Result: 2

For jax.jit(), the function is executed once using the Python interpreter, at which time the Inside printing happens, and the first value of y is observed. Then, the function is compiled and cached, and executed multiple times with different values of x, but with the same first value of y.
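One way to avoid this surprise is to make the function pure by passing the global value in as an explicit argument. The following is a minimal sketch of that idea; the name pure_func is hypothetical and not part of the original example.

```python
import jax

@jax.jit
def pure_func(x, y):
  # y is an explicit argument, so every call sees the value passed in
  # rather than whatever the global held at trace time.
  return x + y

results = [int(pure_func(y, y)) for y in range(3)]
print(results)  # [0, 2, 4]
```

With y passed as an argument, the jit-compiled results match the output of the original function without jit.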
jit changes the exact numerics of outputs#
Sometimes users are surprised by the fact that wrapping a function with jit() can change the function’s outputs. For example:

>>> from jax import jit
>>> import jax.numpy as jnp
>>> def f(x):
...   return jnp.log(jnp.sqrt(x))
>>> x = jnp.pi
>>> print(f(x))
0.572365
>>> print(jit(f)(x))
0.5723649

This slight difference in output comes from optimizations within the XLA compiler: during compilation, XLA will sometimes rearrange or elide certain operations to make the overall computation more efficient.

In this case, XLA utilizes the properties of the logarithm to replace log(sqrt(x)) with 0.5 * log(x), which is a mathematically identical expression that can be computed more efficiently than the original. The difference in output comes from the fact that floating point arithmetic is only a close approximation of real math, so different ways of computing the same expression may have subtly different results.

Other times, XLA’s optimizations may lead to even more drastic differences. Consider the following example:

>>> def f(x):
...   return jnp.log(jnp.exp(x))
>>> x = 100.0
>>> print(f(x))
inf
>>> print(jit(f)(x))
100.0

In non-JIT-compiled op-by-op mode, the result is inf because jnp.exp(x) overflows and returns inf. Under JIT, however, XLA recognizes that log is the inverse of exp, and removes the operations from the compiled function, simply returning the input. In this case, JIT compilation produces a more accurate floating point approximation of the real result.

Unfortunately the full list of XLA’s algebraic simplifications is not well documented, but if you’re familiar with C++ and curious about what types of optimizations the XLA compiler makes, you can see them in the source code: algebraic_simplifier.cc.
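Because of these small numerical differences, code that compares eager and compiled outputs is usually better off using a tolerance than exact equality. A minimal sketch, assuming a relative tolerance of 1e-5 is acceptable for float32 (the tolerance value is an illustrative choice, not a recommendation from the JAX documentation):

```python
import jax
import jax.numpy as jnp

def f(x):
  return jnp.log(jnp.sqrt(x))

x = jnp.float32(jnp.pi)
eager = f(x)
compiled = jax.jit(f)(x)

# Compare within a tolerance instead of testing for bitwise equality.
close = bool(jnp.allclose(eager, compiled, rtol=1e-5))
print(close)
```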
jit decorated function is very slow to compile#
If your jit decorated function takes tens of seconds (or more!) to run the first time you call it, but executes quickly when called again, JAX is taking a long time to trace or compile your code.

This is usually a sign that calling your function generates a large amount of code in JAX’s internal representation, typically because it makes heavy use of Python control flow such as for loops. For a handful of loop iterations, Python is OK, but if you need many loop iterations, you should rewrite your code to make use of JAX’s structured control flow primitives (such as lax.scan()) or avoid wrapping the loop with jit (you can still use jit decorated functions inside the loop).

If you’re not sure if this is the problem, you can try running jax.make_jaxpr() on your function. You can expect slow compilation if the output is many hundreds or thousands of lines long.
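To illustrate the diagnostic above, the sketch below counts the top-level equations that jax.make_jaxpr() produces for a Python for loop and for an equivalent lax.fori_loop(). The function names python_loop and structured_loop and the loop body are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax

def python_loop(x):
  # The Python loop is unrolled during tracing: every iteration adds
  # more equations to the jaxpr.
  for _ in range(100):
    x = jnp.sin(x) + 1.0
  return x

def structured_loop(x):
  # The loop body is traced once and wrapped in a single loop primitive.
  return lax.fori_loop(0, 100, lambda i, v: jnp.sin(v) + 1.0, x)

x = jnp.float32(0.5)
n_unrolled = len(jax.make_jaxpr(python_loop)(x).jaxpr.eqns)
n_structured = len(jax.make_jaxpr(structured_loop)(x).jaxpr.eqns)
print(n_unrolled, n_structured)
```

The unrolled count grows with the number of iterations, while the structured version stays small regardless of the trip count.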
Sometimes it isn’t obvious how to rewrite your code to avoid Python loops because your code makes use of many arrays with different shapes. The recommended solution in this case is to make use of functions like jax.numpy.where() to do your computation on padded arrays with fixed shape.
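As a rough sketch of the padding approach, the example below computes the mean of variable-length sequences by padding each one to a fixed length and masking the padding with jnp.where(). The names masked_mean, pad, and MAX_LEN are hypothetical helpers introduced for this illustration.

```python
import jax
import jax.numpy as jnp

MAX_LEN = 8  # hypothetical fixed upper bound on sequence length

@jax.jit
def masked_mean(padded, length):
  # Positions at or beyond `length` are padding and must not contribute.
  mask = jnp.arange(padded.shape[0]) < length
  total = jnp.sum(jnp.where(mask, padded, 0.0))
  return total / length

def pad(values):
  # Copy the values into the front of a fixed-size zero array.
  out = jnp.zeros(MAX_LEN)
  return out.at[:len(values)].set(jnp.asarray(values))

a = masked_mean(pad([1.0, 2.0, 3.0]), 3)
b = masked_mean(pad([4.0, 6.0]), 2)
print(a, b)
```

Because every padded input has the same shape, both calls can reuse a single compiled version of masked_mean.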
If your functions are slow to compile for another reason, please open an issue on GitHub.
How to use jit with methods?#
Most examples of jax.jit() concern decorating stand-alone Python functions, but decorating a method within a class introduces some complication. For example, consider the following simple class, where we’ve used a standard jit() annotation on a method:

>>> import jax.numpy as jnp
>>> from jax import jit

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   @jit  # <---- How to do this correctly?
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y

However, this approach will result in an error when you attempt to call this method:

>>> c = CustomClass(2, True)
>>> c.calc(3)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
  File "<stdin>", line 1, in <module>
TypeError: Argument '<CustomClass object at 0x7f7dd4125890>' of type <class 'CustomClass'> is not a valid JAX type.

The problem is that the first argument to the function is self, which has type CustomClass, and JAX does not know how to handle this type. There are three basic strategies we might use in this case, and we’ll discuss them below.
Strategy 1: JIT-compiled helper function#
The most straightforward approach is to create a helper function external to the class that can be JIT-decorated in the normal way. For example:

>>> from functools import partial

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   def calc(self, y):
...     return _calc(self.mul, self.x, y)

>>> @partial(jit, static_argnums=0)
... def _calc(mul, x, y):
...   if mul:
...     return x * y
...   return y

The result will work as expected:

>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6

The benefit of such an approach is that it is simple, explicit, and it avoids the need to teach JAX how to handle objects of type CustomClass. However, you may wish to keep all the method logic in the same place.
Strategy 2: Marking self as static#
Another common pattern is to use static_argnums to mark the self argument as static. But this must be done with care to avoid unexpected results. You may be tempted to simply do this:

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   # WARNING: this example is broken, as we'll see below. Don't copy & paste!
...   @partial(jit, static_argnums=0)
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y

If you call the method, it will no longer raise an error:

>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6

However, there is a catch: if you mutate the object after the first method call, the subsequent method call may return an incorrect result:

>>> c.mul = False
>>> print(c.calc(3))  # Should print 3
6
Why is this? When you mark an object as static, it will effectively be used as a dictionary key in JIT’s internal compilation cache, meaning its hash (i.e. hash(obj)), equality (i.e. obj1 == obj2), and object identity (i.e. obj1 is obj2) will be assumed to have consistent behavior. The default __hash__ for a custom object is its object ID, and so JAX has no way of knowing that a mutated object should trigger a re-compilation.

You can partially address this by defining appropriate __hash__ and __eq__ methods for your object; for example:

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   @partial(jit, static_argnums=0)
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y
...
...   def __hash__(self):
...     return hash((self.x, self.mul))
...
...   def __eq__(self, other):
...     return (isinstance(other, CustomClass) and
...             (self.x, self.mul) == (other.x, other.mul))

(see the object.__hash__() documentation for more discussion of the requirements when overriding __hash__).

This should work correctly with JIT and other transforms so long as you never mutate your object. Mutations of objects used as hash keys lead to several subtle problems, which is why, for example, mutable Python containers (e.g. dict, list) don’t define __hash__, while their immutable counterparts (e.g. tuple) do.

If your class relies on in-place mutations (such as setting self.attr = ... within its methods), then your object is not really “static” and marking it as such may lead to problems. Fortunately, there’s another option for this case.
Strategy 3: Making CustomClass a PyTree#
The most flexible approach to correctly JIT-compiling a class method is to register the type as a custom PyTree object; see Custom pytree nodes. This lets you specify exactly which components of the class should be treated as static and which should be treated as dynamic. Here’s how it might look:

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   @jit
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y
...
...   def _tree_flatten(self):
...     children = (self.x,)  # arrays / dynamic values
...     aux_data = {'mul': self.mul}  # static values
...     return (children, aux_data)
...
...   @classmethod
...   def _tree_unflatten(cls, aux_data, children):
...     return cls(*children, **aux_data)

>>> from jax import tree_util
>>> tree_util.register_pytree_node(CustomClass,
...                                CustomClass._tree_flatten,
...                                CustomClass._tree_unflatten)

This is certainly more involved, but it solves all the issues associated with the simpler approaches used above:

>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6

>>> c.mul = False  # mutation is detected
>>> print(c.calc(3))
3

>>> c = CustomClass(jnp.array(2), True)  # non-hashable x is supported
>>> print(c.calc(3))
6

So long as your tree_flatten and tree_unflatten functions correctly handle all relevant attributes in the class, you should be able to use objects of this type directly as arguments to JIT-compiled functions, without any special annotations.
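As an alternative to calling register_pytree_node() by hand, the same registration can be sketched with the jax.tree_util.register_pytree_node_class decorator, which expects methods named tree_flatten and tree_unflatten on the class. The class name ScaledClass below is hypothetical, and this sketch stores the static data in a tuple rather than a dict.

```python
import jax
import jax.numpy as jnp
from jax import tree_util

@tree_util.register_pytree_node_class
class ScaledClass:
  def __init__(self, x, mul):
    self.x = x
    self.mul = mul

  @jax.jit
  def calc(self, y):
    if self.mul:
      return self.x * y
    return y

  def tree_flatten(self):
    children = (self.x,)    # arrays / dynamic values
    aux_data = (self.mul,)  # static values, kept in a hashable tuple
    return children, aux_data

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children, *aux_data)

c = ScaledClass(jnp.array(2), True)
first = int(c.calc(3))
c.mul = False  # changes the static data, so calc is retraced
second = int(c.calc(3))
print(first, second)  # 6 3
```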
Is JAX faster than NumPy?#
One question users frequently attempt to answer with benchmarks is whether JAX is faster than NumPy; due to the differences between the two packages, there is not a simple answer.

Broadly speaking:

NumPy operations are executed eagerly, synchronously, and only on CPU.
JAX operations may be executed eagerly or after compilation (if inside jit()); they are dispatched asynchronously (see Asynchronous dispatch); and they can be executed on CPU, GPU, or TPU, each of which has vastly different and continuously evolving performance characteristics.

These architectural differences make meaningful direct benchmark comparisons between NumPy and JAX difficult.

Additionally, these differences have led to different engineering focus between the packages: for example, NumPy has put significant effort into decreasing the per-call dispatch overhead for individual array operations, because in NumPy’s computational model that overhead cannot be avoided. JAX, on the other hand, has several ways to avoid dispatch overhead (e.g. JIT compilation, asynchronous dispatch, batching transforms, etc.), and so reducing per-call overhead has been less of a priority.

Keeping all that in mind, in summary: if you’re doing microbenchmarks of individual array operations on CPU, you can generally expect NumPy to outperform JAX due to its lower per-operation dispatch overhead. If you’re running your code on GPU or TPU, or are benchmarking more complicated JIT-compiled sequences of operations on CPU, you can generally expect JAX to outperform NumPy.
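If you do want to time the two libraries against each other, a fair comparison has to account for compilation and asynchronous dispatch. The following is a rough sketch only: the functions np_fn and jax_fn and the array size are hypothetical, and a single timed call is far too noisy to draw conclusions from.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np

def np_fn(x):
  return np.sum(np.tanh(x) ** 2)

@jax.jit
def jax_fn(x):
  return jnp.sum(jnp.tanh(x) ** 2)

x_np = np.linspace(0.0, 1.0, 100_000, dtype=np.float32)
x_jax = jnp.asarray(x_np)

# Warm-up call so that compilation time is excluded from the measurement.
jax_fn(x_jax).block_until_ready()

t0 = time.perf_counter()
np_result = np_fn(x_np)
np_time = time.perf_counter() - t0

t0 = time.perf_counter()
# block_until_ready() waits for the asynchronously dispatched computation.
jax_result = jax_fn(x_jax).block_until_ready()
jax_time = time.perf_counter() - t0

print(f"NumPy: {np_time:.6f} s, JAX: {jax_time:.6f} s")
```

Without the warm-up call and block_until_ready(), the JAX timing would measure compilation or just the dispatch rather than the computation itself.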
Gradients contain NaN where using where#
If you define a function using where to avoid an undefined value, you may obtain a NaN for reverse differentiation if you are not careful:

def my_log(x):
  return jnp.where(x > 0., jnp.log(x), 0.)

my_log(0.)            # ==> 0.  (Ok)
jax.grad(my_log)(0.)  # ==> NaN

A short explanation is that during grad computation the adjoint corresponding to the undefined jnp.log(x) is a NaN and it gets accumulated to the adjoint of the jnp.where. The correct way to write such functions is to ensure that there is a jnp.where inside the partially-defined function, so that the adjoint is always finite:

def safe_for_grad_log(x):
  return jnp.log(jnp.where(x > 0., x, 1.))

safe_for_grad_log(0.)            # ==> 0.  (Ok)
jax.grad(safe_for_grad_log)(0.)  # ==> 0.  (Ok)

The inner jnp.where may be needed in addition to the original one, e.g.:

def my_log_or_y(x, y):
  """Return log(x) if x > 0 or y"""
  return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.)), y)
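The difference between the two patterns can be checked directly. The sketch below compares my_log_or_y with a single-where variant; the name naive_log_or_y is hypothetical and exists only to show the failure mode.

```python
import jax
import jax.numpy as jnp

def naive_log_or_y(x, y):
  # Only the outer where: log(x) is still differentiated at x = 0.
  return jnp.where(x > 0., jnp.log(x), y)

def my_log_or_y(x, y):
  """Return log(x) if x > 0 or y"""
  return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.)), y)

naive_grad = jax.grad(naive_log_or_y)(0., 5.)  # NaN
safe_grad = jax.grad(my_log_or_y)(0., 5.)      # 0.0
pos_grad = jax.grad(my_log_or_y)(2., 5.)       # 0.5, i.e. 1 / x
print(naive_grad, safe_grad, pos_grad)
```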
Why are gradients zero for functions based on sort order?#
If you define a function that processes the input using operations that depend on the relative ordering of inputs (e.g. max, greater, argsort, etc.) then you may be surprised to find that the gradient is everywhere zero. Here is an example, where we define f(x) to be a step function that returns 0 when x is negative, and 1 when x is positive:

import jax
import numpy as np
import jax.numpy as jnp

def f(x):
  return (x > 0).astype(float)

df = jax.vmap(jax.grad(f))

x = jnp.array([-1.0, -0.5, 0.0, 0.5, 1.0])

print(f"f(x)  = {f(x)}")
# f(x)  = [0. 0. 0. 1. 1.]

print(f"df(x) = {df(x)}")
# df(x) = [0. 0. 0. 0. 0.]

The fact that the gradient is everywhere zero may be confusing at first glance: after all, the output does change in response to the input, so how can the gradient be zero? However, zero turns out to be the correct result in this case.
Why is this? Remember that what differentiation measures is the change in f given an infinitesimal change in x. For x=1.0, f returns 1.0. If we perturb x to make it slightly larger or smaller, this does not change the output, so by definition, grad(f)(1.0) should be zero. This same logic holds for all values of x greater than zero: infinitesimally perturbing the input does not change the output, so the gradient is zero. Similarly, for all values of x less than zero, the output is zero. Perturbing x does not change this output, so the gradient is zero.

That leaves us with the tricky case of x=0. Surely, if you perturb x upward, it will change the output, but this is problematic: an infinitesimal change in x produces a finite change in the function value, which implies the gradient is undefined. Fortunately, there’s another way for us to measure the gradient in this case: we perturb the input downward, in which case the output does not change, and so the gradient is zero.

JAX and other autodiff systems tend to handle discontinuities in this way: if the positive gradient and negative gradient disagree, but one is defined and the other is not, we use the one that is defined. Under this definition of the gradient, mathematically and numerically the gradient of this function is everywhere zero.
The problem stems from the fact that our function has a discontinuity at x=0. Our f here is essentially a Heaviside Step Function, and we can use a Sigmoid Function as a smoothed replacement. The sigmoid is approximately equal to the Heaviside function when x is far from zero, but replaces the discontinuity at x=0 with a smooth, differentiable curve. As a result of using jax.nn.sigmoid(), we get a similar computation with well-defined gradients:

def g(x):
  return jax.nn.sigmoid(x)

dg = jax.vmap(jax.grad(g))

x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])

with np.printoptions(suppress=True, precision=2):
  print(f"g(x)  = {g(x)}")
  # g(x)  = [0.   0.27 0.5  0.73 1.  ]

  print(f"dg(x) = {dg(x)}")
  # dg(x) = [0.   0.2  0.25 0.2  0.  ]

The jax.nn submodule also has smooth versions of other common rank-based functions, for example jax.nn.softmax() can replace uses of jax.numpy.argmax(), jax.nn.soft_sign() can replace uses of jax.numpy.sign(), jax.nn.softplus() or jax.nn.squareplus() can replace uses of jax.nn.relu(), etc.
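As one possible illustration of replacing argmax with softmax, the sketch below computes a softmax-weighted average of the indices. The helper soft_argmax and its temperature parameter are hypothetical, not functions provided by JAX.

```python
import jax
import jax.numpy as jnp

def soft_argmax(x, temperature=1.0):
  # Softmax-weighted average of the indices; a lower temperature
  # concentrates the weight on the largest entry.
  weights = jax.nn.softmax(x / temperature)
  return jnp.sum(weights * jnp.arange(x.shape[0]))

x = jnp.array([0.1, 2.0, 0.3])

hard = jnp.argmax(x)                    # integer index 1, not usefully differentiable
soft = soft_argmax(x, temperature=0.1)  # close to 1.0
grad = jax.grad(soft_argmax)(x)         # nonzero gradient with respect to x
print(hard, soft, grad)
```

The temperature trades off smoothness against fidelity to the hard argmax, so a suitable value depends on the application.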
How can I convert a JAX Tracer to a NumPy array?#
When inspecting a transformed JAX function at runtime, you’ll find that array values are replaced by jax.core.Tracer objects:

@jax.jit
def f(x):
  print(type(x))
  return x

f(jnp.arange(5))

This prints the following:

<class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>

A frequent question is how such a tracer can be converted back to a normal NumPy array. In short, it is impossible to convert a Tracer to a NumPy array, because a tracer is an abstract representation of every possible value with a given shape and dtype, while a NumPy array is a concrete member of that abstract class. For more discussion of how tracers work within the context of JAX transformations, see JIT mechanics.

The question of converting Tracers back to arrays usually comes up within the context of another goal, related to accessing intermediate values in a computation at runtime. For example:

If you wish to print a traced value at runtime for debugging purposes, you might consider using jax.debug.print().
If you wish to call non-JAX code within a transformed JAX function, you might consider using jax.pure_callback(), an example of which is available at Pure callback example.
If you wish to input or output array buffers at runtime (for example, load data from file, or log the contents of the array to disk), you might consider using jax.experimental.io_callback(), an example of which can be found at IO callback example.

For more information on runtime callbacks and examples of their use, see External callbacks in JAX.
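The first two options above can be sketched as follows. The host function host_sin is a hypothetical stand-in for any non-JAX code; it receives concrete values at runtime rather than tracers.

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_sin(x):
  # Runs on the host with concrete values, so plain NumPy is fine here.
  x = np.asarray(x)
  return np.sin(x).astype(x.dtype)

@jax.jit
def f(x):
  # Prints the runtime value instead of the tracer.
  jax.debug.print("x at runtime: {}", x)
  # Describe the callback's output so JAX can continue tracing.
  result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
  return jax.pure_callback(host_sin, result_shape, x)

x = jnp.arange(5, dtype=jnp.float32)
out = f(x)
print(out)
```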
Why do some CUDA libraries fail to load/initialize?#
When resolving dynamic libraries, JAX uses the usual dynamic linker search pattern. JAX sets RPATH to point to the JAX-relative location of the pip-installed NVIDIA CUDA packages, preferring them if installed. If ld.so cannot find your CUDA runtime libraries along its usual search path, then you must include the paths to those libraries explicitly in LD_LIBRARY_PATH. The easiest way to ensure your CUDA files are discoverable is to simply install the nvidia-*-cu12 pip packages, which are included in the standard jax[cuda12] install option.

Occasionally, even when you have ensured that your runtime libraries are discoverable, there may still be some issues with loading or initializing them. A common cause of such issues is simply having insufficient memory for CUDA library initialization at runtime. This sometimes occurs because JAX pre-allocates too large a chunk of the currently available device memory for faster execution, leaving insufficient memory available for runtime CUDA library initialization.

This is especially likely when running multiple JAX instances, running JAX in tandem with TensorFlow, which performs its own pre-allocation, or when running JAX on a system where the GPU is being heavily utilized by other processes. When in doubt, try running the program again with reduced pre-allocation, either by reducing XLA_PYTHON_CLIENT_MEM_FRACTION from the default of 0.75, or setting XLA_PYTHON_CLIENT_PREALLOCATE=false. For more details, please see the page on JAX GPU memory allocation.
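These environment variables can be set in the shell or, as sketched below, from Python before JAX initializes its backend. The fraction 0.5 is an arbitrary illustrative value, not a recommended setting.

```python
import os

# Set these before JAX initializes its GPU backend, otherwise they
# have no effect. 0.5 is an illustrative value only.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"
# Alternatively, disable pre-allocation entirely:
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax

devices = jax.devices()
print(devices)
```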
Controlling data and computation placement on devices#
Moved to Controlling data and computation placement on devices.
Benchmarking JAX code#
Moved to Benchmarking JAX code.
Buffer donation#
Moved to Buffer donation.
