Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit7f5d5f1

Browse files
committed
Adding support for refining ElementWise, ExpandDims and Broadcast
1 parent96060c0 commit7f5d5f1

File tree

2 files changed

+459
-3
lines changed

2 files changed

+459
-3
lines changed

‎lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -394,17 +394,28 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
394394
// For the lane (i.e., thread) dimension, these threads are along the
395395
// matrix C's N dimension, with 32 consecutive threads covering a whole
396396
// row and the next 32 threads start after a gap spanning 4 rows.
397+
std::vector<std::vector<int>> regBases = { {0,1}, {0,2} };
398+
if (getWarpsPerCTA()[1] >1) {
399+
regBases.push_back({0,8});
400+
regBases.push_back({0,16});
401+
}
397402
tileLayout =LinearLayout(
398-
{{kRegister,{{0,1}, {0,2}, {0,8},/*gap*/ {0,16}}},
403+
{{kRegister,regBases},
399404
{kLane, {{1,0}, {2,0}, {4,0}, {8,0}, {16,0},/*gap*/ {0,4}}}},
400405
{outDimNames[order[0]], outDimNames[order[1]]});
401406
// For mfma.transposed layout, the element ownership among threads are
402407
// "transposed" within each warp.
403-
if (getIsTransposed())
408+
if (getIsTransposed()) {
409+
regBases = { {1,0}, {2,0} };
410+
if (getWarpsPerCTA()[1] >1) {
411+
regBases.push_back({8,0});
412+
regBases.push_back({16,0});
413+
}
404414
tileLayout =LinearLayout(
405-
{{kRegister,{{1,0}, {2,0}, {8,0},/*gap*/ {16,0}}},
415+
{{kRegister,regBases},
406416
{kLane, {{0,1}, {0,2}, {0,4}, {0,8}, {0,16},/*gap*/ {4,0}}}},
407417
{outDimNames[order[0]], outDimNames[order[1]]});
418+
}
408419
}else {
409420
assert(getMDim() ==16);
410421
// For mfma with 16x16 output, each of the 64 threads holds 4 elements.

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp