Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings
/jaxPublic

Commit c584cbe

Browse files
apaszke and Google-ML-Automation
authored and committed
[Mosaic GPU] Add support for 32x32b loads/stores of arbitrary TMEM layouts
If the array in registers uses a layout exactly equal to the TMEM layout, and each register is exactly 32-bit, then the whole load/store operation is a trivial copy of registers into TMEM. This also adds support for arbitrary bitwidths in 32x32b TMEM transfers. PiperOrigin-RevId: 837775689
1 parent 03fd0cd commit c584cbe

File tree

4 files changed

+182
-63
lines changed

4 files changed

+182
-63
lines changed

‎jax/_src/pallas/mosaic_gpu/core.py‎

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1432,7 +1432,7 @@ def to_mgpu(self, *args, **kwargs) -> mgpu.FragmentedLayout:
14321432

14331433
@dataclasses.dataclass(frozen=True)
14341434
classParameterizedLayout(SomeLayout):
1435-
layout_cls:Layout
1435+
layout_cls:Layout|TMEMLayout
14361436
args:Sequence[Any]
14371437
kwargs:Any
14381438

@@ -1473,6 +1473,7 @@ class Layout(SomeLayout, enum.Enum):
14731473
TCGEN05_TRANSPOSED=enum.auto()
14741474
TCGEN05_M64_COLLECTIVE=enum.auto()
14751475
TCGEN05_TMEM_NATIVE=enum.auto()
1476+
TCGEN05_M64_COLLECTIVE_NATIVE=enum.auto()
14761477

14771478
SMEM_GMEM_COPY=enum.auto()
14781479
TMA_GATHER_INDICES=enum.auto()
@@ -1525,6 +1526,8 @@ def check_no_args():
15251526
returnmgpu.TMEM_NATIVE_LAYOUT
15261527
caseLayout.TCGEN05_M64_COLLECTIVE:
15271528
returntcgen05.fa_m64_collective_layout(*args,**kwargs)# pytype: disable=missing-parameter
1529+
caseLayout.TCGEN05_M64_COLLECTIVE_NATIVE:
1530+
returntcgen05.tmem_m64_collective_layout(*args,**kwargs).as_tiled_layout()# pytype: disable=missing-parameter
15281531
caseLayout.SMEM_GMEM_COPY:
15291532
normalize_args=lambdashape,dtype,swizzle: (shape,dtype,swizzle)
15301533
shape,dtype,swizzle=normalize_args(*args,**kwargs)
@@ -1548,15 +1551,22 @@ def check_no_args():
15481551

15491552
classTMEMLayout(enum.Enum):
15501553
"""Layout for TMEM references."""
1554+
# TODO(apaszke): Remove the layout suffix.
15511555
SCALES_LAYOUT=enum.auto()
15521556
SPARSE_METADATA_LAYOUT=enum.auto()
1557+
M64_COLLECTIVE_LAYOUT=enum.auto()
15531558

1554-
defto_mgpu(self)->tcgen05.TMEMLayout:
1559+
def__call__(self,*args,**kwargs)->ParameterizedLayout:
1560+
returnParameterizedLayout(self,args,kwargs)
1561+
1562+
defto_mgpu(self,*args,**kwargs)->tcgen05.TMEMLayout:
15551563
matchself:
15561564
caseTMEMLayout.SCALES_LAYOUT:
1557-
returntcgen05.scales_layout()
1565+
returntcgen05.scales_layout(*args,**kwargs)
15581566
caseTMEMLayout.SPARSE_METADATA_LAYOUT:
1559-
returntcgen05.sparse_meta_layout()
1567+
returntcgen05.sparse_meta_layout(*args,**kwargs)
1568+
caseTMEMLayout.M64_COLLECTIVE_LAYOUT:
1569+
returntcgen05.tmem_m64_collective_layout(*args,**kwargs)# pytype: disable=missing-parameter
15601570

15611571

15621572
defTryClusterCancelResult(

‎jax/experimental/mosaic/gpu/tcgen05.py‎

Lines changed: 95 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -902,6 +902,11 @@ def canonicalize(self) -> TMEMLayout:
902902
_check_canonical=False,
903903
)
904904

905+
defas_tiled_layout(self)->fa.TiledLayout:
906+
returnfa.TiledLayout(
907+
self.tiling,self.warp_dims,self.lane_dims,self.vector_dim
908+
)
909+
905910

906911
def_infer_tmem_load_registers_layout(
907912
tmem_layout:TMEMLayout,columns:int,packing:int
@@ -1115,23 +1120,47 @@ def slice(self, *idxs) -> TMEMRef:
11151120
)
11161121

11171122
defload(self,layout:fa.TiledLayout|None=None,is_signed:bool|None=None)->fa.FragmentedArray:
1118-
ifutils.bitwidth(self.dtype)notin {16,32}:
1119-
raiseNotImplementedError(f"Unsupported dtype:{self.dtype}")
11201123
packing=self.packing
11211124
iflayoutisNone:
11221125
layout=_infer_tmem_load_registers_layout(
11231126
self.layout,self.shape[1],packing
11241127
)
1128+
bitwidth=utils.bitwidth(self.dtype)
1129+
has_default_layout=self.layout==tmem_default_layout(packing=packing)
11251130
regs_shape=layout.registers_shape(self.shape)
11261131
ifregs_shape[0]!=1:# We'll need to issue multiple loads below.
11271132
raiseNotImplementedError("Loading multiple row tiles")
11281133
iflayout==LAYOUTandself.layout==tmem_default_layout(packing=packing):
11291134
registers=_load_32xcols(
11301135
self.address,self.shape[1],self.dtype,packing
11311136
).T.reshape(regs_shape)
1132-
eliflayout==TMEM_NATIVE_LAYOUTandself.layout==tmem_default_layout(packing=packing):
1137+
eliflayout==self.layout.as_tiled_layout()andpacking*bitwidth==32:
1138+
assertlen(layout.base_tile_shape)==2
1139+
# We could allow replicated dims in the input, but we'd need to divide the
1140+
# split factor computed below by the replication factor of the input.
1141+
assertnotany(isinstance(d,fa.Replicated)fordinlayout.warp_dims)
1142+
assertnotany(isinstance(d,fa.Replicated)fordinlayout.lane_dims)
1143+
warp_split_factor=math.prod(
1144+
d.timesifisinstance(d,fa.Replicated)else1
1145+
fordinlayout.remove_dimension(1).warp_dims
1146+
)
1147+
lane_split_factor=math.prod(
1148+
d.timesifisinstance(d,fa.Replicated)else1
1149+
fordinlayout.remove_dimension(1).lane_dims
1150+
)
1151+
split_factor=warp_split_factor*lane_split_factor
11331152
registers=_load_32xcols_native(
1134-
self.address,self.shape[1],self.dtype,packing
1153+
self.address,self.shape[1]//split_factor,self.dtype,packing,packing
1154+
).reshape(regs_shape)
1155+
# TODO(apaszke): Support the case where we have a long vector length in the
1156+
# FA more generally, not just for 2x32b.
1157+
# 16-bit types are special, because the store instruction can unpack them.
1158+
eliflayout==TMEM_NATIVE_LAYOUTandhas_default_layoutand (
1159+
(bitwidth==16andpacking==1)
1160+
or (bitwidth==32andlayout.vector_length==2)
1161+
):
1162+
registers=_load_32xcols_native(
1163+
self.address,self.shape[1],self.dtype,packing,TMEM_NATIVE_LAYOUT.vector_length
11351164
).reshape(regs_shape)
11361165
eliflayout==fa.WGMMA_LAYOUTandself.layout==tmem_half_lane_layout(self.shape[1],packing=packing):
11371166
# Load half the columns, since they are folded over lanes.
@@ -1157,8 +1186,6 @@ def load(self, layout: fa.TiledLayout | None = None, is_signed: bool | None = No
11571186
)
11581187

11591188
defstore(self,value:fa.FragmentedArray):
1160-
ifutils.bitwidth(self.dtype)notin {8,16,32}:
1161-
raiseNotImplementedError(f"Unsupported dtype:{self.dtype}")
11621189
ifnotisinstance(value,fa.FragmentedArray):
11631190
raiseTypeError(f"TMEM stores expect a FragmentedArray, got:{value}")
11641191
ifvalue.shape!=self.shape:
@@ -1171,27 +1198,38 @@ def store(self, value: fa.FragmentedArray):
11711198
f"Stored array has dtype{value.mlir_dtype}, but TMEM has dtype"
11721199
f"{self.dtype}"
11731200
)
1201+
ifnotisinstance(value.layout,fa.TiledLayout):
1202+
raiseTypeError(f"Stored array has layout{value.layout}, but TMEM stores expect a TiledLayout")
11741203
packing=self.packing
11751204
has_default_layout=self.layout==tmem_default_layout(packing=packing)
1205+
bitwidth=utils.bitwidth(self.dtype)
11761206
ifvalue.layout==LAYOUTandhas_default_layout:
11771207
_store_32xcols(
11781208
self.address,value.registers.T.reshape((4,-1)),packing
11791209
)
1180-
elif (
1181-
utils.bitwidth(self.dtype)==8
1182-
andvalue.layout==fa.tmem_native_layout(vector_length=packing)
1183-
andhas_default_layout
1210+
elifvalue.layout==self.layout.as_tiled_layout()andpacking*bitwidth==32:
1211+
_store_32xcols_native(self.address,value.registers.reshape(-1),packing)
1212+
# TODO(apaszke): Support the case where we have a long vector length in the
1213+
# FA more generally, not just for 2x32b.
1214+
# TODO(apaszke): Support a wider range of layouts when dealing with unpacking.
1215+
# 16-bit types are special, because the store instruction can unpack them.
1216+
elifvalue.layout==TMEM_NATIVE_LAYOUTandhas_default_layoutand (
1217+
(bitwidth==16andpacking==1)
1218+
or (bitwidth==32andvalue.layout.vector_length==2)
11841219
):
11851220
_store_32xcols_native(self.address,value.registers.reshape(-1),packing)
1186-
elifvalue.layout==TMEM_NATIVE_LAYOUTandhas_default_layout:
1187-
_store_32xcols_native(
1188-
self.address,value.registers.reshape(-1),packing
1189-
)
1190-
elifvalue.layout==fa.WGMMA_LAYOUTandself.layout==tmem_half_lane_layout(self.shape[1],packing=packing):
1221+
elif (
1222+
value.layout==fa.WGMMA_LAYOUT
1223+
andself.layout==tmem_half_lane_layout(self.shape[1],packing=packing)
1224+
):
11911225
registers=value.registers.T.reshape(2,-1)
11921226
registers=np.concatenate(np.split(registers,2,axis=1),axis=0)
11931227
_store_32xcols(self.address,registers,packing)
1194-
elifvalue.layout==fa_m64_collective_layout(self.shape[1])andself.layout==tmem_m64_collective_layout(self.shape[1],packing=packing):
1228+
elifvalue.layout==fa_m64_collective_layout(
1229+
self.shape[1]
1230+
)andself.layout==tmem_m64_collective_layout(
1231+
self.shape[1],packing=packing
1232+
):
11951233
_store_32xcols(self.address,value.registers.reshape(4,-1),packing)
11961234
else:
11971235
raiseValueError(
@@ -1306,37 +1344,49 @@ def _store_32xcols(base_addr, vector_regs, tmem_packing) -> None:
13061344
def_store_32xcols_native(base_addr,vector_regs,tmem_packing)->None:
13071345
i32=ir.IntegerType.get_signless(32)
13081346
assertvector_regs.ndim==1
1309-
cols=len(vector_regs)*TMEM_NATIVE_LAYOUT.vector_length
13101347
vec_ty=ir.VectorType(vector_regs.flat[0].type)
1311-
reg_packing=64//utils.bitwidth(vec_ty)
1312-
store_atom_shape= (32,reg_packing)
1348+
[vector_length]=vec_ty.shape
13131349
elt_bitwidth=utils.bitwidth(vec_ty.element_type)
1350+
reg_packing=32//elt_bitwidth
1351+
store_atom_shape= (32,reg_packing)
1352+
# TODO(apaszke): More general register splitting code, not just 2x32b.
13141353
ifreg_packing==1:
1315-
# Transform data such that each reg is 32 bits wide.
1316-
assertelt_bitwidth==32,elt_bitwidth
1317-
regs= [None]*cols
1318-
c0=arith.constant(i32,0)
1319-
c1=arith.constant(i32,1)
1320-
foridx,vreginenumerate(vector_regs):
1321-
regs[2*idx]=llvm.extractelement(vreg,c0)
1322-
regs[2*idx+1]=llvm.extractelement(vreg,c1)
1354+
ifvector_length==2:
1355+
# Transform data such that each reg is 32 bits wide.
1356+
regs= [None]* (len(vector_regs)*2)
1357+
c0=arith.constant(i32,0)
1358+
c1=arith.constant(i32,1)
1359+
foridx,vreginenumerate(vector_regs):
1360+
regs[2*idx]=llvm.extractelement(vreg,c0)
1361+
regs[2*idx+1]=llvm.extractelement(vreg,c1)
1362+
else:
1363+
regs= [utils.bitcast(r,i32)forrinvector_regs]
13231364
asserttmem_packing==1
13241365
unpack=False
13251366
elifreg_packing==2:
1367+
assertvector_length==2
13261368
# In this case, registers are already packed into 32-bit registers.
1327-
regs=vector_regs
1369+
regs=[utils.bitcast(r,i32)forrinvector_regs]
13281370
ifelt_bitwidth==16:
13291371
assert1<=tmem_packing<=2
13301372
unpack=tmem_packing==1
13311373
else:
1332-
iftmem_packing==1:
1374+
iftmem_packing==1andelt_bitwidth!=32:
13331375
raiseNotImplementedError(
13341376
f"Unsupported packing:{tmem_packing} for element type{elt_bitwidth}"
13351377
)
13361378
asserttmem_packing==32//elt_bitwidth
13371379
unpack=False
13381380
else:
1339-
raiseNotImplementedError(reg_packing)
1381+
iftmem_packing!=reg_packing:
1382+
raiseNotImplementedError(
1383+
f"Only{reg_packing} packing supported for bitwidth{elt_bitwidth},"
1384+
f" but got TMEM packing of{tmem_packing}"
1385+
)
1386+
assertutils.bitwidth(vec_ty)==32
1387+
regs= [utils.bitcast(r,i32)forrinvector_regs]
1388+
unpack=False
1389+
cols=len(regs)*reg_packing
13401390
it=_transfer_32xcols(base_addr,cols,store_atom_shape,tmem_packing,reg_packing)
13411391
foraddr_row_col,instr_num,lane_step,num_sliceinit:
13421392
assertlane_step==0
@@ -1393,21 +1443,23 @@ def _load_32xcols(base_addr, cols, dtype, tmem_packing) -> np.ndarray:
13931443
returnvector_regs
13941444

13951445

1396-
def_load_32xcols_native(base_addr,cols,dtype,tmem_packing)->np.ndarray:
1446+
def_load_32xcols_native(base_addr,cols,dtype,tmem_packing,vector_length)->np.ndarray:
13971447
i32=ir.IntegerType.get_signless(32)
1398-
vec_ty=ir.VectorType.get((2,),dtype)
1448+
vec_ty=ir.VectorType.get((vector_length,),dtype)
13991449
reg_packing=32//utils.bitwidth(dtype)
1450+
assertvector_length%reg_packing==0
14001451
load_shape="32x32b"
1401-
ifreg_packing==1:
1402-
load_atom_shape= (32,1)
1403-
asserttmem_packing==1
1404-
pack=False
1405-
elifreg_packing==2:
1406-
load_atom_shape= (32,2)
1452+
load_atom_shape= (32,reg_packing)
1453+
ifreg_packing==2:
14071454
assert1<=tmem_packing<=2
14081455
pack=tmem_packing==1
14091456
else:
1410-
raiseNotImplementedError(reg_packing)
1457+
iftmem_packing!=reg_packing:
1458+
raiseNotImplementedError(
1459+
f"Only{reg_packing} supported for element type{dtype}, but got"
1460+
f" TMEM packing of{tmem_packing}"
1461+
)
1462+
pack=False
14111463

14121464
it=_transfer_32xcols(base_addr,cols,load_atom_shape,tmem_packing,reg_packing)
14131465
c0=arith.constant(i32,0)
@@ -1416,24 +1468,22 @@ def _load_32xcols_native(base_addr, cols, dtype, tmem_packing) -> np.ndarray:
14161468
foraddr_row_col,instr_num,lane_step,num_sliceinit:
14171469
assertlane_step==0,lane_step
14181470
instr_regs=_tmem_load(addr_row_col,load_shape,instr_num,pack)
1419-
ifreg_packing==1:
1471+
ifreg_packing==1andvector_length==2:
14201472
regs[num_slice]= [llvm.bitcast(dtype,r)forrininstr_regs]
14211473
else:
1422-
assertreg_packing==2
1423-
regs[num_slice]= [llvm.bitcast(vec_ty,r)forrininstr_regs]
1474+
regs[num_slice]= [utils.bitcast(r,vec_ty)forrininstr_regs]
14241475

1425-
ifreg_packing==1:
1476+
ifreg_packing==1andvector_length==2:
14261477
vector_regs=np.ndarray((cols//2,),dtype=object)
14271478
undef=llvm.mlir_undef(vec_ty)
14281479
foridxinrange(vector_regs.size):
14291480
high_undef=llvm.insertelement(undef,regs[2*idx],c0)
14301481
vreg=llvm.insertelement(high_undef,regs[2*idx+1],c1)
14311482
vector_regs[idx]=vreg
14321483
else:
1433-
assertreg_packing==2
1484+
assertvector_length==reg_packing
14341485
vector_regs=np.asarray(regs,dtype=object)
14351486

1436-
assertvector_regs.shape== (cols//TMEM_NATIVE_LAYOUT.vector_length,)
14371487
returnvector_regs
14381488

14391489

‎tests/mosaic/gpu_test.py‎

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1199,18 +1199,33 @@ def setUp(self):
11991199
self.skipTest("Only works on GPU with capability sm_100a or sm_101a")
12001200

12011201
@parameterized.product(
1202-
jax_dtype_packing=[(jnp.float32,1), (jnp.float16,1), (jnp.float16,2)],
1202+
jax_dtype_packing=[(jnp.float32,1), (jnp.float16,1), (jnp.float16,2), (jnp.float8_e5m2,4)],
12031203
reg_tmem_layout_m=[
1204-
(lambda_:tcgen05.LAYOUT,lambda_,p:tcgen05.tmem_default_layout(p),128),
1205-
(lambda_:fa.WGMMA_LAYOUT,tcgen05.tmem_half_lane_layout,64),
1206-
(tcgen05.fa_m64_collective_layout,tcgen05.tmem_m64_collective_layout,64),
1204+
(lambda_c,_p:tcgen05.LAYOUT,lambda_,p:tcgen05.tmem_default_layout(p),128),
1205+
(lambda_c,_p:fa.WGMMA_LAYOUT,tcgen05.tmem_half_lane_layout,64),
1206+
(
1207+
lambdac,_p:tcgen05.fa_m64_collective_layout(c),
1208+
tcgen05.tmem_m64_collective_layout,
1209+
64,
1210+
),
1211+
(
1212+
lambdac,p:tcgen05.tmem_m64_collective_layout(c,p).as_tiled_layout(),
1213+
tcgen05.tmem_m64_collective_layout,
1214+
64,
1215+
),
12071216
],
12081217
)
12091218
deftest_load_store_tmem(self,jax_dtype_packing,reg_tmem_layout_m):
12101219
jax_dtype,packing=jax_dtype_packing
12111220
reg_layout_f,tmem_layout_f,m=reg_tmem_layout_m
12121221
n=160
1213-
reg_layout=reg_layout_f(n)
1222+
reg_layout=reg_layout_f(n,packing)
1223+
iftmem_layout_fistcgen05.tmem_m64_collective_layout:
1224+
ifjax_dtype==jnp.float16andpacking==1:
1225+
self.skipTest("Not implemented yet")
1226+
is_native_transfer=tmem_layout_f(n,packing).as_tiled_layout()==reg_layout
1227+
ifnotis_native_transferandjax_dtype==jnp.float8_e5m2:
1228+
self.skipTest("Not implemented yet")
12141229

12151230
defkernel(ctx,input,output,tmem):
12161231
delctx
@@ -1220,19 +1235,28 @@ def kernel(ctx, input, output, tmem):
12201235

12211236
x=self.prng.uniform(-1,1, (m,n)).astype(jax_dtype)
12221237
y=mgpu.as_gpu_kernel(
1223-
kernel, (1,1,1), (128,1,1),x,x,mgpu.TMEM(x.shape,jax_dtype,layout=tmem_layout_f(n,packing)),
1238+
kernel, (1,1,1), (128,1,1),x,x,
1239+
mgpu.TMEM(x.shape,jax_dtype,layout=tmem_layout_f(n,packing)),
12241240
)(x)
12251241
np.testing.assert_array_equal(x,y)
12261242

1227-
@parameterized.parameters([(jnp.float32,1), (jnp.float16,1), (jnp.float16,2)])
1243+
@parameterized.parameters([
1244+
(jnp.float32,1),
1245+
(jnp.float16,1),
1246+
(jnp.float16,2),
1247+
(jnp.float8_e5m2,4),
1248+
# TODO(apaszke): Enable. LLVM lowering doesn't like 4 bits yet.
1249+
# (jnp.float4_e2m1fn, 8),
1250+
])
12281251
deftest_load_store_tmem_native(self,jax_dtype,packing):
12291252
# TODO(bchetioui): add a test for int8 with a native layout with vector
12301253
# length equal to 4 once TMEM load is implemented for it.
12311254
defkernel(ctx,input,output,tmem):
12321255
delctx
1233-
tmem.store(fa.FragmentedArray.load_untiled(input,layout=tcgen05.TMEM_NATIVE_LAYOUT,optimized=False))
1256+
reg_layout=tcgen05.tmem_default_layout(max(packing,2)).as_tiled_layout()
1257+
tmem.store(fa.FragmentedArray.load_untiled(input,layout=reg_layout,optimized=False))
12341258
tcgen05.commit_tmem()
1235-
tmem.load(tcgen05.TMEM_NATIVE_LAYOUT).store_untiled(output,optimized=False)
1259+
tmem.load(reg_layout).store_untiled(output,optimized=False)
12361260

12371261
x=self.prng.uniform(-1,1, (128,128)).astype(jax_dtype)
12381262
y=mgpu.as_gpu_kernel(

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp