Pallas: a JAX kernel language
Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU. It aims to provide fine-grained control over the generated code, combined with the high-level ergonomics of JAX tracing and the jax.numpy API.
This section contains tutorials, guides and examples for using Pallas. See also the jax.experimental.pallas module API documentation.
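As a rough illustration of what a custom kernel looks like, here is a minimal sketch of an element-wise addition kernel. The names `add_vectors_kernel` and `add_vectors` are illustrative, and `interpret=True` is an assumption made here so that the sketch can run on a machine without a GPU or TPU; a real kernel would normally drop it and compile for the target backend. Because Pallas is experimental, the exact API may differ between JAX versions.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_vectors_kernel(x_ref, y_ref, o_ref):
    # Inputs and outputs arrive as mutable references; read both inputs,
    # add them with ordinary array arithmetic, and write to the output.
    o_ref[...] = x_ref[...] + y_ref[...]


def add_vectors(x, y):
    # pallas_call turns the kernel into a function on JAX arrays.
    # out_shape describes the array the kernel writes into.
    return pl.pallas_call(
        add_vectors_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # assumption: interpreter mode, so no GPU/TPU is needed
    )(x, y)


x = jnp.arange(8, dtype=jnp.float32)
result = add_vectors(x, x)
```

The kernel body operates on references rather than returning a value, which is the main difference from an ordinary jax.numpy function.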
Warning
Pallas is experimental and is changing frequently. See the Pallas Changelog for the recent changes.
You can expect to encounter errors and unimplemented cases, e.g., when lowering high-level JAX concepts that would require emulation, or simply because Pallas is still under development.
Guides
TPU backend guide
Mosaic GPU backend guide
Other
