Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit1abb7b4

Browse files
committed
Modified insertion of IGLP_OPT intrinsics
1 parent13b20fe commit1abb7b4

File tree

8 files changed

+75
-21
lines changed

8 files changed

+75
-21
lines changed

‎third_party/amd/backend/compiler.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,16 @@ def make_llir(src, metadata, options):
166166
## depends on the value of kernel arg `allow_flush_denorm`.
167167
## 3. __HIP_FTZ is default to 1 and not exposed as a kernel argument.
168168
## For now it is used as a controller for developers only.
169+
sched_mode=""
170+
if"AMD_OPS_SCHED_MODE"inos.environ.keys():
171+
sched_mode=os.environ['AMD_OPS_SCHED_MODE']
172+
allowed= ["iglp-opt-0","iglp-opt-1",""]
173+
ifnotsched_modeinallowed:
174+
raiseRuntimeError(
175+
f'unknown mode for `AMD_OPS_SCHED_MODE`. Given `{sched_mode}`. Allowed:{", ".join(allowed)}')
176+
169177
__HIP_FTZ=True
170-
amd.passes.ttgpuir.add_to_llvmir(pm,options.arch,__HIP_FTZ)
178+
amd.passes.ttgpuir.add_to_llvmir(pm,options.arch,__HIP_FTZ,sched_mode)
171179
passes.common.add_canonicalizer(pm)
172180
passes.common.add_cse(pm)
173181

‎third_party/amd/include/TritonAMDGPUToLLVM/Passes.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ createDecomposeUnsupportedConversionsPass(StringRef targetArch);
2525
}// namespace AMD
2626

2727
std::unique_ptr<OperationPass<ModuleOp>>
28-
createConvertTritonAMDGPUToLLVMPass(StringRef targetArch,bool ftz);
28+
createConvertTritonAMDGPUToLLVMPass(StringRef targetArch,bool ftz,
29+
std::string schedMode);
2930
std::unique_ptr<OperationPass<ModuleOp>>createConvertBuiltinFuncToLLVMPass();
3031

3132
#defineGEN_PASS_REGISTRATION

‎third_party/amd/include/TritonAMDGPUToLLVM/Passes.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def DecomposeUnsupportedAMDConversions : Pass<"decompose-unsupported-amd-convers
1515

1616
def ConvertTritonAMDGPUToLLVM : Pass<"convert-triton-amdgpu-to-llvm", "mlir::ModuleOp"> {
1717
let summary = "Convert TritonGPU to LLVM";
18-
let constructor = "mlir::triton::createConvertTritonAMDGPUToLLVMPass(\"\", /*ftz=*/true)";
18+
let constructor = "mlir::triton::createConvertTritonAMDGPUToLLVMPass(\"\", /*ftz=*/true, \"\")";
1919

2020
let dependentDialects = ["mlir::arith::ArithDialect",
2121
"mlir::math::MathDialect",
@@ -32,6 +32,8 @@ def ConvertTritonAMDGPUToLLVM : Pass<"convert-triton-amdgpu-to-llvm", "mlir::Mod
3232
"gfx target device architecture, e.g., gfx942">,
3333
Option<"ftz", "ftz", "bool", /*default*/"true",
3434
"flush denorms for math functions">,
35+
Option<"sched", "sched", "std::string", /*default*/"\"\"",
36+
"scheduling variants">,
3537
];
3638
}
3739

‎third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ using ::mlir::triton::gpu::getShapePerCTA;
99
namespacemlir::triton::AMD {
1010
LogicalResultconvertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
1111
const LLVMTypeConverter *typeConverter,
12-
ConversionPatternRewriter &rewriter);
12+
ConversionPatternRewriter &rewriter,
13+
StringRef schedMode);
1314

1415
LogicalResultconvertWMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
1516
const LLVMTypeConverter *typeConverter,
@@ -18,7 +19,11 @@ LogicalResult convertWMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
1819

1920
namespace {
2021
structDotOpConversion :publicConvertOpToLLVMPattern<triton::DotOp> {
21-
using ConvertOpToLLVMPattern<triton::DotOp>::ConvertOpToLLVMPattern;
22+
// using ConvertOpToLLVMPattern<triton::DotOp>::ConvertOpToLLVMPattern;
23+
DotOpConversion(LLVMTypeConverter &typeConverter, PatternBenefit benefit,
24+
StringRef schedMode)
25+
: ConvertOpToLLVMPattern<triton::DotOp>(typeConverter, benefit),
26+
schedMode(schedMode) {}
2227

2328
LogicalResult
2429
matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
@@ -37,7 +42,8 @@ struct DotOpConversion : public ConvertOpToLLVMPattern<triton::DotOp> {
3742
if (!isOuter) {
3843
auto dEncoding = cast<RankedTensorType>(D.getType()).getEncoding();
3944
if (isa<AMDMfmaEncodingAttr>(dEncoding) &&supportMFMA(op)) {
40-
returnAMD::convertMFMA(op, adaptor,getTypeConverter(), rewriter);
45+
returnAMD::convertMFMA(op, adaptor,getTypeConverter(), rewriter,
46+
schedMode);
4147
}
4248
if (isa<AMDWmmaEncodingAttr>(dEncoding)) {
4349
returnAMD::convertWMMA(op, adaptor,getTypeConverter(), rewriter);
@@ -51,14 +57,17 @@ struct DotOpConversion : public ConvertOpToLLVMPattern<triton::DotOp> {
5157
llvm::report_fatal_error(
5258
"Unsupported DotOp found when converting TritonGPU to LLVM.");
5359
}
60+
61+
private:
62+
StringRef schedMode;
5463
};
5564
}// namespace
5665

5766
namespacemlir::triton::AMD {
5867
voidpopulateDotOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
5968
RewritePatternSet &patterns,int numWarps,
6069
ModuleAxisInfoAnalysis &axisInfoAnalysis,
61-
PatternBenefit benefit) {
62-
patterns.add<DotOpConversion>(typeConverter, benefit);
70+
PatternBenefit benefit, StringRef schedMode) {
71+
patterns.add<DotOpConversion>(typeConverter, benefit, schedMode);
6372
}
6473
}// namespace mlir::triton::AMD

‎third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,20 +38,25 @@ using ::mlir::triton::gpu::SharedEncodingAttr;
3838

3939
using ValueTable = std::map<std::array<int,3>, Value>;
4040

41+
enumclassSchedulingOptionsEnum { IGLP_OPT_0 =0, IGLP_OPT_1 =1, NONE_SCHED };
42+
4143
structDotOpMFMAConversionHelper {
4244
AMDMfmaEncodingAttr mfmaLayout;
4345

4446
ConversionPatternRewriter &rewriter;
4547
const LLVMTypeConverter *typeConverter;
48+
SchedulingOptionsEnum schedMode;
4649
Location loc;
4750
MLIRContext *ctx{};
4851

4952
explicitDotOpMFMAConversionHelper(AMDMfmaEncodingAttr mfmaLayout,
5053
ConversionPatternRewriter &rewriter,
5154
const LLVMTypeConverter *typeConverter,
55+
SchedulingOptionsEnum schedMode,
5256
Location loc)
5357
: mfmaLayout(mfmaLayout), rewriter(rewriter),
54-
typeConverter(typeConverter), loc(loc), ctx(mfmaLayout.getContext()) {}
58+
typeConverter(typeConverter), schedMode(schedMode), loc(loc),
59+
ctx(mfmaLayout.getContext()) {}
5560

5661
ValuegetThreadId()const {
5762
auto llvmIndexTy = typeConverter->getIndexType();
@@ -70,6 +75,21 @@ struct DotOpMFMAConversionHelper {
7075
return rewriter.create(loweredOp)->getResult(0);
7176
}
7277

78+
voidgeneratedIglpIntrinsic()const {
79+
if (!((schedMode == SchedulingOptionsEnum::IGLP_OPT_0) ||
80+
(schedMode == SchedulingOptionsEnum::IGLP_OPT_1))) {
81+
return;
82+
}
83+
auto intrinsicName =StringAttr::get(ctx,"llvm.amdgcn.iglp.opt");
84+
LLVM::FastmathFlagsAttr defaultFlags{};
85+
Typei32 = rewriter.getI32Type();
86+
87+
auto option = rewriter.create<LLVM::ConstantOp>(
88+
loc, rewriter.getIntegerAttr(i32,static_cast<int>(schedMode)));
89+
rewriter.create<LLVM::CallIntrinsicOp>(loc, TypeRange{}, intrinsicName,
90+
ValueRange{option}, defaultFlags);
91+
}
92+
7393
intgetNumSubmatrices(Type elementType,intmDim,int nDim)const {
7494
if ((mDim ==64 && nDim ==4) || (mDim ==4 && nDim ==64))
7595
return1;
@@ -171,6 +191,8 @@ struct DotOpMFMAConversionHelper {
171191
assert((mDim == nDim && (mDim ==32 ||mDim ==16 ||mDim ==4)) ||
172192
(mDim ==64 && nDim ==4) || (mDim ==4 && nDim ==64));
173193

194+
generatedIglpIntrinsic();
195+
174196
Value a = op.getA();
175197
Value b = op.getB();
176198
Value d = op.getD();
@@ -351,13 +373,13 @@ struct DotOpMFMAConversionHelper {
351373
return dotOpVals;
352374
}
353375
};
354-
355376
}// namespace
356377

357378
namespacemlir::triton::AMD {
358379
LogicalResultconvertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
359380
const LLVMTypeConverter *typeConverter,
360-
ConversionPatternRewriter &rewriter) {
381+
ConversionPatternRewriter &rewriter,
382+
StringRef schedMode) {
361383
auto rankedTType = [](Value tensor) {
362384
return cast<RankedTensorType>(tensor.getType());
363385
};
@@ -375,11 +397,19 @@ LogicalResult convertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
375397
cTensorTy.getShape()[1] == dTensorTy.getShape()[1] &&
376398
"DotOp's $c operand should pass the same number of values as $d");
377399

400+
staticconst DenseMap<StringRef, SchedulingOptionsEnum> schedModesToEnum = {
401+
{"iglp-opt-0", SchedulingOptionsEnum::IGLP_OPT_0},
402+
{"iglp-opt-1", SchedulingOptionsEnum::IGLP_OPT_1},
403+
{"", SchedulingOptionsEnum::NONE_SCHED}};
404+
assert(schedModesToEnum.contains(schedMode) &&
405+
"sched mode must be in the allowed set");
406+
378407
auto loc = op.getLoc();
379408
auto mfmaLayout = cast<AMDMfmaEncodingAttr>(
380409
cast<RankedTensorType>(op.getResult().getType()).getEncoding());
381410

382-
DotOpMFMAConversionHelperhelper(mfmaLayout, rewriter, typeConverter, loc);
411+
DotOpMFMAConversionHelperhelper(mfmaLayout, rewriter, typeConverter,
412+
schedModesToEnum.at(schedMode), loc);
383413

384414
return helper.convertDot(op, adaptor);
385415
}

‎third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ void populateConvertLayoutOpToLLVMPatterns(
1515
voidpopulateDotOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
1616
RewritePatternSet &patterns,int numWarps,
1717
ModuleAxisInfoAnalysis &axisInfoAnalysis,
18-
PatternBenefit benefit);
18+
PatternBenefit benefit, StringRef schedMode);
1919
voidpopulateElementwiseOpToLLVMPatterns(
2020
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,bool ftz,
2121
ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation,

‎third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,11 @@ class TritonLLVMConversionTarget : public ConversionTarget {
6363
structConvertTritonAMDGPUToLLVM
6464
: public triton::impl::ConvertTritonAMDGPUToLLVMBase<
6565
ConvertTritonAMDGPUToLLVM> {
66-
explicitConvertTritonAMDGPUToLLVM(StringRef targetArch,bool ftz) {
66+
explicitConvertTritonAMDGPUToLLVM(StringRef targetArch,bool ftz,
67+
StringRef schedMode) {
6768
this->arch = targetArch.str();
6869
this->ftz = ftz;
70+
this->sched = schedMode.str();
6971
}
7072

7173
voidgetDependentDialects(DialectRegistry &registry)constoverride {
@@ -174,7 +176,7 @@ struct ConvertTritonAMDGPUToLLVM
174176
mlir::triton::populateConvertLayoutOpToLLVMPatterns(
175177
typeConverter, targetInfo, patterns, commonBenefit);
176178
AMD::populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps,
177-
axisInfoAnalysis, AMDBenefit);
179+
axisInfoAnalysis, AMDBenefit, sched);
178180
AMD::populateElementwiseOpToLLVMPatterns(typeConverter, patterns, ftz,
179181
axisInfoAnalysis, allocation,
180182
targetInfo, AMDBenefit);
@@ -246,8 +248,10 @@ namespace mlir {
246248
namespacetriton {
247249

248250
std::unique_ptr<OperationPass<ModuleOp>>
249-
createConvertTritonAMDGPUToLLVMPass(StringRef targetArch,bool ftz) {
250-
return std::make_unique<ConvertTritonAMDGPUToLLVM>(targetArch, ftz);
251+
createConvertTritonAMDGPUToLLVMPass(StringRef targetArch,bool ftz,
252+
std::string schedMode) {
253+
return std::make_unique<ConvertTritonAMDGPUToLLVM>(targetArch, ftz,
254+
schedMode);
251255
}
252256

253257
}// namespace triton

‎third_party/amd/python/triton_amd.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@ namespace py = pybind11;
3434
namespace {
3535
voidinit_triton_amd_passes_ttgpuir(py::module &&m) {
3636
usingnamespacemlir::triton;
37-
m.def("add_to_llvmir",
38-
[](mlir::PassManager &pm,const std::string &arch,bool ftz) {
39-
pm.addPass(createConvertTritonAMDGPUToLLVMPass(arch, ftz));
40-
});
37+
m.def("add_to_llvmir", [](mlir::PassManager &pm,const std::string &arch,
38+
bool ftz,const std::string &sched) {
39+
pm.addPass(createConvertTritonAMDGPUToLLVMPass(arch, ftz, sched));
40+
});
4141
m.def("add_builtin_func_to_llvmir", [](mlir::PassManager &pm) {
4242
pm.addPass(createConvertBuiltinFuncToLLVMPass());
4343
});

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp