@@ -902,6 +902,11 @@ def canonicalize(self) -> TMEMLayout:
902902_check_canonical = False ,
903903 )
904904
905+ def as_tiled_layout (self )-> fa .TiledLayout :
906+ return fa .TiledLayout (
907+ self .tiling ,self .warp_dims ,self .lane_dims ,self .vector_dim
908+ )
909+
905910
906911def _infer_tmem_load_registers_layout (
907912tmem_layout :TMEMLayout ,columns :int ,packing :int
@@ -1115,23 +1120,47 @@ def slice(self, *idxs) -> TMEMRef:
11151120 )
11161121
11171122def load (self ,layout :fa .TiledLayout | None = None ,is_signed :bool | None = None )-> fa .FragmentedArray :
1118- if utils .bitwidth (self .dtype )not in {16 ,32 }:
1119- raise NotImplementedError (f"Unsupported dtype:{ self .dtype } " )
11201123packing = self .packing
11211124if layout is None :
11221125layout = _infer_tmem_load_registers_layout (
11231126self .layout ,self .shape [1 ],packing
11241127 )
1128+ bitwidth = utils .bitwidth (self .dtype )
1129+ has_default_layout = self .layout == tmem_default_layout (packing = packing )
11251130regs_shape = layout .registers_shape (self .shape )
11261131if regs_shape [0 ]!= 1 :# We'll need to issue multiple loads below.
11271132raise NotImplementedError ("Loading multiple row tiles" )
11281133if layout == LAYOUT and self .layout == tmem_default_layout (packing = packing ):
11291134registers = _load_32xcols (
11301135self .address ,self .shape [1 ],self .dtype ,packing
11311136 ).T .reshape (regs_shape )
1132- elif layout == TMEM_NATIVE_LAYOUT and self .layout == tmem_default_layout (packing = packing ):
1137+ elif layout == self .layout .as_tiled_layout ()and packing * bitwidth == 32 :
1138+ assert len (layout .base_tile_shape )== 2
1139+ # We could allow replicated dims in the input, but we'd need to divide the
1140+ # split factor computed below by the replication factor of the input.
1141+ assert not any (isinstance (d ,fa .Replicated )for d in layout .warp_dims )
1142+ assert not any (isinstance (d ,fa .Replicated )for d in layout .lane_dims )
1143+ warp_split_factor = math .prod (
1144+ d .times if isinstance (d ,fa .Replicated )else 1
1145+ for d in layout .remove_dimension (1 ).warp_dims
1146+ )
1147+ lane_split_factor = math .prod (
1148+ d .times if isinstance (d ,fa .Replicated )else 1
1149+ for d in layout .remove_dimension (1 ).lane_dims
1150+ )
1151+ split_factor = warp_split_factor * lane_split_factor
11331152registers = _load_32xcols_native (
1134- self .address ,self .shape [1 ],self .dtype ,packing
1153+ self .address ,self .shape [1 ]// split_factor ,self .dtype ,packing ,packing
1154+ ).reshape (regs_shape )
1155+ # TODO(apaszke): Support the case where we have a long vector length in the
1156+ # FA more generally, not just for 2x32b.
1157+ # 16-bit types are special, because the load instruction can pack them.
1158+ elif layout == TMEM_NATIVE_LAYOUT and has_default_layout and (
1159+ (bitwidth == 16 and packing == 1 )
1160+ or (bitwidth == 32 and layout .vector_length == 2 )
1161+ ):
1162+ registers = _load_32xcols_native (
1163+ self .address ,self .shape [1 ],self .dtype ,packing ,TMEM_NATIVE_LAYOUT .vector_length
11351164 ).reshape (regs_shape )
11361165elif layout == fa .WGMMA_LAYOUT and self .layout == tmem_half_lane_layout (self .shape [1 ],packing = packing ):
11371166# Load half the columns, since they are folded over lanes.
@@ -1157,8 +1186,6 @@ def load(self, layout: fa.TiledLayout | None = None, is_signed: bool | None = No
11571186 )
11581187
11591188def store (self ,value :fa .FragmentedArray ):
1160- if utils .bitwidth (self .dtype )not in {8 ,16 ,32 }:
1161- raise NotImplementedError (f"Unsupported dtype:{ self .dtype } " )
11621189if not isinstance (value ,fa .FragmentedArray ):
11631190raise TypeError (f"TMEM stores expect a FragmentedArray, got:{ value } " )
11641191if value .shape != self .shape :
@@ -1171,27 +1198,38 @@ def store(self, value: fa.FragmentedArray):
11711198f"Stored array has dtype{ value .mlir_dtype } , but TMEM has dtype"
11721199f"{ self .dtype } "
11731200 )
1201+ if not isinstance (value .layout ,fa .TiledLayout ):
1202+ raise TypeError (f"Stored array has layout{ value .layout } , but TMEM stores expect a TiledLayout" )
11741203packing = self .packing
11751204has_default_layout = self .layout == tmem_default_layout (packing = packing )
1205+ bitwidth = utils .bitwidth (self .dtype )
11761206if value .layout == LAYOUT and has_default_layout :
11771207_store_32xcols (
11781208self .address ,value .registers .T .reshape ((4 ,- 1 )),packing
11791209 )
1180- elif (
1181- utils .bitwidth (self .dtype )== 8
1182- and value .layout == fa .tmem_native_layout (vector_length = packing )
1183- and has_default_layout
1210+ elif value .layout == self .layout .as_tiled_layout ()and packing * bitwidth == 32 :
1211+ _store_32xcols_native (self .address ,value .registers .reshape (- 1 ),packing )
1212+ # TODO(apaszke): Support the case where we have a long vector length in the
1213+ # FA more generally, not just for 2x32b.
1214+ # TODO(apaszke): Support a wider range of layouts when dealing with unpacking.
1215+ # 16-bit types are special, because the store instruction can unpack them.
1216+ elif value .layout == TMEM_NATIVE_LAYOUT and has_default_layout and (
1217+ (bitwidth == 16 and packing == 1 )
1218+ or (bitwidth == 32 and value .layout .vector_length == 2 )
11841219 ):
11851220_store_32xcols_native (self .address ,value .registers .reshape (- 1 ),packing )
1186- elif value .layout == TMEM_NATIVE_LAYOUT and has_default_layout :
1187- _store_32xcols_native (
1188- self .address ,value .registers .reshape (- 1 ),packing
1189- )
1190- elif value .layout == fa .WGMMA_LAYOUT and self .layout == tmem_half_lane_layout (self .shape [1 ],packing = packing ):
1221+ elif (
1222+ value .layout == fa .WGMMA_LAYOUT
1223+ and self .layout == tmem_half_lane_layout (self .shape [1 ],packing = packing )
1224+ ):
11911225registers = value .registers .T .reshape (2 ,- 1 )
11921226registers = np .concatenate (np .split (registers ,2 ,axis = 1 ),axis = 0 )
11931227_store_32xcols (self .address ,registers ,packing )
1194- elif value .layout == fa_m64_collective_layout (self .shape [1 ])and self .layout == tmem_m64_collective_layout (self .shape [1 ],packing = packing ):
1228+ elif value .layout == fa_m64_collective_layout (
1229+ self .shape [1 ]
1230+ )and self .layout == tmem_m64_collective_layout (
1231+ self .shape [1 ],packing = packing
1232+ ):
11951233_store_32xcols (self .address ,value .registers .reshape (4 ,- 1 ),packing )
11961234else :
11971235raise ValueError (
@@ -1306,37 +1344,49 @@ def _store_32xcols(base_addr, vector_regs, tmem_packing) -> None:
13061344def _store_32xcols_native (base_addr ,vector_regs ,tmem_packing )-> None :
13071345i32 = ir .IntegerType .get_signless (32 )
13081346assert vector_regs .ndim == 1
1309- cols = len (vector_regs )* TMEM_NATIVE_LAYOUT .vector_length
13101347vec_ty = ir .VectorType (vector_regs .flat [0 ].type )
1311- reg_packing = 64 // utils .bitwidth (vec_ty )
1312- store_atom_shape = (32 ,reg_packing )
1348+ [vector_length ]= vec_ty .shape
13131349elt_bitwidth = utils .bitwidth (vec_ty .element_type )
1350+ reg_packing = 32 // elt_bitwidth
1351+ store_atom_shape = (32 ,reg_packing )
1352+ # TODO(apaszke): More general register splitting code, not just 2x32b.
13141353if reg_packing == 1 :
1315- # Transform data such that each reg is 32 bits wide.
1316- assert elt_bitwidth == 32 ,elt_bitwidth
1317- regs = [None ]* cols
1318- c0 = arith .constant (i32 ,0 )
1319- c1 = arith .constant (i32 ,1 )
1320- for idx ,vreg in enumerate (vector_regs ):
1321- regs [2 * idx ]= llvm .extractelement (vreg ,c0 )
1322- regs [2 * idx + 1 ]= llvm .extractelement (vreg ,c1 )
1354+ if vector_length == 2 :
1355+ # Transform data such that each reg is 32 bits wide.
1356+ regs = [None ]* (len (vector_regs )* 2 )
1357+ c0 = arith .constant (i32 ,0 )
1358+ c1 = arith .constant (i32 ,1 )
1359+ for idx ,vreg in enumerate (vector_regs ):
1360+ regs [2 * idx ]= llvm .extractelement (vreg ,c0 )
1361+ regs [2 * idx + 1 ]= llvm .extractelement (vreg ,c1 )
1362+ else :
1363+ regs = [utils .bitcast (r ,i32 )for r in vector_regs ]
13231364assert tmem_packing == 1
13241365unpack = False
13251366elif reg_packing == 2 :
1367+ assert vector_length == 2
13261368# In this case, registers are already packed into 32-bit registers.
1327- regs = vector_regs
1369+ regs = [ utils . bitcast ( r , i32 ) for r in vector_regs ]
13281370if elt_bitwidth == 16 :
13291371assert 1 <= tmem_packing <= 2
13301372unpack = tmem_packing == 1
13311373else :
1332- if tmem_packing == 1 :
1374+ if tmem_packing == 1 and elt_bitwidth != 32 :
13331375raise NotImplementedError (
13341376f"Unsupported packing:{ tmem_packing } for element type{ elt_bitwidth } "
13351377 )
13361378assert tmem_packing == 32 // elt_bitwidth
13371379unpack = False
13381380else :
1339- raise NotImplementedError (reg_packing )
1381+ if tmem_packing != reg_packing :
1382+ raise NotImplementedError (
1383+ f"Only{ reg_packing } packing supported for bitwidth{ elt_bitwidth } ,"
1384+ f" but got TMEM packing of{ tmem_packing } "
1385+ )
1386+ assert utils .bitwidth (vec_ty )== 32
1387+ regs = [utils .bitcast (r ,i32 )for r in vector_regs ]
1388+ unpack = False
1389+ cols = len (regs )* reg_packing
13401390it = _transfer_32xcols (base_addr ,cols ,store_atom_shape ,tmem_packing ,reg_packing )
13411391for addr_row_col ,instr_num ,lane_step ,num_slice in it :
13421392assert lane_step == 0
@@ -1393,21 +1443,23 @@ def _load_32xcols(base_addr, cols, dtype, tmem_packing) -> np.ndarray:
13931443return vector_regs
13941444
13951445
1396- def _load_32xcols_native (base_addr ,cols ,dtype ,tmem_packing )-> np .ndarray :
1446+ def _load_32xcols_native (base_addr ,cols ,dtype ,tmem_packing , vector_length )-> np .ndarray :
13971447i32 = ir .IntegerType .get_signless (32 )
1398- vec_ty = ir .VectorType .get ((2 ,),dtype )
1448+ vec_ty = ir .VectorType .get ((vector_length ,),dtype )
13991449reg_packing = 32 // utils .bitwidth (dtype )
1450+ assert vector_length % reg_packing == 0
14001451load_shape = "32x32b"
1401- if reg_packing == 1 :
1402- load_atom_shape = (32 ,1 )
1403- assert tmem_packing == 1
1404- pack = False
1405- elif reg_packing == 2 :
1406- load_atom_shape = (32 ,2 )
1452+ load_atom_shape = (32 ,reg_packing )
1453+ if reg_packing == 2 :
14071454assert 1 <= tmem_packing <= 2
14081455pack = tmem_packing == 1
14091456else :
1410- raise NotImplementedError (reg_packing )
1457+ if tmem_packing != reg_packing :
1458+ raise NotImplementedError (
1459+ f"Only{ reg_packing } supported for element type{ dtype } , but got"
1460+ f" TMEM packing of{ tmem_packing } "
1461+ )
1462+ pack = False
14111463
14121464it = _transfer_32xcols (base_addr ,cols ,load_atom_shape ,tmem_packing ,reg_packing )
14131465c0 = arith .constant (i32 ,0 )
@@ -1416,24 +1468,22 @@ def _load_32xcols_native(base_addr, cols, dtype, tmem_packing) -> np.ndarray:
14161468for addr_row_col ,instr_num ,lane_step ,num_slice in it :
14171469assert lane_step == 0 ,lane_step
14181470instr_regs = _tmem_load (addr_row_col ,load_shape ,instr_num ,pack )
1419- if reg_packing == 1 :
1471+ if reg_packing == 1 and vector_length == 2 :
14201472regs [num_slice ]= [llvm .bitcast (dtype ,r )for r in instr_regs ]
14211473else :
1422- assert reg_packing == 2
1423- regs [num_slice ]= [llvm .bitcast (vec_ty ,r )for r in instr_regs ]
1474+ regs [num_slice ]= [utils .bitcast (r ,vec_ty )for r in instr_regs ]
14241475
1425- if reg_packing == 1 :
1476+ if reg_packing == 1 and vector_length == 2 :
14261477vector_regs = np .ndarray ((cols // 2 ,),dtype = object )
14271478undef = llvm .mlir_undef (vec_ty )
14281479for idx in range (vector_regs .size ):
14291480high_undef = llvm .insertelement (undef ,regs [2 * idx ],c0 )
14301481vreg = llvm .insertelement (high_undef ,regs [2 * idx + 1 ],c1 )
14311482vector_regs [idx ]= vreg
14321483else :
1433- assert reg_packing == 2
1484+ assert vector_length == reg_packing
14341485vector_regs = np .asarray (regs ,dtype = object )
14351486
1436- assert vector_regs .shape == (cols // TMEM_NATIVE_LAYOUT .vector_length ,)
14371487return vector_regs
14381488
14391489