
Writing TPU kernels with Pallas

This page focuses on the details that are important when attempting to run Pallas kernels on Google TPUs. For one, the TPU backend is still in an experimental phase, and only a subset of JAX NumPy will be accepted. Furthermore, writing performant code for TPUs might require thinking carefully about the native capabilities of the hardware. While many patterns that are unnatural to the hardware will be accepted, they might end up requiring software emulation, which can slow down the computation.

Warning

This feature should still be considered experimental as work is still in progress (in particular on improving the error messages).

Note

While all the features described here are experimental, we remain very serious about maintaining their correctness. As such, it is not uncommon to see a “not implemented” error while attempting to write TPU kernels. But if a kernel is accepted by the compiler, it must return the expected results.

If you see unexpected outputs, please compare them against a kernel run with interpret=True passed in to pallas_call. If the results diverge, please file a bug report.
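As a minimal sketch of that comparison (the kernel and helper names are illustrative), the elementwise kernel below is first run with interpret=True, which evaluates the kernel body with JAX's interpreter instead of compiling it for the TPU, so it also runs on CPU. The interpreted result serves as the reference, and the compiled run is only attempted when a TPU backend is present.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def add_kernel(x_ref, y_ref, o_ref):
    # Without a grid, each ref covers the whole (8, 128) array.
    o_ref[...] = x_ref[...] + y_ref[...]


def add(x, y, interpret):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=interpret,
    )(x, y)


x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
y = jnp.ones((8, 128), dtype=jnp.float32)

# Reference result computed by the interpreter (no TPU compiler involved).
expected = add(x, y, interpret=True)

# On a TPU, the compiled kernel should match the interpreted one.
if jax.default_backend() == "tpu":
    np.testing.assert_allclose(add(x, y, interpret=False), expected)
```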

What is a TPU?

[Figure: A TPUv4 board]

TPU is a hardware accelerator developed at Google. You can think of TPUs as GPUs, but specialized for machine learning workloads. As such, their architecture differs quite significantly from that of GPUs. However, we believe that Pallas can make it easy to start writing TPU kernels, even without having a full understanding of the underlying hardware. Having said that, understanding the hardware well will certainly make it easier to write performant kernels.

In a nutshell, the main difference between TPUs and GPUs is that TPUs are sequential machines with a very wide vector register (kind of like a CPU!). At the same time, they allow the software to schedule certain operations in the background, making them execute asynchronously with respect to the main instruction stream. This includes things like HBM memory accesses (which cannot be issued directly, but instead have to be prefetched to lower levels of the memory hierarchy by the DMA subunits), matrix multiplies (supported by the MXU unit) or matrix transpositions and permutes (supported by the XLU unit).

If you’re interested in learning more about the TPU architecture in detail, we recommend reading a collection of papers published over the years. While many of them talk about specific TPU generations, many of the ideas described transfer to later generations as well.

Noteworthy properties and restrictions

BlockSpecs and grid iteration

BlockSpecs (see BlockSpec, a.k.a. how to chunk up inputs) generally behave as expected in Pallas: every invocation of the kernel body gets access to slices of the inputs and is meant to initialize a slice of the output.

Note

Not all block shapes are supported. On TPU, only blocks with rank at least 1 are supported. Furthermore, the last two dimensions of your block shape must be divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array.
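The block-shape rule can be expressed as a small check. This is a hypothetical helper written for illustration, not part of the Pallas API; in particular, treating a rank-1 block as subject only to the 128 rule on its single dimension is an assumption.

```python
def is_supported_block_shape(block_shape, array_shape):
    """Hypothetical check of the TPU block-shape rule described above.

    The last two block dimensions must be divisible by 8 and 128,
    respectively, or equal the corresponding array dimensions.
    """
    if len(block_shape) < 1 or len(block_shape) != len(array_shape):
        return False
    # Rank >= 2: divisors (8, 128). Rank 1 (assumption): only 128 applies.
    divisors = (8, 128)[-min(len(block_shape), 2):]
    n = len(divisors)
    pairs = zip(block_shape[-n:], array_shape[-n:], divisors)
    return all(b % d == 0 or b == a for b, a, d in pairs)


print(is_supported_block_shape((128, 256), (1024, 1024)))  # aligned to (8, 128)
print(is_supported_block_shape((3, 1000), (3, 1000)))      # full array dims
print(is_supported_block_shape((4, 128), (64, 512)))       # 4 is not a multiple of 8
```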

One interesting aspect of Pallas TPU kernels is the way they handle memory spaces: while the inputs to pallas_call will often reside in HBM (the main TPU memory), the references passed in to the kernel body will point to buffers in lower levels of the memory hierarchy (VMEM or SMEM). This enables the kernel body to write and read them at very high speeds, while all the communication with HBM (which has very high latency) is handled by the compiler and overlapped with compute.

What’s more, compared to GPUs, TPUs are actually highly sequential machines. Ergo, the grid is generally not processed in parallel, but sequentially, in lexicographic order (though see the Multicore TPU configurations section for exceptions). This unlocks some interesting capabilities:

  • When two (lexicographically) consecutive grid indices use the same slice of an input, the HBM transfer for the second iteration is skipped, as the data is already available.

  • Multiple invocations of the kernel body can write to the same slice of the output, without any risk of race conditions. However, we do require that all invocations that write to a particular slice are consecutive.

The “consecutive” restriction on the output usually means that some prefix of the grid dimensions always varies the slice of the output an invocation needs to access, while the output window remains constant for the remaining suffix.
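To make these grid-order properties concrete, the hypothetical helpers below walk a grid in lexicographic order (as itertools.product does) and collapse consecutive repeats of the block index returned by an index map. Under the skipping behavior described above, the number of runs roughly tracks how many transfers a window needs, and an output window satisfies the “consecutive” restriction when no block index appears in more than one run. These helpers are illustrations, not Pallas APIs.

```python
import itertools


def consecutive_runs(index_map, grid):
    """Block indices visited in lexicographic grid order, with consecutive
    repeats collapsed into a single entry."""
    runs = []
    for idx in itertools.product(*(range(n) for n in grid)):
        block = index_map(*idx)
        if not runs or runs[-1] != block:
            runs.append(block)
    return runs


def writes_are_consecutive(out_index_map, grid):
    """True if every output block is visited in exactly one contiguous run."""
    runs = consecutive_runs(out_index_map, grid)
    return len(runs) == len(set(runs))


grid = (2, 2, 4)
# Output window ignores the last axis: 4 runs, each block visited once.
print(len(consecutive_runs(lambda i, j, k: (i, j), grid)))
# Output window depends on the last axis and revisits blocks: not allowed.
print(writes_are_consecutive(lambda i, j, k: (k, j), grid))
```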

For example, when implementing a Pallas TPU kernel for matrix multiplication, one would generally use a 3-dimensional grid: the first two dimensions would correspond to slicing along the first axis of the left operand and the second axis of the right operand. The third and last grid axis would tile the reduction dimension. The grid axis corresponding to the reduction dimension has to be the last one, since the output window does not vary along this axis. The output reference can then be used as an accumulator for partial results.
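A minimal sketch of such a matrix multiplication kernel follows, assuming float32 inputs whose dimensions are divisible by the block sizes (bm, bk and bn are illustrative names). The output window is zeroed on the first step of the reduction axis and then accumulated into on every step. It runs here with interpret=True; on a TPU you would drop that flag.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def matmul_kernel(x_ref, y_ref, o_ref):
    # First step along the reduction axis: start the accumulator at zero.
    @pl.when(pl.program_id(2) == 0)
    def _():
        o_ref[...] = jnp.zeros(o_ref.shape, o_ref.dtype)

    # The output window stays resident across the reduction axis, so it can
    # accumulate the partial products.
    o_ref[...] += jnp.dot(x_ref[...], y_ref[...], preferred_element_type=jnp.float32)


@functools.partial(jax.jit, static_argnames=("bm", "bk", "bn", "interpret"))
def matmul(x, y, bm=128, bk=128, bn=128, interpret=True):
    m, k = x.shape
    _, n = y.shape
    return pl.pallas_call(
        matmul_kernel,
        out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32),
        # (rows of x, columns of y, reduction); the reduction axis is last.
        grid=(m // bm, n // bn, k // bk),
        in_specs=[
            pl.BlockSpec(block_shape=(bm, bk), index_map=lambda i, j, r: (i, r)),
            pl.BlockSpec(block_shape=(bk, bn), index_map=lambda i, j, r: (r, j)),
        ],
        out_specs=pl.BlockSpec(block_shape=(bm, bn), index_map=lambda i, j, r: (i, j)),
        interpret=interpret,
    )(x, y)


x = jax.random.normal(jax.random.PRNGKey(0), (256, 384), dtype=jnp.float32)
y = jax.random.normal(jax.random.PRNGKey(1), (384, 256), dtype=jnp.float32)
out = matmul(x, y)
ref = jnp.dot(x, y, precision=jax.lax.Precision.HIGHEST)
```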

Note

VMEM is fairly large for such a low-level memory hierarchy (16MB+), making it possible to use large window sizes. And, oftentimes, the larger the window size, the better the eventual hardware utilization will be. However, it is possible to specify a window size that (together with the space necessary to hold spilled vector registers) exceeds the size of VMEM. In this case, you will likely see a low-level compiler error message complaining about an out-of-memory error.
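As a rough, hypothetical estimate of whether the windows of a matrix multiplication like the one described above fit, the helper below multiplies out the window sizes. It assumes each window is kept in two buffers so the next HBM transfer can overlap with compute (an assumption about the pipeline, not something stated above), takes 16 MiB as the VMEM size, and ignores spilled registers.

```python
def matmul_window_bytes(bm, bk, bn, itemsize=4, buffers=2):
    """Hypothetical VMEM estimate for the (bm, bk), (bk, bn) and (bm, bn)
    windows of a blocked matmul.

    Assumes every window is held in `buffers` copies (2 models
    double-buffering); spilled registers and other scratch are not counted.
    """
    elements = bm * bk + bk * bn + bm * bn
    return elements * itemsize * buffers


VMEM_BYTES = 16 * 2**20  # Assumed lower bound; the real size varies by chip.

small = matmul_window_bytes(512, 512, 512)     # 6 MiB for float32 windows
large = matmul_window_bytes(1024, 1024, 1024)  # 24 MiB for float32 windows
print(small <= VMEM_BYTES, large <= VMEM_BYTES)
```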

Array Layouts

Dimension ordering of arrays is meaningful in Pallas. In JAX programs, the ordering of intermediate arrays inside jax.jit usually has no impact on performance, as the compiler is free to rearrange them. However, as Pallas is meant to expose lower-level capabilities, the dimension order can have great impact on the quality of generated code.

TPUs perform the bulk of the computation on 2D vector registers, which are typically of size 8x128 for 32-bit values (as of TPU v6). When a vector value is loaded from VMEM into registers (e.g. x = x_ref[...]), the last two dimensions of the array will be tiled into the registers. Pallas will only ever consider mapping the last two dimensions of intermediate arrays to the 8x128 vector register dimensions (sublanes and lanes respectively).

Here is a graphical example of how a 12x320 array can be tiled using six 8x128 tiles:

[Figure: a 12x320 array tiled with six 8x128 tiles]

Tiled layouts have several important ramifications for kernel writers:

  • The last two axes of an array are treated differently than other axes. For example, reductions, reshapes, and transposes are generally more expensive when involving the last two axes. Some reshapes involving the last two dimensions are not supported and will result in a compiler error, whereas reshapes of the other dimensions are “free” and performed at compile time.

  • While sometimes unavoidable, it is generally wasteful to have singleton dimensions in the last two axes, since they will occupy 1 element out of the entire tile dimension. Consuming too many registers can also potentially cause register spills into VMEM, which degrades kernel performance.

  • Related to the above point, all vector computation is padded up to the tile size. Adding two 1x1 arrays costs as much as adding two 8x128 arrays, and adding two 8x128x1x1 arrays will be 1024 times as expensive as adding two 8x128 arrays, since the 8x128x1x1 array will be padded to 8x128x8x128.
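The tiling arithmetic above can be checked with a couple of hypothetical helpers, assuming 32-bit values and the 8x128 tile: leading dimensions are multiplied out, and the last two are rounded up to whole tiles. They reproduce the six tiles of the 12x320 example and the 1024x cost of the 8x128x1x1 case.

```python
import math


def tile_count(shape, tile=(8, 128)):
    """Number of (8, 128) register tiles needed for an array of `shape`
    (rank >= 2, 32-bit elements assumed)."""
    *leading, rows, cols = shape
    return math.prod(leading) * math.ceil(rows / tile[0]) * math.ceil(cols / tile[1])


def padded_elements(shape, tile=(8, 128)):
    """Elements actually processed once the last two dims are padded to tiles."""
    return tile_count(shape, tile) * tile[0] * tile[1]


print(tile_count((12, 320)))  # 2 tiles of rows x 3 tiles of lanes
print(padded_elements((8, 128, 1, 1)) // padded_elements((8, 128)))
```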

Multicore TPU configurations

In newer TPU generations, the two cores on a chip are often abstracted as a single device. To take advantage of multiple cores, Pallas has to break the sequential grid execution guarantees, and will need to parallelize one of the grid axes over cores. This is an opt-in procedure. To allow that, pallas_call requires a dimension_semantics setting, passed through its compiler_params argument:

```python
pallas_call(
    ...,
    compiler_params=pltpu.CompilerParams(
        dimension_semantics=["parallel", "parallel", "arbitrary"]
    ),
)
```

The dimension_semantics value is a list with one entry per grid axis. Only parallel dimensions can be partitioned over cores. As a rule of thumb, a dimension is parallel unless the output window does not vary along it. As such, dimension_semantics is always a number of parallel axes followed by a number of arbitrary axes.
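The rule of thumb can be sketched as a hypothetical helper that probes the output index map: an axis is "parallel" if moving along it changes the output block index and "arbitrary" otherwise. It only compares steps 0 and 1 of each axis, so it is a heuristic for simple index maps, not a Pallas API.

```python
def infer_dimension_semantics(out_index_map, grid):
    """Hypothetical helper: mark an axis "parallel" if the output block index
    changes along it, "arbitrary" if the output window stays constant."""
    base = (0,) * len(grid)
    semantics = []
    for axis, size in enumerate(grid):
        probe = list(base)
        probe[axis] = 1 if size > 1 else 0  # Size-1 axes cannot vary.
        varies = out_index_map(*probe) != out_index_map(*base)
        semantics.append("parallel" if varies else "arbitrary")
    return semantics


# Matmul-style grid: the output window does not move along the reduction axis.
print(infer_dimension_semantics(lambda i, j, k: (i, j), (2, 2, 4)))
```

The resulting list could then be passed as dimension_semantics in pltpu.CompilerParams.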

While partitioning a kernel over a 2-core TPU device often leads to a 2x speedup, the speedup can in fact be significantly smaller. This is especially true if different instances of the body have highly varying cost. If all of the expensive steps get mapped to one core, but all cheap steps are assigned to the other, the second core will be sitting idle until the first one completes its tasks.

Pallas TPU generally favors partitioning axes of a size that is a multiple of thenumber of TPU cores, and prefers to partition leading grid axes.
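The effect of uneven work can be estimated with simple arithmetic. The hypothetical helper below assumes that a parallel axis is split into two contiguous halves, one per core (the actual assignment used by Pallas may differ), and divides the total cost by the load of the busier core. An odd number of equal-cost steps already falls short of 2x, which is one reason axis sizes that are multiples of the core count are preferable.

```python
def two_core_speedup(step_costs):
    """Estimated speedup from splitting a parallel axis over two cores,
    assuming (hypothetically) a contiguous split into two halves."""
    half = (len(step_costs) + 1) // 2
    loads = (sum(step_costs[:half]), sum(step_costs[half:]))
    # The kernel finishes only when the busier core finishes.
    return sum(step_costs) / max(loads)


print(two_core_speedup([1, 1, 1, 1]))  # balanced: 2x
print(two_core_speedup([4, 4, 1, 1]))  # expensive steps on one core
print(two_core_speedup([1, 1, 1]))     # axis size not a multiple of 2
```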

Placing operands in SMEM

Most of the compute on the TPU will happen on the vector unit. Still, there are many cases where it is useful to perform a number of scalar operations, e.g., to carry out control flow. For that reason, TPUs come with a separate scalar unit, and a separate scalar memory (SMEM) attached to it. As a rule of thumb, any data used to perform control-flow decisions should be placed in SMEM.

SMEM is a low-latency memory that supports random access, but lets you only read and write 32-bit values with a single instruction (very small compared to the 4KiB granularity of VMEM transactions, but much more flexible due to lack of alignment requirements!).

The scalar memory is also very useful when implementing kernels that do not access the tiles of inputs in a regular pattern, such as when writing block-sparse kernels. In Pallas, this can be achieved by replacing the grid argument to pallas_call with a grid_spec of PrefetchScalarGridSpec with a non-zero num_scalar_prefetch argument. If num_scalar_prefetch is n, then the first n arguments to pallas_call will be placed in SMEM. No BlockSpecs should be specified for those arguments. But the BlockSpecs for all subsequent arguments will receive not only the grid indices, but also the SMEM references to the leading operands.

See Scalar Prefetch and Block-Sparse Computation for examples of using this feature.
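A minimal sketch of scalar prefetch, using a block gather rather than a full block-sparse kernel: an int32 array of block ids is passed as the first pallas_call argument (num_scalar_prefetch=1), and the input index map reads it to decide which 8-row block of x each grid step copies. The kernel and helper names are illustrative; the example runs with interpret=True.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


def gather_kernel(block_ids_ref, x_ref, o_ref):
    # The scalar-prefetched ref is passed first; here the index map already
    # used it to pick the input block, so the body is a plain copy.
    del block_ids_ref
    o_ref[...] = x_ref[...]


def gather_blocks(block_ids, x, rows=8, interpret=True):
    cols = x.shape[1]
    grid_spec = pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=1,
        grid=(block_ids.shape[0],),
        # Index maps receive the grid index followed by the SMEM ref(s).
        in_specs=[
            pl.BlockSpec(block_shape=(rows, cols),
                         index_map=lambda i, ids_ref: (ids_ref[i], 0)),
        ],
        out_specs=pl.BlockSpec(block_shape=(rows, cols),
                               index_map=lambda i, ids_ref: (i, 0)),
    )
    return pl.pallas_call(
        gather_kernel,
        out_shape=jax.ShapeDtypeStruct((rows * block_ids.shape[0], cols), x.dtype),
        grid_spec=grid_spec,
        interpret=interpret,
    )(block_ids, x)


x = jnp.arange(32 * 128, dtype=jnp.float32).reshape(32, 128)
block_ids = jnp.array([3, 0, 2, 1], dtype=jnp.int32)
out = gather_blocks(block_ids, x)
expected = x.reshape(4, 8, 128)[block_ids].reshape(32, 128)
```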

Supported data types

At the moment Pallas TPU supports the following data types:

  • jnp.float32

  • jnp.bfloat16

  • jnp.int* (all precisions, except for jnp.int4)

  • jnp.uint* (all precisions)

  • jnp.bool_

Computation placement

All scalar (i.e. 0D) arrays will be stored in scalar registers, and operations on them will be executed on the scalar core. All other operations (even on single-element, but 1D+ arrays) will be executed on the vector core.

Supported operations

Matrix multiplication

Matrix multiplication always produces results in the float32 format. If your inputs are not float32, we recommend using lax.dot with preferred_element_type set to jnp.float32.

When using lax.dot_general, it is possible to fuse transpositions of the last two dimensions of matrix multiplication operands into the operation, which can improve overall kernel performance.
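A minimal sketch of both recommendations, assuming bfloat16 inputs: lax.dot_general contracts the last axis of both operands, which computes x @ y.T without a separate transpose, and preferred_element_type=jnp.float32 keeps the float32 result. It runs with interpret=True, so it shows the API usage rather than the fusion itself; the kernel name is illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def nt_matmul_kernel(x_ref, y_ref, o_ref):
    # Contract axis 1 of x with axis 1 of y (no batch dims): x @ y.T.
    o_ref[...] = jax.lax.dot_general(
        x_ref[...],
        y_ref[...],
        dimension_numbers=(((1,), (1,)), ((), ())),
        preferred_element_type=jnp.float32,
    )


x = jax.random.normal(jax.random.PRNGKey(0), (128, 256)).astype(jnp.bfloat16)
y = jax.random.normal(jax.random.PRNGKey(1), (128, 256)).astype(jnp.bfloat16)
out = pl.pallas_call(
    nt_matmul_kernel,
    out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float32),
    interpret=True,
)(x, y)
ref = jnp.dot(x.astype(jnp.float32), y.astype(jnp.float32).T,
              precision=jax.lax.Precision.HIGHEST)
```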

Precision control

Pallas TPU lowering is aware of jax.default_matmul_precision. For best performance (and lowest precision), use bfloat16. If you care about numerical accuracy, you might want to set the precision to float32.

Warning

Even if you pass in 32-bit operands to a matrix multiplication, they will be rounded to bfloat16 unless float32 precision is requested.
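Requesting float32 precision could look like the sketch below, which wraps a float32 kernel call in the jax.default_matmul_precision context manager. Under interpret=True on CPU, float32 matmuls are computed at full precision regardless, so this only shows where the setting goes, not its effect on TPU. The kernel name is illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def f32_matmul_kernel(x_ref, y_ref, o_ref):
    o_ref[...] = jnp.dot(x_ref[...], y_ref[...], preferred_element_type=jnp.float32)


x = jax.random.normal(jax.random.PRNGKey(0), (128, 128), dtype=jnp.float32)
y = jax.random.normal(jax.random.PRNGKey(1), (128, 128), dtype=jnp.float32)

# Matmuls traced inside this context request float32 precision, so 32-bit
# operands are not rounded to bfloat16.
with jax.default_matmul_precision("float32"):
    out = pl.pallas_call(
        f32_matmul_kernel,
        out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float32),
        interpret=True,
    )(x, y)

ref = jnp.dot(x, y, precision=jax.lax.Precision.HIGHEST)
```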

Transposition

If the value has at least 4 dimensions, arbitrary transpositions of all but the last two axes are free. Otherwise, only the transposition of the last two axes is implemented. Note that some transpositions of the last two dimensions can be fused into matrix multiplication.
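A short sketch of the two supported kinds of transposition, run with interpret=True: permuting the leading axes of a 4D value while leaving the last two in place, and swapping the two axes of a 2D value. Shapes are chosen to be whole (8, 128) tiles; the kernel name is illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def transpose_kernel(x4_ref, x2_ref, o4_ref, o2_ref):
    # Permute only the leading axes of a 4D value; the last two stay put.
    o4_ref[...] = jnp.transpose(x4_ref[...], (1, 0, 2, 3))
    # Swap the last two axes of a 2D value.
    o2_ref[...] = x2_ref[...].T


x4 = jnp.arange(2 * 4 * 8 * 128, dtype=jnp.float32).reshape(2, 4, 8, 128)
x2 = jnp.arange(128 * 256, dtype=jnp.float32).reshape(128, 256)
o4, o2 = pl.pallas_call(
    transpose_kernel,
    out_shape=(
        jax.ShapeDtypeStruct((4, 2, 8, 128), jnp.float32),
        jax.ShapeDtypeStruct((256, 128), jnp.float32),
    ),
    interpret=True,
)(x4, x2)
```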

Accessing memory

Arbitrary slices of references can be read or updated, subject to implementation constraints. Currently, no restrictions are placed on inputs that are 32-bit wide, but only some slicing patterns are supported for narrower types. Reads and writes whose offsets and lengths in the last two dimensions are multiples of 8 and 128, respectively, are always supported.

Reads and writes to vector memory generally happen on tiles of shape (8, 128). As such, when reading or writing to references that have at least two dimensions, the best performance is achieved when the base offset of the memory access has indices divisible by the tiling, and the size of the read region is a multiple of the tile size.
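For example, the kernel below (a sketch with an illustrative name, run with interpret=True) reads a window whose offsets and sizes in the last two dimensions are multiples of 8 and 128, using pl.ds(start, size) to describe each dimension's slice.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def window_kernel(x_ref, o_ref):
    # Offsets (8, 128) and sizes (16, 128) are whole multiples of the
    # (8, 128) tile in the last two dimensions.
    o_ref[...] = x_ref[pl.ds(8, 16), pl.ds(128, 128)]


x = jnp.arange(32 * 384, dtype=jnp.float32).reshape(32, 384)
out = pl.pallas_call(
    window_kernel,
    out_shape=jax.ShapeDtypeStruct((16, 128), jnp.float32),
    interpret=True,
)(x)
```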

Elementwise operations

Many elementwise operations are supported. It is worth noting that the hardware generally only supports elementwise computation using 32-bit types. When loading operands that use lower-precision types, they should generally be upcast to a 32-bit type before applying elementwise ops.
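Following that advice, a kernel on bfloat16 data might upcast right after loading and cast back only when storing, as in this sketch (illustrative names, run with interpret=True).

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def scale_kernel(x_ref, o_ref):
    # Upcast the bfloat16 operand to float32 for the elementwise math,
    # then cast the result back to the output dtype on store.
    x = x_ref[...].astype(jnp.float32)
    o_ref[...] = (x * 2.0 + 1.0).astype(o_ref.dtype)


x = jnp.linspace(-1.0, 1.0, 8 * 128, dtype=jnp.float32).reshape(8, 128).astype(jnp.bfloat16)
out = pl.pallas_call(
    scale_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, jnp.bfloat16),
    interpret=True,
)(x)
ref = x.astype(jnp.float32) * 2.0 + 1.0
```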

Elementwise operations can also vary significantly in cost. As such, we outline three categories of supported operations: cheap (🟢), medium (🌕) and expensive (🔴).

Operation                   Cost
jnp.add, +                  🟢
jnp.subtract, -             🟢
jnp.multiply, *             🟢
/, //, %                    🌕
jnp.maximum, jnp.minimum    🟢
jnp.where (select)          🟢
jnp.abs                     🟢
|, ^, &, ~                  🟢
<<, >>                      🟢
Comparisons (==, …)         🟢
Type casts (.astype)        🟢
jnp.exp                     🌕
jnp.tanh                    🌕
jnp.pow                     🌕
jnp.sin                     🔴
jnp.cos                     🔴

Many JAX functions are implemented in terms of other JAX primitives, so this list might not be comprehensive. For example, jax.nn.relu is implemented in terms of comparisons and jnp.where, so it will work in Pallas kernels too.
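As an illustration, the sketch below uses jax.nn.relu directly inside a kernel next to an equivalent comparison-plus-jnp.where formulation; both should produce the same result. It runs with interpret=True, and the kernel name is illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def relu_kernel(x_ref, relu_ref, where_ref):
    x = x_ref[...]
    # A composite jax.nn function used directly inside the kernel.
    relu_ref[...] = jax.nn.relu(x)
    # The same result written with a comparison and a select.
    where_ref[...] = jnp.where(x > 0, x, 0.0)


x = jax.random.normal(jax.random.PRNGKey(0), (8, 128), dtype=jnp.float32)
relu_out, where_out = pl.pallas_call(
    relu_kernel,
    out_shape=(
        jax.ShapeDtypeStruct((8, 128), jnp.float32),
        jax.ShapeDtypeStruct((8, 128), jnp.float32),
    ),
    interpret=True,
)(x)
```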

Array constructors

All constant array constructors are supported (jnp.ones, jnp.zeros, jnp.full).

Reductions

sum, max, min (for floating point values) reductions are supported, as well as any and all for boolean values. Integer reductions are not supported.

Reductions over the last array dimension are generally the slowest. Reductions over the second-to-last dimension are faster, but still slower than over the leading dimensions.
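The sketch below (illustrative names, interpret=True) performs a float32 sum over the leading axis, the cheap case, and a max over the last axis, the slowest case.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def reduce_kernel(x_ref, sum_ref, max_ref):
    x = x_ref[...]  # (4, 8, 128) float32
    # Reduction over the leading axis: the cheapest case.
    sum_ref[...] = jnp.sum(x, axis=0)
    # Reduction over the last (lane) axis: generally the slowest case.
    max_ref[...] = jnp.max(x, axis=-1)


x = jax.random.normal(jax.random.PRNGKey(0), (4, 8, 128), dtype=jnp.float32)
sums, maxes = pl.pallas_call(
    reduce_kernel,
    out_shape=(
        jax.ShapeDtypeStruct((8, 128), jnp.float32),
        jax.ShapeDtypeStruct((4, 8), jnp.float32),
    ),
    interpret=True,
)(x)
```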

Broadcasting

The performance characteristics of broadcasting are very similar to those of reductions. Broadcasting along all but the two trailing dimensions is always supported and free. Broadcasting along the second-to-last dimension is slower, while broadcasting along the last dimension is the slowest.
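The sketch below (illustrative names, interpret=True) shows two kinds of broadcast: an (8, 128) tile broadcast along the leading axis, which is free, and a (1, 128) row broadcast along the second-to-last axis, which is slower.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def broadcast_kernel(x_ref, tile_ref, row_ref, o_ref):
    x = x_ref[...]        # (4, 8, 128)
    tile = tile_ref[...]  # (8, 128): broadcast along the leading axis (free)
    row = row_ref[...]    # (1, 128): broadcast along the second-to-last axis
    o_ref[...] = x + tile[None] + row[None]


x = jnp.ones((4, 8, 128), jnp.float32)
tile = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
row = jnp.arange(128, dtype=jnp.float32).reshape(1, 128)
out = pl.pallas_call(
    broadcast_kernel,
    out_shape=jax.ShapeDtypeStruct((4, 8, 128), jnp.float32),
    interpret=True,
)(x, tile, row)
```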

Reshapes

As usual, reshapes in all dimensions but the last two dimensions are supported and free.

The only two supported cases when a reshape can modify the last two dimensions of an array are when (1) some leading dimensions are flattened onto the second-to-last dimension, or (2) it adds a dimension that was just removed by a reduction.
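Both supported cases are shown in the sketch below (illustrative names, interpret=True): flattening the leading dimension of a (4, 8, 128) value onto the second-to-last one, and re-adding the last dimension right after a reduction removed it.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def reshape_kernel(x_ref, flat_ref, rowsum_ref):
    x = x_ref[...]  # (4, 8, 128)
    # Case (1): the leading dimension is flattened onto the second-to-last one.
    flat_ref[...] = x.reshape(32, 128)
    # Case (2): the reduction drops the last dimension and the reshape
    # immediately adds it back.
    rowsum_ref[...] = jnp.sum(x, axis=-1).reshape(4, 8, 1)


x = jnp.arange(4 * 8 * 128, dtype=jnp.float32).reshape(4, 8, 128)
flat, rowsum = pl.pallas_call(
    reshape_kernel,
    out_shape=(
        jax.ShapeDtypeStruct((32, 128), jnp.float32),
        jax.ShapeDtypeStruct((4, 8, 1), jnp.float32),
    ),
    interpret=True,
)(x)
```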

Random Number Generation

Pallas supports the most commonly used functions from the jax.random module, such as uniform, normal, and bernoulli. The key should be a threefry2x32 key, which is the default setting in JAX. Keys can be directly passed into a kernel, or generated inside of a kernel.

Control flow

The TPU backend features limited support for control flow at the moment. The currently supported functions are cond, fori_loop and for_loop. However, loop primitives currently get fully unrolled during compilation, so try to keep the loop trip count reasonably small.
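A sketch of a small, fixed-trip-count loop (illustrative names, interpret=True): jax.lax.fori_loop sums four 8-row chunks of the input, using pl.ds to slice the ref at an offset that depends on the loop index.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def chunk_sum_kernel(x_ref, o_ref):
    # Only 4 iterations: loops are fully unrolled on TPU, so the trip
    # count is kept small.
    def body(i, acc):
        return acc + x_ref[pl.ds(i * 8, 8), :]

    o_ref[...] = jax.lax.fori_loop(0, 4, body, jnp.zeros((8, 128), jnp.float32))


x = jnp.arange(32 * 128, dtype=jnp.float32).reshape(32, 128)
out = pl.pallas_call(
    chunk_sum_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    interpret=True,
)(x)
```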

Overusing control flow can lead to significant regressions in low-level code generation, and it is recommended to try to squeeze as many computationally expensive operations into a single basic block as possible.

