Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 17614f7

Browse files
Add the SubgroupMatrixMultiply shader
1 parent 3d6e51c · commit 17614f7

File tree

1 file changed

+182
-43
lines changed

1 file changed

+182
-43
lines changed

‎examples/matmul/run.cpp‎

Lines changed: 182 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,93 @@ inline KernelCode createMatmulWithTranspose(const char *shaderTemplate, const si
613613
return {unrolledCode, workgroupSize, precision};
614614
}
615615

616+
inline KernelCodecreateMatmul12(constchar *shaderTemplate,constsize_t M,
617+
constsize_t K,constsize_t N,
618+
constsize_t TM,constsize_t TN,
619+
constsize_t LID,
620+
const Shape &workgroupSize = {256,1,1},
621+
NumType precision = kf32) {
622+
std::stringcodeString(shaderTemplate);
623+
replaceAll(codeString, {{"{{precision}}",toString(precision)},
624+
{"{{M}}",toString(M)},
625+
{"{{K}}",toString(K)},
626+
{"{{N}}",toString(N)},
627+
{"{{TM}}",toString(TM)},
628+
{"{{TN}}",toString(TN)},
629+
{"{{LID}}",toString(LID)}
630+
});
631+
return {loopUnrolling(codeString), workgroupSize, precision};
632+
}
633+
634+
// ─────────────────────────────────────────────────────────────────────────────
// Optimised WGSL matrix-multiply kernel using subgroupMatrixLoad/Store and
// subgroupMatrixMultiplyAccumulate. Requires Dawn's
// chromium_experimental_subgroup_matrix feature (unsafe API toggle).
//
// Template placeholders: {{precision}}, {{M}}, {{K}}, {{N}}, {{TM}}, {{TN}},
// {{LID}}, {{workgroupSize}} — substituted by createMatmul12 / KernelCode.
// Each subgroup computes a TM x TN grid of 8x8 output tiles of C; a workgroup
// holds LID subgroups stacked along y, each covering a different column band.
// A is read row-major with stride {{K}}; B with stride {{N}}.
// ─────────────────────────────────────────────────────────────────────────────
const char *kShaderSubgroupMatrixMultiply = R"(
enable chromium_experimental_subgroup_matrix;
diagnostic (off, chromium.subgroup_matrix_uniformity);

@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
@group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;

@compute @workgroup_size({{workgroupSize}})
fn main(@builtin(workgroup_id) wg: vec3<u32>,
        @builtin(local_invocation_id) localID : vec3<u32>) {

    // Top-left element of this subgroup's output region in C.
    let rowStart: u32 = wg.x * 8u * {{TM}};
    let colStart: u32 = (wg.y * {{LID}} + localID.y) * 8u * {{TN}};

    let baseA: u32 = rowStart * {{K}};
    let baseB: u32 = colStart;
    let cBase: u32 = rowStart * {{N}} + colStart;

    var Ax: array<subgroup_matrix_left<{{precision}}, 8, 8>, {{TM}}>;
    var Bx: array<subgroup_matrix_right<{{precision}}, 8, 8>, {{TN}}>;

    // TM x TN accumulators (one 8x8 subgroup matrix each), column-major:
    // accxx[i + j*TM] holds tile (row i, col j).
    var accxx: array<subgroup_matrix_result<{{precision}}, 8, 8>, {{TM}} * {{TN}}>;

    for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
        Ax[idx_i] = subgroup_matrix_left<{{precision}}, 8, 8>(0);
    }

    for (var idx_i: u32 = 0; idx_i < {{TN}}; idx_i++) {
        Bx[idx_i] = subgroup_matrix_right<{{precision}}, 8, 8>(0);
    }

    for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
        for (var idx_j: u32 = 0; idx_j < {{TN}}; idx_j++) {
            accxx[idx_i+idx_j*{{TM}}] = subgroup_matrix_result<{{precision}}, 8, 8>(0);
        }
    }

    // March along the shared K dimension 8 columns/rows at a time.
    for (var k: u32 = 0u; k < {{K}}; k = k + 8u) {
        workgroupBarrier();
        for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
            Ax[idx_i] = subgroupMatrixLoad<subgroup_matrix_left<{{precision}},8,8>>(&A, baseA + k + 8u * {{K}} * idx_i, false, {{K}});
        }

        for (var idx_i: u32 = 0; idx_i < {{TN}}; idx_i++) {
            Bx[idx_i] = subgroupMatrixLoad<subgroup_matrix_right<{{precision}},8,8>>(&B, baseB + k * {{N}} + 8u * idx_i, false, {{N}});
        }

        for (var idx_j: u32 = 0; idx_j < {{TN}}; idx_j++) {
            for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
                accxx[idx_j*{{TM}} + idx_i] = subgroupMatrixMultiplyAccumulate(Ax[idx_i], Bx[idx_j], accxx[idx_j*{{TM}} + idx_i]);
            }
        }
    }

    // Write the finished TM x TN tile grid back to C.
    workgroupBarrier();
    for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
        for (var idx_j: u32 = 0; idx_j < {{TN}}; idx_j++) {
            subgroupMatrixStore(&C, cBase + idx_i * 8u * {{N}} + 8u * idx_j, accxx[idx_j*{{TM}} + idx_i], false, {{N}});
        }
    }
}
)";
702+
616703
/**
617704
* @brief No-Op shader with matmul bindings for performance testing
618705
*/
@@ -683,26 +770,30 @@ Kernel selectMatmul(Context &ctx, int version,
683770
const Bindings</* input, weights, output*/3> &bindings,
684771
size_t M,size_t K,size_t N, NumType numtype) {
685772
Kernel kernel;
773+
CompilationInfo info;
686774
if (version ==1) {
687775
Shape wgSize = {256,1,1};
688776
Shape nWorkgroups =cdiv({M, N,1}, {16,16,1});
689777
KernelCode matmul =createNoOp(kShaderNoOp,/*wgsize*/ wgSize);
690778
kernel =createKernel(ctx, matmul, bindings,
691-
/*nWorkgroups*/ nWorkgroups);
779+
/*nWorkgroups*/ nWorkgroups,
780+
NoParam{}, &info);
692781
}elseif (version ==2) {
693782
Shape wgSize = {16,16,1};
694783
LOG(kDefLog,kInfo,"wgSize: %s",toString(wgSize).c_str());
695784
KernelCode matmul =
696785
createMatmul1(kShaderMatmul1, M, K, N,/*wgsize*/ wgSize, numtype);
697786
kernel =createKernel(ctx, matmul, bindings,
698-
/*nWorkgroups*/cdiv({M, N,1}, wgSize));
787+
/*nWorkgroups*/cdiv({M, N,1}, wgSize),
788+
NoParam{}, &info);
699789
}elseif (version ==3) {
700790
staticconstexprsize_t tileSize =16;
701791
KernelCode matmul =createMatmul2(kShaderMatmul2, M, K, N,
702792
/*wgSize*/ {tileSize * tileSize,1,1}, numtype);
703793
kernel =
704794
createKernel(ctx, matmul, bindings,
705-
/* nWorkgroups*/cdiv({M, N,1}, {tileSize, tileSize,1}));
795+
/* nWorkgroups*/cdiv({M, N,1}, {tileSize, tileSize,1}),
796+
NoParam{}, &info);
706797
}elseif (version ==4 || version ==6) {
707798
staticconstexprsize_t BM =64;
708799
staticconstexprsize_t BK =4;
@@ -721,7 +812,8 @@ Kernel selectMatmul(Context &ctx, int version,
721812
numtype,
722813
/*Loop unrolling*/ version ==6 ?true:false);
723814
kernel =createKernel(ctx, matmul, bindings,
724-
/*nWorkgroups*/ nWorkgroups);
815+
/*nWorkgroups*/ nWorkgroups,
816+
NoParam{}, &info);
725817
}elseif (version ==5 || version ==7) {
726818
staticconstexprsize_t BM =64;
727819
staticconstexprsize_t BK =8;
@@ -739,7 +831,8 @@ Kernel selectMatmul(Context &ctx, int version,
739831
numtype,
740832
/*Loop unrolling*/ version ==7 ?true:false);
741833
kernel =createKernel(ctx, matmul, bindings,
742-
/*nWorkgroups*/ nWorkgroups);
834+
/*nWorkgroups*/ nWorkgroups,
835+
NoParam{}, &info);
743836
}elseif (version ==8 || version ==10) {
744837
staticconstexprsize_t BM =64;
745838
staticconstexprsize_t BK =8;
@@ -757,7 +850,8 @@ Kernel selectMatmul(Context &ctx, int version,
757850
numtype,
758851
/*Loop unrolling*/true);
759852
kernel =createKernel(ctx, matmul, bindings,
760-
/*nWorkgroups*/ nWorkgroups);
853+
/*nWorkgroups*/ nWorkgroups,
854+
NoParam{}, &info);
761855
}elseif (version ==9 || version ==11) {
762856
staticconstexprsize_t BM =64;
763857
staticconstexprsize_t BK =8;
@@ -774,8 +868,38 @@ Kernel selectMatmul(Context &ctx, int version,
774868
/*wgSize*/ wgSize,
775869
numtype);
776870
kernel =createKernel(ctx, matmul, bindings,
777-
/*nWorkgroups*/ nWorkgroups);
871+
/*nWorkgroups*/ nWorkgroups,
872+
NoParam{}, &info);
873+
}elseif (version ==12 || version ==13) {
874+
// f16: Subgroup matrix multiply
875+
staticconstexprsize_t TM =4;
876+
staticconstexprsize_t TN =8;
877+
staticconstexprsize_t LID =2;
878+
Shape wgSize = {32, LID,1};// One subgroup per workgroup
879+
Shape nWorkgroups = {cdiv(M,8 * TM),cdiv(N,8 * TN * LID),1};
880+
LOG(kDefLog,kInfo,"M: %zu, K: %zu, N: %zu", M, K, N);
881+
LOG(kDefLog,kInfo,"wgSize: ( %s )",toString(wgSize).c_str());
882+
LOG(kDefLog,kInfo,"nWorkgroups: ( %s )",toString(nWorkgroups).c_str());
883+
KernelCode matmul =createMatmul12(kShaderSubgroupMatrixMultiply, M, K, N, TM, TN, LID, wgSize, numtype);
884+
kernel =createKernel(ctx, matmul, bindings, nWorkgroups,
885+
NoParam{}, &info);
886+
}
887+
888+
if (info.status != WGPUCompilationInfoRequestStatus_Success) {
889+
LOG(kDefLog,kError,"Failed to compile shader");
890+
for (size_t i =0; i < info.messages.size(); i++) {
891+
LOG(kDefLog,kError,"Line %llu, Pos %llu: %s", info.lineNums[i],
892+
info.linePos[i], info.messages[i].c_str());
893+
}
894+
exit(1);
895+
}else {
896+
LOG(kDefLog,kInfo,"Shader compiled successfully");
897+
for (size_t i =0; i < info.messages.size(); i++) {
898+
LOG(kDefLog,kInfo,"Line %llu, Pos %llu: %s", info.lineNums[i],
899+
info.linePos[i], info.messages[i].c_str());
900+
}
778901
}
902+
779903
return kernel;
780904
}
781905

@@ -791,41 +915,51 @@ void runTest(int version, size_t M, size_t K, size_t N,
791915
assert(numtype == kf16);
792916
}
793917

794-
// Allocate GPU buffers and copy data
795-
WGPUDeviceDescriptor devDescriptor = {};
796-
devDescriptor.requiredFeatureCount =1;
797-
devDescriptor.requiredFeatures =std::array{WGPUFeatureName_ShaderF16}.data();
798-
799-
Context ctx;
800-
if (numtype == kf16) {
801-
ctx =createContext(
802-
{}, {},
803-
/*device descriptor, enabling f16 in WGSL*/
804-
{
805-
.requiredFeatureCount =1,
806-
.requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data()
807-
});
808-
if (ctx.adapterStatus != WGPURequestAdapterStatus_Success) {
809-
LOG(kDefLog,kError,"Failed to create adapter with f16 support, try running an f32 test instead (`export MATMUL_VERSION=9).");
810-
exit(1);
918+
static WGPUDawnTogglesDescriptor toggles = {};
919+
toggles.chain.sType = WGPUSType_DawnTogglesDescriptor;
920+
constchar* enableList[] = {"allow_unsafe_apis"};
921+
toggles.enabledToggles =enableList;
922+
toggles.enabledToggleCount =1;
923+
924+
static WGPUDeviceDescriptor devDesc = {};
925+
devDesc.nextInChain = &toggles.chain;
926+
devDesc.requiredFeatureCount =3,
927+
devDesc.requiredFeatures = std::array{
928+
WGPUFeatureName_ShaderF16,
929+
WGPUFeatureName_Subgroups,
930+
WGPUFeatureName_ChromiumExperimentalSubgroupMatrix
931+
}.data();
932+
devDesc.uncapturedErrorCallbackInfo = WGPUUncapturedErrorCallbackInfo {
933+
.callback = [](WGPUDeviceconst * device, WGPUErrorType type, WGPUStringView msg,void*,void*) {
934+
LOG(kDefLog,kError,"[Uncaptured %d] %.*s\n", (int)type, (int)msg.length, msg.data);
811935
}
812-
if (ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
813-
LOG(kDefLog,kError,"Failed to create device with f16 support, try running an f32 test instead. (`export MATMUL_VERSION=9)");
814-
exit(1);
936+
};
937+
devDesc.deviceLostCallbackInfo = WGPUDeviceLostCallbackInfo {
938+
.mode = WGPUCallbackMode_AllowSpontaneous,
939+
.callback = [](WGPUDeviceconst * device, WGPUDeviceLostReason reason, WGPUStringView msg,void*,void*) {
940+
LOG(kDefLog,kError,"[DeviceLost %d] %.*s\n", (int)reason, (int)msg.length, msg.data);
815941
}
816-
}
942+
};
817943

818-
if (numtype == kf32) {
819-
ctx =createContext({}, {}, {});
820-
if (ctx.adapterStatus != WGPURequestAdapterStatus_Success ||
821-
ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
822-
LOG(kDefLog,kError,"Failed to create adapter or device");
823-
// stop execution
824-
exit(1);
825-
}else {
826-
LOG(kDefLog,kInfo,"Successfully created adapter and device");
944+
static WGPULimits requiredLimits = WGPU_LIMITS_INIT;
945+
devDesc.requiredLimits = &requiredLimits;
946+
Context ctx =createContext({}, {}, devDesc);
947+
948+
WGPULoggingCallbackInfo logCb{
949+
.callback = [](WGPULoggingType type, WGPUStringView msg,void*,void*) {
950+
LOG(kDefLog,kError,"[WGPU %d] %.*s\n", (int)type, (int)msg.length, msg.data);
827951
}
828-
}
952+
};
953+
wgpuDeviceSetLoggingCallback(ctx.device, logCb);
954+
955+
if (ctx.adapterStatus != WGPURequestAdapterStatus_Success ||
956+
ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
957+
LOG(kDefLog,kError,"Failed to create adapter or device");
958+
// stop execution
959+
exit(1);
960+
}else {
961+
LOG(kDefLog,kInfo,"Successfully created adapter and device");
962+
}
829963

830964
Tensor input =createTensor(ctx, Shape{M, K}, numtype, inputPtr.get());
831965
Tensor weights =createTensor(ctx, Shape{N, K}, numtype, weightsPtr.get());// column-major
@@ -859,7 +993,7 @@ void runTest(int version, size_t M, size_t K, size_t N,
859993
// Use microsecond for more accurate time measurement
860994
auto duration =
861995
std::chrono::duration_cast<std::chrono::microseconds>(end - start);
862-
float gflops =2 * M * N *
996+
float gflops =2.0f * M * N *
863997
K /// factor of 2 for multiplication & accumulation
864998
(static_cast<double>(duration.count()) /1000000.0) /
865999
1000000000.0 *static_cast<float>(nIter);
@@ -870,7 +1004,7 @@ void runTest(int version, size_t M, size_t K, size_t N,
8701004
show<precision>(outputPtr.get(), M, N,"Output[0]").c_str());
8711005

8721006
LOG(kDefLog,kInfo,"\n\n===================================================================="
873-
"============\nExecution Time: (M = %d, K = %d, N = %d) x %d iterations"
1007+
"============\nExecution Time: (M = %zu, K = %zu, N = %zu) x %zu iterations"
8741008
":\n%.1f"
8751009
"milliseconds / dispatch ~ %.2f"
8761010
"GFLOPS\n================================================================"
@@ -913,13 +1047,16 @@ const std::string versionToStr(int version){
9131047
case9:return"f32: 2D blocktiling with loop unrolling, vectorization and transpose";
9141048
case10:return"f16: 2D blocktiling with loop unrolling and vectorization";
9151049
case11:return"f16: 2D blocktiling with loop unrolling, vectorization and transpose";
1050+
case12:return"f16: Subgroup matrix multiply with transpose (default)";
1051+
case13:return"f32: Subgroup matrix multiply with transpose";
9161052
default:return"Not specified";
9171053
}
9181054
}
9191055

9201056
intmain() {
1057+
std::cout <<"Starting matmul test..." << std::endl;
9211058
char* version_str =getenv("MATMUL_VERSION");
922-
int version = version_str ==NULL ?10 :atoi(version_str);
1059+
int version = version_str ==NULL ?12 :atoi(version_str);
9231060
// 1 == f32: No-Op
9241061
// 2 == f32: naive matmul
9251062
// 3 == f32: tiling
@@ -931,8 +1068,10 @@ int main() {
9311068
// 9 == f32: 2D blocktiling with loop unrolling, vectorization and transpose
9321069
// 10 == f16: 2D blocktiling with loop unrolling and vectorization (default)
9331070
// 11 == f16: 2D blocktiling with loop unrolling, vectorization and transpose
934-
bool enableF16 = version ==10 || version ==11;
935-
bool transposedInput = version ==9 || version ==11;
1071+
// 12 == f16: Subgroup matrix multiply with transpose (default)
1072+
// 13 == f32: Subgroup matrix multiply with transpose
1073+
bool enableF16 = version ==10 || version ==11 || version ==12;
1074+
bool transposedInput = version ==9 || version ==11 || version ==12 || version ==13;
9361075
NumType numtype = enableF16 ? kf16 : kf32;
9371076

9381077
size_t M, K, N;// Matrix dimensions

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp