XLA compiler flags
Introduction
This guide gives a brief overview of XLA and how XLA relates to Jax. For in-depth details, please refer to the XLA documentation.
XLA: The Powerhouse Behind Jax
XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra that plays a pivotal role in Jax’s performance and flexibility. It enables Jax to generate optimized code for various hardware backends (CPUs, GPUs, TPUs) by transforming and compiling your Python/NumPy-like code into efficient machine instructions.
Jax uses XLA’s JIT compilation capabilities to transform your Python functions into optimized XLA computations at runtime.
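A minimal sketch of this pipeline, using `jax.jit` together with Jax's ahead-of-time API (`.lower()` and `.compile()`) to inspect what is handed to XLA. The function `f` is a hypothetical example, and the exact text printed depends on the installed Jax and XLA versions.

```python
import jax
import jax.numpy as jnp


def f(x):
    # An ordinary NumPy-style function; nothing XLA-specific here.
    return jnp.sin(x) * 2.0 + 1.0


# jax.jit wraps f; tracing and XLA compilation happen on the first call
# for each new combination of input shapes and dtypes.
f_jit = jax.jit(f)

x = jnp.arange(4.0)
y = f_jit(x)
print(y)

# Lower f for this input to the IR that Jax hands to XLA (StableHLO text).
lowered = f_jit.lower(x)
print(lowered.as_text())

# Ask XLA to compile it; this text shows the HLO after XLA's
# optimization passes for the current backend.
compiled = lowered.compile()
print(compiled.as_text())
```

Comparing the lowered text with the compiled text is a quick way to see what XLA's passes changed for a given backend.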
Configuring XLA in Jax
You can influence XLA’s behavior in Jax by setting the XLA_FLAGS environment variable before running your Python script or Colab notebook.
For Colab notebooks:
Provide flags using os.environ['XLA_FLAGS']:
import os

# Set multiple flags separated by spaces
os.environ['XLA_FLAGS'] = '--flag1=value1 --flag2=value2'
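When several flags are combined, or when XLA_FLAGS may already be set by the environment, a small helper can build the string instead of writing it by hand. The sketch below is a hypothetical helper (`set_xla_flags` is not part of Jax or XLA); it assumes flags take the `--name=value` form shown above and writes booleans as lowercase `true`/`false`.

```python
import os


def set_xla_flags(flags: dict, append: bool = True) -> str:
    """Hypothetical helper: write `flags` into the XLA_FLAGS variable.

    `flags` maps flag names (without the leading "--") to values.
    If `append` is True, any existing XLA_FLAGS content is kept and the
    new flags are added after it; otherwise it is replaced.
    Returns the resulting XLA_FLAGS string.
    """
    parts = []
    for name, value in flags.items():
        # Write Python booleans as "true"/"false" rather than "True"/"False".
        if isinstance(value, bool):
            value = "true" if value else "false"
        parts.append(f"--{name}={value}")
    new_flags = " ".join(parts)

    existing = os.environ.get("XLA_FLAGS", "").strip()
    if append and existing:
        new_flags = existing + " " + new_flags

    os.environ["XLA_FLAGS"] = new_flags
    return new_flags


# Usage: call before importing jax (placeholder flag names as above).
set_xla_flags({"flag1": "value1", "flag2": "value2"})
```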
For Python scripts:
Specify XLA_FLAGS as part of the command line:
XLA_FLAGS='--flag1=value1 --flag2=value2' python3 source.py
Important Notes:
- Set XLA_FLAGS before importing Jax or other relevant libraries. Changing XLA_FLAGS after backend initialization has no effect, and because the exact moment of backend initialization is not clearly defined, it is usually safest to set XLA_FLAGS before executing any Jax code.
- Experiment with different flags to optimize performance for your specific use case.
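The first note can be observed directly. The sketch below assumes a fresh Python process in which Jax has not been imported yet, and uses `--xla_force_host_platform_device_count`, an XLA flag that makes the CPU backend expose the given number of logical devices, chosen only because its effect is easy to check.

```python
import os

# Set the flag before importing jax. Note: this overwrites any existing
# XLA_FLAGS value; the device-count flag is only an illustrative choice.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax

# The first device query initializes the backends, which read XLA_FLAGS.
n_before = len(jax.devices("cpu"))
print(n_before)

# Changing XLA_FLAGS now has no effect: the CPU backend already exists.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
n_after = len(jax.devices("cpu"))
print(n_after)
```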
For further information:
Complete and up-to-date documentation about XLA can be found in the official XLA documentation.
For backends supported by the open-source version of XLA (CPU, GPU), XLA flags are defined, along with their default values, in xla/debug_options_flags.cc, and a complete list of flags can be found here.
A guide on how to use key XLA flags can be found here.
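As an illustration of a flag defined in xla/debug_options_flags.cc, the sketch below uses `--xla_dump_to`, which asks XLA to write the HLO modules it compiles into a directory. The temporary directory and the function `g` are illustrative choices, and the exact file names XLA produces vary between versions, so the code only lists them. It assumes a fresh Python process.

```python
import os
import tempfile

# Create a scratch directory for XLA's dump files.
dump_dir = tempfile.mkdtemp(prefix="xla_dump_")

# Set the flag before importing jax so that XLA picks it up.
os.environ["XLA_FLAGS"] = f"--xla_dump_to={dump_dir}"

import jax
import jax.numpy as jnp


@jax.jit
def g(x):
    return jnp.tanh(x) @ x.T


# The first call compiles g, and XLA writes its HLO modules to dump_dir.
g(jnp.ones((3, 3))).block_until_ready()

dumped = sorted(os.listdir(dump_dir))
for name in dumped:
    print(name)
```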
