
Grids and BlockSpecs

grid, a.k.a. kernels in a loop

When using jax.experimental.pallas.pallas_call(), the kernel function is executed multiple times on different inputs, as specified via the grid argument to pallas_call. Conceptually:

pl.pallas_call(some_kernel, grid=(n,))(...)

maps to

for i in range(n):
  some_kernel(...)

Grids can be generalized to be multi-dimensional, corresponding to nested loops. For example,

pl.pallas_call(some_kernel, grid=(n, m))(...)

is equivalent to

for i in range(n):
  for j in range(m):
    some_kernel(...)

This generalizes to any tuple of integers (a length d grid will correspond to d nested loops). The kernel is executed as many times as prod(grid). The default grid value () results in one kernel invocation. Each of these invocations is referred to as a “program”. To access which program (i.e. which element of the grid) the kernel is currently executing, we use jax.experimental.pallas.program_id(). For example, for invocation (1, 2), program_id(axis=0) returns 1 and program_id(axis=1) returns 2. You can also use jax.experimental.pallas.num_programs() to get the grid size for a given axis.
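To make the loop analogy concrete, here is a plain-Python model of how a grid enumerates programs; it does not use Pallas at all. The helper names iterate_grid and num_programs_model are hypothetical: they only mirror what program_id and num_programs report, assuming the sequential order in which the nested loops above would run (real hardware may execute some programs in parallel).

```python
import itertools
import math


def iterate_grid(grid: tuple[int, ...]):
    """Yield the invocation indices of every program, last axis fastest.

    Hypothetical helper: models a length-d grid as d nested for-loops.
    """
    # itertools.product over the per-axis ranges is exactly a set of nested loops.
    yield from itertools.product(*(range(n) for n in grid))


def num_programs_model(grid: tuple[int, ...], axis: int) -> int:
    """Hypothetical stand-in for pl.num_programs(axis): the grid size on that axis."""
    return grid[axis]


grid = (2, 3)
programs = list(iterate_grid(grid))
print(len(programs) == math.prod(grid))  # True: the kernel runs prod(grid) times
print(programs[5])                        # (1, 2): program_id(0) == 1, program_id(1) == 2
print(list(iterate_grid(())))             # [()]: the default grid is a single program
```

Each tuple yielded here is what the kernel would observe through program_id(0), program_id(1), and so on.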

See Grids by example for a simple kernel that uses this API.

BlockSpec, a.k.a. how to chunk up inputs

In conjunction with the grid argument, we need to provide Pallas with the information on how to slice up the input for each invocation. Specifically, we need to provide a mapping from each iteration of the loop to the block of our inputs and outputs that it operates on. This is provided via jax.experimental.pallas.BlockSpec objects.

Before we get into the details of BlockSpecs, you may want to revisit Block specs by example in Pallas Quickstart.

BlockSpecs are provided to pallas_call via the in_specs and out_specs arguments, one for each input and output respectively.

First, we discuss the semantics of BlockSpec when indexing_mode == pl.Blocked().

Informally, the index_map of the BlockSpec takes as arguments the invocation indices (as many as the length of the grid tuple), and returns block indices (one block index for each axis of the overall array). Each block index is then multiplied by the corresponding axis size from block_shape to get the actual element index on the corresponding array axis.
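Written as a formula: for each array axis a, with array size N_a, block size s_a (from block_shape), and block index b_a (returned by index_map), the block covers the half-open element range below. The constraint on the right is the requirement, stated later in this section, that each block start within bounds.

```latex
\text{start}_a = b_a \, s_a, \qquad
\text{block}_a = [\, b_a s_a,\; b_a s_a + s_a \,), \qquad
0 \le b_a s_a < N_a
```

For example, with block_shape = (10, 20) and block indices (2, 4), the block starts at element (20, 80).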

Note

Not all block shapes are supported.

  • On TPU, only blocks with rank at least 1 are supported. Furthermore, the last two dimensions of your block shape must be equal to the respective dimension of the overall array, or be divisible by 8 and 128 respectively. For blocks of rank 1, the block dimension must be equal to the array dimension, or be divisible by 128 * (32 / bitwidth(dtype)).

  • On GPU, when using the Mosaic GPU backend, the size of the blocks is unrestricted. However, due to hardware limitations, the size of the minormost array dimension must be a multiple of 16 bytes. For example, it must be a multiple of 8 if the input is jnp.float16.

  • On GPU, when using the Triton backend, the size of the blocks themselves is unrestricted, but each operation (including a load or store) must operate on arrays whose size is a power of 2.
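The TPU and Mosaic GPU rules above can be encoded as small predicates. This is a plain-Python sketch of the rules as stated in the note, not an API that Pallas exposes. The function names are hypothetical, the sketch assumes every block dimension is an integer (no squeezed entries), and the real lowering may impose further constraints.

```python
def tpu_block_shape_ok(block_shape, array_shape, bitwidth):
    """Check the TPU block-shape rule from the note (hypothetical helper).

    bitwidth is the dtype's width in bits, e.g. 32 for float32, 16 for bfloat16.
    """
    if len(block_shape) < 1 or len(block_shape) != len(array_shape):
        return False
    if len(block_shape) == 1:
        (b,), (n,) = block_shape, array_shape
        # Rank-1 blocks: equal to the array dim or divisible by 128 * (32 / bitwidth).
        return b == n or b % (128 * 32 // bitwidth) == 0
    # Second-to-last dim: equal to the array dim or divisible by 8;
    # last dim: equal to the array dim or divisible by 128.
    for b, n, multiple in zip(block_shape[-2:], array_shape[-2:], (8, 128)):
        if b != n and b % multiple != 0:
            return False
    return True


def mosaic_gpu_minor_dim_ok(array_shape, itemsize_bytes):
    """The minormost array dimension must span a multiple of 16 bytes (hypothetical helper)."""
    return (array_shape[-1] * itemsize_bytes) % 16 == 0


print(tpu_block_shape_ok((8, 128), (1024, 1024), 32))  # True
print(tpu_block_shape_ok((8, 100), (1024, 100), 32))   # True: last dim equals the array dim
print(mosaic_gpu_minor_dim_ok((4, 8), 2))              # True: 8 float16 values = 16 bytes
```

The bitwidth argument can be derived from a dtype as its item size in bytes times 8.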

If the block shape does not evenly divide the overall shape, then the last iteration on each axis will still receive references to blocks of block_shape, but the elements that are out-of-bounds are padded on input and discarded on output. The values of the padding are unspecified, and you should assume they are garbage. In the interpret=True mode, we pad with NaN for floating-point values, to give users a chance to spot accesses to out-of-bounds elements, but this behavior should not be depended upon. Note that at least one of the elements in each block must be within bounds.
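When the block shape does not divide the array shape, the grid along that axis is typically sized with a ceiling division, and only a prefix of the last block holds real data. The plain-Python sketch below shows the numbers for an axis of size 90 split into blocks of 20, the same case as the partially out-of-bounds example further down. Here cdiv computes the same value as pl.cdiv for positive integers, and in_bounds_extent is a hypothetical helper.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division for positive ints, i.e. the number of blocks needed."""
    return (a + b - 1) // b


def in_bounds_extent(block_idx: int, block_size: int, dim: int) -> int:
    """How many elements of this block lie inside an axis of size `dim`."""
    start = block_idx * block_size
    # The note requires at least one in-bounds element per block.
    assert 0 <= start < dim
    return min(block_size, dim - start)


dim, block_size = 90, 20
num_blocks = cdiv(dim, block_size)  # 5 blocks along this axis
extents = [in_bounds_extent(b, block_size, dim) for b in range(num_blocks)]
print(num_blocks, extents)          # 5 [20, 20, 20, 20, 10]
```

The last block covers elements 80 to 99, but only 80 to 89 exist; the remaining 10 positions are the padding described above.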

More precisely, the slices for each axis of the input x of shape x_shape are computed as in the function slices_for_invocation below:

>>> import jax
>>> from jax.experimental import pallas as pl
>>> def slices_for_invocation(x_shape: tuple[int, ...],
...                           x_spec: pl.BlockSpec,
...                           grid: tuple[int, ...],
...                           invocation_indices: tuple[int, ...]) -> list[slice]:
...   assert len(invocation_indices) == len(grid)
...   assert all(0 <= i < grid_size for i, grid_size in zip(invocation_indices, grid))
...   block_indices = x_spec.index_map(*invocation_indices)
...   assert len(x_shape) == len(x_spec.block_shape) == len(block_indices)
...   elem_indices = []
...   for x_size, block_size, block_idx in zip(x_shape, x_spec.block_shape, block_indices):
...     start_idx = block_idx * block_size
...     # At least one element of the block must be within bounds
...     assert start_idx < x_size
...     elem_indices.append(slice(start_idx, start_idx + block_size))
...   return elem_indices

For example:

>>> slices_for_invocation(x_shape=(100, 100),
...                       x_spec=pl.BlockSpec((10, 20), lambda i, j: (i, j)),
...                       grid=(10, 5),
...                       invocation_indices=(2, 4))
[slice(20, 30, None), slice(80, 100, None)]

>>> # Same shape of the array and blocks, but we iterate over each block 4 times
>>> slices_for_invocation(x_shape=(100, 100),
...                       x_spec=pl.BlockSpec((10, 20), lambda i, j, k: (i, j)),
...                       grid=(10, 5, 4),
...                       invocation_indices=(2, 4, 0))
[slice(20, 30, None), slice(80, 100, None)]

>>> # An example when the block is partially out-of-bounds in the 2nd axis.
>>> slices_for_invocation(x_shape=(100, 90),
...                       x_spec=pl.BlockSpec((10, 20), lambda i, j: (i, j)),
...                       grid=(10, 5),
...                       invocation_indices=(2, 4))
[slice(20, 30, None), slice(80, 100, None)]

The function show_program_ids defined below uses Pallas to show the invocation indices. The program_ids_kernel will fill each output block with a decimal number whose first digit represents the invocation index over the first axis, and whose second digit represents the invocation index over the second axis:

>>> import jax.numpy as jnp
>>> import numpy as np
>>> def show_program_ids(x_shape, block_shape, grid,
...                      index_map=lambda i, j: (i, j)):
...   def program_ids_kernel(o_ref):
...     # Fill the output block with a decimal number with one digit per grid
...     # axis, e.g. 10 * program_id(0) + program_id(1) for a 2D grid.
...     axes = 0
...     for axis in range(len(grid)):
...       axes += pl.program_id(axis) * 10**(len(grid) - 1 - axis)
...     o_ref[...] = jnp.full(o_ref.shape, axes)
...   res = pl.pallas_call(program_ids_kernel,
...                        out_shape=jax.ShapeDtypeStruct(x_shape, dtype=np.int32),
...                        grid=grid,
...                        in_specs=[],
...                        out_specs=pl.BlockSpec(block_shape, index_map),
...                        interpret=True)()
...   print(res)

For example:

>>> show_program_ids(x_shape=(8, 6), block_shape=(2, 3), grid=(4, 2),
...                  index_map=lambda i, j: (i, j))
[[ 0  0  0  1  1  1]
 [ 0  0  0  1  1  1]
 [10 10 10 11 11 11]
 [10 10 10 11 11 11]
 [20 20 20 21 21 21]
 [20 20 20 21 21 21]
 [30 30 30 31 31 31]
 [30 30 30 31 31 31]]

>>> # An example with out-of-bounds accesses
>>> show_program_ids(x_shape=(7, 5), block_shape=(2, 3), grid=(4, 2),
...                  index_map=lambda i, j: (i, j))
[[ 0  0  0  1  1]
 [ 0  0  0  1  1]
 [10 10 10 11 11]
 [10 10 10 11 11]
 [20 20 20 21 21]
 [20 20 20 21 21]
 [30 30 30 31 31]]

>>> # It is allowed for the shape to be smaller than block_shape
>>> show_program_ids(x_shape=(1, 2), block_shape=(2, 3), grid=(1, 1),
...                  index_map=lambda i, j: (i, j))
[[0 0]]

When multiple invocations write to the same elements of the output array, the result is platform dependent.

In the example below, we have a 3D grid with the last grid dimension not used in the block selection (index_map=lambda i, j, k: (i, j)). Hence, we iterate over the same output block 10 times. The output shown below was generated on CPU using interpret=True mode, which at the moment executes the invocations sequentially. On TPUs, programs are executed in a combination of parallel and sequential fashion, and this function also generates the output shown. See Noteworthy properties and restrictions.

>>> show_program_ids(x_shape=(8, 6), block_shape=(2, 3), grid=(4, 2, 10),
...                  index_map=lambda i, j, k: (i, j))
[[  9   9   9  19  19  19]
 [  9   9   9  19  19  19]
 [109 109 109 119 119 119]
 [109 109 109 119 119 119]
 [209 209 209 219 219 219]
 [209 209 209 219 219 219]
 [309 309 309 319 319 319]
 [309 309 309 319 319 319]]
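The values in that output follow from sequential execution: the program that writes a block last wins, so each block ends up holding the id of the program with k = 9. The plain-Python model below is a hypothetical sketch that assumes the row-major sequential order of interpret=True; it replays the grid=(4, 2, 10) example and reproduces the printed digits. As noted above, other platforms need not produce the same result.

```python
import itertools


def simulate_overlapping_writes(x_shape, block_shape, grid, index_map):
    """Run every program in row-major order; each writes its id into its block."""
    rows, cols = x_shape
    out = [[None] * cols for _ in range(rows)]
    for ids in itertools.product(*(range(n) for n in grid)):
        # Same encoding as show_program_ids: one decimal digit per grid axis.
        value = sum(p * 10 ** (len(grid) - 1 - axis) for axis, p in enumerate(ids))
        bi, bj = index_map(*ids)
        for r in range(bi * block_shape[0], min((bi + 1) * block_shape[0], rows)):
            for c in range(bj * block_shape[1], min((bj + 1) * block_shape[1], cols)):
                out[r][c] = value  # later programs overwrite earlier ones
    return out


out = simulate_overlapping_writes((8, 6), (2, 3), (4, 2, 10), lambda i, j, k: (i, j))
print(out[0])  # [9, 9, 9, 19, 19, 19]
print(out[7])  # [309, 309, 309, 319, 319, 319]
```

With a 2D grid and no revisited blocks, the same function reproduces the first show_program_ids output, because each block is written exactly once.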

A None value appearing as a dimension value in the block_shape behaves as the value 1, except that the corresponding block axis is squeezed (you could also pass in pl.Squeezed() instead of None). In the example below, observe that the shape of the o_ref is (2,) when the block shape was specified as (None, 2) (the leading dimension was squeezed).

>>> def kernel(o_ref):
...   assert o_ref.shape == (2,)
...   o_ref[...] = jnp.full((2,), 10 * pl.program_id(1) + pl.program_id(0))
>>> pl.pallas_call(kernel,
...                jax.ShapeDtypeStruct((3, 4), dtype=np.int32),
...                out_specs=pl.BlockSpec((None, 2), lambda i, j: (i, j)),
...                grid=(3, 2), interpret=True)()
Array([[ 0,  0, 10, 10],
       [ 1,  1, 11, 11],
       [ 2,  2, 12, 12]], dtype=int32)
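The shape the kernel sees can be computed from block_shape by dropping the None (squeezed) entries, while a squeezed axis still advances one element per block index because it behaves as block size 1. The two helpers below are hypothetical plain-Python sketches of that rule; they assume every non-squeezed entry is an integer and do not model pl.Squeezed() objects.

```python
def kernel_ref_shape(block_shape):
    """Shape of the Ref passed to the kernel: squeezed (None) axes are dropped."""
    return tuple(d for d in block_shape if d is not None)


def block_start(block_shape, block_indices):
    """Element offset of a block; a squeezed axis counts as block size 1."""
    return tuple(b * (1 if d is None else d) for d, b in zip(block_shape, block_indices))


print(kernel_ref_shape((None, 2)))     # (2,)
print(block_start((None, 2), (2, 1)))  # (2, 2): row 2, columns 2 and 3
```

For program (2, 1) in the example above, this is the row-2, columns-2-to-3 slot that holds the value 12.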

When we construct a BlockSpec, we can use the value None for the block_shape parameter, in which case the shape of the overall array is used as block_shape. And if we use the value None for the index_map parameter, then a default index map function that returns a tuple of zeros is used: index_map=lambda *invocation_indices: (0,) * len(block_shape).

>>> show_program_ids(x_shape=(4, 4), block_shape=None, grid=(2, 3),
...                  index_map=None)
[[12 12 12 12]
 [12 12 12 12]
 [12 12 12 12]
 [12 12 12 12]]

>>> show_program_ids(x_shape=(4, 4), block_shape=(4, 4), grid=(2, 3),
...                  index_map=None)
[[12 12 12 12]
 [12 12 12 12]
 [12 12 12 12]
 [12 12 12 12]]
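The two defaults can be summarized by a hypothetical plain-Python helper (not Pallas code) that fills in a missing block_shape and index_map. It also explains why every element above is 12: each of the six programs writes the whole array, and under sequential execution the last program, with invocation indices (1, 2), writes last.

```python
def resolve_block_spec(array_shape, block_shape=None, index_map=None):
    """Fill in the BlockSpec defaults described above (hypothetical helper)."""
    if block_shape is None:
        block_shape = tuple(array_shape)  # one block covering the whole array
    if index_map is None:
        def index_map(*invocation_indices):
            return (0,) * len(block_shape)  # always the first (and only) block
    return block_shape, index_map


block_shape, index_map = resolve_block_spec((4, 4))
print(block_shape)      # (4, 4)
print(index_map(1, 2))  # (0, 0)
```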

The “element” indexing mode

The behavior documented above applies to the default “blocked” indexing mode. When integers are used in the block_shape tuple, e.g. (4, 8), it is equivalent to passing in a pl.Blocked(block_size) object for each dimension instead, e.g. (pl.Blocked(4), pl.Blocked(8)). Blocked indexing mode means the indices returned by index_map are block indices. We can pass in objects other than pl.Blocked to change the semantics of index_map, most notably pl.Element(block_size).

When using the pl.Element indexing mode, the values returned by the index map function are used directly as the array indices, without first scaling them by the block size. In pl.Element mode you can also specify virtual padding of the array as a (low, high) padding tuple for the dimension, e.g. pl.Element(2, (1, 0)): the behavior is as if the overall array were padded on input. No guarantees are made for the padding values in element mode, similarly to the padding values in blocked indexing mode when the block shape does not divide the overall array shape.

The Element mode is currently supported only on TPUs.

>>> # element without padding
>>> show_program_ids(x_shape=(8, 6), block_shape=(pl.Element(2), pl.Element(3)),
...                  grid=(4, 2),
...                  index_map=lambda i, j: (2*i, 3*j))
[[ 0  0  0  1  1  1]
 [ 0  0  0  1  1  1]
 [10 10 10 11 11 11]
 [10 10 10 11 11 11]
 [20 20 20 21 21 21]
 [20 20 20 21 21 21]
 [30 30 30 31 31 31]
 [30 30 30 31 31 31]]

>>> # element, first pad the array with 1 row and 2 columns.
>>> show_program_ids(x_shape=(7, 7),
...                  block_shape=(pl.Element(2, (1, 0)),
...                               pl.Element(3, (2, 0))),
...                  grid=(4, 3),
...                  index_map=lambda i, j: (2*i, 3*j))
[[ 0  1  1  1  2  2  2]
 [10 11 11 11 12 12 12]
 [10 11 11 11 12 12 12]
 [20 21 21 21 22 22 22]
 [20 21 21 21 22 22 22]
 [30 31 31 31 32 32 32]
 [30 31 31 31 32 32 32]]
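To see where the numbers in the padded example come from, the plain-Python sketch below models element indexing: the index_map value is used directly as a start offset into the virtually padded array, and subtracting the low padding maps it back onto the real array. The helper names are hypothetical, and the model simply skips positions that fall in the padding, whose values are unspecified as noted above.

```python
import itertools


def element_range(start, block_size, pad_low):
    """Indices into the unpadded array covered by an Element-mode block.

    Negative indices (and indices past the end) fall in the virtual padding.
    """
    first = start - pad_low
    return range(first, first + block_size)


def simulate_element_mode(x_shape, block_sizes, pads_low, grid, index_map):
    rows, cols = x_shape
    out = [[None] * cols for _ in range(rows)]
    for i, j in itertools.product(range(grid[0]), range(grid[1])):
        r0, c0 = index_map(i, j)  # element offsets, not block indices
        for r in element_range(r0, block_sizes[0], pads_low[0]):
            for c in element_range(c0, block_sizes[1], pads_low[1]):
                if 0 <= r < rows and 0 <= c < cols:  # skip padding positions
                    out[r][c] = 10 * i + j
    return out


out = simulate_element_mode((7, 7), (2, 3), (1, 2), (4, 3), lambda i, j: (2 * i, 3 * j))
print(out[0])  # [0, 1, 1, 1, 2, 2, 2]
print(out[1])  # [10, 11, 11, 11, 12, 12, 12]
```

Program (0, 0) covers padded rows 0 and 1 and padded columns 0 to 2, but only row 0, column 0 of the real array lies inside that window, which is why the top-left 0 appears exactly once.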

