jax.experimental.pallas.tpu module
Mosaic-specific Pallas APIs.
Classes
| |
| Mosaic TPU compiler parameters. |
| |
| |
| |
| |
| TPU hardware information. |
Functions
| Loads an array from the given ref. |
| Stores a value to the given ref. |
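The two functions above read from and write to refs inside a kernel body. The kernel below is a minimal sketch of that read-compute-write pattern. It uses ordinary ref indexing (`x_ref[...]`), which is the common Pallas spelling, rather than the listed helpers, whose names are not shown on this page. It runs through `pl.pallas_call` with `interpret=True`, so it assumes a recent JAX installation but no TPU hardware. The names `add_one_kernel` and `add_one` are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def add_one_kernel(x_ref, o_ref):
    # Read the whole input block from its ref, compute, and write the result to the output ref.
    o_ref[...] = x_ref[...] + 1.0


def add_one(x):
    return pl.pallas_call(
        add_one_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        # interpret=True runs the kernel through the Pallas interpreter, so no TPU is needed.
        interpret=True,
    )(x)


x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
y = add_one(x)
```

Without a `grid`, the kernel is invoked once and each ref covers the full array.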
Communication
| Issues a DMA copying from src_ref to dst_ref. |
| Issues a remote DMA copying from src_ref to dst_ref. |
| Creates a description of an asynchronous copy operation. |
| Creates a description of a remote copy operation. |
| Synchronously copies a PyTree of refs to another PyTree of refs. |
Pipelining
| | A helper class to automate VMEM double buffering in Pallas pipelines. |
| | Abstract interface for BufferedRefs. |
| | Creates a function to emit a manual Pallas pipeline. |
| | Creates a Pallas pipeline and its top-level allocation preparation functions. |
| | Retrieves a named pipeline schedule, or passes through a fully specified one. |
| | Creates BufferedRefs for the pipeline. |
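The helpers above let a kernel emit its own manual pipeline. For comparison, a plain `pl.pallas_call` with a `grid` and `pl.BlockSpec`s already gets an automatic top-level pipeline. Pallas iterates over the grid and hands the kernel one block per operand at each step, and on TPU it double-buffers those blocks between HBM and VMEM. The sketch below shows only that top-level form, not the manual pipeline API listed here. `BLOCK_ROWS`, `scale_kernel`, and `scale` are illustrative names, and `interpret=True` lets it run without a TPU.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl

BLOCK_ROWS = 8  # hypothetical block height; the input row count must be a multiple of it


def scale_kernel(x_ref, o_ref):
    # Each grid step sees one (BLOCK_ROWS, cols) block of x and writes the matching block of o.
    o_ref[...] = x_ref[...] * 2.0


def scale(x):
    rows, cols = x.shape
    # Grid step i maps to block (i, 0): rows [i * BLOCK_ROWS, (i + 1) * BLOCK_ROWS), all columns.
    spec = pl.BlockSpec(block_shape=(BLOCK_ROWS, cols), index_map=lambda i: (i, 0))
    return pl.pallas_call(
        scale_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(rows // BLOCK_ROWS,),
        in_specs=[spec],
        out_specs=spec,
        interpret=True,  # run with the Pallas interpreter instead of compiling for TPU
    )(x)


x = jnp.ones((32, 128), jnp.float32)
y = scale(x)
```

The manual pipeline helpers become useful when this one-level, grid-driven scheme is not enough, for example when a kernel needs to pipeline over data it fetches itself.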
Pseudorandom Number Generation
| | Sets the seed for the PRNG. |
| | Samples a block of random values with invariance guarantees. |
| | Samples Bernoulli random values with the given shape and mean. |
| | Samples uniform bits in the form of unsigned integers. |
| | Samples standard normal random values with the given shape and float dtype. |
| | Samples uniform random values in [minval, maxval) with the given shape and dtype. |
| | Helper function for converting non-Pallas PRNG keys into Pallas keys. |
Interpret Mode
| | Context manager that forces TPU interpret mode under its dynamic context. |
| | Parameters for TPU interpret mode. |
| | Resets all global, shared state used by TPU interpret mode. |
Miscellaneous
| | Synchronizes all cores in a given axis. |
| | Returns a barrier semaphore. |
| | Returns the TPU hardware information for the current device. |
| | Returns whether the current device is a TPU. |
| Runs a function on the first core in a given axis. |
| Constrains the memory space of an array. |
