Pallas Quickstart
Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU. Pallas allows you to use the same JAX functions and APIs but operates at a lower level of abstraction.
Specifically, Pallas requires users to think about memory access and how to divide up computations across multiple compute units in a hardware accelerator. On GPUs, Pallas lowers to Triton and on TPUs, Pallas lowers to Mosaic.
Let’s dive into some examples.
Note: Pallas is still an experimental API, and changes to it may break your code!
Hello world in Pallas#
from functools import partial

import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np
We’ll first write the “hello world” in Pallas, a kernel that adds two vectors.
def add_vectors_kernel(x_ref, y_ref, o_ref):
  x, y = x_ref[...], y_ref[...]
  o_ref[...] = x + y
Ref types
Let’s dissect this function a bit. Unlike most JAX functions you’ve probably written, it does not take in jax.Arrays as inputs and doesn’t return any values. Instead, it takes in Ref objects as inputs, which represent mutable buffers in memory. Note that we also don’t have any outputs but we are given an o_ref, which corresponds to the desired output.
Reading from Refs
In the body, we are first reading from x_ref and y_ref, indicated by the [...] (the ellipsis means we are reading the whole Ref; alternatively we also could have used x_ref[:]). Reading from a Ref like this returns a jax.Array.
Writing to Refs
We then write x + y to o_ref. Mutation has not historically been supported in JAX: jax.Arrays are immutable! Refs are new (experimental) types that allow mutation under certain circumstances. We can interpret writing to a Ref as mutating its underlying buffer.
Indexing and Slicing Refs with .at
In addition to accessing the entire underlying buffer through a reference, it is possible to also access only a slice by using the .at property. Using x_ref.at[slice] does not immediately read or write data; it creates a new Ref object that points to a slice of the original buffer. For example, ref.at[0:128] creates a view of the first 128 elements; ref.at[::2] creates a strided view.
Once you have a new Ref that represents a slice, you can read it or write to it with the usual syntax. Here is a simple example:
def add_sliced_kernel(x_ref, y_ref, o_ref):
  small_mid = x_ref.shape[0] // 2

  x_left = x_ref.at[:small_mid]
  x_right = x_ref.at[small_mid:]
  y_left = y_ref.at[:small_mid]
  y_right = y_ref.at[small_mid:]

  # The output shape is (4*small_mid).
  large_mid = 2 * small_mid
  o_ref.at[:large_mid][:small_mid] = x_left[...] + y_left[...]
  o_ref.at[:large_mid][small_mid:] = x_left[...] + y_right[...]
  o_ref.at[large_mid:][:small_mid] = x_right[...] + y_left[...]
  o_ref.at[large_mid:][small_mid:] = x_right[...] + y_right[...]
Note that using x_ref.at[slice][...] is equivalent to x_ref[slice]. The .at is useful if you want to compose multiple slices (e.g. x_ref.at[block_slice][thread_slice]) or if you need to pass a slice to a subkernel function that takes a Ref.
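As a rough analogy only (this is plain NumPy, not Pallas), basic NumPy slices are also views into one underlying buffer, so composed slices write through to the original array in much the same way the composed .at slices above write into o_ref. The arrays x, y and o below are hypothetical stand-ins for x_ref, y_ref and o_ref.

```python
import numpy as np

# NumPy analogy only: basic slices are views into the same buffer,
# similar in spirit to how `ref.at[...]` yields a Ref over part of a buffer.
x = np.arange(4)                # stands in for x_ref, shape (2 * small_mid,)
y = np.arange(4) * 10           # stands in for y_ref
o = np.zeros(8, dtype=x.dtype)  # stands in for o_ref, shape (4 * small_mid,)

small_mid = x.shape[0] // 2
large_mid = 2 * small_mid
x_left, x_right = x[:small_mid], x[small_mid:]
y_left, y_right = y[:small_mid], y[small_mid:]

# Each composed slice is a view, so the assignment mutates `o` itself.
o[:large_mid][:small_mid] = x_left + y_left
o[:large_mid][small_mid:] = x_left + y_right
o[large_mid:][:small_mid] = x_right + y_left
o[large_mid:][small_mid:] = x_right + y_right

print(o)  # [ 0 11 20 31  2 13 22 33]
```

The analogy stops at the indexing pattern: unlike a NumPy view, a Ref has to be read explicitly (for example with [...]) before its contents can be used as an array value.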
So we’ve written what we call a “kernel”, which we define as a program that will run as an atomic unit of execution on an accelerator, without any interaction with the host. How do we invoke it from a JAX computation? We use the pallas_call higher-order function.
@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
  return pl.pallas_call(
      add_vectors_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
  )(x, y)
add_vectors(jnp.arange(8), jnp.arange(8))
Array([ 0, 2, 4, 6, 8, 10, 12, 14], dtype=int32)
pallas_call lifts the Pallas kernel function into an operation that can be called as part of a larger JAX program. But, to do so, it needs a few more details. Here we specify out_shape, an object that has a .shape and .dtype (or a list thereof). out_shape determines the shape/dtype of o_ref in our add_vectors_kernel.
pallas_call returns a function that takes in and returns jax.Arrays.
What’s actually happening here?
Thus far we’ve described how to think about Pallas kernels, but what we’ve actually accomplished is writing a function that executes very close to the compute units, since values are loaded into the innermost (fastest) portion of the memory hierarchy.
On GPU, x_ref corresponds to a value in high-bandwidth memory (HBM) and when we do x_ref[...] we are copying the value from HBM into static RAM (SRAM) (this is a costly operation generally speaking!). We then use GPU vector compute to execute the addition, then copy the resulting value in SRAM back to HBM.
On TPU, we do something slightly different. Before the kernel is ever executed, we fetch the value from HBM into SRAM. x_ref therefore corresponds to a value in SRAM and when we do x_ref[...] we are copying the value from SRAM into a register. We then use TPU vector compute to execute the addition, then copy the resulting value back into SRAM. After the kernel is executed, the SRAM value is copied back into HBM.
We are in the process of writing backend-specific Pallas guides. Coming soon!
Pallas programming model#
In our “hello world” example, we wrote a very simple kernel. It takes advantage of the fact that our 8-sized arrays can comfortably fit inside the SRAM of hardware accelerators. In most real-world applications, this will not be the case!
Part of writing Pallas kernels is thinking about how to take big arrays that live in high-bandwidth memory (HBM, also known as DRAM) and express computations that operate on “blocks” of those arrays that can fit in SRAM.
Grids by example#
To automatically “carve” up the inputs and outputs, you provide a grid and BlockSpecs to pallas_call.
A grid is a tuple of integers (e.g. (), (2, 3, 4), or (8,)) that specifies an iteration space. For example, a grid (4, 5) would have 20 elements: (0, 0), (0, 1), ..., (0, 4), (1, 0), ..., (3, 4). We run the kernel function once for each element, a style of single-program multiple-data (SPMD) programming.

A 2D grid
When we provide a grid to pallas_call, the kernel is executed as many times as prod(grid). Each of these invocations is referred to as a “program”. To access which program (i.e. which element of the grid) the kernel is currently executing, we use program_id(axis=...). For example, for invocation (1, 2), program_id(axis=0) returns 1 and program_id(axis=1) returns 2.
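To make the iteration space concrete, here is a minimal pure-Python sketch that enumerates a (4, 5) grid. The helper run_grid is hypothetical and is not part of Pallas; it visits the programs one after another, whereas a real backend may run them in parallel, so only the set of program ids carries over.

```python
import itertools
import math

def run_grid(grid, kernel):
    """Hypothetical helper: call `kernel` once per grid element, sequentially."""
    for program_ids in itertools.product(*(range(n) for n in grid)):
        kernel(program_ids)

grid = (4, 5)
visited = []
run_grid(grid, visited.append)  # record the ids each "program" would see

# One invocation per grid element: prod(grid) programs in total.
assert len(visited) == math.prod(grid)

# Entry (1, 2) is the program for which program_id(axis=0) would be 1
# and program_id(axis=1) would be 2.
print(visited[0], visited[7], visited[-1])  # prints: (0, 0) (1, 2) (3, 4)
```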
Here’s an example kernel that uses a grid and program_id.
def iota_kernel(o_ref):
  i = pl.program_id(0)
  o_ref[i] = i
We now execute it using pallas_call with an additional grid argument. On GPUs, we can call the kernel directly like so:
# GPU version
def iota(size: int):
  return pl.pallas_call(iota_kernel,
                        out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
                        grid=(size,))()
iota(8)
Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
TPUs distinguish between vector and scalar memory spaces and in this case the output must be placed in scalar memory (MemorySpace.SMEM) since i is a scalar. For more details read TPU and its memory spaces. To call the above kernel on TPU, run:
# TPU version
from jax.experimental.pallas import tpu as pltpu

def iota(size: int):
  return pl.pallas_call(iota_kernel,
                        out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.SMEM),
                        out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
                        grid=(size,))()
iota(8)
Grid semantics#
On GPUs, each program is executed in parallel on separate threads. Thus, we need to think about race conditions on writes to HBM. A reasonable approach is to write our kernels in such a way that different programs write to disjoint locations in HBM to avoid these parallel writes. On the other hand, parallelizing the computation is how we can execute operations like matrix multiplications really quickly.
In contrast, TPUs operate like a very wide SIMD machine. Some TPU models contain multiple cores, but in many cases a TPU can be treated as a single-threaded processor. The grid on a TPU can be specified in a combination of parallel and sequential dimensions, where sequential dimensions are guaranteed to run serially.
You can read more details at grid, a.k.a. kernels in a loop and Noteworthy properties and restrictions.
Block specs by example#
With grid and program_id in mind, Pallas provides an abstraction that takes care of some common indexing patterns seen in a lot of kernels. To build intuition, let’s try to implement a matrix multiplication.
A simple strategy for implementing a matrix multiplication in Pallas is to implement it recursively. We know our underlying hardware has support for small matrix multiplications (using GPU and TPU tensorcores), so we just express a big matrix multiplication in terms of smaller ones.
Suppose we have input matrices \(X\) and \(Y\) and are computing \(Z = XY\). We first express \(X\) and \(Y\) as block matrices. \(X\) will have “row” blocks and \(Y\) will have “column” blocks.
Our strategy is that because \(Z\) is also a block matrix, we can assign each of the programs in our Pallas kernel one of the output blocks. Computing each output block corresponds to doing a smaller matrix multiply between a “row” block of \(X\) and a “column” block of \(Y\).
To express this pattern, we use BlockSpecs. A BlockSpec specifies a block shape for each input and output, and an “index map” function that maps a set of program indices to a block index.

A visualization of a BlockSpec
For a concrete example, let’s say we’d like to multiply two (1024, 1024) matrices x and y together to produce z, and would like to parallelize the computation 4 ways. We split up z into 4 (512, 512) blocks where each block is computed with a (512, 1024) x (1024, 512) matrix multiplication. To express this, we’d first use a (2, 2) grid (one block for each program).
For x, we use BlockSpec((512, 1024), lambda i, j: (i, 0)), which carves x up into “row” blocks. To see this, note how both program instances (1, 0) and (1, 1) pick the (1, 0) block in x. For y, we use a transposed version BlockSpec((1024, 512), lambda i, j: (0, j)). Finally, for z we use BlockSpec((512, 512), lambda i, j: (i, j)).
These BlockSpecs are passed into pallas_call via in_specs and out_specs.
For more detail on BlockSpecs see BlockSpec, a.k.a. how to chunk up inputs.
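To build intuition for what these index maps select, the sketch below emulates the carving in plain NumPy. It is not how Pallas implements BlockSpec: the helper emulate_blockspec is hypothetical, it assumes a block index is simply scaled by the block shape along each axis, and it uses 8x8 matrices as a small stand-in for the (1024, 1024) case.

```python
import numpy as np

def emulate_blockspec(array, block_shape, index_map, program_ids):
    """Hypothetical helper: return the block (a view) one program would see."""
    block_index = index_map(*program_ids)
    slices = tuple(
        slice(b * size, (b + 1) * size)
        for b, size in zip(block_index, block_shape)
    )
    return array[slices]

n = 8  # small stand-in for 1024
rng = np.random.default_rng(0)
x = rng.standard_normal((n, n))
y = rng.standard_normal((n, n))
z = np.zeros((n, n))

x_map = lambda i, j: (i, 0)  # "row" blocks of x
y_map = lambda i, j: (0, j)  # "column" blocks of y
z_map = lambda i, j: (i, j)  # one output block per program

for ids in [(0, 0), (0, 1), (1, 0), (1, 1)]:  # the (2, 2) grid
    x_block = emulate_blockspec(x, (n // 2, n), x_map, ids)
    y_block = emulate_blockspec(y, (n, n // 2), y_map, ids)
    z_block = emulate_blockspec(z, (n // 2, n // 2), z_map, ids)
    z_block[...] = x_block @ y_block  # z_block is a view, so this fills z

np.testing.assert_allclose(z, x @ y, rtol=1e-6, atol=1e-9)
```

Each program writes a different output block, so the four writes touch disjoint regions of z.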
Underneath the hood, pallas_call will automatically carve up your inputs and outputs into Refs for each block that will be passed into the kernel.
def matmul_kernel(x_ref, y_ref, z_ref):
  z_ref[...] = x_ref[...] @ y_ref[...]

def matmul(x: jax.Array, y: jax.Array):
  return pl.pallas_call(
    matmul_kernel,
    out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),
    grid=(2, 2),
    in_specs=[
        pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)),
        pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j))
    ],
    out_specs=pl.BlockSpec(
        (x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j),
    )
  )(x, y)

k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (1024, 1024))
y = jax.random.normal(k2, (1024, 1024))
z = matmul(x, y)
np.testing.assert_allclose(z, x @ y)
Note that this is a very naive implementation of a matrix multiplication but consider it a starting point for various types of optimizations. Let’s add an additional feature to our matrix multiply: fused activation. It’s actually really easy! Just pass a higher-order activation function into the kernel.
def matmul_kernel(x_ref, y_ref, z_ref, *, activation):
  z_ref[...] = activation(x_ref[...] @ y_ref[...])

def matmul(x: jax.Array, y: jax.Array, *, activation):
  return pl.pallas_call(
    partial(matmul_kernel, activation=activation),
    out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),
    grid=(2, 2),
    in_specs=[
        pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)),
        pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j))
    ],
    out_specs=pl.BlockSpec(
        (x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j)
    ),
  )(x, y)

k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (1024, 1024))
y = jax.random.normal(k2, (1024, 1024))
z = matmul(x, y, activation=jax.nn.relu)
np.testing.assert_allclose(z, jax.nn.relu(x @ y))
To conclude, let’s highlight a cool feature of Pallas: it composes with jax.vmap! To turn this matrix multiplication into a batched version, we just need to vmap it.
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (4, 1024, 1024))
y = jax.random.normal(k2, (4, 1024, 1024))
z = jax.vmap(partial(matmul, activation=jax.nn.relu))(x, y)
np.testing.assert_allclose(z, jax.nn.relu(jax.vmap(jnp.matmul)(x, y)))
