Commit fbd6b4b

Modified insertion of IGLP_OPT intrinsics

1 parent 13b20fe, commit fbd6b4b

8 files changed: +119 -21 lines changed

‎third_party/amd/backend/compiler.py

Lines changed: 9 additions & 1 deletion

@@ -166,8 +166,16 @@ def make_llir(src, metadata, options):
     ## depends on the value of kernel arg `allow_flush_denorm`.
     ## 3. __HIP_FTZ is default to 1 and not exposed as a kernel argument.
     ## For now it is used as a controller for developers only.
+    sched_mode = ""
+    if "AMD_OPS_SCHED_MODE" in os.environ.keys():
+        sched_mode = os.environ['AMD_OPS_SCHED_MODE']
+        allowed = ["iglp-opt-0", "iglp-opt-1", "sched-barriers", ""]
+        if not sched_mode in allowed:
+            raise RuntimeError(
+                f'unknown mode for `AMD_OPS_SCHED_MODE`. Given `{sched_mode}`. Allowed: {", ".join(allowed)}')
+
     __HIP_FTZ = True
-    amd.passes.ttgpuir.add_to_llvmir(pm, options.arch, __HIP_FTZ)
+    amd.passes.ttgpuir.add_to_llvmir(pm, options.arch, __HIP_FTZ, sched_mode)
     passes.common.add_canonicalizer(pm)
     passes.common.add_cse(pm)
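
Since the new `AMD_OPS_SCHED_MODE` knob is read from the environment inside `make_llir`, it only takes effect if it is set before a kernel is JIT-compiled. A minimal usage sketch follows; the kernel, shapes, and tensors are hypothetical illustrations, and only the environment variable name and its allowed values come from this commit:

import os

# Must be set before the kernel below is JIT-compiled, because make_llir reads
# it while lowering TritonGPU IR to LLVM IR. Allowed values per compiler.py:
# "iglp-opt-0", "iglp-opt-1", "sched-barriers", or "" (no scheduling hints).
os.environ["AMD_OPS_SCHED_MODE"] = "iglp-opt-0"

import torch
import triton
import triton.language as tl


@triton.jit
def dot_kernel(a_ptr, b_ptr, c_ptr,
               M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    # Hypothetical single-block kernel: any tl.dot lowered through the AMD
    # MFMA path is affected by the scheduling mode chosen above.
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
    c = tl.dot(a, b)
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], c)


a = torch.randn((32, 32), device="cuda", dtype=torch.float16)
b = torch.randn((32, 32), device="cuda", dtype=torch.float16)
c = torch.empty((32, 32), device="cuda", dtype=torch.float32)
dot_kernel[(1,)](a, b, c, M=32, N=32, K=32)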

‎third_party/amd/include/TritonAMDGPUToLLVM/Passes.h

Lines changed: 2 additions & 1 deletion

@@ -25,7 +25,8 @@ createDecomposeUnsupportedConversionsPass(StringRef targetArch);
 } // namespace AMD

 std::unique_ptr<OperationPass<ModuleOp>>
-createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz);
+createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz,
+                                    std::string schedMode);
 std::unique_ptr<OperationPass<ModuleOp>> createConvertBuiltinFuncToLLVMPass();

 #define GEN_PASS_REGISTRATION

‎third_party/amd/include/TritonAMDGPUToLLVM/Passes.td

Lines changed: 3 additions & 1 deletion

@@ -15,7 +15,7 @@ def DecomposeUnsupportedAMDConversions : Pass<"decompose-unsupported-amd-convers

 def ConvertTritonAMDGPUToLLVM : Pass<"convert-triton-amdgpu-to-llvm", "mlir::ModuleOp"> {
     let summary = "Convert TritonGPU to LLVM";
-    let constructor = "mlir::triton::createConvertTritonAMDGPUToLLVMPass(\"\", /*ftz=*/true)";
+    let constructor = "mlir::triton::createConvertTritonAMDGPUToLLVMPass(\"\", /*ftz=*/true, \"\")";

     let dependentDialects = ["mlir::arith::ArithDialect",
                              "mlir::math::MathDialect",
@@ -32,6 +32,8 @@ def ConvertTritonAMDGPUToLLVM : Pass<"convert-triton-amdgpu-to-llvm", "mlir::Mod
            "gfx target device architecture, e.g., gfx942">,
     Option<"ftz", "ftz", "bool", /*default*/"true",
            "flush denorms for math functions">,
+    Option<"sched", "sched", "std::string", /*default*/"\"\"",
+           "scheduling variants">,
   ];
 }

‎third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM.cpp

Lines changed: 14 additions & 5 deletions

@@ -9,7 +9,8 @@ using ::mlir::triton::gpu::getShapePerCTA;
 namespace mlir::triton::AMD {
 LogicalResult convertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
                           const LLVMTypeConverter *typeConverter,
-                          ConversionPatternRewriter &rewriter);
+                          ConversionPatternRewriter &rewriter,
+                          StringRef schedMode);

 LogicalResult convertWMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
                           const LLVMTypeConverter *typeConverter,
@@ -18,7 +19,11 @@ LogicalResult convertWMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,

 namespace {
 struct DotOpConversion : public ConvertOpToLLVMPattern<triton::DotOp> {
-  using ConvertOpToLLVMPattern<triton::DotOp>::ConvertOpToLLVMPattern;
+  // using ConvertOpToLLVMPattern<triton::DotOp>::ConvertOpToLLVMPattern;
+  DotOpConversion(LLVMTypeConverter &typeConverter, PatternBenefit benefit,
+                  StringRef schedMode)
+      : ConvertOpToLLVMPattern<triton::DotOp>(typeConverter, benefit),
+        schedMode(schedMode) {}

   LogicalResult
   matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
@@ -37,7 +42,8 @@ struct DotOpConversion : public ConvertOpToLLVMPattern<triton::DotOp> {
     if (!isOuter) {
       auto dEncoding = cast<RankedTensorType>(D.getType()).getEncoding();
       if (isa<AMDMfmaEncodingAttr>(dEncoding) && supportMFMA(op)) {
-        return AMD::convertMFMA(op, adaptor, getTypeConverter(), rewriter);
+        return AMD::convertMFMA(op, adaptor, getTypeConverter(), rewriter,
+                                schedMode);
       }
       if (isa<AMDWmmaEncodingAttr>(dEncoding)) {
         return AMD::convertWMMA(op, adaptor, getTypeConverter(), rewriter);
@@ -51,14 +57,17 @@ struct DotOpConversion : public ConvertOpToLLVMPattern<triton::DotOp> {
     llvm::report_fatal_error(
         "Unsupported DotOp found when converting TritonGPU to LLVM.");
   }
+
+private:
+  StringRef schedMode;
 };
 } // namespace

 namespace mlir::triton::AMD {
 void populateDotOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
                                  RewritePatternSet &patterns, int numWarps,
                                  ModuleAxisInfoAnalysis &axisInfoAnalysis,
-                                 PatternBenefit benefit) {
-  patterns.add<DotOpConversion>(typeConverter, benefit);
+                                 PatternBenefit benefit, StringRef schedMode) {
+  patterns.add<DotOpConversion>(typeConverter, benefit, schedMode);
 }
 } // namespace mlir::triton::AMD

‎third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp

Lines changed: 78 additions & 4 deletions

@@ -38,20 +38,41 @@ using ::mlir::triton::gpu::SharedEncodingAttr;

 using ValueTable = std::map<std::array<int, 3>, Value>;

+enum class SchedulingOptionsEnum : int64_t {
+  IGLP_OPT_0 = 0,
+  IGLP_OPT_1 = 1,
+  SCHED_BARRIERS,
+  NONE_SCHED
+};
+enum class InstructionMaskEnum : int64_t {
+  VALU = 0x00000002,
+  SALU = 0x00000004,
+  MFMA = 0x00000008,
+  ALL_VMEM = 0x00000010,
+  VMEM_READ = 0x00000020,
+  VMEM_WRITE = 0x00000040,
+  ALL_DS = 0x00000080,
+  DS_READ = 0x00000100,
+  DS_WRITE = 0x00000200
+};
+
 struct DotOpMFMAConversionHelper {
   AMDMfmaEncodingAttr mfmaLayout;

   ConversionPatternRewriter &rewriter;
   const LLVMTypeConverter *typeConverter;
+  SchedulingOptionsEnum schedMode;
   Location loc;
   MLIRContext *ctx{};

   explicit DotOpMFMAConversionHelper(AMDMfmaEncodingAttr mfmaLayout,
                                      ConversionPatternRewriter &rewriter,
                                      const LLVMTypeConverter *typeConverter,
+                                     SchedulingOptionsEnum schedMode,
                                      Location loc)
       : mfmaLayout(mfmaLayout), rewriter(rewriter),
-        typeConverter(typeConverter), loc(loc), ctx(mfmaLayout.getContext()) {}
+        typeConverter(typeConverter), schedMode(schedMode), loc(loc),
+        ctx(mfmaLayout.getContext()) {}

   Value getThreadId() const {
     auto llvmIndexTy = typeConverter->getIndexType();
@@ -70,6 +91,45 @@ struct DotOpMFMAConversionHelper {
     return rewriter.create(loweredOp)->getResult(0);
   }

+  void generatedIglpIntrinsic() const {
+    if (!((schedMode == SchedulingOptionsEnum::IGLP_OPT_0) ||
+          (schedMode == SchedulingOptionsEnum::IGLP_OPT_1))) {
+      return;
+    }
+    auto intrinsicName = StringAttr::get(ctx, "llvm.amdgcn.iglp.opt");
+    LLVM::FastmathFlagsAttr defaultFlags{};
+    Type i32 = rewriter.getI32Type();
+
+    auto option = rewriter.create<LLVM::ConstantOp>(
+        loc, rewriter.getIntegerAttr(i32, static_cast<int>(schedMode)));
+    rewriter.create<LLVM::CallIntrinsicOp>(loc, TypeRange{}, intrinsicName,
+                                           ValueRange{option}, defaultFlags);
+  }
+
+  void buildSchedGroupBarrier(InstructionMaskEnum maskValue, int sizeValue,
+                              int groupIdValue) const {
+    auto intrinsicName =
+        StringAttr::get(ctx, "llvm.amdgcn.sched.group.barrier");
+    LLVM::FastmathFlagsAttr defaultFlags{};
+    Type i32 = rewriter.getI32Type();
+    auto mask = rewriter.create<LLVM::ConstantOp>(
+        loc, rewriter.getIntegerAttr(i32, static_cast<int64_t>(maskValue)));
+    auto size = rewriter.create<LLVM::ConstantOp>(
+        loc, rewriter.getIntegerAttr(i32, sizeValue));
+    auto groupId = rewriter.create<LLVM::ConstantOp>(
+        loc, rewriter.getIntegerAttr(i32, groupIdValue));
+
+    rewriter.create<LLVM::CallIntrinsicOp>(loc, TypeRange{}, intrinsicName,
+                                           ValueRange{mask, size, groupId},
+                                           defaultFlags);
+  }
+
+  void insertSchedBarriers() const {
+    if (!(schedMode == SchedulingOptionsEnum::SCHED_BARRIERS))
+      return;
+    // TODO(ravil)
+  }
+
   int getNumSubmatrices(Type elementType, int mDim, int nDim) const {
     if ((mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64))
       return 1;
@@ -171,6 +231,8 @@ struct DotOpMFMAConversionHelper {
     assert((mDim == nDim && (mDim == 32 || mDim == 16 || mDim == 4)) ||
            (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64));

+    generatedIglpIntrinsic();
+
     Value a = op.getA();
     Value b = op.getB();
     Value d = op.getD();
@@ -263,6 +325,9 @@ struct DotOpMFMAConversionHelper {
     Type structTy = LLVM::LLVMStructType::getLiteral(
         ctx, SmallVector<Type>(fc.size(), dstElemTy));
     Value res = packLLElements(loc, typeConverter, fc, rewriter, structTy);
+
+    insertSchedBarriers();
+
     rewriter.replaceOp(op, res);

     return success();
@@ -351,13 +416,13 @@ struct DotOpMFMAConversionHelper {
     return dotOpVals;
   }
 };
-
 } // namespace

 namespace mlir::triton::AMD {
 LogicalResult convertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
                           const LLVMTypeConverter *typeConverter,
-                          ConversionPatternRewriter &rewriter) {
+                          ConversionPatternRewriter &rewriter,
+                          StringRef schedMode) {
   auto rankedTType = [](Value tensor) {
     return cast<RankedTensorType>(tensor.getType());
   };
@@ -375,11 +440,20 @@ LogicalResult convertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
          cTensorTy.getShape()[1] == dTensorTy.getShape()[1] &&
          "DotOp's $c operand should pass the same number of values as $d");

+  static const DenseMap<StringRef, SchedulingOptionsEnum> schedModesToEnum = {
+      {"iglp-opt-0", SchedulingOptionsEnum::IGLP_OPT_0},
+      {"iglp-opt-1", SchedulingOptionsEnum::IGLP_OPT_1},
+      {"sched-barriers", SchedulingOptionsEnum::SCHED_BARRIERS},
+      {"", SchedulingOptionsEnum::NONE_SCHED}};
+  assert(schedModesToEnum.contains(schedMode) &&
+         "sched mode must be in the allowed set");
+
   auto loc = op.getLoc();
   auto mfmaLayout = cast<AMDMfmaEncodingAttr>(
       cast<RankedTensorType>(op.getResult().getType()).getEncoding());

-  DotOpMFMAConversionHelper helper(mfmaLayout, rewriter, typeConverter, loc);
+  DotOpMFMAConversionHelper helper(mfmaLayout, rewriter, typeConverter,
+                                   schedModesToEnum.at(schedMode), loc);

   return helper.convertDot(op, adaptor);
 }
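
For quick reference, here is a non-normative Python summary of how the mode string handed to convertMFMA maps onto the intrinsics emitted by the helper above. This table is not part of the commit; the values are read off SchedulingOptionsEnum, InstructionMaskEnum, and the two emit functions in the hunks:

# Which LLVM intrinsic, if any, a given AMD_OPS_SCHED_MODE value causes
# DotOpMFMAConversionHelper to emit, and with which immediate operand.
SCHED_MODE_TO_INTRINSIC = {
    "iglp-opt-0": ("llvm.amdgcn.iglp.opt", 0),   # via generatedIglpIntrinsic()
    "iglp-opt-1": ("llvm.amdgcn.iglp.opt", 1),   # via generatedIglpIntrinsic()
    "sched-barriers": ("llvm.amdgcn.sched.group.barrier", None),  # insertSchedBarriers(), body still TODO(ravil)
    "": (None, None),                            # NONE_SCHED: nothing emitted
}

# Instruction-class mask bits accepted by llvm.amdgcn.sched.group.barrier,
# mirroring InstructionMaskEnum above.
INSTRUCTION_MASK = {
    "VALU": 0x0002, "SALU": 0x0004, "MFMA": 0x0008,
    "ALL_VMEM": 0x0010, "VMEM_READ": 0x0020, "VMEM_WRITE": 0x0040,
    "ALL_DS": 0x0080, "DS_READ": 0x0100, "DS_WRITE": 0x0200,
}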

‎third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@ void populateConvertLayoutOpToLLVMPatterns(
 void populateDotOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
                                  RewritePatternSet &patterns, int numWarps,
                                  ModuleAxisInfoAnalysis &axisInfoAnalysis,
-                                 PatternBenefit benefit);
+                                 PatternBenefit benefit, StringRef schedMode);
 void populateElementwiseOpToLLVMPatterns(
     LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, bool ftz,
     ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation,

‎third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 8 additions & 4 deletions

@@ -63,9 +63,11 @@ class TritonLLVMConversionTarget : public ConversionTarget {
 struct ConvertTritonAMDGPUToLLVM
     : public triton::impl::ConvertTritonAMDGPUToLLVMBase<
          ConvertTritonAMDGPUToLLVM> {
-  explicit ConvertTritonAMDGPUToLLVM(StringRef targetArch, bool ftz) {
+  explicit ConvertTritonAMDGPUToLLVM(StringRef targetArch, bool ftz,
+                                     StringRef schedMode) {
     this->arch = targetArch.str();
     this->ftz = ftz;
+    this->sched = schedMode.str();
   }

   void getDependentDialects(DialectRegistry &registry) const override {
@@ -174,7 +176,7 @@ struct ConvertTritonAMDGPUToLLVM
     mlir::triton::populateConvertLayoutOpToLLVMPatterns(
         typeConverter, targetInfo, patterns, commonBenefit);
     AMD::populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps,
-                                     axisInfoAnalysis, AMDBenefit);
+                                     axisInfoAnalysis, AMDBenefit, sched);
     AMD::populateElementwiseOpToLLVMPatterns(typeConverter, patterns, ftz,
                                              axisInfoAnalysis, allocation,
                                              targetInfo, AMDBenefit);
@@ -246,8 +248,10 @@ namespace mlir {
 namespace triton {

 std::unique_ptr<OperationPass<ModuleOp>>
-createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz) {
-  return std::make_unique<ConvertTritonAMDGPUToLLVM>(targetArch, ftz);
+createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz,
+                                    std::string schedMode) {
+  return std::make_unique<ConvertTritonAMDGPUToLLVM>(targetArch, ftz,
+                                                     schedMode);
 }

 } // namespace triton

‎third_party/amd/python/triton_amd.cc

Lines changed: 4 additions & 4 deletions

@@ -34,10 +34,10 @@ namespace py = pybind11;
 namespace {
 void init_triton_amd_passes_ttgpuir(py::module &&m) {
   using namespace mlir::triton;
-  m.def("add_to_llvmir",
-        [](mlir::PassManager &pm, const std::string &arch, bool ftz) {
-          pm.addPass(createConvertTritonAMDGPUToLLVMPass(arch, ftz));
-        });
+  m.def("add_to_llvmir", [](mlir::PassManager &pm, const std::string &arch,
+                            bool ftz, const std::string &sched) {
+    pm.addPass(createConvertTritonAMDGPUToLLVMPass(arch, ftz, sched));
+  });
   m.def("add_builtin_func_to_llvmir", [](mlir::PassManager &pm) {
     pm.addPass(createConvertBuiltinFuncToLLVMPass());
   });

0 commit comments