@@ -70,17 +70,27 @@ std::vector<int64_t> getTmaCompatibleInputIndices(
7070 }
7171// must be cacheable
7272if (scheduler_utils::getCacheableUses (tv).empty ()) {
73+ scheduler_debug_utils::log (
74+ " [Pointwise TMA scheduler] no cacheable uses, tv:" , tv->toString ());
7375continue ;
7476 }
7577
7678// must be suitable for TMA based on the number of elements and dtype size
77- if (!scheduler_utils::isTvSizeSuitableForTma (tv, runtime_info)) {
79+ if (!scheduler_utils::isTvSizeSuitableForTma (
80+ tv, runtime_info, break_point)) {
81+ scheduler_debug_utils::log (
82+ " [Pointwise TMA scheduler] size not suitable for TMA, tv:" ,
83+ tv->toString ());
7884continue ;
7985 }
8086
8187// must have the same number of logical dimensions as the reference tensor
8288// to avoid loading tensors that are smaller than the reference tensor
8389if (scheduler_utils::nLogicalDims (tv) != n_valid_dims) {
90+ scheduler_debug_utils::log (
91+ " [Pointwise TMA scheduler] number of logical dimensions not suitable"
92+ " for TMA, tv:" ,
93+ tv->toString ());
8494continue ;
8595 }
8696
@@ -92,6 +102,8 @@ std::vector<int64_t> getTmaCompatibleInputIndices(
92102 [](const std::optional<bool >& contiguity) {
93103return !contiguity.has_value () || !contiguity.value ();
94104 })) {
105+ scheduler_debug_utils::log (
106+ " [Pointwise TMA scheduler] not contiguous, tv:" , tv->toString ());
95107continue ;
96108 }
97109
@@ -100,8 +112,14 @@ std::vector<int64_t> getTmaCompatibleInputIndices(
100112// To use TMA, this tv must have both lhs and rhs.
101113// see PointwiseTest.BroadcastAddInner for example.
102114if ((int64_t )tv->getLoopDomain ().size () <= break_point) {
115+ scheduler_debug_utils::log (
116+ " [Pointwise TMA scheduler] break point not suitable for TMA, tv:" ,
117+ tv->toString ());
103118continue ;
104119 }
120+ tma_compatible_input_indices.push_back (input_idx);
121+ scheduler_debug_utils::log (
122+ " [Pointwise TMA scheduler] suitable for TMA, tv:" , tv->toString ());
105123 }
106124return tma_compatible_input_indices;
107125}
@@ -119,9 +137,15 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
119137 params->cparams .index_type = prop.index_type ;
120138 params->use_tma_load =true ;
121139
140+ auto bp_info =pointwise_utils::getBreakPoint (
141+ fusion, prop, data_cache,/* is_tma =*/ true );
142+ params->break_point = bp_info.break_point ;
143+
122144auto tma_compatible_input_indices =getTmaCompatibleInputIndices (
123145 fusion, runtime_info, prop.largest_out , params->break_point );
124146if (tma_compatible_input_indices.empty ()) {
147+ scheduler_debug_utils::log (
148+ " [Pointwise TMA scheduler] no suitable inputs found" );
125149return nullptr ;
126150 }else {
127151 params->tma_compatible_input_indices = tma_compatible_input_indices;
@@ -135,6 +159,8 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
135159int64_t tma_domain_inner =scheduler_utils::getInnerTmaDomainSize (
136160 prop.n_elems , target_inner_tma_domain_size, min_dtype_bits);
137161if (tma_domain_inner ==1 || prop.n_elems % tma_domain_inner !=0 ) {
162+ scheduler_debug_utils::log (
163+ " [Pointwise TMA scheduler] tma_domain_inner is not suitable for TMA" );
138164return nullptr ;
139165 }
140166// constexpr int64_t align_bytes = 16;
@@ -146,10 +172,6 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
146172const int64_t tma_outer_domain_size = prop.n_elems / tma_domain_inner;
147173 params->tma_domain_inner = tma_domain_inner;
148174
149- auto bp_info =pointwise_utils::getBreakPoint (
150- fusion, prop, data_cache,/* is_tma =*/ true );
151- params->break_point = bp_info.break_point ;
152-
153175// Compute elements_per_cta: Each CTA issues one TMA load operation. We
154176// calculate the number of elements per TMA load based on the required bits
155177// in flight, assuming 8 CTAs per SM. This is a guideline; the actual tile
@@ -205,6 +227,11 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
205227debug () <<" \n ==== Pointwise TMA Scheduler Heuristics ====\n " ;
206228debug () <<" Domain sizes:\n " ;
207229debug () <<" n_elems:" << prop.n_elems <<" \n " ;
230+ debug () <<" elem_counts:" ;
231+ for (auto element : prop.elem_counts ) {
232+ debug () << element <<" ," ;
233+ }
234+ debug () <<" \n " ;
208235debug () <<" break_point:" << bp_info.break_point <<" \n " ;
209236debug () <<" tma_domain_inner:" << tma_domain_inner <<" \n " ;
210237debug () <<" tma_outer_domain_size:" << tma_outer_domain_size <<" \n " ;