jax.experimental.pallas module
Module for Pallas, a JAX extension for custom kernels.
See the Pallas documentation at https://docs.jax.dev/en/latest/pallas/index.html.
Backends
Classes
| Specifies how an array should be sliced for each invocation of a kernel. |
| Encodes the grid parameters for |
| A slice with a start index and a size. |
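A block specification and a grid work together: the grid says how many times the kernel runs, and each block spec says which tile of an operand each run sees. The sketch below is a minimal illustration under stated assumptions. It uses `pl.BlockSpec` and `pl.pallas_call` from `jax.experimental.pallas`, which the first row above and the kernel entry-point rows appear to describe. The helper names `add_kernel` and `tiled_add` and the `(2, 4)` block shape are hypothetical choices. `interpret=True` runs the kernel on CPU, and real GPU/TPU backends may impose additional block-shape constraints that this toy shape ignores.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def add_kernel(x_ref, y_ref, o_ref):
    # Each kernel invocation sees only one block of each operand.
    o_ref[...] = x_ref[...] + y_ref[...]


def tiled_add(x, y, block=(2, 4)):
    # Hypothetical helper: assumes x.shape divides evenly by `block`.
    rows, cols = x.shape
    grid = (rows // block[0], cols // block[1])  # one invocation per block
    # index_map receives grid indices (i, j) and returns *block* indices;
    # with the default block=(2, 4), invocation (i, j) reads
    # x[2*i:2*i+2, 4*j:4*j+4].
    spec = pl.BlockSpec(block_shape=block, index_map=lambda i, j: (i, j))
    return pl.pallas_call(
        add_kernel,
        grid=grid,
        in_specs=[spec, spec],
        out_specs=spec,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # run on CPU; no GPU/TPU needed
    )(x, y)


x = jnp.arange(32, dtype=jnp.float32).reshape(4, 8)
y = jnp.ones_like(x)
print(np.allclose(tiled_add(x, y), x + y))  # True
```

Because `index_map` returns block indices rather than element offsets, the same kernel body works unchanged for any block shape that tiles the array.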
Functions
| Runs a function on a mesh, mapping it over the devices in the mesh. |
| Entry point for creating a Pallas kernel. |
| Entry point for creating a Pallas kernel. |
| Returns the kernel execution position along the given axis of the grid. |
| Returns the size of the grid along the given axis. |
| Computes the ceiling division of a divided by b. |
| Constructs a |
| Create an empty array of possibly uninitialized values. |
| Create an empty PyTree of possibly uninitialized values. |
| Broadcasts an array to a new shape. |
| Check the condition if |
| Prints values from inside a Pallas kernel. |
| Computes the dot product of two arrays. |
| Returns a global reference that persists across all kernel invocations. |
| Returns a decorator that calls the decorated function in a loop. |
| A compiler hint that asserts a value is a static multiple of another. |
| Calls the function with allocated references and returns the result. |
| Calls the decorated function when the condition is met. |
| Returns a function decorator that runs a function with provided allocations. |
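Several rows above (grid position, grid size, ceiling division, conditional execution) combine naturally in a reduction over the grid. The following is a minimal sketch, assuming those rows describe `pl.program_id`, `pl.num_programs`, `pl.cdiv` and `pl.when`. The names `block_mean_kernel`, `block_mean` and `BLOCK` are hypothetical. Every grid step maps to the same output block, so that block acts as an accumulator: it is zeroed on the first step and divided by the step count on the last. This relies on the steps that share an output block running one after another, which holds in interpret mode. On hardware backends that may run grid axes in parallel, check the backend documentation before reusing the pattern.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl

BLOCK = 4  # hypothetical block length, chosen for illustration


def block_mean_kernel(x_ref, o_ref):
    step = pl.program_id(0)       # position along grid axis 0
    n_steps = pl.num_programs(0)  # size of grid axis 0

    # First step: the output buffer starts uninitialized, so zero it.
    @pl.when(step == 0)
    def _init():
        o_ref[...] = jnp.zeros(o_ref.shape, o_ref.dtype)

    # Every step: add this step's input block into the shared output block.
    o_ref[...] = o_ref[...] + x_ref[...]

    # Last step: turn the running sum into a mean.
    @pl.when(step == n_steps - 1)
    def _finalize():
        o_ref[...] = o_ref[...] / n_steps


def block_mean(x):
    """Elementwise mean of the length-BLOCK blocks of a 1-D array.

    Assumes x.shape[0] is a multiple of BLOCK.
    """
    return pl.pallas_call(
        block_mean_kernel,
        grid=(pl.cdiv(x.shape[0], BLOCK),),
        # Input: grid step i reads block i.
        in_specs=[pl.BlockSpec(block_shape=(BLOCK,), index_map=lambda i: (i,))],
        # Output: every grid step maps to block 0, the accumulator.
        out_specs=pl.BlockSpec(block_shape=(BLOCK,), index_map=lambda i: (0,)),
        out_shape=jax.ShapeDtypeStruct((BLOCK,), x.dtype),
        interpret=True,  # run on CPU without a GPU/TPU backend
    )(x)


x = jnp.arange(16, dtype=jnp.float32)
print(block_mean(x))  # mean of blocks [0..3], [4..7], [8..11], [12..15]
```

The `pl.when` guards are needed because the output buffer is not zero-initialized; without the first one, the accumulation would start from whatever the buffer happens to hold.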
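Slicing a reference and multiplying blocks are the two other everyday operations in a kernel body. This sketch assumes that the "Constructs a" row refers to `pl.ds(start, size)`, which builds a `Slice`, and that the dot-product row refers to `pl.dot`. The names `top_rows_matmul_kernel` and `top_rows_matmul` are hypothetical. No grid or block specs are passed, so the kernel receives each whole array as a single reference. The start index of `pl.ds` may also be a traced value such as one derived from `pl.program_id`. A static start keeps the example short.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def top_rows_matmul_kernel(a_ref, b_ref, o_ref):
    # pl.ds(start, size) builds a Slice: here, rows 0..3 of a_ref.
    top = a_ref[pl.ds(0, 4), :]
    # pl.dot multiplies two 2-D arrays inside the kernel.
    o_ref[...] = pl.dot(top, b_ref[...])


def top_rows_matmul(a, b):
    # Hypothetical helper: computes a[:4] @ b inside a Pallas kernel.
    return pl.pallas_call(
        top_rows_matmul_kernel,
        out_shape=jax.ShapeDtypeStruct((4, b.shape[1]), jnp.float32),
        interpret=True,  # run on CPU without a GPU/TPU backend
    )(a, b)


a = jnp.arange(64, dtype=jnp.float32).reshape(8, 8)
b = jnp.ones((8, 8), dtype=jnp.float32)
print(np.allclose(top_rows_matmul(a, b), a[:4] @ b))  # True
```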
Synchronization
| Reads the value of a semaphore. |
| Increments the value of a semaphore. |
| Blocks execution of the current thread until a semaphore reaches a value. |
