Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit e254a02

Browse files
committed
Modified insertion of IGLP_OPT intrinsics
1 parent 13b20fe commit e254a02

File tree

8 files changed

+72
-21
lines changed

8 files changed

+72
-21
lines changed

‎third_party/amd/backend/compiler.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -166,8 +166,15 @@ def make_llir(src, metadata, options):
166166
## depends on the value of kernel arg `allow_flush_denorm`.
167167
## 3. __HIP_FTZ is default to 1 and not exposed as a kernel argument.
168168
## For now it is used as a controller for developers only.
169+
sched_mode=""
170+
if"AMD_OPS_SCHED_MODE"inos.environ.keys():
171+
sched_mode=os.environ['AMD_OPS_SCHED_MODE']
172+
allowed= ["iglp-opt-0","iglp-opt-1",""]
173+
ifnotsched_modeinallowed:
174+
raiseRuntimeError(f'unknown mode for `AMD_OPS_SCHED_MODE`. Given `{sched_mode}`. Allowed:{", ".join(allowed)}')
175+
169176
__HIP_FTZ=True
170-
amd.passes.ttgpuir.add_to_llvmir(pm,options.arch,__HIP_FTZ)
177+
amd.passes.ttgpuir.add_to_llvmir(pm,options.arch,__HIP_FTZ,sched_mode)
171178
passes.common.add_canonicalizer(pm)
172179
passes.common.add_cse(pm)
173180

‎third_party/amd/include/TritonAMDGPUToLLVM/Passes.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,8 @@ createDecomposeUnsupportedConversionsPass(StringRef targetArch);
2525
}// namespace AMD
2626

2727
std::unique_ptr<OperationPass<ModuleOp>>
28-
createConvertTritonAMDGPUToLLVMPass(StringRef targetArch,bool ftz);
28+
createConvertTritonAMDGPUToLLVMPass(StringRef targetArch,bool ftz,
29+
std::string schedMode);
2930
std::unique_ptr<OperationPass<ModuleOp>>createConvertBuiltinFuncToLLVMPass();
3031

3132
#defineGEN_PASS_REGISTRATION

‎third_party/amd/include/TritonAMDGPUToLLVM/Passes.td

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@ def DecomposeUnsupportedAMDConversions : Pass<"decompose-unsupported-amd-convers
1515

1616
def ConvertTritonAMDGPUToLLVM : Pass<"convert-triton-amdgpu-to-llvm", "mlir::ModuleOp"> {
1717
let summary = "Convert TritonGPU to LLVM";
18-
let constructor = "mlir::triton::createConvertTritonAMDGPUToLLVMPass(\"\", /*ftz=*/true)";
18+
let constructor = "mlir::triton::createConvertTritonAMDGPUToLLVMPass(\"\", /*ftz=*/true, \"\")";
1919

2020
let dependentDialects = ["mlir::arith::ArithDialect",
2121
"mlir::math::MathDialect",
@@ -32,6 +32,8 @@ def ConvertTritonAMDGPUToLLVM : Pass<"convert-triton-amdgpu-to-llvm", "mlir::Mod
3232
"gfx target device architecture, e.g., gfx942">,
3333
Option<"ftz", "ftz", "bool", /*default*/"true",
3434
"flush denorms for math functions">,
35+
Option<"sched", "sched", "std::string", /*default*/"\"\"",
36+
"scheduling variants">,
3537
];
3638
}
3739

‎third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM.cpp

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,8 @@ using ::mlir::triton::gpu::getShapePerCTA;
99
namespacemlir::triton::AMD {
1010
LogicalResultconvertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
1111
const LLVMTypeConverter *typeConverter,
12-
ConversionPatternRewriter &rewriter);
12+
ConversionPatternRewriter &rewriter,
13+
StringRef schedMode);
1314

1415
LogicalResultconvertWMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
1516
const LLVMTypeConverter *typeConverter,
@@ -18,7 +19,11 @@ LogicalResult convertWMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
1819

1920
namespace {
2021
structDotOpConversion :publicConvertOpToLLVMPattern<triton::DotOp> {
21-
using ConvertOpToLLVMPattern<triton::DotOp>::ConvertOpToLLVMPattern;
22+
// using ConvertOpToLLVMPattern<triton::DotOp>::ConvertOpToLLVMPattern;
23+
DotOpConversion(LLVMTypeConverter &typeConverter, PatternBenefit benefit,
24+
StringRef schedMode)
25+
: ConvertOpToLLVMPattern<triton::DotOp>(typeConverter, benefit),
26+
schedMode(schedMode) {}
2227

2328
LogicalResult
2429
matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
@@ -37,7 +42,8 @@ struct DotOpConversion : public ConvertOpToLLVMPattern<triton::DotOp> {
3742
if (!isOuter) {
3843
auto dEncoding = cast<RankedTensorType>(D.getType()).getEncoding();
3944
if (isa<AMDMfmaEncodingAttr>(dEncoding) &&supportMFMA(op)) {
40-
returnAMD::convertMFMA(op, adaptor,getTypeConverter(), rewriter);
45+
returnAMD::convertMFMA(op, adaptor,getTypeConverter(), rewriter,
46+
schedMode);
4147
}
4248
if (isa<AMDWmmaEncodingAttr>(dEncoding)) {
4349
returnAMD::convertWMMA(op, adaptor,getTypeConverter(), rewriter);
@@ -51,14 +57,17 @@ struct DotOpConversion : public ConvertOpToLLVMPattern<triton::DotOp> {
5157
llvm::report_fatal_error(
5258
"Unsupported DotOp found when converting TritonGPU to LLVM.");
5359
}
60+
61+
private:
62+
StringRef schedMode;
5463
};
5564
}// namespace
5665

5766
namespacemlir::triton::AMD {
5867
voidpopulateDotOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
5968
RewritePatternSet &patterns,int numWarps,
6069
ModuleAxisInfoAnalysis &axisInfoAnalysis,
61-
PatternBenefit benefit) {
62-
patterns.add<DotOpConversion>(typeConverter, benefit);
70+
PatternBenefit benefit, StringRef schedMode) {
71+
patterns.add<DotOpConversion>(typeConverter, benefit, schedMode);
6372
}
6473
}// namespace mlir::triton::AMD

‎third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp

Lines changed: 32 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -38,20 +38,25 @@ using ::mlir::triton::gpu::SharedEncodingAttr;
3838

3939
using ValueTable = std::map<std::array<int,3>, Value>;
4040

41+
enumclassSchedulingOptionsEnum { IGLP_OPT_0 =0, IGLP_OPT_1 =1, NONE_SCHED };
42+
4143
structDotOpMFMAConversionHelper {
4244
AMDMfmaEncodingAttr mfmaLayout;
4345

4446
ConversionPatternRewriter &rewriter;
4547
const LLVMTypeConverter *typeConverter;
48+
SchedulingOptionsEnum schedMode;
4649
Location loc;
4750
MLIRContext *ctx{};
4851

4952
explicitDotOpMFMAConversionHelper(AMDMfmaEncodingAttr mfmaLayout,
5053
ConversionPatternRewriter &rewriter,
5154
const LLVMTypeConverter *typeConverter,
55+
SchedulingOptionsEnum schedMode,
5256
Location loc)
5357
: mfmaLayout(mfmaLayout), rewriter(rewriter),
54-
typeConverter(typeConverter), loc(loc), ctx(mfmaLayout.getContext()) {}
58+
typeConverter(typeConverter), schedMode(schedMode), loc(loc),
59+
ctx(mfmaLayout.getContext()) {}
5560

5661
ValuegetThreadId()const {
5762
auto llvmIndexTy = typeConverter->getIndexType();
@@ -70,6 +75,19 @@ struct DotOpMFMAConversionHelper {
7075
return rewriter.create(loweredOp)->getResult(0);
7176
}
7277

78+
voidgeneratedIglpIntrinsic()const {
79+
if (schedMode == SchedulingOptionsEnum::NONE_SCHED)
80+
return;
81+
auto intrinsicName =StringAttr::get(ctx,"llvm.amdgcn.iglp.opt");
82+
LLVM::FastmathFlagsAttr defaultFlags{};
83+
Typei32 = rewriter.getI32Type();
84+
85+
auto option = rewriter.create<LLVM::ConstantOp>(
86+
loc, rewriter.getIntegerAttr(i32,static_cast<int>(schedMode)));
87+
rewriter.create<LLVM::CallIntrinsicOp>(loc, TypeRange{}, intrinsicName,
88+
ValueRange{option}, defaultFlags);
89+
}
90+
7391
intgetNumSubmatrices(Type elementType,intmDim,int nDim)const {
7492
if ((mDim ==64 && nDim ==4) || (mDim ==4 && nDim ==64))
7593
return1;
@@ -171,6 +189,8 @@ struct DotOpMFMAConversionHelper {
171189
assert((mDim == nDim && (mDim ==32 ||mDim ==16 ||mDim ==4)) ||
172190
(mDim ==64 && nDim ==4) || (mDim ==4 && nDim ==64));
173191

192+
generatedIglpIntrinsic();
193+
174194
Value a = op.getA();
175195
Value b = op.getB();
176196
Value d = op.getD();
@@ -351,13 +371,13 @@ struct DotOpMFMAConversionHelper {
351371
return dotOpVals;
352372
}
353373
};
354-
355374
}// namespace
356375

357376
namespacemlir::triton::AMD {
358377
LogicalResultconvertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
359378
const LLVMTypeConverter *typeConverter,
360-
ConversionPatternRewriter &rewriter) {
379+
ConversionPatternRewriter &rewriter,
380+
StringRef schedMode) {
361381
auto rankedTType = [](Value tensor) {
362382
return cast<RankedTensorType>(tensor.getType());
363383
};
@@ -375,11 +395,19 @@ LogicalResult convertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
375395
cTensorTy.getShape()[1] == dTensorTy.getShape()[1] &&
376396
"DotOp's $c operand should pass the same number of values as $d");
377397

398+
staticconst DenseMap<StringRef, SchedulingOptionsEnum> schedModesToEnum = {
399+
{"iglp-opt-0", SchedulingOptionsEnum::IGLP_OPT_0},
400+
{"iglp-opt-1", SchedulingOptionsEnum::IGLP_OPT_1},
401+
{"", SchedulingOptionsEnum::NONE_SCHED}};
402+
assert(schedModesToEnum.contains(schedMode) &&
403+
"sched mode must be in the allowed set");
404+
378405
auto loc = op.getLoc();
379406
auto mfmaLayout = cast<AMDMfmaEncodingAttr>(
380407
cast<RankedTensorType>(op.getResult().getType()).getEncoding());
381408

382-
DotOpMFMAConversionHelperhelper(mfmaLayout, rewriter, typeConverter, loc);
409+
DotOpMFMAConversionHelperhelper(mfmaLayout, rewriter, typeConverter,
410+
schedModesToEnum.at(schedMode), loc);
383411

384412
return helper.convertDot(op, adaptor);
385413
}

‎third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@ void populateConvertLayoutOpToLLVMPatterns(
1515
voidpopulateDotOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
1616
RewritePatternSet &patterns,int numWarps,
1717
ModuleAxisInfoAnalysis &axisInfoAnalysis,
18-
PatternBenefit benefit);
18+
PatternBenefit benefit, StringRef schedMode);
1919
voidpopulateElementwiseOpToLLVMPatterns(
2020
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,bool ftz,
2121
ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation,

‎third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -63,9 +63,11 @@ class TritonLLVMConversionTarget : public ConversionTarget {
6363
structConvertTritonAMDGPUToLLVM
6464
: public triton::impl::ConvertTritonAMDGPUToLLVMBase<
6565
ConvertTritonAMDGPUToLLVM> {
66-
explicitConvertTritonAMDGPUToLLVM(StringRef targetArch,bool ftz) {
66+
explicitConvertTritonAMDGPUToLLVM(StringRef targetArch,bool ftz,
67+
StringRef schedMode) {
6768
this->arch = targetArch.str();
6869
this->ftz = ftz;
70+
this->sched = schedMode.str();
6971
}
7072

7173
voidgetDependentDialects(DialectRegistry &registry)constoverride {
@@ -174,7 +176,7 @@ struct ConvertTritonAMDGPUToLLVM
174176
mlir::triton::populateConvertLayoutOpToLLVMPatterns(
175177
typeConverter, targetInfo, patterns, commonBenefit);
176178
AMD::populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps,
177-
axisInfoAnalysis, AMDBenefit);
179+
axisInfoAnalysis, AMDBenefit, sched);
178180
AMD::populateElementwiseOpToLLVMPatterns(typeConverter, patterns, ftz,
179181
axisInfoAnalysis, allocation,
180182
targetInfo, AMDBenefit);
@@ -246,8 +248,10 @@ namespace mlir {
246248
namespacetriton {
247249

248250
std::unique_ptr<OperationPass<ModuleOp>>
249-
createConvertTritonAMDGPUToLLVMPass(StringRef targetArch,bool ftz) {
250-
return std::make_unique<ConvertTritonAMDGPUToLLVM>(targetArch, ftz);
251+
createConvertTritonAMDGPUToLLVMPass(StringRef targetArch,bool ftz,
252+
std::string schedMode) {
253+
return std::make_unique<ConvertTritonAMDGPUToLLVM>(targetArch, ftz,
254+
schedMode);
251255
}
252256

253257
}// namespace triton

‎third_party/amd/python/triton_amd.cc

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -34,10 +34,10 @@ namespace py = pybind11;
3434
namespace {
3535
voidinit_triton_amd_passes_ttgpuir(py::module &&m) {
3636
usingnamespacemlir::triton;
37-
m.def("add_to_llvmir",
38-
[](mlir::PassManager &pm,const std::string &arch,bool ftz) {
39-
pm.addPass(createConvertTritonAMDGPUToLLVMPass(arch, ftz));
40-
});
37+
m.def("add_to_llvmir", [](mlir::PassManager &pm,const std::string &arch,
38+
bool ftz,const std::string &sched) {
39+
pm.addPass(createConvertTritonAMDGPUToLLVMPass(arch, ftz, sched));
40+
});
4141
m.def("add_builtin_func_to_llvmir", [](mlir::PassManager &pm) {
4242
pm.addPass(createConvertBuiltinFuncToLLVMPass());
4343
});

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp