
Scalar Prefetch and Block-Sparse Computation

In this tutorial, we will cover the basics of block-sparse computing in Pallas. Sparse computation is a major reason to write custom Pallas kernels over simply using JAX/XLA, since it is generally difficult to express programs that perform a dynamic amount of computation in XLA due to static array shapes. In this tutorial we will learn how to use the scalar prefetch feature of Pallas in order to write block-sparse kernels that can dynamically skip over computation and blocks of memory.

import functools
import timeit
import numpy as np
import jax
from jax import numpy as jnp
from jax import lax
from jax.experimental import checkify
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

assert "TPU" in jax.devices()[0].device_kind, "Please run this notebook with TPU devices."
print("Running on", jax.devices()[0].device_kind)
Running on TPU v5 lite

Dynamic Block Indexing with Scalar Prefetch

We will be exploiting the “scalar prefetch” feature of Pallas to write sparse kernels. Scalar prefetch allows you to pass a small amount of data into SMEM (“scalar memory”) that is loaded before the start of the pipeline (“prefetch”). Because this data is loaded before the pipeline starts, it is available for use in the index_map of each BlockSpec, allowing you to perform data-dependent indexing calculations. The main goal of this tutorial is to go over common programming patterns that utilize this feature.

To use scalar prefetch, use pltpu.PrefetchScalarGridSpec in place of the standard pl.GridSpec:

class PrefetchScalarGridSpec:
  def __init__(
      self,
      num_scalar_prefetch: int,
      grid: tuple[int, ...],
      in_specs: PyTree[BlockSpec],
      out_specs: PyTree[BlockSpec],
      scratch_shapes: tuple[MemorySpace, ...]):
    ...

The num_scalar_prefetch parameter indicates the number of scalar prefetch values. When this is set to a non-zero value, it changes the call signature of the kernel and index maps to expect additional prefetch values. The prefetch Refs passed in to the index_map and kernel are all allocated in SMEM and are not partitioned into blocks, as they do not have a BlockSpec defined. Moreover, the order of arguments to both the index_map and the kernel is always fixed, as described below:

  • Each BlockSpec’s index_map now expects the prefetch Refs to come after the grid indices:

def index_map(*grid_indices, *prefetch_refs):
  ...

  • The user-defined kernel expects prefetch Refs to come before the input Refs. Additionally, the scratch refs come after the output Refs.

def kernel(*prefetch_refs, *input_refs, *output_refs, *scratch_refs):
  ...

  • When calling a new kernel using pallas_call, the function returned by pallas_call also expects the scalar prefetch arguments to come before the inputs, e.g.

kernel = pl.pallas_call(...)
result = kernel(*prefetch_args, *input_args)
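The signatures above are schematic, since Python does not accept two starred parameters in a single def. As a rough illustration of the ordering only, the plain-Python sketch below mimics how an index map for a rank-2 grid with one prefetch argument would be called. The call_index_map helper is hypothetical and is not part of Pallas; a Python list stands in for the SMEM Ref.

```python
# Plain-Python illustration of the argument order; no Pallas involved.
# `call_index_map` is a hypothetical stand-in for what the pipeline does.

def call_index_map(index_map, grid_indices, prefetch_refs):
    # Grid indices are passed first, followed by every scalar prefetch ref.
    return index_map(*grid_indices, *prefetch_refs)


def x_index_map(i, j, block_idx_ref):
    # Ignore the grid position and pick the block named by the prefetch ref.
    del i, j
    return (block_idx_ref[0], block_idx_ref[1])


block_idx = [1, 2]  # stands in for a small integer array living in SMEM
selected = call_index_map(x_index_map, grid_indices=(0, 0), prefetch_refs=(block_idx,))
print(selected)  # (1, 2)
```

With num_scalar_prefetch=2 the same index map would simply take two trailing Ref parameters after the grid indices.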

Example: Block Dynamic Slice with Scalar Prefetch

Let’s begin with a basic example that demonstrates how to use the scalar prefetch feature. We will implement a block-aligned dynamic slice kernel which simply extracts a block out of a larger array based on user-specified indices:

  1. Outside of the kernel, we compute the block index to extract as: block_idx = (start[0] // size[0], start[1] // size[1])

  2. We pass block_idx as a scalar prefetch argument into pallas_call.

  3. In our index map, we use the block index to select the corresponding block by returning (block_idx[0], block_idx[1]).

Of course, this kernel is limited in that our slice sizes must fit inside of a kernel block (limited by VMEM size) and we can only start on size-aligned indices. A more advanced kernel would decouple the kernel block size from the slice size and allow non-aligned start indices.
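The block-index arithmetic from the steps above can be checked on the host with plain NumPy. This is only a sketch of the indexing logic, not of the Pallas kernel; the helper names block_index and take_block are made up for illustration.

```python
import numpy as np


def block_index(starts, sizes):
    # Only block-aligned starts are supported: each start must be a
    # multiple of the block size along its axis.
    assert all(s % b == 0 for s, b in zip(starts, sizes)), "Starts must be divisible by size."
    return tuple(s // b for s, b in zip(starts, sizes))


def take_block(x, block_idx, sizes):
    # Emulates selecting block `block_idx` of an array tiled into `sizes` blocks.
    (bi, bj), (bm, bn) = block_idx, sizes
    return x[bi * bm:(bi + 1) * bm, bj * bn:(bj + 1) * bn]


x = np.arange(512 * 512, dtype=np.int32).reshape(512, 512)
starts, sizes = (128, 256), (128, 128)
idx = block_index(starts, sizes)
block = take_block(x, idx, sizes)
expected = x[128:256, 256:384]  # the same region addressed by element offsets
print(idx, np.array_equal(block, expected))
```

A start such as (100, 256) would trip the assertion, which mirrors the alignment restriction described above.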

def dynamic_slice_kernel(indices, x_ref, o_ref):
  del indices
  o_ref[...] = x_ref[...]

@checkify.checkify
@functools.partial(jax.jit, static_argnums=(2,))
def block_dynamic_slice(x, starts, sizes):
  grid_spec = pltpu.PrefetchScalarGridSpec(
      num_scalar_prefetch=1,
      grid=(1, 1),
      in_specs=[pl.BlockSpec(
          sizes,
          lambda i, j, block_idx: (block_idx[0], block_idx[1]))],
      out_specs=pl.BlockSpec(sizes, lambda *_: (0, 0)),
  )

  kernel = pl.pallas_call(
      dynamic_slice_kernel,
      grid_spec=grid_spec,
      out_shape=jax.ShapeDtypeStruct(shape=sizes, dtype=x.dtype),
  )
  # Checkify inserts a runtime assert that starts are divisible by block size.
  checkify.check(starts[0] % sizes[0] == 0, "Starts must be divisible by size.")
  checkify.check(starts[1] % sizes[1] == 0, "Starts must be divisible by size.")
  block_idx = jnp.array([starts[0] // sizes[0], starts[1] // sizes[1]])
  return kernel(block_idx, x)

shape = (512, 512)
x = jnp.reshape(jnp.arange(np.prod(shape), dtype=jnp.int32), shape)
err, result = block_dynamic_slice(x, starts=(128, 256), sizes=(128, 128))
err.throw()
ref = lax.dynamic_slice(x, start_indices=(128, 256), slice_sizes=(128, 128))
diff = jnp.max(jnp.abs(result - ref))
print("Error |result - lax.dynamic_slice| =", diff)
Error |result - lax.dynamic_slice| = 0

Sparse Kernels: Representing Sparse Data

Before we dive into implementing sparse kernels, let’s first review how sparse matrices are represented. While there are several popular formats for storing sparse matrices, we will be following a blocked variant of the coordinate-list format (COO) in which we will store a matrix as a list of (block_index, block_data) pairs. All blocks that are not explicitly stored in the list are assumed to be zero, meaning we can save a significant amount of memory if there are many zero blocks in the matrix.

The following figure demonstrates how we convert a 4x4 dense matrix (left) into a block-COO format (right) with a block size of 2x2. Note that in the sparse format, we can avoid explicitly storing the upper-right block which consists of all zero elements.

[Figure: block_coo]
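As a small host-side sketch of the conversion just described, the NumPy code below turns a 4x4 matrix with an all-zero upper-right 2x2 block into block-COO form and back. The matrix values and the helper names to_block_coo and from_block_coo are illustrative assumptions; they are not taken from the figure.

```python
import numpy as np


def to_block_coo(dense, blk_m, blk_n):
    """Returns (indices_i, indices_j, block_data) for the non-zero blocks."""
    m, n = dense.shape
    indices_i, indices_j, blocks = [], [], []
    for bi in range(m // blk_m):          # row-major scan keeps row indices grouped
        for bj in range(n // blk_n):
            block = dense[bi * blk_m:(bi + 1) * blk_m, bj * blk_n:(bj + 1) * blk_n]
            if np.any(block != 0):        # all-zero blocks are simply not stored
                indices_i.append(bi)
                indices_j.append(bj)
                blocks.append(block)
    return np.array(indices_i), np.array(indices_j), np.stack(blocks)


def from_block_coo(indices_i, indices_j, block_data, shape):
    """Rebuilds the dense matrix; blocks missing from the list stay zero."""
    _, blk_m, blk_n = block_data.shape
    dense = np.zeros(shape, dtype=block_data.dtype)
    for bi, bj, block in zip(indices_i, indices_j, block_data):
        dense[bi * blk_m:(bi + 1) * blk_m, bj * blk_n:(bj + 1) * blk_n] = block
    return dense


dense = np.array([[1, 2, 0, 0],
                  [3, 4, 0, 0],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12]])
indices_i, indices_j, block_data = to_block_coo(dense, 2, 2)
print(indices_i, indices_j, block_data.shape)  # [0 1 1] [0 0 1] (3, 2, 2)
restored = from_block_coo(indices_i, indices_j, block_data, dense.shape)
```

Only three of the four blocks are stored, and the round trip recovers the original matrix.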

We will use the following helper function to sample a block-sparse matrix. It returns a dense matrix used for checking our results, as well as a list of block data and indices for each axis.

def generate_block_sparse_mat(key, M, N, blk_M, blk_N, p=0.2, dtype=jnp.float32):
  """Returns a sampled matrix and its block-sparse representation.

  Args:
    key: RNG Key.
    M: Major array dimension.
    N: Minor array dimension.
    blk_M: Block size along M dimension.
    blk_N: Block size along N dimension.
    p: Probability that a block will be non-zero.
    dtype: dtype of the sampled matrix.

  Returns:
    dense_mat: A (M, N) dense sampled array.
    block_data: A (num_blocks, blk_M, blk_N) array of data blocks representing
      the non-zero blocks of the matrix.
    indices_i: A (num_blocks,) array of block indices for the first axis.
    indices_j: A (num_blocks,) array of block indices for the second axis.
  """
  mask_key, blocks_key = jax.random.split(key)
  num_blocks = (M // blk_M, N // blk_N)
  # We first sample a block mask, denoting which blocks are nonzero.
  block_mask = jax.random.bernoulli(mask_key, p=p, shape=num_blocks)
  num_blocks = jnp.sum(block_mask)
  indices = jnp.where(block_mask)
  # For each non-zero block, we sample a block of random values.
  block_data = jax.random.uniform(blocks_key,
                                  shape=(num_blocks, blk_M, blk_N),
                                  dtype=dtype)
  # For checking purposes, create the dense version of the sparse matrix.
  dense_mat = jnp.zeros((M, N), dtype=dtype)
  for blk in range(num_blocks):
    idx_i = indices[0][blk]
    idx_j = indices[1][blk]
    slice_i = slice(idx_i * blk_M, (idx_i + 1) * blk_M)
    slice_j = slice(idx_j * blk_N, (idx_j + 1) * blk_N)
    dense_mat = dense_mat.at[slice_i, slice_j].set(block_data[blk])
  return dense_mat, block_data, indices[0], indices[1]

Example: Sparse @ Dense Matrix Multiplication

In our first example, we will multiply a sparse LHS matrix with a dense RHS matrix to produce a dense output.

We will structure our kernel grid with 2 loops: the outer loop over the columns of the RHS/output, and the inner loop over the sparse blocks of the LHS. During each inner loop iteration, we load one block from the LHS and look up the corresponding block in the RHS using the block index of the contracting dimension (K). We multiply the two blocks together and accumulate into the correct output block. One outer loop iteration will compute a result for an entire column, as depicted by the following diagram:

[Figure: sparse_matmul]

It is important that we group the block indices by row (e.g. [0, 0, 1, 2, 3, 3]) before we pass them into the kernel, for two reasons. First, in our kernel we need to know when to initially zero-out the accumulator in the output ref, and it is easy to do so if the row indices are grouped. Second, the pipelining logic for Pallas does not allow us to re-visit blocks in the output Ref on non-consecutive iterations, and therefore we need to do all accumulation into an output block in consecutive kernel iterations. This is because the pipeline emitter will realize that we are loading the same output block on consecutive iterations and keep the block in VMEM. When we change output blocks, Pallas will finally store the output into HBM and assume we never touch it again. Failure to access output blocks consecutively will result in incorrect values even though the kernel is otherwise logically correct.
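To make the grouping requirement concrete, here is a host-side NumPy emulation of the two loops described above. It is a simplified model under stated assumptions, not the Pallas pipeline: a single accumulator is reset when the output row changes and written back once when the row is about to change, so a row revisited later overwrites its earlier result. The function name dsd_loop and the tiny sizes are made up for illustration.

```python
import numpy as np


def dsd_loop(block_data, idxs_i, idxs_k, y, m, blk):
    """Emulates the grid: outer loop over output columns, inner over LHS blocks."""
    num_blocks = len(idxs_i)
    out = np.zeros((m, y.shape[1]))
    accum = np.zeros((blk, blk))
    for j in range(y.shape[1] // blk):
        cols = slice(j * blk, (j + 1) * blk)
        for b in range(num_blocks):
            i, k = idxs_i[b], idxs_k[b]
            if b == 0 or idxs_i[b - 1] != i:
                accum[...] = 0.0  # entering a new output block: reset
            accum += block_data[b] @ y[k * blk:(k + 1) * blk, cols]
            if b == num_blocks - 1 or idxs_i[b + 1] != i:
                # Leaving this output block: write it back exactly once.
                out[i * blk:(i + 1) * blk, cols] = accum
    return out


blk, m, k_dim, n = 2, 8, 8, 4
idxs_i = np.array([0, 0, 1, 2, 3, 3])  # grouped by row
idxs_k = np.array([0, 2, 1, 3, 0, 3])
rng = np.random.default_rng(0)
block_data = rng.uniform(0.5, 1.5, size=(len(idxs_i), blk, blk))
y = rng.uniform(0.5, 1.5, size=(k_dim, n))

# Dense copy of the sparse LHS, used only to check the result.
x_dense = np.zeros((m, k_dim))
for b, (i, k) in enumerate(zip(idxs_i, idxs_k)):
    x_dense[i * blk:(i + 1) * blk, k * blk:(k + 1) * blk] = block_data[b]

grouped_ok = np.allclose(dsd_loop(block_data, idxs_i, idxs_k, y, m, blk), x_dense @ y)

# Same blocks, but row 0 is now visited on non-consecutive iterations.
order = np.array([0, 2, 1, 3, 4, 5])
ungrouped_ok = np.allclose(
    dsd_loop(block_data[order], idxs_i[order], idxs_k[order], y, m, blk), x_dense @ y)
print(grouped_ok, ungrouped_ok)
```

With grouped indices the emulation matches the dense product; with the permuted order the second visit to row 0 discards the first contribution, which is the kind of silent error the paragraph above warns about.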

M = N = K = 16384
blk_M = blk_N = blk_K = 512


def dsd_kernel(idxs_i_ref, idxs_k_ref,  # Scalar prefetch inputs.
               x_ref, y_ref, _, o_ref,  # Kernel inputs.
               accum_scratch,
               ):
  """A DSD (Dense = Sparse @ Dense) matmul kernel."""
  del idxs_k_ref
  blk_idx = pl.program_id(1)
  is_start = blk_idx == 0
  changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx - 1, 0)])
  @pl.when(is_start | changed_blocks)
  def _():
    accum_scratch[...] = jnp.zeros_like(accum_scratch)
  accum_scratch[...] += jnp.dot(x_ref[0, :, :], y_ref[...], preferred_element_type=jnp.float32)

  # Clamp to the last valid index (num_blocks - 1) so that the final iteration
  # does not read past the end of idxs_i_ref; is_end covers that iteration.
  next_block_change = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.minimum(blk_idx + 1, num_blocks - 1)])
  is_end = blk_idx == (num_blocks - 1)
  @pl.when(is_end | next_block_change)
  def _():
    o_ref[...] = accum_scratch[...].astype(o_ref.dtype)


def x_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
  del j, blk_idxs_i, blk_idxs_k
  return (blk_idx, 0, 0)

def y_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
  del blk_idxs_i
  return (blk_idxs_k[blk_idx], j)

def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
  del blk_idxs_k
  return (blk_idxs_i[blk_idx], j)

(X_dense, X_blocks, indices_i, indices_k) = generate_block_sparse_mat(
    jax.random.key(0), M, K, blk_M, blk_K, p=0.1, dtype=jnp.bfloat16)
num_blocks = X_blocks.shape[0]
Y = jax.random.uniform(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16)
zeros = jnp.zeros((M, N), dtype=jnp.bfloat16)
out_shape = jax.ShapeDtypeStruct((M, N), dtype=jnp.bfloat16)

grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=2,
    # Note that while num_blocks is static here, Pallas does support
    # dynamic grid sizes.
    grid=(N // blk_N, num_blocks),
    in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),
              pl.BlockSpec((blk_K, blk_N), y_map),
              # Placeholder for a zeros-array used by input_output_aliases.
              pl.BlockSpec((blk_M, blk_N), o_map),
              ],
    out_specs=pl.BlockSpec((blk_M, blk_N), o_map),
    scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)]
)
kernel = pl.pallas_call(
    dsd_kernel,
    grid_spec=grid_spec,
    out_shape=out_shape,
    # We use input-output aliases to zero-out o_ref for blocks that we never
    # visit. By passing in an array of zeros we avoid having o_ref start with
    # uninitialized values.
    input_output_aliases={4: 0},  # Map zeros to o_ref.
)
args = (indices_i, indices_k, X_blocks, Y, zeros)
result = kernel(*args)

ref = X_dense @ Y
diff = jnp.abs(ref - result)
print('mean |result - ref|:', jnp.mean(diff))
mean |result - ref|: 0

We can do a quick benchmark to compare the performance of our sparse kernel against a dense matmul in JAX. On a TPU v5e chip, this kernel achieves a roughly 6x speed increase, compared to the theoretical 10x from the sparsity factor.

There are a few tips for performance here, mostly centered around reducing the communication overhead between HBM and VMEM:

  • Using dtype=jnp.bfloat16 is critical for performance, since it halves the amount of data that must be moved compared to float32.

  • Using larger block sizes also helps, since matrix multiply is an \(O(N^3)\) compute and \(O(N^2)\) memory operation. As \(N\) grows larger, the kernel becomes compute-bound. However, a counter-argument to this in practice is that smaller block sizes also enable the data to be more sparse, so this is a parameter that should be selected carefully.
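Both tips can be seen in a back-of-the-envelope arithmetic-intensity estimate. The sketch below assumes a deliberately simple cost model that is not from the tutorial: one square block matmul costs 2·b³ FLOPs and moves three b×b blocks (two inputs, one output) between HBM and VMEM. Real kernels reuse the accumulator across iterations, so treat the numbers as a trend rather than a prediction.

```python
def arithmetic_intensity(block_size, bytes_per_element):
    """FLOPs per byte moved for one (b, b) @ (b, b) block matmul (assumed model)."""
    flops = 2 * block_size ** 3                               # multiply-adds
    bytes_moved = 3 * block_size ** 2 * bytes_per_element     # x block, y block, out block
    return flops / bytes_moved


for block_size in (128, 256, 512):
    bf16 = arithmetic_intensity(block_size, 2)  # bfloat16: 2 bytes per element
    f32 = arithmetic_intensity(block_size, 4)   # float32: 4 bytes per element
    print(f"block={block_size}: bf16 {bf16:.1f} FLOPs/byte, f32 {f32:.1f} FLOPs/byte")
```

Under this model, intensity grows linearly with the block size and doubles when switching from float32 to bfloat16, which is why both choices push the kernel toward being compute-bound.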

# Benchmark Sparse Pallas kernel vs reference JAX implementation

def benchmark(f, ntrials: int = 100):
  def run(*args, **kwargs):
    # Compile function first
    jax.block_until_ready(f(*args, **kwargs))
    # Time function
    result = timeit.timeit(lambda: jax.block_until_ready(f(*args, **kwargs)),
                           number=ntrials)
    time = result / ntrials
    return time
  return run


n_trials = 100

pallas_impl = lambda *args: kernel(*args)
time = benchmark(pallas_impl, n_trials)(indices_i, indices_k, X_blocks, Y, zeros)
print("Sparse Kernel: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))

ref_impl = jax.jit(lambda x, y: x @ y)
time = benchmark(ref_impl, n_trials)(X_dense, Y)
print("Reference: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))
Sparse Kernel: 8.136 ms (avg over 100 trials)
Reference: 46.953 ms (avg over 100 trials)

Sparse Access Patterns on Dense Data

In our previous example we considered the case when the data itself is sparse. This manifested itself in the kernel structure as a dimension in the kernel grid that was dynamic and looped over the number of nonzero blocks (num_blocks).

A second useful programming pattern emerges when the underlying data is dense, but we wish to perform sparse computation over it. Our kernel grid in this case will be dense, but we wish to skip over some blocks in the grid as indicated by a block-sparse mask. This type of programming pattern commonly arises when using masks in many machine learning applications, such as causal or local masks in self-attention. In these cases, we can entirely skip over computation in blocks where the mask is zeroed-out. Examples of this programming pattern can be found in the Splash Attention and Grouped Matrix Multiplication kernels located in jax/experimental/pallas/ops/tpu, or in PyTorch’s FlexAttention.

The main performance consideration when dealing with a sparse access pattern on dense data is the interaction with pipelining. On any given kernel iteration, the Pallas pipeline emitter will attempt to prefetch the next block of data by calling the index_map for each BlockSpec on the next iteration of the grid. However, if our computation is sparse, we may be skipping the computation for the next block in the grid, so we need some way to tell the pipeline to instead begin fetching the next block that we are not skipping. To do this, we need to construct prefetch maps, which contain indices to the next non-skipped block of data for each kernel input. The following diagram illustrates how a prefetch map could be constructed for a block-sparse mask that is stored in a COO-like format.

[Figure: prefetch_map]

Left: A sparse access pattern, where the color blue denotes blocks with non-zero masks that we need to compute. Right: The prefetch map, where each element of the array contains the index of the next non-zero block data.
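One possible way to build such a map on the host is a reverse scan over the block mask, sketched below in NumPy. The mask values and the helper name build_prefetch_map are hypothetical and do not reproduce the figure; trailing skipped blocks fall back to index 0 here, which is an arbitrary choice made for this sketch.

```python
import numpy as np


def build_prefetch_map(block_mask):
    """For each grid position, index of the next non-zero block's data."""
    flat = np.asarray(block_mask).reshape(-1) != 0
    # Position of each non-zero block in the packed (COO-like) data array.
    data_index = np.cumsum(flat) - 1
    prefetch = np.zeros(flat.shape, dtype=np.int32)
    next_idx = 0  # fallback for trailing skipped blocks: nothing left to fetch
    for pos in range(flat.size - 1, -1, -1):  # scan in reverse row-major order
        if flat[pos]:
            next_idx = data_index[pos]
        prefetch[pos] = next_idx
    return prefetch.reshape(np.shape(block_mask))


block_mask = np.array([[1, 0, 0],
                       [0, 1, 1],
                       [0, 1, 0]])
prefetch_map = build_prefetch_map(block_mask)
print(prefetch_map)
```

A non-zero position maps to its own data block, while a skipped position maps to the data block of the next non-zero position in row-major order, so the pipeline always has a useful block to fetch.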

Once the prefetch map has been constructed, we can pass the map as a scalar prefetch argument and query it in the index_map function of the BlockSpec.

# Grid indices come first, followed by the scalar prefetch refs.
def mask_index_map(i, j, prefetch_map, ...):
  next_nonzero_block = prefetch_map[i, j]
  return (next_nonzero_block, 0, 0)

We can construct similar index maps for the other inputs to the kernel. For dense inputs you will most likely need to construct prefetch maps which point to the next non-zero block index in the grid. Our next example demonstrates how to use these prefetch maps.

Example: Dense @ Dense Matrix Multiplication with a Block-Sparse Output Mask

In our next example we will cover dense matrix multiplication fused with a sparse output mask using a prefetch map to improve pipelining performance. We will use the mask to selectively skip computing output blocks that are zeroed-out, therefore saving on computation costs.

As we will be working with a sparse mask, we will begin by implementing a function that converts an M x N mask stored in dense format into a block-sparse format. We additionally need to compute prefetch maps to help the pipeline emitter know which block to fetch next. In total, our sparsify_mask function computes:

  • A block_mask of shape (num_M_blocks, num_N_blocks) indicating if a block is all-zeros (value 0) or contains non-zero elements (value 1). If the block_mask has a value of 0, we can skip computing the block in the kernel.

  • A prefetch_mask array of shape (num_M_blocks, num_N_blocks) consisting of indices into mask_data for the next non-zero block.

  • A prefetch_i array of shape (num_M_blocks, num_N_blocks) consisting of the i index of the next non-masked block of the mask.

  • A prefetch_j array of shape (num_M_blocks, num_N_blocks) consisting of the j index of the next non-masked block of the mask.

  • A mask_data array of shape (num_blocks, blk_M, blk_N) containing data for non-zero blocks of the mask.

def sparsify_mask(mask: jax.Array,
                  block_shape: tuple[int, int]):
  """Preprocesses a mask into a sparse representation.

  Args:
    mask: A boolean array of shape [M, N]
    block_shape: The size of a single block.

  Returns:
    block_mask: A (M // bm, N // bn) array indicating whether a block
      is all-zeros (0) or contains non-zero elements (1).
    prefetch_mask: A (M // bm, N // bn) array of integers indicating the index
      of the next non-zero block.
    prefetch_i: A (M // bm, N // bn) array with the i index of the next
      non-zero block.
    prefetch_j: A (M // bm, N // bn) array with the j index of the next
      non-zero block.
    mask_data: A (num_blocks, *block_shape) array containing
      the data for non-zero blocks of the mask.
  """
  M, N = mask.shape
  bm, bn = block_shape

  block_mask = jnp.zeros((M // bm, N // bn), dtype=mask.dtype)
  mask_types_finder = []
  mask_data = []

  next_mask_type_idx = 0
  prefetch_mask = jnp.zeros_like(block_mask)
  next_i = (M // bm) - 1
  next_j = (N // bn) - 1
  prefetch_i = jnp.zeros_like(block_mask)
  prefetch_j = jnp.zeros_like(block_mask)
  # Scan the blocks in reverse order, starting from the last valid block index.
  for i in range(M // bm - 1, -1, -1):
    for j in range(N // bn - 1, -1, -1):
      mask_block = mask[i * bm:(i + 1) * bm,
                        j * bn:(j + 1) * bn]
      is_nonzero = jnp.any(mask_block)
      if is_nonzero:
        try:
          type_index = mask_types_finder.index(str(mask_block))
        except ValueError:
          type_index = len(mask_types_finder)
          mask_types_finder.append(str(mask_block))
          mask_data.append(mask_block)
        next_mask_type_idx = type_index
        next_i = i
        next_j = j
      else:
        type_index = -1
      block_mask = block_mask.at[i, j].set(is_nonzero)
      prefetch_mask = prefetch_mask.at[i, j].set(next_mask_type_idx)
      prefetch_i = prefetch_i.at[i, j].set(next_i)
      prefetch_j = prefetch_j.at[i, j].set(next_j)
  return block_mask, prefetch_mask, prefetch_i, prefetch_j, jnp.stack(mask_data)

In terms of the structure of the kernel, we use the same grid pattern as the standard matrix multiplication kernel we covered in previous tutorials, with three loops over the M, N, and K dimensions. Within the kernel itself, we first check the block_mask to see if the mask for the current output block is all zeros. If the mask is all zeros, we can skip computation and move on to the next block; otherwise we need to compute the matrix multiplication and then mask the result.
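The control flow described above can be emulated on the host with NumPy to check the logic before involving Pallas. This is only a sketch under simplifying assumptions: it uses tiny shapes, a single square block size, and a Python continue in place of pl.when; the function name masked_matmul_blocks is made up for illustration.

```python
import numpy as np


def masked_matmul_blocks(x, y, mask, blk):
    """Computes mask * (x @ y) block by block, skipping all-zero mask blocks."""
    m, n = mask.shape
    k_dim = x.shape[1]
    out = np.zeros((m, n))
    skipped = 0
    for i in range(m // blk):
        for j in range(n // blk):
            rows = slice(i * blk, (i + 1) * blk)
            cols = slice(j * blk, (j + 1) * blk)
            mask_block = mask[rows, cols]
            if not mask_block.any():
                skipped += 1  # all-zero mask block: output stays zero
                continue
            accum = np.zeros((blk, blk))
            for k in range(k_dim // blk):  # innermost loop over the contraction
                ks = slice(k * blk, (k + 1) * blk)
                accum += x[rows, ks] @ y[ks, cols]
            out[rows, cols] = mask_block * accum  # mixed blocks are masked here
    return out, skipped


rng = np.random.default_rng(0)
x = rng.normal(size=(8, 8))
y = rng.normal(size=(8, 8))
mask = np.tril(np.ones((8, 8), dtype=np.int32))
out, skipped = masked_matmul_blocks(x, y, mask, blk=2)
matches = np.allclose(out, mask * (x @ y))
print(matches, skipped)
```

For this 4x4 grid of blocks with a lower-triangular mask, the 6 blocks strictly above the block diagonal are skipped and the remaining 10 are computed.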

M = N = K = 16384
blk_M = blk_N = 512
blk_K = 1024

def sparse_mask_matmul(
    block_mask_ref, prefetch_mask, prefetch_i, prefetch_j,  # Scalar prefetch inputs.
    x_ref, y_ref, mask_ref, o_ref,  # Kernel inputs.
    accum_scratch
    ):
  del prefetch_mask, prefetch_i, prefetch_j
  i, j, k = pl.program_id(0), pl.program_id(1), pl.program_id(2)
  should_compute = block_mask_ref[i, j] != 0
  @pl.when(k == 0)
  def _():
    o_ref[...] = jnp.zeros_like(o_ref)
    accum_scratch[...] = jnp.zeros_like(accum_scratch[...])

  # We only compute the output for blocks with non-zero masks.
  # Otherwise we skip the computation entirely.
  @pl.when(should_compute)
  def _():
    result = jnp.dot(x_ref[...], y_ref[...], preferred_element_type=jnp.float32)
    accum_scratch[...] += result
    @pl.when(k == pl.num_programs(2) - 1)
    def _():
      o_ref[...] = (mask_ref[0, ...] * accum_scratch[...]).astype(o_ref.dtype)

X = jax.random.normal(jax.random.key(0), shape=(M, K), dtype=jnp.bfloat16)
Y = jax.random.normal(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16)
mask = jnp.ones((M, N), dtype=jnp.int32)
mask = jnp.tril(mask)
block_mask, prefetch_mask, prefetch_i, prefetch_j, sparse_mask_data = sparsify_mask(mask, (blk_M, blk_N))

def x_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j):
  del prefetch_mask, prefetch_j
  # Zero-out the k index if the mask is zero, to avoid constantly fetching
  # new blocks in the inner loop for blocks we are skipping.
  k_fetch = (block_mask[i, j] != 0) * k
  return (prefetch_i[i, j], k_fetch)

def y_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j):
  del prefetch_mask, prefetch_i
  k_fetch = (block_mask[i, j] != 0) * k
  return (k_fetch, prefetch_j[i, j])

def mask_map(i, j, k, block_mask, prefetch_mask, *_):
  del k, block_mask
  return (prefetch_mask[i, j], 0, 0)

def o_map(i, j, k, *_):
  del k
  return (i, j)

grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=4,
    grid=(M // blk_M, N // blk_N, K // blk_K),
    in_specs=[pl.BlockSpec((blk_M, blk_K), x_map),
              pl.BlockSpec((blk_K, blk_N), y_map),
              pl.BlockSpec((1, blk_M, blk_N), mask_map)],
    out_specs=pl.BlockSpec((blk_M, blk_N), o_map),
    scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)]
)
kernel = pl.pallas_call(
    sparse_mask_matmul,
    grid_spec=grid_spec,
    out_shape=jax.ShapeDtypeStruct((M, N), jnp.bfloat16),
)
args = (block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data)
result = kernel(*args)

ref = mask * (X @ Y)
diff = jnp.abs(ref - result)
print('mean |result - ref|:', jnp.mean(diff))
mean |result - ref|: 1.0252e-05

Now let’s compare performance versus a naive dense implementation. On TPU v5e, we achieve around a 1.8x speed increase with the sparse kernel, compared to a theoretical best-case of 2x from using a lower triangular mask and only visiting half of the possible outputs.

We would generally expect performance to get closer to the theoretical peak as our inputs get larger, since the main reasons we fall short of it become less significant with size:

  • We skip slightly less than half of the computation, since the blocks along the diagonal contain a mix of 0s and 1s, and for mixed blocks we need to compute the entire block. With larger inputs, the overhead for mixed blocks becomes smaller relative to the overall computation.

  • The pipeline bubble also accounts for a smaller percentage of the overall runtime as inputs become larger.
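The first point can be quantified with a short calculation. The sketch below counts how many output blocks a lower-triangular mask forces us to compute, under the simplifying assumption that every computed block costs the same and that nothing else (such as the pipeline bubble) contributes to runtime. The helper name tril_block_stats is made up for illustration.

```python
def tril_block_stats(size, blk):
    """Returns (computed_blocks, total_blocks, best_case_speedup) for a tril mask."""
    n = size // blk
    computed = n * (n + 1) // 2  # blocks on or below the block diagonal
    total = n * n
    return computed, total, total / computed


for size in (16384, 32768, 65536):
    computed, total, speedup = tril_block_stats(size, 512)
    print(f"size={size}: {computed}/{total} blocks computed, best case {speedup:.3f}x")
```

For the 16384 x 16384 output with 512 x 512 blocks used here, 528 of the 1024 blocks must be computed, so the block-level best case under this model is about 1.94x rather than exactly 2x, and it approaches 2x as the matrix grows.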

n_trials = 100

pallas_impl = lambda *args: kernel(*args)
time = benchmark(pallas_impl, n_trials)(block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data)
print("Sparse Kernel: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))

ref_impl = jax.jit(lambda mask, x, y: mask * (x @ y))
time = benchmark(ref_impl, n_trials)(mask, X, Y)
print("Reference: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))
Sparse Kernel: 28.648 ms (avg over 100 trials)
Reference: 49.988 ms (avg over 100 trials)
