Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commita21ff37

Browse files
committed
consider breakpoint
1 parent857c5b8 commita21ff37

File tree

4 files changed

+56
-15
lines changed

4 files changed

+56
-15
lines changed

‎csrc/scheduler/pointwise.cpp‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,8 @@ bool mayHaveTmaCompatibleInputs(
340340
continue;
341341
}
342342

343-
if (!scheduler_utils::isTvSizeSuitableForTma(tv, runtime_info)) {
343+
if (!scheduler_utils::isTvSizeSuitableForTma(
344+
tv, runtime_info,/*break_point =*/0)) {
344345
continue;
345346
}
346347

‎csrc/scheduler/pointwise_tma.cpp‎

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,17 +70,27 @@ std::vector<int64_t> getTmaCompatibleInputIndices(
7070
}
7171
// must be cacheable
7272
if (scheduler_utils::getCacheableUses(tv).empty()) {
73+
scheduler_debug_utils::log(
74+
"[Pointwise TMA scheduler] no cacheable uses, tv:", tv->toString());
7375
continue;
7476
}
7577

7678
// must be suitable for TMA based on the number of elements and dtype size
77-
if (!scheduler_utils::isTvSizeSuitableForTma(tv, runtime_info)) {
79+
if (!scheduler_utils::isTvSizeSuitableForTma(
80+
tv, runtime_info, break_point)) {
81+
scheduler_debug_utils::log(
82+
"[Pointwise TMA scheduler] size not suitable for TMA, tv:",
83+
tv->toString());
7884
continue;
7985
}
8086

8187
// must have the same number of logical dimensions as the reference tensor
8288
// to avoid loading tensors that are smaller than the reference tensor
8389
if (scheduler_utils::nLogicalDims(tv) != n_valid_dims) {
90+
scheduler_debug_utils::log(
91+
"[Pointwise TMA scheduler] number of logical dimensions not suitable"
92+
"for TMA, tv:",
93+
tv->toString());
8494
continue;
8595
}
8696

@@ -92,6 +102,8 @@ std::vector<int64_t> getTmaCompatibleInputIndices(
92102
[](const std::optional<bool>& contiguity) {
93103
return !contiguity.has_value() || !contiguity.value();
94104
})) {
105+
scheduler_debug_utils::log(
106+
"[Pointwise TMA scheduler] not contiguous, tv:", tv->toString());
95107
continue;
96108
}
97109

@@ -100,8 +112,14 @@ std::vector<int64_t> getTmaCompatibleInputIndices(
100112
// To use TMA, this tv must have both lhs and rhs.
101113
// see PointwiseTest.BroadcastAddInner for example.
102114
if ((int64_t)tv->getLoopDomain().size() <= break_point) {
115+
scheduler_debug_utils::log(
116+
"[Pointwise TMA scheduler] break point not suitable for TMA, tv:",
117+
tv->toString());
103118
continue;
104119
}
120+
tma_compatible_input_indices.push_back(input_idx);
121+
scheduler_debug_utils::log(
122+
"[Pointwise TMA scheduler] suitable for TMA, tv:", tv->toString());
105123
}
106124
return tma_compatible_input_indices;
107125
}
@@ -119,9 +137,15 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
119137
params->cparams.index_type = prop.index_type;
120138
params->use_tma_load =true;
121139

140+
auto bp_info =pointwise_utils::getBreakPoint(
141+
fusion, prop, data_cache,/*is_tma =*/true);
142+
params->break_point = bp_info.break_point;
143+
122144
auto tma_compatible_input_indices =getTmaCompatibleInputIndices(
123145
fusion, runtime_info, prop.largest_out, params->break_point);
124146
if (tma_compatible_input_indices.empty()) {
147+
scheduler_debug_utils::log(
148+
"[Pointwise TMA scheduler] no suitable inputs found");
125149
returnnullptr;
126150
}else {
127151
params->tma_compatible_input_indices = tma_compatible_input_indices;
@@ -135,6 +159,8 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
135159
int64_t tma_domain_inner =scheduler_utils::getInnerTmaDomainSize(
136160
prop.n_elems, target_inner_tma_domain_size, min_dtype_bits);
137161
if (tma_domain_inner ==1 || prop.n_elems % tma_domain_inner !=0) {
162+
scheduler_debug_utils::log(
163+
"[Pointwise TMA scheduler] tma_domain_inner is not suitable for TMA");
138164
returnnullptr;
139165
}
140166
// constexpr int64_t align_bytes = 16;
@@ -146,10 +172,6 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
146172
constint64_t tma_outer_domain_size = prop.n_elems / tma_domain_inner;
147173
params->tma_domain_inner = tma_domain_inner;
148174

149-
auto bp_info =pointwise_utils::getBreakPoint(
150-
fusion, prop, data_cache,/*is_tma =*/true);
151-
params->break_point = bp_info.break_point;
152-
153175
// Compute elements_per_cta: Each CTA issues one TMA load operation. We
154176
// calculate the number of elements per TMA load based on the required bits
155177
// in flight, assuming 8 CTAs per SM. This is a guideline; the actual tile
@@ -205,6 +227,11 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
205227
debug() <<"\n==== Pointwise TMA Scheduler Heuristics ====\n";
206228
debug() <<"Domain sizes:\n";
207229
debug() <<" n_elems:" << prop.n_elems <<"\n";
230+
debug() <<" elem_counts:";
231+
for (auto element : prop.elem_counts) {
232+
debug() << element <<",";
233+
}
234+
debug() <<"\n";
208235
debug() <<" break_point:" << bp_info.break_point <<"\n";
209236
debug() <<" tma_domain_inner:" << tma_domain_inner <<"\n";
210237
debug() <<" tma_outer_domain_size:" << tma_outer_domain_size <<"\n";

‎csrc/scheduler/utils.cpp‎

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3445,31 +3445,43 @@ int64_t getInnerTmaDomainSize(
34453445
return best_divisible_size;
34463446
}
34473447

3448-
int64_tgetNumElements(
3448+
std::pair<std::vector<int64_t>,int64_t>getNumElements(
34493449
const TensorView* tv,
34503450
SchedulerRuntimeInfo& runtime_info) {
34513451
int64_t num_elements =1;
3452+
std::vector<int64_t> elem_counts;
3453+
elem_counts.reserve(tv->getLogicalDomain().size());
34523454
for (auto logical_id : tv->getLogicalDomain()) {
34533455
auto inferred_val =
34543456
runtime_info.expressionEvaluator().evaluate(logical_id->extent());
34553457
NVF_ERROR(
34563458
inferred_val.hasValue(),
34573459
"Error inferring extent of:",
34583460
logical_id->toString());
3459-
num_elements *= inferred_val.as<int64_t>();
3461+
auto extent = inferred_val.as<int64_t>();
3462+
elem_counts.push_back(extent);
3463+
num_elements *= extent;
34603464
}
3461-
return num_elements;
3465+
return{elem_counts,num_elements};
34623466
}
34633467

34643468
boolisTvSizeSuitableForTma(
34653469
const TensorView* tv,
3466-
SchedulerRuntimeInfo& runtime_info) {
3470+
SchedulerRuntimeInfo& runtime_info,
3471+
int64_t break_point) {
34673472
auto dtype_bits =
34683473
dataTypeSizeBit(tv->getDataType().value(), runtime_info.getIndexType());
3469-
auto elem_count =getNumElements(tv, runtime_info);
3474+
auto [elem_counts, total_elem_count] =getNumElements(tv, runtime_info);
3475+
int64_t inner_elem_count = break_point ==0
3476+
? total_elem_count
3477+
:std::accumulate(
3478+
elem_counts.begin() + break_point,
3479+
elem_counts.end(),
3480+
1,
3481+
std::multiplies<int64_t>());
34703482
constint64_t min_inner_tma_domain_size =2 *128 / dtype_bits;
3471-
if (elem_count % min_inner_tma_domain_size ==0 &&
3472-
elem_count > min_inner_tma_domain_size) {
3483+
if (inner_elem_count % min_inner_tma_domain_size ==0 &&
3484+
inner_elem_count > min_inner_tma_domain_size) {
34733485
returntrue;
34743486
}
34753487
returnfalse;

‎csrc/scheduler/utils.h‎

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,7 +1007,7 @@ int64_t getInnerTmaDomainSize(
10071007
int64_t min_dtype_bits =8);
10081008

10091009
// Get the total number of elements in a given TensorView
1010-
int64_tgetNumElements(
1010+
std::pair<std::vector<int64_t>,int64_t>getNumElements(
10111011
const TensorView* tv,
10121012
SchedulerRuntimeInfo& runtime_info);
10131013

@@ -1019,7 +1019,8 @@ int64_t getNumElements(
10191019
// outer TMA domain is 1, which is not a valid 2D TMA configuration.
10201020
boolisTvSizeSuitableForTma(
10211021
const TensorView* tv,
1022-
SchedulerRuntimeInfo& runtime_info);
1022+
SchedulerRuntimeInfo& runtime_info,
1023+
int64_t break_point);
10231024
}// namespace scheduler_utils
10241025

10251026
}// namespace nvfuser

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp