
Pallas Core-specific Programming#

In this guide, we explore using pl.core_map to write Pallas kernels. Compared with pallas_call, core_map offers a few key characteristics:

  • Per-core level programming: You write code for a TPU/GPU core, not for a JAX device. This gives you full control over what runs on every core, and over how cores communicate and distribute work among one another.

  • Collectives: core_map explicitly models physical cores, so inter-core communication can be expressed safely.

  • Platform generic: the core_map programming model works for TPU (TensorCore and SparseCore) and GPU with minimal boilerplate changes.

This guide focuses on TPU. For how to use core_map on GPU to achieve higher thread flexibility, check out our Pallas GPU core_map tutorial.

Environment setup#

Modern accelerators often have multiple cores under a device. On recent TPU chips (v4, v5p), every JAX device may contain 2 TensorCores (a.k.a. a Megacore). Some TPUs (v5p, v6e, 7x) also contain SparseCores, each of which consists of many subcores.

This guide was written on a v5p chip, which contains 4 devices (2 TensorCores each) and 4 SparseCores, each with 16 subcores.

from functools import partial

import jax
from jax.sharding import NamedSharding
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax.experimental.pallas import tpu_sc as plsc
import jax.numpy as jnp
import numpy as np


num_devices = jax.local_device_count()
assert num_devices > 1, "Please run this notebook with more than one device."

tpu_info = pltpu.get_tpu_info()  # This notebook only runs on TPU.
print(f"Running on {num_devices} TPU {tpu_info.chip_version} devices.")
Running on 4 TPU v5p devices.

In addition to the typical TPU device mesh, you need to make a mesh of cores. Think of this as an additional dimension called core, with length 2, on top of the 4-device mesh you work with. That is 8 cores in total.

# Mesh of devices
mesh = jax.make_mesh((jax.device_count(),), ('device',))
print(mesh)

# Mesh of cores, within a JAX device
tc_mesh = pltpu.create_tensorcore_mesh('core')
print(tc_mesh)

num_devices = mesh.size
num_cores = len(tc_mesh.devices)
print(f"There are {num_devices} devices, and {num_cores} cores each.")
Mesh('device': 4, axis_types=(Explicit,))
TensorCoreMesh(devices=array([TensorCore(id=0), TensorCore(id=1)], dtype=object), axis_names=('core',))
There are 4 devices, and 2 cores each.
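
To make the count concrete, here is a minimal pure-Python sketch (no TPU required) that enumerates the (device, core) coordinates of the combined mesh. The sizes are assumptions copied from the v5p output above, and core_coordinates is a hypothetical helper for illustration, not part of Pallas.

```python
from itertools import product

# Assumed sizes, copied from the v5p output above.
NUM_DEVICES = 4
NUM_CORES_PER_DEVICE = 2


def core_coordinates(num_devices, num_cores):
    """Enumerate every (device, core) coordinate of the combined mesh."""
    return list(product(range(num_devices), range(num_cores)))


coords = core_coordinates(NUM_DEVICES, NUM_CORES_PER_DEVICE)
print(len(coords))  # 8
print(coords[:3])   # [(0, 0), (0, 1), (1, 0)]
```

Each coordinate corresponds to one instance of the per-core code you write in the rest of this guide.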

A simple per-core kernel#

pl.core_map allows you to write per-core local code, just as jax.shard_map allows you to write per-device code.

In the example kernel below, each core has its own VMEM and semaphore allocations. As with a normal kernel, you can initiate copies between HBM and VMEM refs using pltpu.async_copy.

Communication between cores

Before communicating between cores, it is good practice to perform a barrier (using pltpu.semaphore_signal) to ensure that resources have been allocated and that both cores are at the same point in the program.

Once the cores are synchronized, use pltpu.make_async_remote_copy to send data between them. The device_id keyword argument generically allows sending to any core on any device, but if you just pass in {'core': other_core_id}, it will perform an intra-device inter-core copy (the other axis names are held constant).

# This runs on every core
def swap_cores_kernel(in_hbm, out_hbm,
                      in_vmem, scratch_vmem, out_vmem,
                      sem, send_sem, recv_sem):
  core_index = jax.lax.axis_index('core')
  num_cores = jax.lax.axis_size('core')
  slc_size = in_hbm.shape[-1] // num_cores
  slc = pl.ds(core_index * slc_size, slc_size)

  # Copy in a core-dependent slice of the input
  pltpu.async_copy(in_hbm.at[:, slc], in_vmem, sem).wait()

  # A barrier to make sure all cores have entered run_scoped.
  # You won't need this if not doing inter-core communications.
  dst_core = (core_index + 1) % num_cores
  sem0 = pltpu.get_barrier_semaphore()
  pltpu.semaphore_signal(sem0, 1, device_id={'core': dst_core})
  pltpu.semaphore_wait(sem0, 1)

  # Swap data between core 0 and core 1
  the_copy = pltpu.make_async_remote_copy(
      in_vmem, scratch_vmem, send_sem, recv_sem, device_id={'core': dst_core},
  )
  the_copy.start()
  the_copy.wait()

  # Core-local compute
  out_vmem[...] = scratch_vmem[...] * 2

  # Copy out the output
  pltpu.async_copy(out_vmem, out_hbm.at[:, slc], sem).wait()

Once you have the local kernel:

  • Start your top-level JAX code with HBM refs, and allocate output refs if needed.

  • Use pl.core_map, which takes the TensorCore mesh, to start per-core programming.

    • You will need collective_id for the barrier semaphore.

  • Inside pl.core_map, invoke pl.run_scoped to allocate per-core scratch spaces (VMEM and semaphores) and run the local kernel.

input_shape = (32, 256)
local_vmem_shape = (32 // num_devices, 256 // num_cores)
in_spec = jax.P('device', None)
sharding = NamedSharding(mesh, in_spec)

@jax.jit
@partial(jax.shard_map, mesh=mesh, in_specs=in_spec, out_specs=in_spec,
         check_vma=False)
def swap_cores(x):
  # Get buffers out of the input and output
  x_hbm_ref = jax.new_ref(x)
  o_hbm_ref = jax.new_ref(jax.lax.empty(x.shape, x.dtype))

  @pl.core_map(tc_mesh, compiler_params=pltpu.CompilerParams(collective_id=0))
  def _():
    pl.run_scoped(
        partial(swap_cores_kernel, x_hbm_ref, o_hbm_ref),
        *([pltpu.VMEM(local_vmem_shape, x.dtype)] * 3),  # VMEM allocations
        *([pltpu.SemaphoreType.DMA] * 3),                # semaphores
    )
  return o_hbm_ref[...]

x = jax.random.normal(jax.random.key(0), input_shape, jnp.float32)
x = jax.device_put(x, sharding)
y = swap_cores(x)

np.testing.assert_array_equal(y[:, 128:], x[:, :128] * 2)
np.testing.assert_array_equal(y[:, :128], x[:, 128:] * 2)
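
The two assertions at the end follow from the slice and neighbor arithmetic in swap_cores_kernel. The sketch below replays that arithmetic in pure Python on a single row, with plain lists standing in for HBM and VMEM. simulate_swap is a hypothetical helper written only for this illustration. It models the data movement, not the DMA or semaphore behavior.

```python
def simulate_swap(row, num_cores):
    """Replay the data movement of swap_cores_kernel on one row."""
    slc_size = len(row) // num_cores

    # Step 1: every core copies its own column slice into local memory.
    in_vmem = [row[c * slc_size:(c + 1) * slc_size] for c in range(num_cores)]

    # Step 2: every core sends its slice to the next core (ring order).
    scratch_vmem = [None] * num_cores
    for core in range(num_cores):
        dst_core = (core + 1) % num_cores
        scratch_vmem[dst_core] = in_vmem[core]

    # Step 3: every core doubles what it received and writes the result
    # back to its own slice of the output.
    out = [None] * len(row)
    for core in range(num_cores):
        start = core * slc_size
        out[start:start + slc_size] = [2 * v for v in scratch_vmem[core]]
    return out


row = list(range(8))
result = simulate_swap(row, num_cores=2)
print(result)  # [8, 10, 12, 14, 0, 2, 4, 6]
```

With two cores, the second half of the output holds the doubled first half of the input and vice versa, which is the relationship the assertions check.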

Save the boilerplate#

You can use the pl.kernel decorator to wrap boilerplate such as core_map, run_scoped, and output buffer allocation.

Note that this should run inside any jax.shard_map you may have at the top level.

@jax.jit
@partial(jax.shard_map, mesh=mesh, in_specs=in_spec, out_specs=in_spec,
         check_vma=False)
def swap_cores(x):
  scratch_shapes = [pltpu.VMEM(local_vmem_shape, x.dtype)] * 3 + [pltpu.SemaphoreType.DMA] * 3
  return pl.kernel(swap_cores_kernel, out_shape=x, mesh=tc_mesh,
                   scratch_shapes=scratch_shapes,
                   compiler_params=pltpu.CompilerParams(collective_id=0))(x)

y = swap_cores(x)
np.testing.assert_array_equal(y[:, 128:], x[:, :128] * 2)
np.testing.assert_array_equal(y[:, :128], x[:, 128:] * 2)

Pipelining with core_map#

Note that the kernel above only does simple copies and compute, without automatic pipelining via Pallas grid and BlockSpec. To do pipelining inside core_map, use pltpu.emit_pipeline inside the core-local kernel.

Automatically parallelize work amongst cores

The simplest way is to annotate a block axis as pltpu.PARALLEL, and Pallas will automatically parallelize work along this axis. Both pl.pallas_call and pltpu.emit_pipeline support this, via the arguments core_axis and dimension_semantics. The pallas_call example is in another guide, and the emit_pipeline case is shown below.

When the PARALLEL annotation is provided, the corresponding grid dimension will be logically split and executed on separate cores. (The exact semantics of which grid dimensions are executed on which core is guaranteed).

Scratch shapes allocation

Note that in the example below, the top-level pl.run_scoped (wrapped inside kernel) does not allocate any VMEM scratch buffers. Instead, pltpu.emit_pipeline allocates its own scratch buffers in VMEM and uses them for its multiple buffering.

def add_one_body(in_vmem, out_vmem):
  out_vmem[...] = in_vmem[...] + 1

input_shape = (1024, 1024)
in_spec = jax.P('device', None)

def add_one_kernel(x_hbm_ref, o_hbm_ref):
  in_shape = x_hbm_ref.shape
  pltpu.emit_pipeline(
      add_one_body,
      grid=(in_shape[0] // 8, in_shape[1] // 128),
      in_specs=[pl.BlockSpec(
          block_shape=(8, 128), index_map=lambda i, j: (i, j),
      )],
      out_specs=[pl.BlockSpec(
          block_shape=(8, 128), index_map=lambda i, j: (i, j),
      )],
      core_axis_name='core',
      dimension_semantics=(pltpu.PARALLEL, pltpu.ARBITRARY),
  )(x_hbm_ref, o_hbm_ref)


@jax.jit
@partial(jax.shard_map, mesh=mesh, in_specs=in_spec, out_specs=in_spec,
         check_vma=False)
def add_one(x):
  return pl.kernel(add_one_kernel, out_shape=x, mesh=tc_mesh, scratch_shapes=[])(x)


x = jax.random.normal(jax.random.key(0), input_shape, jnp.float32)
x = jax.device_put(x, NamedSharding(mesh, in_spec))
y = add_one(x)

np.testing.assert_array_equal(y, x + 1)
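
If the relationship between grid, block_shape, and index_map is unfamiliar, the following pure-Python sketch may help. It works out which element ranges each block index covers for the device-local array in the example above, assuming the 4-device sharding used in this guide. block_bounds is a hypothetical helper for illustration. The sketch does not model how PARALLEL assigns blocks to cores.

```python
def block_bounds(block_index, block_shape):
    """Half-open element range covered by a block index, per dimension."""
    return tuple((b * s, (b + 1) * s) for b, s in zip(block_index, block_shape))


# Assumption: 4 devices shard dimension 0 of the (1024, 1024) input.
local_shape = (1024 // 4, 1024)
block_shape = (8, 128)
grid = tuple(n // b for n, b in zip(local_shape, block_shape))
index_map = lambda i, j: (i, j)  # identity mapping, as in add_one_kernel

# Collect the top-left corner of every block visited by the grid.
covered = set()
for i in range(grid[0]):
    for j in range(grid[1]):
        (r0, _), (c0, _) = block_bounds(index_map(i, j), block_shape)
        covered.add((r0, c0))

print(grid)                                # (32, 8)
print(block_bounds((1, 2), block_shape))   # ((8, 16), (256, 384))
print(len(covered))                        # 256
```

The 32 x 8 grid visits 256 distinct blocks, which tile the 256 x 1024 device-local array exactly once.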

Scalar prefetch#

The code below extends the kernel above but uses scalar prefetch and dynamic block indexing to select a specific sub-slice of the input.

This involves pre-allocating an SMEM buffer (via the pl.run_scoped call inside kernel) and populating the buffer using pltpu.sync_copy before the pipeline starts. Close over the dynamic index value inside the index_map to use it.

Manually delegate work amongst cores

The code example below also shows how core_map allows you to customize exactly how the work is split between cores, without relying on the automatic API shown above.

To achieve that, customize your index_map to use the core index to work on different slices on different cores.

input_shape = (1024, 1024)
in_spec = jax.P('device', None)
output_shape = (1024, 512)

def indexed_add_one_kernel(in_refs, out_refs, i_smem_ref):
  (x_hbm_ref, i_hbm_ref), o_hbm_ref = in_refs, out_refs
  in_shape = x_hbm_ref.shape
  pltpu.sync_copy(i_hbm_ref, i_smem_ref)

  core_idx = jax.lax.axis_index('core')
  core_slc_size = in_shape[0] // num_cores
  i_map = lambda i: core_idx * core_slc_size // 8 + i  # split work among cores
  j_map = lambda j: i_smem_ref[0] // 128 + j  # use the prefetched offset

  pltpu.emit_pipeline(
      add_one_body,
      grid=(core_slc_size // 8, output_shape[1] // 128),
      in_specs=[pl.BlockSpec(
          block_shape=(8, 128), index_map=lambda i, j: (i_map(i), j_map(j)),
      )],
      out_specs=[pl.BlockSpec(
          block_shape=(8, 128), index_map=lambda i, j: (i_map(i), j),
      )]
  )(x_hbm_ref, o_hbm_ref)


@jax.jit
@partial(jax.shard_map, mesh=mesh,
         in_specs=(in_spec, jax.P()), out_specs=in_spec, check_vma=False)
def indexed_add_one(x, index):
  out_shape = jax.ShapeDtypeStruct((x.shape[0], x.shape[1] // 2), x.dtype)
  return pl.kernel(indexed_add_one_kernel,
                   out_shape=out_shape, mesh=tc_mesh,
                   scratch_shapes=[pltpu.SMEM((1,), jnp.int32)])((x, index))


xs = jax.random.normal(jax.random.key(0), input_shape, jnp.float32)
xs = jax.device_put(xs, NamedSharding(mesh, in_spec))
idx = 256
y = indexed_add_one(xs, jnp.array([idx]))

np.testing.assert_array_equal(y, xs[:, idx:(idx + 512)] + 1)
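
To check that the manual split leaves no gaps or overlaps, you can replay the i_map and j_map arithmetic off-device. The pure-Python sketch below lists the (input block, output block) pairs each core would visit, assuming 2 cores per device and a device-local input of 256 rows as in this guide. block_assignment is a hypothetical helper that mirrors the index maps above, not a Pallas API.

```python
def block_assignment(core_idx, num_cores, local_rows, out_cols, offset,
                     block_shape=(8, 128)):
    """List the (input_block, output_block) index pairs one core visits."""
    block_rows, block_cols = block_shape
    core_slc_size = local_rows // num_cores
    pairs = []
    for i in range(core_slc_size // block_rows):
        for j in range(out_cols // block_cols):
            # Same arithmetic as i_map: each core starts at its own row offset.
            row_block = core_idx * core_slc_size // block_rows + i
            # Same arithmetic as j_map: shift input columns by the offset.
            in_block = (row_block, offset // block_cols + j)
            out_block = (row_block, j)
            pairs.append((in_block, out_block))
    return pairs


# Assumptions: 2 cores, 1024 // 4 = 256 local rows, 512 output columns, idx = 256.
core0 = block_assignment(0, 2, 256, 512, 256)
core1 = block_assignment(1, 2, 256, 512, 256)

print(core0[0])  # ((0, 2), (0, 0))
print(core1[0])  # ((16, 2), (16, 0))

rows0 = {out[0] for _, out in core0}
rows1 = {out[0] for _, out in core1}
print(rows0.isdisjoint(rows1), len(rows0 | rows1))  # True 32
```

Core 0 handles row blocks 0 to 15 and core 1 handles row blocks 16 to 31, while both read input column blocks 2 to 5 (columns 256 to 767) and write output column blocks 0 to 3.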

Mapping over SparseCores#

TPU v5p contains 4 SparseCores, which are specialized for sparse memory access and operations. This guide will not dive into the full capabilities of SparseCore, but rather show how to run a program on SparseCore with the same semantics and minimal changes from the TensorCore code.

Start by looking up the basic SparseCore specs of your chip, and create a VectorSubcoreMesh for vector operations. Note that each SparseCore has 16 subcores on TPU v5p (the number may differ on other chips), and core_map will run your code SPMD on each of them.

sc_info = pltpu.get_tpu_info().sparse_core
assert sc_info is not None
print(sc_info)

sc_mesh = plsc.VectorSubcoreMesh(
    core_axis_name="core", subcore_axis_name="subcore",
    num_cores=sc_info.num_cores
)
sc_num_cores = sc_info.num_cores
sc_num_subcores = sc_info.num_subcores
SparseCoreInfo(num_cores=4, num_subcores=16, num_lanes=8)

The code below is very similar to the add_one_kernel we wrote earlier, except for a few differences:

  1. You need to split the work amongst all subcores, so a few lines compute the specific slice for each subcore.

  2. SparseCore register computation only allows small slices (4x16 max for int32), so you need nested loops to iterate over the slice during the computation phase.

input_shape = (4096, 128)
SC_REG_OP_SHAPE = (4, 16)

def sc_add_one_body(in_vmem, out_vmem):
  @pl.loop(0, in_vmem.shape[0], step=SC_REG_OP_SHAPE[0])
  def _reg_loop_0(c0):
    @pl.loop(0, in_vmem.shape[1], step=SC_REG_OP_SHAPE[1])
    def _reg_loop_1(c1):
      slc = (pl.ds(c0, SC_REG_OP_SHAPE[0]), pl.ds(c1, SC_REG_OP_SHAPE[1]))
      out_vmem[slc] = in_vmem[slc] + 1


def sc_add_one_kernel(x_hbm_ref, o_hbm_ref):
  in_shape = x_hbm_ref.shape
  core_idx = jax.lax.axis_index('core')
  subcore_idx = jax.lax.axis_index("subcore")
  cm_idx = core_idx * sc_num_subcores + subcore_idx  # index on the core_map
  slc_size = in_shape[0] // (sc_num_subcores * sc_num_cores)
  index_map = lambda i, j: (
      pl.ds(pl.multiple_of(cm_idx * slc_size + i * 8, 8), 8), j)

  pltpu.emit_pipeline(
      sc_add_one_body,
      grid=(slc_size // 8, in_shape[1] // 128),
      in_specs=[pl.BlockSpec(
          block_shape=(pl.BoundedSlice(8), 128), index_map=index_map,
      )],
      out_specs=[pl.BlockSpec(
          block_shape=(pl.BoundedSlice(8), 128), index_map=index_map,
      )]
  )(x_hbm_ref, o_hbm_ref)


@jax.jit
@partial(jax.shard_map, mesh=mesh, in_specs=in_spec, out_specs=in_spec,
         check_vma=False)
def sc_add_one(x):
  return pl.kernel(sc_add_one_kernel, out_shape=x, mesh=sc_mesh, scratch_shapes=[])(x)


x = jax.random.randint(jax.random.key(0), input_shape, 0, 64, jnp.int32)
x = jax.device_put(x, NamedSharding(mesh, in_spec))
y = sc_add_one(x)

np.testing.assert_array_equal(y, x + 1)
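
The two differences listed above come down to index arithmetic that can be checked without a TPU. The pure-Python sketch below assumes the SparseCoreInfo printed earlier (4 cores, 16 subcores) and a device-local input of 4096 // 4 = 1024 rows. subcore_row_starts and register_tiles are hypothetical helpers that mirror cm_idx, slc_size, and the nested register loops. They are not part of Pallas.

```python
def subcore_row_starts(core_idx, subcore_idx, local_rows,
                       num_cores=4, num_subcores=16, block_rows=8):
    """Starting rows of the blocks one subcore processes (mirrors index_map)."""
    cm_idx = core_idx * num_subcores + subcore_idx
    slc_size = local_rows // (num_subcores * num_cores)
    return [cm_idx * slc_size + i * block_rows
            for i in range(slc_size // block_rows)]


def register_tiles(block_shape=(8, 128), reg_shape=(4, 16)):
    """Top-left corners of the register-sized tiles inside one block."""
    return [(c0, c1)
            for c0 in range(0, block_shape[0], reg_shape[0])
            for c1 in range(0, block_shape[1], reg_shape[1])]


local_rows = 4096 // 4  # assumption: 4 devices shard dimension 0

all_starts = sorted(
    start
    for core in range(4)
    for subcore in range(16)
    for start in subcore_row_starts(core, subcore, local_rows)
)

print(subcore_row_starts(0, 0, local_rows))    # [0, 8]
print(subcore_row_starts(3, 15, local_rows))   # [1008, 1016]
print(all_starts == list(range(0, local_rows, 8)))  # True
print(len(register_tiles()))                   # 16
```

Under these assumptions each of the 64 subcores processes two 8-row blocks, the blocks cover every row exactly once, and each 8x128 block is walked in sixteen 4x16 register tiles.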
