@@ -180,7 +180,7 @@ sm100_fp8_fp4_gemm_1d1d_impl(int* grouped_layout,
180180
181181 // Block scheduler
182182 uint32_t m_block_idx, n_block_idx;
183- auto scheduler = scheduler ::Scheduler<kGemmType , BLOCK_M, BLOCK_N, kNumGroups , kNumMulticast , kIsMulticastOnA , kNumSMs >(
183+ auto scheduler = sched ::Scheduler<kGemmType , BLOCK_M, BLOCK_N, kNumGroups , kNumMulticast , kIsMulticastOnA , kNumSMs >(
184184 shape_m, shape_n, shape_k, grouped_layout);
185185
186186 // Pipeline and TMA phases
@@ -209,19 +209,19 @@ sm100_fp8_fp4_gemm_1d1d_impl(int* grouped_layout,
209209
210210 // Compute offsets
211211 // NOTES: the group is always concatenated with the outer dimension
212- uint32_t m_idx = scheduler.template get_global_idx <(kGemmType == GemmType::MGroupedMasked), scheduler ::IndexType::MN> (
212+ uint32_t m_idx = scheduler.template get_global_idx <(kGemmType == GemmType::MGroupedMasked), sched ::IndexType::MN> (
213213 shape_m, BLOCK_M, m_block_idx);
214- uint32_t n_idx = scheduler.template get_global_idx <(kMajorB == cute::UMMA::Major::K), scheduler ::IndexType::MN> (
214+ uint32_t n_idx = scheduler.template get_global_idx <(kMajorB == cute::UMMA::Major::K), sched ::IndexType::MN> (
215215 shape_n, BLOCK_N, n_block_idx, m_block_idx);
216216
217217 // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
218218 // And for all m-grouped GEMMs, A must be K-majored
219219 DG_STATIC_ASSERT (kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or
220220 kMajorA == cute::UMMA::Major::K, " Invalid major" );
221221 uint32_t k_idx = k_block_idx * BLOCK_K;
222- uint32_t k_a_idx = scheduler.template get_global_idx <(kMajorA == cute::UMMA::Major::MN), scheduler ::IndexType::K> (
222+ uint32_t k_a_idx = scheduler.template get_global_idx <(kMajorA == cute::UMMA::Major::MN), sched ::IndexType::K> (
223223 shape_k, BLOCK_K, k_block_idx, m_block_idx);
224- uint32_t k_b_idx = scheduler.template get_global_idx <(kMajorB == cute::UMMA::Major::MN), scheduler ::IndexType::K> (
224+ uint32_t k_b_idx = scheduler.template get_global_idx <(kMajorB == cute::UMMA::Major::MN), sched ::IndexType::K> (
225225 shape_k, BLOCK_K, k_block_idx, m_block_idx);
226226
227227 // Add 2 CTA offsets
@@ -252,14 +252,14 @@ sm100_fp8_fp4_gemm_1d1d_impl(int* grouped_layout,
252252 // No swizzling, so one TMA for one SF is enough
253253 if (k_block_idx % kNumSFAStagesPerLoad == 0 ) {
254254 uint32_t sfa_m_idx = m_block_idx * BLOCK_M;
255- uint32_t sfa_k_idx = scheduler.template get_global_idx <(not is_m_grouped_contiguous (kGemmType )), scheduler ::IndexType::SF_K>(
255+ uint32_t sfa_k_idx = scheduler.template get_global_idx <(not is_m_grouped_contiguous (kGemmType )), sched ::IndexType::SF_K>(
256256 shape_sfa_k, 1 , math::ceil_div (k_idx, BLOCK_K * kNumSFAStagesPerLoad ));
257257 tma::copy<BLOCK_M, 1 , 0 >(&tensor_map_sfa, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx);
258258 num_arrival_bytes += BLOCK_M * sizeof (uint32_t );
259259 }
260260 if (k_block_idx % kNumSFBStagesPerLoad == 0 ) {
261261 uint32_t sfb_n_idx = n_block_idx * BLOCK_N;
262- uint32_t sfb_k_idx = scheduler.template get_global_idx <true , scheduler ::IndexType::SF_K>(
262+ uint32_t sfb_k_idx = scheduler.template get_global_idx <true , sched ::IndexType::SF_K>(
263263 shape_sfb_k, 1 , math::ceil_div (k_idx, BLOCK_K * kNumSFBStagesPerLoad ), m_block_idx);
264264 tma::copy<BLOCK_N, 1 , 0 >(&tensor_map_sfb, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx);
265265 num_arrival_bytes += BLOCK_N * sizeof (uint32_t );
@@ -460,7 +460,7 @@ sm100_fp8_fp4_gemm_1d1d_impl(int* grouped_layout,
460460 ptx::tcgen05_after_thread_sync ();
461461
462462 const auto tmem_base_addr = accum_stage_idx * UMMA_N;
463- const auto base_m_idx = scheduler.template get_global_idx <(not is_m_grouped_contiguous (kGemmType )), scheduler ::IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
463+ const auto base_m_idx = scheduler.template get_global_idx <(not is_m_grouped_contiguous (kGemmType )), sched ::IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
464464 const auto base_n_idx = n_block_idx * BLOCK_N;
465465
466466 if constexpr (kSwapAB ) {
0 commit comments