
Pallas Design

In this document, we explain the initial Pallas design. This is a snapshot of some of the earlier design decisions made, and Pallas’s specific APIs might have changed since.

Introduction

JAX is being used for a diverse set of workloads, from large-scale machine learning to scientific computing. JAX’s success story is as much a success story for XLA, the primary compiler that JAX targets: XLA compiles JAX programs for accelerators and has enabled JAX to scale to the largest ML models. JAX describes logical computations in XLA’s representation, HLO. HLO describes how computations happen logically but not physically. Given a logical HLO computation, XLA decides how that computation is to be executed physically. For a wide variety of ML applications, XLA does a good job of compiling user programs, but inevitably some users hit XLA’s limitations. In these cases, we need to provide an “escape hatch” that allows experts to write hand-tuned kernels that outperform XLA at that point in time. Furthermore, advances in ML systems research take some time to be incorporated into XLA, and users often want to run ahead with them. Over time, the compiler can incorporate the optimizations that were proven out experimentally through hand-tuned kernels.

XLA does offer the CustomCall mechanism as an escape hatch, but it requires users to write C++, and on GPU it requires users to learn the CUDA programming model. The CUDA programming model is arguably too low-level for many machine learning GPU kernels, like matrix multiplication, and even expert users will have trouble using CUDA to implement efficient matrix multiplication or multi-headed attention. Not only this, JAX users are usually familiar with Python and NumPy-style array programming, which doesn’t involve writing any C++ or thinking about GPU parallelism. All popular machine learning frameworks share this idea: manipulating (usually) arrays with high-level operations like matmul or convolution. Unfortunately, this means implementing a custom operation via CustomCall is a big investment, potentially involving learning C++ and/or GPU programming.

Triton, a GPU compiler built and maintained by OpenAI, has taken the ML compiler world by storm. Triton offers the best of both worlds: a high-level, array-based programming model for writing low-level GPU kernels. Triton is the primary code generation route for torch.compile in PyTorch 2.0, via the Torch Inductor library. Triton actively hides some aspects of GPU programming in the name of a more accessible programming model that can be used from Python and that generates optimized code from a higher-level representation. While GPUs are more flexible than what Triton offers, in the ML domain Triton seems to be expressive enough for many applications.

In this document, we describe Pallas, an extension to JAX that enables kernel programming for both GPUs and TPUs using a Triton-like model. A JAX-based kernel language offers several advantages:

  • Although Triton exposes a TPU-like programming model to users, i.e. writing programs for tiles of arrays in L1-cache, it is specialized enough for GPUs that we cannot directly compile Triton for TPU. For example, Triton offers atomic operations specifically meant to handle parallel writes that don’t necessarily make sense on TPU. A higher-level front end can abstract away details of the platform while surfacing just that tile-based programming model. The kernels will thus be portable across different hardware platforms.

  • JAX as a tracing-based frontend for numerical computing is both mature and well-used. By embedding the kernel programming language in JAX itself, we can reuse JAX’s tracing infrastructure and provide a NumPy-like frontend that’s already familiar to users.

  • JAX transformations are key to its success, allowing users to express simple programs but transform them to achieve complex functionality. We can leverage the same transformations (vmap, jvp, etc.) to transform user-written kernels.

The open question is: is JAX a good fit for a kernel language at all? We think so. Triton demonstrates that an array programming language can be practical for writing GPU kernels, and JAX is just that. JAX has also proven to be a flexible front-end for compilers and for program transformations.

We describe Pallas as follows: we first describe the ways in which we extend JAX to support writing custom kernels. We then show how we can lower Pallas to both Triton and Mosaic. We conclude by describing existing and potential ways to transform Pallas kernels via JAX transformations.

[Figure: Visualization of Pallas lowering paths]

Pallas: Extending JAX for kernels

The key point we’d like to make is that Pallas is just JAX, with some extensions:

  1. Users now use reference types called Refs in their JAX code. This gives users more precise control over memory access, and layout in JAX will more closely resemble physical layout.

  2. Users write their JAX programs using a subset of JAX primitives, along with a set of Pallas-specific primitives.

  3. Users embed their Pallas kernels in an outer JAX program via a special pallas_call higher-order function that executes the kernel in a map. It is analogous to pmap or shard_map, except with references to shared memory.

We’ll go over these three extensions one at a time, by example.

Note that these APIs are still experimental and subject to change.

Reference types

Let’s look at an example Pallas program for adding two vectors:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
  # In this code, `x_ref`, `y_ref` and `o_ref` are (8,)-shaped `Ref`s
  x = x_ref[:]
  y = y_ref[:]
  o_ref[:] = x + y

x, y = jnp.arange(8), jnp.arange(8, 16)
add = pl.pallas_call(add_kernel, out_shape=jax.ShapeDtypeStruct((8,), jnp.int32))
add(x, y)
```

Unlike a regular JAX program, add_kernel does not receive immutable array arguments. Instead, it’s provided with references that can be read from and updated in-place using NumPy-like syntax. Refs are not a Pallas-specific concept: they were introduced to JAX to represent stateful computations. However, we can leverage them when writing kernels that operate on mutable memory too.

Pallas kernels not only receive Refs corresponding to the inputs to the kernel, but also receive Refs for the outputs as well (specified in pallas_call via out_shape). Refs are special types that cannot be passed into the usual set of JAX primitives without being read from first. When you read from a Ref you get a JAX Array type out, and you must write an Array into a Ref.
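As a small illustration of those rules, the sketch below reads an input Ref into an array, writes an array into the output Ref, and then reads the output Ref back to update it in place. It assumes a recent JAX release in which `pl.pallas_call` accepts `interpret=True` (so it also runs on CPU) and Refs support `[...]` indexing; the kernel name is made up for this example.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def scale_then_shift_kernel(x_ref, o_ref):
  # Reading a Ref produces an ordinary, immutable JAX array.
  x = x_ref[...]
  # Writing an array into the output Ref mutates it in place.
  o_ref[...] = x * 2
  # Output Refs can be read back and written again within the same kernel.
  o_ref[...] = o_ref[...] + 1


x = jnp.arange(8, dtype=jnp.float32)
out = pl.pallas_call(
    scale_then_shift_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,  # emulate the kernel with XLA instead of compiling it
)(x)
```

Here `out` equals `2 * x + 1`.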

Reading from/writing into Refs

Reading from a Ref corresponds to loading an array into the lowest level of the memory hierarchy (L1-cache on GPU and vector registers on TPU). Writing into a Ref is analogous.

```python
def f(x_ref, o_ref):
  # Using vanilla Python indexing
  x = x_ref[0, 2:5, :]
  # Or via NumPy advanced int indexing
  o_ref[jnp.arange(3), :] = x

# Note that in order to use NumPy advanced int indexing, you need to
# broadcast the indices against each other into the desired
# multidimensional shape:
def f(x_ref):
  # Assume x_ref is (8, 4) and we want to read out a (2, 3) slice
  x = x_ref[jnp.arange(2)[..., None], jnp.arange(3)[None, ...]]
```

Writing to Refs can be done via analogous __setitem__-style indexing.
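The broadcasting requirement for advanced integer indices is ordinary NumPy behavior, so it can be checked with plain NumPy arrays standing in for Refs. In this sketch (an illustration, not Pallas code), reshaping the row and column index vectors to `(2, 1)` and `(1, 3)` yields a `(2, 3)` block, while two unbroadcast 1-D index vectors select individual elements pairwise.

```python
import numpy as np

x = np.arange(32).reshape(8, 4)  # stands in for an (8, 4) Ref

rows = np.arange(2)[..., None]  # shape (2, 1)
cols = np.arange(3)[None, ...]  # shape (1, 3)
# The index arrays broadcast against each other to (2, 3),
# so the result is the (2, 3) block x[0:2, 0:3].
block = x[rows, cols]

# Without broadcasting, 1-D index arrays of equal length are paired
# element by element: this picks x[0, 0] and x[1, 1].
pairs = x[np.arange(2), np.arange(2)]
```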

Other forms of indexing (for example, dynamic slicing) can be done via pallas.load and pallas.store, new JAX primitives designed to make loading from/storing into memory easier. We’ll discuss these new primitives later.

Extending JAX with new Pallas primitives

Because JAX was designed with HLO in mind, the set of JAX primitives closely mirrors the set of HLO operations. Targeting a new compiler (e.g. Triton or Mosaic) means we might need to supplement JAX’s primitives with new ones specific to the new compiler. At the same time, we may not be able to lower all JAX primitives, so we need to restrict ourselves to a subset.

Because Pallas was initially designed with Triton in mind, we offer a set of new primitives targeting the Triton programming model. As we’ll show later, we can lower these primitives to Mosaic as well.

pallas.load and pallas.store

pallas.load and pallas.store are primitives that allow loading from memory and storing into memory. Unlike __getitem__ and __setitem__, they are more flexible at the cost of being more verbose. Specifically, you can use the pallas.dynamic_slice (pallas.ds for short) construct (which should maybe be upstreamed into JAX to be used with Ref __getitem__ and __setitem__).

```python
def f(x_ref, o_ref):
  # Reading from memory via pallas.load
  x = pl.load(x_ref, (0, slice(2, 5), slice(None)))
  # Using integer indexing automatically broadcasts
  x = pl.load(x_ref, (0, 2 + jnp.arange(3), slice(None)))
  # You can also use `pl.dynamic_slice` (`pl.ds` for short) objects
  pl.store(o_ref, (0, pl.ds(start=2, size=3), slice(None)), x)
```
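To illustrate why a `pl.ds` construct is useful, the sketch below reads and writes a window whose start depends on `pl.program_id` (described later in this document), a traced value that a Python `slice` cannot hold. It assumes a recent JAX release in which Refs accept `pl.ds(start, size)` directly in `[...]` indexing and in which `pallas_call` supports `interpret=True`; `window_kernel` and `BLOCK` are illustrative names. With no `in_specs`/`out_specs`, every grid step sees the full arrays.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

BLOCK = 2  # illustrative window size


def window_kernel(x_ref, o_ref):
  i = pl.program_id(axis=0)
  # The start `i * BLOCK` is a traced value, so a Python slice cannot express
  # this window; pl.ds(start, size) can, as long as `size` is static.
  window = pl.ds(i * BLOCK, BLOCK)
  o_ref[window] = x_ref[window] * 10


x = jnp.arange(8, dtype=jnp.int32)
out = pl.pallas_call(
    window_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    grid=(x.shape[0] // BLOCK,),
    interpret=True,
)(x)
```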

pallas.load and pallas.store also support masking via the mask argument.

```python
def f(x_ref, o_ref):
  # Reading from memory via pallas.load
  idx = jnp.arange(8)
  mask = idx < 5
  x = pl.load(x_ref, (idx,), mask=mask, other=float('-inf'))
```

Masking is important when doing out-of-bounds loads/stores. For loads, other supplies the value returned in masked-off positions. The operational semantics of masking can be compiler-determined (if we understand the documentation properly, Triton avoids the read from/write to memory if it’s masked).
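The observable result of a masked load can be modeled in NumPy: lanes where `mask` is true take the value from memory, and the rest take `other`. The sketch below is such a model (the helper `masked_load` is hypothetical), not how Triton or Mosaic implement masking; it redirects masked-off lanes to index 0 only so that the NumPy gather itself stays in bounds.

```python
import numpy as np


def masked_load(buffer, idx, mask, other):
  """Hypothetical NumPy model of the result of a masked load."""
  # Masked-off lanes are redirected to index 0 so that this NumPy gather
  # stays in bounds; their loaded values are then replaced by `other`.
  safe_idx = np.where(mask, idx, 0)
  return np.where(mask, buffer[safe_idx], other)


buffer = np.arange(5, dtype=np.float32)  # only 5 valid elements in memory
idx = np.arange(8)                       # but the load covers 8 lanes
mask = idx < 5
x = masked_load(buffer, idx, mask, other=float('-inf'))
```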

pallas.program_id and pallas.num_programs

As we’ll soon see, we’ll be executing the same Pallas kernels many times (either in parallel or in a pipeline depending on the backend). These new primitives tell us “where” we are in the execution of the kernel.

pallas.program_id takes in an axis argument, which tells us which index in an axis of a multidimensional grid this kernel is currently executing in (analogous to blockIdx in CUDA programming or lax.axis_index in jax.pmap). Note that we are currently borrowing the “program” terminology from Triton, and in the future we might want to change it to something more familiar to JAX users.

```python
def f(x_ref, o_ref):
  i = pl.program_id(axis=0)  # execution index in the first axis of the grid
  o_ref[i] = jnp.exp(x_ref[i])
```

pallas.num_programs also takes in an axis and returns the grid size for that axis.
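The sketch below uses both primitives to fill an output with a reversed iota: grid step `i` writes `num_programs - 1 - i` into position `i`. It assumes a recent JAX release where `pallas_call` accepts `interpret=True`; `reverse_iota_kernel` is an illustrative name. Note that the kernel takes no inputs, only the output Ref.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def reverse_iota_kernel(o_ref):
  i = pl.program_id(axis=0)    # index of this grid step along axis 0
  n = pl.num_programs(axis=0)  # grid size along axis 0
  o_ref[i] = n - 1 - i


size = 8
out = pl.pallas_call(
    reverse_iota_kernel,
    out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
    grid=(size,),
    interpret=True,
)()
```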

Note that while program_id and num_programs are Triton-specific terminology, they are easily generalized to make sense on TPU as well.

Using a subset of JAX primitives in Pallas

Because we’re writing kernels, not high-level HLO programs, some JAX primitives may not be efficiently representable in our underlying substrate. However, we know we can support most elementwise operations, simple dot products, and JAX control flow.

While we haven’t yet mapped out exactly all the JAX primitives that we can support in Pallas kernels, we can certainly identify some that are not easy to lower or are unlikely to be useful:

  • conv_general - convolution usually isn’t offered as a primitive in the underlying hardware.

  • gather/scatter - the underlying compiler may not support noncontiguous memory reads and writes.

Executing Pallas kernels with pallas_call

Now that we’ve written our Pallas kernels (a.k.a. JAX with Refs and the extra Pallas primitives), how do we execute them on a GPU or TPU? We use pallas_call, a higher-order function (akin to jax.jit and jax.pmap) that executes the kernel.

The signature of pallas_call is as follows:

```python
def pallas_call(
    kernel: Callable,
    out_shape: Sequence[jax.ShapeDtypeStruct],
    *,
    in_specs: Sequence[Spec],
    out_specs: Sequence[Spec],
    grid: Optional[Tuple[int, ...]] = None,
) -> Callable:
  ...
```

When we provide a kernel to pallas_call, we provide additional information. The first is out_shape, which tells the kernel what the outputs look like (pallas_call will pass a Ref corresponding to these into the kernel to be written to). The rest of the information (in_specs, out_specs, and grid) describes how the kernel will be scheduled on the accelerator.

The (rough) semantics for pallas_call are as follows:

```python
import itertools

def pallas_call(kernel, out_shape, *, in_specs, out_specs, grid):
  def execute(*args):
    # `empty_ref` is a placeholder for allocating an uninitialized output Ref.
    outputs = [empty_ref(shape) for shape in out_shape]
    grid_indices = [range(size) for size in grid]
    for indices in itertools.product(*grid_indices):  # Could run in parallel!
      local_inputs = [in_spec.transform(arg, *indices)
                      for arg, in_spec in zip(args, in_specs)]
      local_outputs = [out_spec.transform(arg, *indices)
                       for arg, out_spec in zip(outputs, out_specs)]
      kernel(*local_inputs, *local_outputs)  # writes to outputs
    return outputs
  return execute
```

Specifically, pallas_call will “loop” over the grid iteration space, applying a transformation to the inputs and outputs specified via the in_specs and out_specs. In each iteration, the kernel will be called on the transformed inputs and outputs. Note that the “loop” over the iteration space could be executed in parallel (e.g. on GPU). pallas_call also provides no guarantees on the order of loop iterations over the iteration space, just that every member of the iteration space will be looped over. Compilers like Triton and Mosaic will have more specific operational semantics associated with the grid.
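Because no iteration order is guaranteed, a kernel whose grid steps write disjoint parts of the output should produce the same result under any order. The plain-Python model below (not Pallas; `run_grid` and `double_kernel` are hypothetical) runs the same kernel over the grid in lexicographic order and in a shuffled order and gets identical outputs.

```python
import itertools
import random

import numpy as np


def run_grid(kernel, grid, x, shuffle_seed=None):
  """Hypothetical model: call `kernel` once per grid point, in some order."""
  out = np.zeros_like(x)
  points = list(itertools.product(*(range(size) for size in grid)))
  if shuffle_seed is not None:
    # Every point is visited exactly once, but in an arbitrary order.
    random.Random(shuffle_seed).shuffle(points)
  for indices in points:
    kernel(indices, x, out)
  return out


def double_kernel(indices, x, out):
  # Each grid point writes its own disjoint 2-element block of `out`.
  (i,) = indices
  out[2 * i:2 * i + 2] = 2 * x[2 * i:2 * i + 2]


x = np.arange(8)
in_order = run_grid(double_kernel, (4,), x)
shuffled = run_grid(double_kernel, (4,), x, shuffle_seed=0)
```

A kernel whose steps read values written by other steps (for example, a running prefix sum across blocks) would not have this property, which is why such kernels need backend-specific guarantees about the grid.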

Transformation functions

The in_specs and out_specs arguments to pallas_call allow inputs and outputs to be transformed in some way. The two options that Pallas offers right now are an identity transformation (where inputs and outputs are left unchanged) and BlockSpecs, which take fixed-size slices of Refs determined by the loop index.

A BlockSpec takes an index_map function and a block_shape. Logically, it takes an array and slices it along each axis into blocks of size block_shape. The index_map function takes loop indices (from the grid index set) and maps them to block indices. The transform function converts Refs into logical views of the Ref at the corresponding block. When we specify None in an entry in block_shape, that corresponds to “mapping” over that dimension, removing it from the block within the kernel.

```python
from typing import Callable, Optional, Tuple

class BlockSpec:
  index_map: Callable[..., Tuple[int, ...]]
  block_shape: Tuple[Optional[int], ...]

  def transform(self, ref, *loop_indices):
    block_indices = self.index_map(*loop_indices)
    # Returns a view of `ref` of shape self.block_shape at block position
    # `block_indices` (element offset block_indices[d] * block_shape[d] on axis d)
    ...
```
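The block arithmetic can be made concrete with a NumPy model in which block index `b` along an axis of block size `s` covers elements `b * s` up to `(b + 1) * s`, and a `None` block dimension becomes an integer index that drops the axis. This is a sketch under those assumptions; `block_view` and `index_map` here are hypothetical and not the Pallas implementation.

```python
import numpy as np


def block_view(ref, block_shape, block_indices):
  """Hypothetical NumPy model of the view a BlockSpec passes to the kernel.

  `block_indices` count in units of blocks. A `None` entry in `block_shape`
  means a size-1 block along that axis, which is indexed away (squeezed).
  """
  index = []
  for size, b in zip(block_shape, block_indices):
    if size is None:
      index.append(b)  # integer index: the axis disappears from the view
    else:
      index.append(slice(b * size, (b + 1) * size))
  # Basic indexing returns a view, so kernel writes would reach `ref`.
  return ref[tuple(index)]


def index_map(b, i):
  # Illustrative map for a grid of shape (3, 4): point (b, i) -> block (b, i, 0).
  return (b, i, 0)


x = np.arange(3 * 8 * 4).reshape(3, 8, 4)
blk = block_view(x, (None, 2, 4), index_map(1, 3))
```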

We could also imagine other Specs that are used with pallas_call, for example a Spec that corresponds to overlapping windows to, say, implement convolutions.

Immediate benefits of Pallas as a front-end

By offering a JAX front-end for kernel writing, we can immediately reap some benefits.

More flexible front end

The first is that JAX users are already accustomed to the benefits (and limitations) of programming with JAX and its tracing-based transformations. This means users can use closures and other familiar Python constructs when writing Pallas kernels. This is unlike the existing AST-parsing-based Triton front end or the MLIR builders for Mosaic. For example, this makes Pallas far more amenable to templating than Triton.

Here is an example of how we can use higher-order functions in Python to template a kernel:

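```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def make_kernel(eltwise_kernel):
  def add(x_ref, y_ref, o_ref):
    x = pl.load(x_ref, ())
    y = pl.load(y_ref, ())
    pl.store(o_ref, (), eltwise_kernel(x + y))
  return add

kernel1 = make_kernel(lambda x: x * 2)
kernel2 = make_kernel(jnp.exp)

# Both inputs and the output are scalars (0-d arrays).
out_shape = jax.ShapeDtypeStruct((), jnp.float32)
x = jnp.float32(1.)
pl.pallas_call(kernel1, out_shape=out_shape, grid=1)(x, x)
pl.pallas_call(kernel2, out_shape=out_shape, grid=1)(x, x)
```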

Emulation mode

By representing kernels as programs with JAX primitives and some new Pallas primitives, we can also lower Pallas programs to StableHLO directly and compile/execute them with XLA. Specifically, a pallas_call can be implemented as a lax.scan over the grid. This enables us to develop GPU or TPU kernels on any XLA-supported platform (even CPU!) and debug them using JAX/XLA debugging tools (like jax.debug.print). We can also use the more reliable and better-tested XLA numerics to verify the correctness of the Triton and Mosaic compilers. One could also imagine perturbing the scan ordering to simulate the parallel reads and writes that happen on GPU.
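The idea of implementing a `pallas_call` as a `lax.scan` over the grid can be sketched in plain JAX. The example below is hand-written and is not how Pallas implements interpretation: it emulates a vector add with `(2,)`-sized blocks by scanning over block indices and using `lax.dynamic_slice` and `lax.dynamic_update_slice` for the block reads and writes. `BLOCK`, `add_block` and `scan_emulated_add` are illustrative names. In current releases, passing `interpret=True` to `pallas_call` is the built-in way to get this kind of emulation.

```python
import jax
import jax.numpy as jnp
from jax import lax

BLOCK = 2  # illustrative block size


def add_block(x_blk, y_blk):
  # The "kernel body": it only ever sees one (BLOCK,)-sized block.
  return x_blk + y_blk


@jax.jit
def scan_emulated_add(x, y):
  def step(out, i):
    # Grid index i plays the role of pl.program_id(0); the start offset
    # mirrors a BlockSpec with block_shape (BLOCK,) and index_map i -> i.
    start = i * BLOCK
    x_blk = lax.dynamic_slice(x, (start,), (BLOCK,))
    y_blk = lax.dynamic_slice(y, (start,), (BLOCK,))
    # jax.debug.print("step {i}", i=i) could be added here while debugging.
    out = lax.dynamic_update_slice(out, add_block(x_blk, y_blk), (start,))
    return out, None

  num_blocks = x.shape[0] // BLOCK
  out, _ = lax.scan(step, jnp.zeros_like(x), jnp.arange(num_blocks))
  return out


x, y = jnp.arange(8), jnp.arange(8, 16)
z = scan_emulated_add(x, y)
```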

GPU Examples

Note that all the following examples are for GPU only. They will require tweaks to the block sizes to work on TPUs.

add

We modify our add_kernel example to operate over (2,)-sized blocks using BlockSpecs.

```python
def add_kernel(x_ref, y_ref, o_ref):
  # In this code, `x_ref`, `y_ref` and `o_ref` are (2,)-shaped `Ref`s
  x = x_ref[:]
  y = y_ref[:]
  o_ref[:] = x + y

x, y = jnp.arange(8), jnp.arange(8, 16)
add = pl.pallas_call(
    add_kernel,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
    in_specs=[
        pl.BlockSpec((2,), lambda i: i),
        pl.BlockSpec((2,), lambda i: i),
    ],
    out_specs=pl.BlockSpec((2,), lambda i: i),
    grid=(4,),
)
add(x, y)
```

Templated matmul

In this example, we compute tiles of the output by doing an unrolled accumulation over blocks of rows and columns from our input arrays. We inline an activation function into the body of the kernel using a higher-order function so we can emit a fused kernel.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def matmul_kernel(x_ref, y_ref, o_ref, *, activation, block_k):
  acc = jnp.zeros((x_ref.shape[0], y_ref.shape[1]), jnp.float32)
  # Unrolled loop over the contraction dimension in chunks of block_k
  for k in range(x_ref.shape[1] // block_k):
    x = x_ref[:, k * block_k:(k + 1) * block_k]
    y = y_ref[k * block_k:(k + 1) * block_k, :]
    acc += x @ y
  o_ref[:, :] = activation(acc).astype(o_ref.dtype)

x, y = jnp.ones((512, 256)), jnp.ones((256, 1024))
block_shape = 128, 256, 128

@partial(jax.jit, static_argnames=["block_shape", "activation"])
def matmul(x, y, *, block_shape, activation):
  block_m, block_n, block_k = block_shape
  fused_matmul = pl.pallas_call(
      partial(matmul_kernel, block_k=block_k, activation=activation),
      out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), jnp.float32),
      in_specs=[
          pl.BlockSpec((block_m, x.shape[1]), lambda i, j: (i, 0)),
          pl.BlockSpec((y.shape[0], block_n), lambda i, j: (0, j)),
      ],
      out_specs=pl.BlockSpec((block_m, block_n), lambda i, j: (i, j)),
      grid=(x.shape[0] // block_m, y.shape[1] // block_n),  # (4, 4) here
  )
  return fused_matmul(x, y)

z = matmul(x, y, block_shape=block_shape, activation=jax.nn.gelu)
```
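To see how the BlockSpecs and the inner loop tile the computation, the NumPy sketch below (a hypothetical helper `blocked_matmul`, with an identity activation by default) performs the same blocking by hand. For shapes `(512, 256)` and `(256, 1024)` with block sizes `(128, 256, 128)`, the grid is `(512 // 128, 1024 // 256) = (4, 4)` and each output tile accumulates `256 // 128 = 2` partial products. Random inputs are used instead of ones so that a tiling mistake would actually show up.

```python
import numpy as np


def blocked_matmul(x, y, block_m, block_n, block_k, activation=lambda a: a):
  """Hypothetical NumPy model of the tiled matmul kernel and its BlockSpecs."""
  m, k = x.shape
  _, n = y.shape
  out = np.empty((m, n), dtype=np.float32)
  for i in range(m // block_m):        # grid axis 0
    for j in range(n // block_n):      # grid axis 1
      # in_specs[0]: block (i, 0) of shape (block_m, k)
      x_blk = x[i * block_m:(i + 1) * block_m, :]
      # in_specs[1]: block (0, j) of shape (k, block_n)
      y_blk = y[:, j * block_n:(j + 1) * block_n]
      acc = np.zeros((block_m, block_n), dtype=np.float32)
      for kk in range(k // block_k):   # the unrolled loop inside the kernel
        ks = slice(kk * block_k, (kk + 1) * block_k)
        acc += x_blk[:, ks] @ y_blk[ks, :]
      # out_specs: block (i, j) of shape (block_m, block_n)
      out[i * block_m:(i + 1) * block_m,
          j * block_n:(j + 1) * block_n] = activation(acc)
  return out


rng = np.random.default_rng(0)
x = rng.standard_normal((512, 256)).astype(np.float32)
y = rng.standard_normal((256, 1024)).astype(np.float32)
z = blocked_matmul(x, y, block_m=128, block_n=256, block_k=128)
```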

Lowering Pallas

After users express their Pallas kernels, we lower them to different representations depending on the target backend. On GPUs, we lower Pallas to Triton IR, and on TPU we lower Pallas to Mosaic.

Lowering Pallas to Triton for GPU

Lowering Pallas to Triton is easy because Pallas was designed with Triton as a target language in mind. The main differences between Pallas and Triton are that Triton doesn’t have a notion of BlockSpecs and also uses pointers when doing memory loads and stores as opposed to indices.

Triton supports pointers as an array element type in its language, and in Triton you can load from and store to arrays of pointers. In Pallas, given a (4, 5)-shaped Ref x_ref, an index like x_ref[3, 2] must be lowered to computing a Triton pointer to the appropriate row-major position in x_ref (that is, offsetting its base pointer by 5 * 3 + 2 * 1). Similarly, when we lower slices to Triton, e.g. x_ref[3, :], we need to produce an array of pointers offset by 5 * 3 + jnp.arange(5).
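The row-major address arithmetic can be checked with NumPy by treating a flattened copy of the array as raw memory and using element offsets in place of pointers. This is only a model of the address computation, assuming a contiguous row-major `(4, 5)` array; it is not Triton code.

```python
import numpy as np

x = np.arange(20, dtype=np.float32).reshape(4, 5)  # a (4, 5) row-major array
memory = x.ravel()  # stands in for the raw buffer behind x_ref
row_stride, col_stride = x.shape[1], 1  # strides in elements, not bytes

# x_ref[3, 2] -> a single element offset
offset = 3 * row_stride + 2 * col_stride
# x_ref[3, :] -> one offset per column of row 3
offsets = 3 * row_stride + np.arange(x.shape[1]) * col_stride
```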

Other than that, lowering to Triton is fairly straightforward. JAX dot products can be lowered to Triton dot products, and JAX unary primitives are lowered to their Triton equivalents. Triton’s atomic operations are exposed as new Pallas atomic primitives, which lower directly to them.

Lowering Pallas to Mosaic for TPU

Mosaic consumes (mostly) standard dialect MLIR and emits LLO to be compiled for TPU. Pallas can be lowered to Mosaic by translating JAX primitives to MLIR (mostly the vector and arith dialects). The BlockSpecs can be converted into pipeline schedules (i.e. the transform_funcs in Mosaic).

Transforming Pallas

A natural question is how JAX transformations interact with Pallas kernels. There are two main ways: transformations inside Pallas kernels and transformations outside Pallas kernels.

Transformation inside Pallas kernels should actually “just work”, so long as we are able to lower the transformed code. For example, we could use jax.grad(jnp.sin)(...) inside of a Pallas kernel because we can lower a cos to both Triton and Mosaic. However, we might not be able to lower a jax.vmap(lax.dynamic_slice) because it could turn into a gather that we cannot lower.
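One way to see which primitives a transformation introduces is to inspect the jaxpr it produces. In the sketch below, `jax.make_jaxpr` shows that differentiating `jnp.sin` introduces a `cos` primitive, and that batching `lax.dynamic_slice` over its start index introduces a `gather`. The `window` helper is illustrative; the exact jaxpr text may differ between JAX versions, so the checks only look for primitive names.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Differentiating sin introduces a `cos` primitive, which the kernel
# backend must then be able to lower.
grad_jaxpr = jax.make_jaxpr(jax.grad(jnp.sin))(1.0)


def window(x, start):
  # Illustrative helper: a length-2 window starting at `start`.
  return lax.dynamic_slice(x, (start,), (2,))


# Batching over the start index turns the dynamic_slice into a gather.
batched_jaxpr = jax.make_jaxpr(jax.vmap(window, in_axes=(None, 0)))(
    jnp.arange(8.0), jnp.arange(3))

print(grad_jaxpr)
print(batched_jaxpr)
```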

Transformations of Pallas kernels from the outer JAX programs are perhaps the more interesting case. How do we handle things like vmap(pallas_call) and grad(pallas_call)?

vmap-of-pallas_call

vmap automatically vectorizes JAX programs. While kernel writers might want precise control over how a batched kernel will behave differently from its unbatched variant, we can offer a reasonable default vmap rule for pallas_call while offering the jax.custom_vmap customization mechanism. When pallas_call is vmap-ed, we augment the pallas_call to have an extra grid dimension corresponding to the new batch dimension and transform the BlockSpecs to handle indexing along that dimension.
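Under a default batching rule, a `pallas_call` can simply be placed under `jax.vmap`. The sketch assumes a recent JAX release and uses `interpret=True` so that it runs on CPU; `add_kernel` here is a whole-array variant of the earlier vector add, written only for this illustration.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_kernel(x_ref, y_ref, o_ref):
  # Inside the kernel, the Refs have the unbatched shape (8,).
  o_ref[...] = x_ref[...] + y_ref[...]


def add(x, y):
  return pl.pallas_call(
      add_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      interpret=True,
  )(x, y)


xs = jnp.arange(24, dtype=jnp.float32).reshape(3, 8)
ys = jnp.ones((3, 8), dtype=jnp.float32)
# vmap maps `add` over the leading axis of size 3.
batched = jax.vmap(add)(xs, ys)
```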

grad-of-pallas_call

grad of pallas_call enables automatic differentiation of kernels. jax.grad breaks down into applications of three distinct transforms: jvp, partial_eval and transpose. In principle, we can reuse most of JAX’s infrastructure when implementing these rules for pallas_call (since it behaves much like existing JAX higher-order primitives).

However, automatic differentiation of kernels can result in a performance hit due to how memory access is transposed. If we write a GPU kernel with overlapping-and-parallel reads and disjoint-but-parallel writes, we automatically transpose it into a kernel that has overlapping-but-parallel writes (which are slow when done atomically) and disjoint-and-parallel reads. To emit a kernel that better uses parallelism with shared memory, we would need to reorder loops and change how the kernel is vectorized. Unfortunately, we do not have a program representation amenable to that in Pallas. A potential direction for automatically differentiating kernels efficiently is to explore a different representation, perhaps one like that in Dex. We could also look at how Enzyme approaches this problem. However, AD of Pallas kernels may still be useful for a class of kernels that do transpose efficiently (for example, elementwise kernels).

In general, though, jax.custom_vjp is a viable escape hatch to express Pallas kernels that work with jax.grad.
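A minimal sketch of that escape hatch, assuming a recent JAX release and `interpret=True` for CPU execution: the forward pass runs a Pallas kernel for `sin`, and a hand-written backward rule supplies `cos(x) * g`, so `jax.grad` never has to differentiate the `pallas_call` itself. The names `pallas_sin`, `_sin_kernel` and `_sin_pallas` are hypothetical, and the backward rule uses plain `jnp` for brevity, although it could also call a second Pallas kernel.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def _sin_kernel(x_ref, o_ref):
  o_ref[...] = jnp.sin(x_ref[...])


def _sin_pallas(x):
  return pl.pallas_call(
      _sin_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      interpret=True,
  )(x)


@jax.custom_vjp
def pallas_sin(x):
  return _sin_pallas(x)


def _pallas_sin_fwd(x):
  # Run the kernel and save the input as the residual for the backward pass.
  return _sin_pallas(x), x


def _pallas_sin_bwd(x, g):
  # d/dx sin(x) = cos(x); return one cotangent per primal input.
  return (g * jnp.cos(x),)


pallas_sin.defvjp(_pallas_sin_fwd, _pallas_sin_bwd)

x = jnp.linspace(0.0, 1.0, 8)
grads = jax.grad(lambda x: pallas_sin(x).sum())(x)
```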

Other transformations

We could imagine other JAX transformations applying to Pallas kernels that we haven’t explicitly explored yet. For example, checkify is a JAX transformation that does functional error handling. We could imagine using checkify with pallas_call to allow plumbing out error codes from GPU kernels that indicate whether out-of-bounds (OOB) accesses occurred or NaNs were produced.

Another potential transformation to integrate with is custom_partitioning, to enable automatically partitionable kernels to be used with pjit.

