Review updated until commitf7f8f85 DescriptionEnable PDL support in CUTLASS grouped GEMM with launch_with_pdl flag Add wait_for_prior_grid dependencies before cached inputs in cacheInputs utility Add launch_dependent_grid dependencies for terminating outputs in cacheAndForkOutputs utility Filter out schedule operations from vectorization in pointwise scheduler Add BasicOverlap test case to validate PDL functionality with input/output caching
Changes walkthrough
| Relevant files |
|---|
| Enhancement | sync_information.cppExpand RAW sync validation for schedule operations
csrc/device_lower/analysis/sync_information.cpp Add isScheduleOp check to RAW sync validation conditions Expand validation logic to handle schedule operations | +2/-1 | pointwise.cppFilter schedule operations from pointwise vectorization
csrc/scheduler/pointwise.cpp Add device_lower/utils.h include for schedule operation utilities Filter schedule operations from vectorization using copy_if Exclude schedule operations from vectorized_tvs collection | +6/-2 | utils.cppAdd PDL grid synchronization to cache utilities
csrc/scheduler/utils.cpp Add wait_for_prior_grid dependencies before cached inputs in cacheInputs Add launch_dependent_grid dependencies for terminating outputs in cacheAndForkOutputs Track original inputs/outputs for proper PDL synchronization | +27/-0 | group_mm.cuEnable PDL in CUTLASS grouped GEMM
cutlass/group_mm.cu Enable launch_with_pdl flag in CUTLASS grouped GEMM execution Add PDL support to gemm_op.run call with launch_with_pdl=true | +6/-1 |
| | Tests | test_pdl.cppAdd BasicOverlap PDL test case
tests/cpp/test_pdl.cpp Add BasicOverlap test case for PDL functionality Test input/output caching with grid synchronization Validate PDL compilation and execution | +76/-0 |
|
PR Reviewer Guide
Here are some key observations to aid the review process: | 🧪 PR contains tests | | ⚡ Recommended focus areas for review | PDL Synchronization LogicThe new PDL synchronization logic incacheInputs andcacheAndForkOutputs functions adds dependencies for grid synchronization. The implementation checks for terminating outputs before adding launch dependencies, but the correctness of this logic should be validated, especially the condition that checks if an output is in the terminating outputs list. if (!original_inputs.empty()) {for (auto&& [original, cached] :zip(original_inputs, cached_inputs)) {// Add wait for prior grid before getting the cached inputs TensorView* grid_wait =wait_for_prior_grid({original}); cached.first->addDependency(grid_wait); }}Schedule Operation FilteringThe vectorization logic now filters out schedule operations usingstd::copy_if with a lambda that checks!ir_utils::isScheduleOp(tv). This change affects how inputs are vectorized and could impact performance. The correctness of excluding schedule operations from vectorization should be validated. std::copy_if( consumer_tvs.begin(), consumer_tvs.end(), std::back_inserter(vectorized_tvs), [](auto tv) {return !ir_utils::isScheduleOp(tv); });Sync Analysis ExceptionAddedir_utils::isScheduleOp(consumer) as an exception condition in the sync analysis error check. This change relaxes the sync requirements for schedule operations, which could potentially lead to synchronization issues if not properly validated. ir_utils::isScheduleOp(consumer) || ir_utils::isLdMatrixOp(producer->definition()) || ir_utils::isStMatrixOp(consumer->definition()) || |
|
launch_with_pdlflag with bf16 grouped gemmwait_for_prior_gridbeforecacheAfterof inputs in utility functioncacheInputscacheAndForkOutputsbeforecacheBeforeof terminating outputs in utility functioncacheAndForkOutputsStack on: