Concurrency
JAX has limited support for Python concurrency.
Clients may call JAX APIs (e.g., jit() or grad()) concurrently from separate Python threads.
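As a minimal sketch of this permitted pattern, the example below calls a jit()-compiled function and a grad()-derived function from several worker threads at once. The names scaled_sum, grad_fn and worker are illustrative only. Each thread builds its own inputs and only calls JAX APIs from the outside; no thread reaches into another's traced computation.

```python
import concurrent.futures

import jax
import jax.numpy as jnp


@jax.jit
def scaled_sum(x):
    # Compiled once, then safely callable from any thread.
    return jnp.sum(x * 2.0)


# Gradient of sum(x**2), i.e. 2 * x.
grad_fn = jax.grad(lambda x: jnp.sum(x ** 2))


def worker(i):
    # Each thread creates its own input and calls JAX APIs independently.
    x = jnp.arange(4.0) + i
    return float(scaled_sum(x)), grad_fn(x)


with concurrent.futures.ThreadPoolExecutor(max_workers=4) as pool:
    results = list(pool.map(worker, range(4)))
```

Here the concurrency lives entirely outside the functions being traced, which is the kind of use the paragraph above allows.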
It is not permitted to manipulate JAX trace values concurrently from multiple threads. In other words, while it is permissible to call functions that use JAX tracing (e.g., jit()) from multiple threads, you must not use threading to manipulate JAX values inside the implementation of the function f that is passed to jit(). The most likely outcome if you do this is a mysterious error from JAX.
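When the work inside f consists of independent per-item computations, one alternative to spawning threads there is to express it with jax.vmap, so everything is traced on a single thread. The sketch below assumes such independent per-row work; the helper per_row is hypothetical, and vmap is a swapped-in technique rather than something the text above prescribes.

```python
import jax
import jax.numpy as jnp

# Anti-pattern (do NOT do this): starting threading.Thread workers inside f
# that operate on its traced argument `x`. The likely result is an obscure
# JAX error.


def per_row(row):
    # Hypothetical per-item work: squared norm of one row.
    return jnp.dot(row, row)


@jax.jit
def f(x):
    # Batch the independent per-row work with vmap instead of threads;
    # tracing stays on the thread that called f.
    return jax.vmap(per_row)(x)


out = f(jnp.arange(6.0).reshape(2, 3))
```

For the input rows [0, 1, 2] and [3, 4, 5], out holds the squared norms 5.0 and 50.0.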
In multi-controller JAX, different processes must apply the same JAX operations in the same order on a given device. If you use threads with multi-controller JAX, you can use the thread_guard() context manager to detect cases where threads may schedule operations in different orders in different processes, which can lead to non-deterministic crashes. While the thread guard is set, an error is raised at runtime if a JAX operation is called from a thread other than the one in which the thread guard was set.
