@@ -613,6 +613,93 @@ inline KernelCode createMatmulWithTranspose(const char *shaderTemplate, const si
   return {unrolledCode, workgroupSize, precision};
 }
 
+inline KernelCode createMatmul12(const char *shaderTemplate, const size_t M,
+                                 const size_t K, const size_t N,
+                                 const size_t TM, const size_t TN,
+                                 const size_t LID,
+                                 const Shape &workgroupSize = {256, 1, 1},
+                                 NumType precision = kf32) {
+  std::string codeString(shaderTemplate);
+  replaceAll(codeString, {{"{{precision}}", toString(precision)},
+                          {"{{M}}", toString(M)},
+                          {"{{K}}", toString(K)},
+                          {"{{N}}", toString(N)},
+                          {"{{TM}}", toString(TM)},
+                          {"{{TN}}", toString(TN)},
+                          {"{{LID}}", toString(LID)}
+                         });
+  return {loopUnrolling(codeString), workgroupSize, precision};
+}
+
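+// Illustrative call (the dimensions here are hypothetical, not taken from the test):
+//   createMatmul12(kShaderSubgroupMatrixMultiply, /*M*/ 4096, /*K*/ 4096, /*N*/ 4096,
+//                  /*TM*/ 4, /*TN*/ 8, /*LID*/ 2, /*wgSize*/ {32, 2, 1}, kf16);
+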
+// ─────────────────────────────────────────────────────────────────────────────
+// Optimised WGSL matrix-multiply kernel using subgroupMatrixLoad/Store
+// and subgroupMatrixMultiplyAccumulate
+// ─────────────────────────────────────────────────────────────────────────────
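+//
+// Each workgroup computes an (8*TM) x (8*TN*LID) tile of C: every workgroup
+// row (localID.y) owns a strip of TM x TN 8x8 output tiles, and the k-loop
+// accumulates into TM*TN subgroup_matrix_result accumulators, consuming the
+// shared dimension 8 columns at a time.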
+const char *kShaderSubgroupMatrixMultiply = R"(
+enable chromium_experimental_subgroup_matrix;
+diagnostic(off, chromium.subgroup_matrix_uniformity);
+
+@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
+@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
+@group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;
+
+@compute @workgroup_size({{workgroupSize}})
+fn main(@builtin(workgroup_id) wg: vec3<u32>,
+        @builtin(local_invocation_id) localID : vec3<u32>) {
+
+    let rowStart: u32 = wg.x * 8u * {{TM}};
+    let colStart: u32 = (wg.y * {{LID}} + localID.y) * 8u * {{TN}};
+
+    let baseA: u32 = rowStart * {{K}};
+    let baseB: u32 = colStart;
+    let cBase: u32 = rowStart * {{N}} + colStart;
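+
+    // Offsets assume row-major A (MxK) and C (MxN), and B indexed as a KxN
+    // row-major matrix (i.e. the N x K weights stored column-major on the host).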
+
+    var Ax: array<subgroup_matrix_left<{{precision}}, 8, 8>, {{TM}}>;
+    var Bx: array<subgroup_matrix_right<{{precision}}, 8, 8>, {{TN}}>;
+
+    // {{TM}} x {{TN}} accumulators (8x8 each)
+    var accxx: array<subgroup_matrix_result<{{precision}}, 8, 8>, {{TM}} * {{TN}}>;
+
+    for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
+      Ax[idx_i] = subgroup_matrix_left<{{precision}}, 8, 8>(0);
+    }
+
+    for (var idx_i: u32 = 0; idx_i < {{TN}}; idx_i++) {
+      Bx[idx_i] = subgroup_matrix_right<{{precision}}, 8, 8>(0);
+    }
+
+    for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
+      for (var idx_j: u32 = 0; idx_j < {{TN}}; idx_j++) {
+        accxx[idx_i + idx_j * {{TM}}] = subgroup_matrix_result<{{precision}}, 8, 8>(0);
+      }
+    }
+
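+    // Walk the shared dimension 8 columns at a time: load TM tiles of A and
+    // TN tiles of B for this k-slice, then accumulate every TM x TN product.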
+    for (var k: u32 = 0u; k < {{K}}; k = k + 8u) {
+      workgroupBarrier();
+      for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
+        Ax[idx_i] = subgroupMatrixLoad<subgroup_matrix_left<{{precision}}, 8, 8>>(&A, baseA + k + 8u * {{K}} * idx_i, false, {{K}});
+      }
+
+      for (var idx_i: u32 = 0; idx_i < {{TN}}; idx_i++) {
+        Bx[idx_i] = subgroupMatrixLoad<subgroup_matrix_right<{{precision}}, 8, 8>>(&B, baseB + k * {{N}} + 8u * idx_i, false, {{N}});
+      }
+
+      for (var idx_j: u32 = 0; idx_j < {{TN}}; idx_j++) {
+        for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
+          accxx[idx_j * {{TM}} + idx_i] = subgroupMatrixMultiplyAccumulate(Ax[idx_i], Bx[idx_j], accxx[idx_j * {{TM}} + idx_i]);
+        }
+      }
+    }
+
+    workgroupBarrier();
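+    // Write the TM x TN accumulated 8x8 tiles back to C (row-major, stride N).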
+    for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
+      for (var idx_j: u32 = 0; idx_j < {{TN}}; idx_j++) {
+        subgroupMatrixStore(&C, cBase + idx_i * 8u * {{N}} + 8u * idx_j, accxx[idx_j * {{TM}} + idx_i], false, {{N}});
+      }
+    }
+}
+)";
+
 /**
  * @brief No-Op shader with matmul bindings for performance testing
  */
@@ -683,26 +770,30 @@ Kernel selectMatmul(Context &ctx, int version,
                     const Bindings</* input, weights, output */ 3> &bindings,
                     size_t M, size_t K, size_t N, NumType numtype) {
   Kernel kernel;
+  CompilationInfo info;
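+  // Capture WGSL compilation diagnostics for whichever kernel is built below;
+  // they are checked once after the if/else chain.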
   if (version == 1) {
     Shape wgSize = {256, 1, 1};
     Shape nWorkgroups = cdiv({M, N, 1}, {16, 16, 1});
     KernelCode matmul = createNoOp(kShaderNoOp, /*wgsize*/ wgSize);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ nWorkgroups);
+                          /*nWorkgroups*/ nWorkgroups,
+                          NoParam{}, &info);
   } else if (version == 2) {
     Shape wgSize = {16, 16, 1};
     LOG(kDefLog, kInfo, "wgSize: %s", toString(wgSize).c_str());
     KernelCode matmul =
         createMatmul1(kShaderMatmul1, M, K, N, /*wgsize*/ wgSize, numtype);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ cdiv({M, N, 1}, wgSize));
+                          /*nWorkgroups*/ cdiv({M, N, 1}, wgSize),
+                          NoParam{}, &info);
   } else if (version == 3) {
     static constexpr size_t tileSize = 16;
     KernelCode matmul = createMatmul2(kShaderMatmul2, M, K, N,
                                       /*wgSize*/ {tileSize * tileSize, 1, 1}, numtype);
     kernel =
         createKernel(ctx, matmul, bindings,
-                     /*nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1}));
+                     /*nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1}),
+                     NoParam{}, &info);
   } else if (version == 4 || version == 6) {
     static constexpr size_t BM = 64;
     static constexpr size_t BK = 4;
@@ -721,7 +812,8 @@ Kernel selectMatmul(Context &ctx, int version,
                       numtype,
                       /*Loop unrolling*/ version == 6 ? true : false);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ nWorkgroups);
+                          /*nWorkgroups*/ nWorkgroups,
+                          NoParam{}, &info);
   } else if (version == 5 || version == 7) {
     static constexpr size_t BM = 64;
     static constexpr size_t BK = 8;
@@ -739,7 +831,8 @@ Kernel selectMatmul(Context &ctx, int version,
                       numtype,
                       /*Loop unrolling*/ version == 7 ? true : false);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ nWorkgroups);
+                          /*nWorkgroups*/ nWorkgroups,
+                          NoParam{}, &info);
   } else if (version == 8 || version == 10) {
     static constexpr size_t BM = 64;
     static constexpr size_t BK = 8;
@@ -757,7 +850,8 @@ Kernel selectMatmul(Context &ctx, int version,
                       numtype,
                       /*Loop unrolling*/ true);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ nWorkgroups);
+                          /*nWorkgroups*/ nWorkgroups,
+                          NoParam{}, &info);
   } else if (version == 9 || version == 11) {
     static constexpr size_t BM = 64;
     static constexpr size_t BK = 8;
@@ -774,8 +868,38 @@ Kernel selectMatmul(Context &ctx, int version,
                       /*wgSize*/ wgSize,
                       numtype);
     kernel = createKernel(ctx, matmul, bindings,
-                          /*nWorkgroups*/ nWorkgroups);
+                          /*nWorkgroups*/ nWorkgroups,
+                          NoParam{}, &info);
+  } else if (version == 12 || version == 13) {
+    // Subgroup matrix multiply (f16 for version 12, f32 for version 13)
+    static constexpr size_t TM = 4;
+    static constexpr size_t TN = 8;
+    static constexpr size_t LID = 2;
+    Shape wgSize = {32, LID, 1}; // one 32-wide subgroup per workgroup row, LID rows
+    Shape nWorkgroups = {cdiv(M, 8 * TM), cdiv(N, 8 * TN * LID), 1};
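+    // Each workgroup covers an (8*TM) x (8*TN*LID) tile of the output, so the
+    // dispatch grid divides M and N by those tile extents.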
+    LOG(kDefLog, kInfo, "M: %zu, K: %zu, N: %zu", M, K, N);
+    LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str());
+    LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
+    KernelCode matmul = createMatmul12(kShaderSubgroupMatrixMultiply, M, K, N,
+                                       TM, TN, LID, wgSize, numtype);
+    kernel = createKernel(ctx, matmul, bindings, nWorkgroups,
+                          NoParam{}, &info);
+  }
+
+  if (info.status != WGPUCompilationInfoRequestStatus_Success) {
+    LOG(kDefLog, kError, "Failed to compile shader");
+    for (size_t i = 0; i < info.messages.size(); i++) {
+      LOG(kDefLog, kError, "Line %llu, Pos %llu: %s", info.lineNums[i],
+          info.linePos[i], info.messages[i].c_str());
+    }
+    exit(1);
+  } else {
+    LOG(kDefLog, kInfo, "Shader compiled successfully");
+    for (size_t i = 0; i < info.messages.size(); i++) {
+      LOG(kDefLog, kInfo, "Line %llu, Pos %llu: %s", info.lineNums[i],
+          info.linePos[i], info.messages[i].c_str());
+    }
   }
+
   return kernel;
 }
 
@@ -791,41 +915,51 @@ void runTest(int version, size_t M, size_t K, size_t N,
     assert(numtype == kf16);
   }
 
-  // Allocate GPU buffers and copy data
-  WGPUDeviceDescriptor devDescriptor = {};
-  devDescriptor.requiredFeatureCount = 1;
-  devDescriptor.requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data();
-
-  Context ctx;
-  if (numtype == kf16) {
-    ctx = createContext(
-        {}, {},
-        /* device descriptor, enabling f16 in WGSL */
-        {
-            .requiredFeatureCount = 1,
-            .requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data()
-        });
-    if (ctx.adapterStatus != WGPURequestAdapterStatus_Success) {
-      LOG(kDefLog, kError, "Failed to create adapter with f16 support, try running an f32 test instead (`export MATMUL_VERSION=9).");
-      exit(1);
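+  // Dawn gates experimental features such as subgroup matrices behind the
+  // "allow_unsafe_apis" toggle, so enable it on the device descriptor.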
+  static WGPUDawnTogglesDescriptor toggles = {};
+  toggles.chain.sType = WGPUSType_DawnTogglesDescriptor;
+  const char *enableList[] = {"allow_unsafe_apis"};
+  toggles.enabledToggles = enableList;
+  toggles.enabledToggleCount = 1;
+
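+  // Keep the feature list in storage that outlives this statement so the
+  // pointer stored in devDesc stays valid until createContext() reads it
+  // (calling .data() on a temporary std::array would dangle).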
+  static const std::array<WGPUFeatureName, 3> requiredFeatures = {
+      WGPUFeatureName_ShaderF16,
+      WGPUFeatureName_Subgroups,
+      WGPUFeatureName_ChromiumExperimentalSubgroupMatrix};
+
+  static WGPUDeviceDescriptor devDesc = {};
+  devDesc.nextInChain = &toggles.chain;
+  devDesc.requiredFeatureCount = requiredFeatures.size();
+  devDesc.requiredFeatures = requiredFeatures.data();
+  devDesc.uncapturedErrorCallbackInfo = WGPUUncapturedErrorCallbackInfo{
+      .callback = [](WGPUDevice const *device, WGPUErrorType type, WGPUStringView msg, void *, void *) {
+        LOG(kDefLog, kError, "[Uncaptured %d] %.*s\n", (int)type, (int)msg.length, msg.data);
       }
-    if (ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
-      LOG(kDefLog, kError, "Failed to create device with f16 support, try running an f32 test instead. (`export MATMUL_VERSION=9)");
-      exit(1);
+  };
+  devDesc.deviceLostCallbackInfo = WGPUDeviceLostCallbackInfo{
+      .mode = WGPUCallbackMode_AllowSpontaneous,
+      .callback = [](WGPUDevice const *device, WGPUDeviceLostReason reason, WGPUStringView msg, void *, void *) {
+        LOG(kDefLog, kError, "[DeviceLost %d] %.*s\n", (int)reason, (int)msg.length, msg.data);
       }
-  }
+  };
 
-  if (numtype == kf32) {
-    ctx = createContext({}, {}, {});
-    if (ctx.adapterStatus != WGPURequestAdapterStatus_Success ||
-        ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
-      LOG(kDefLog, kError, "Failed to create adapter or device");
-      // stop execution
-      exit(1);
-    } else {
-      LOG(kDefLog, kInfo, "Successfully created adapter and device");
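+  // WGPU_LIMITS_INIT leaves every limit at its "undefined" default, so the
+  // device is requested with the adapter's default limits plus the toggles,
+  // features, and callbacks configured above.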
+  static WGPULimits requiredLimits = WGPU_LIMITS_INIT;
+  devDesc.requiredLimits = &requiredLimits;
+  Context ctx = createContext({}, {}, devDesc);
+
+  WGPULoggingCallbackInfo logCb{
+      .callback = [](WGPULoggingType type, WGPUStringView msg, void *, void *) {
+        LOG(kDefLog, kError, "[WGPU %d] %.*s\n", (int)type, (int)msg.length, msg.data);
       }
-  }
+  };
+  wgpuDeviceSetLoggingCallback(ctx.device, logCb);
+
+  if (ctx.adapterStatus != WGPURequestAdapterStatus_Success ||
+      ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
+    LOG(kDefLog, kError, "Failed to create adapter or device");
+    // stop execution
+    exit(1);
+  } else {
+    LOG(kDefLog, kInfo, "Successfully created adapter and device");
+  }
 
   Tensor input = createTensor(ctx, Shape{M, K}, numtype, inputPtr.get());
   Tensor weights = createTensor(ctx, Shape{N, K}, numtype, weightsPtr.get()); // column-major
@@ -859,7 +993,7 @@ void runTest(int version, size_t M, size_t K, size_t N,
   // Use microsecond for more accurate time measurement
   auto duration =
       std::chrono::duration_cast<std::chrono::microseconds>(end - start);
-  float gflops = 2 * M * N *
+  float gflops = 2.0f * M * N *
                  K / // factor of 2 for multiplication & accumulation
                  (static_cast<double>(duration.count()) / 1000000.0) /
                  1000000000.0 * static_cast<float>(nIter);
@@ -870,7 +1004,7 @@ void runTest(int version, size_t M, size_t K, size_t N,
       show<precision>(outputPtr.get(), M, N, "Output[0]").c_str());
 
   LOG(kDefLog, kInfo, "\n\n===================================================================="
-      "============\nExecution Time: (M = %d, K = %d, N = %d) x %d iterations"
+      "============\nExecution Time: (M = %zu, K = %zu, N = %zu) x %zu iterations"
       ":\n%.1f"
       " milliseconds / dispatch ~ %.2f"
       " GFLOPS\n================================================================"
@@ -913,13 +1047,16 @@ const std::string versionToStr(int version){
   case 9: return "f32: 2D blocktiling with loop unrolling, vectorization and transpose";
   case 10: return "f16: 2D blocktiling with loop unrolling and vectorization";
   case 11: return "f16: 2D blocktiling with loop unrolling, vectorization and transpose";
+  case 12: return "f16: Subgroup matrix multiply with transpose (default)";
+  case 13: return "f32: Subgroup matrix multiply with transpose";
   default: return "Not specified";
   }
 }
 
 int main() {
+  std::cout << "Starting matmul test..." << std::endl;
   char *version_str = getenv("MATMUL_VERSION");
-  int version = version_str == NULL ? 10 : atoi(version_str);
+  int version = version_str == NULL ? 12 : atoi(version_str);
   // 1 == f32: No-Op
   // 2 == f32: naive matmul
   // 3 == f32: tiling
@@ -931,8 +1068,10 @@ int main() {
   // 9 == f32: 2D blocktiling with loop unrolling, vectorization and transpose
-  // 10 == f16: 2D blocktiling with loop unrolling and vectorization (default)
+  // 10 == f16: 2D blocktiling with loop unrolling and vectorization
   // 11 == f16: 2D blocktiling with loop unrolling, vectorization and transpose
-  bool enableF16 = version == 10 || version == 11;
-  bool transposedInput = version == 9 || version == 11;
+  // 12 == f16: Subgroup matrix multiply with transpose (default)
+  // 13 == f32: Subgroup matrix multiply with transpose
+  bool enableF16 = version == 10 || version == 11 || version == 12;
+  bool transposedInput = version == 9 || version == 11 || version == 12 || version == 13;
   NumType numtype = enableF16 ? kf16 : kf32;
 
   size_t M, K, N; // Matrix dimensions