@@ -167,7 +167,8 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
167167 if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync ()) {
168168 const cute::TmaDescriptor* current_tensor_map_a = &tensor_map_a_base;
169169 const cute::TmaDescriptor* current_tensor_map_b = &tensor_map_b_base;
170- uint32_t last_group_idx = kNumGroups , sum_k = 0 ;
170+ uint32_t last_group_idx = kNumGroups ;
171+ uint32_t prefetched_next_group_idx = kNumGroups ; // Track which group was prefetched
171172
172173 // Persistently schedule over blocks
173174 while (scheduler.get_next_block (m_block_idx, n_block_idx)) {
@@ -187,16 +188,45 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
187188 const uint32_t & next_stage_idx = stage_idx ^ 1 ;
188189 last_group_idx = scheduler.current_group_idx ;
189190
190- // Prepare next tensor map
191- sum_k += scheduler.current_shape_k ;
191+ // Check if the current group matches the prefetched group
192+ // If not, we need to prepare the correct tensor map for the current group
193+ if (scheduler.current_num_valid_groups > 0 &&
194+ scheduler.current_group_idx != prefetched_next_group_idx) {
195+ // The prefetched tensor map doesn't match current group
196+ // This happens when block count is small (< num_SMs) and scheduler skips groups
197+ // Need to prepare the correct tensor map for current group
198+ // Use scheduler.current_k_cumsum which correctly tracks k offset even when groups are skipped
199+ const uint64_t current_k_offset = scheduler.current_k_cumsum ;
200+ tensor_map_replace_global_addr_in_smem (smem_tensor_map_a[stage_idx],
201+ gmem_a_ptr + current_k_offset * shape_m);
202+ tensor_map_replace_global_addr_in_smem (smem_tensor_map_b[stage_idx],
203+ gmem_b_ptr + current_k_offset * shape_n);
204+ tensor_map_replace_global_inner_dim_stride_in_smem (smem_tensor_map_a[stage_idx],
205+ scheduler.current_shape_k , scheduler.current_shape_k );
206+ tensor_map_replace_global_inner_dim_stride_in_smem (smem_tensor_map_b[stage_idx],
207+ scheduler.current_shape_k , scheduler.current_shape_k );
208+ *(gmem_tensor_map_a[stage_idx]) = *(smem_tensor_map_a[stage_idx]);
209+ *(gmem_tensor_map_b[stage_idx]) = *(smem_tensor_map_b[stage_idx]);
210+ // NOTE: Don't call tensor_map_release_cta() here!
211+ // We're preparing the current tensor map, not the next one.
212+ // It will be acquired immediately in the "Get current tensor map" section below.
213+ }
214+
215+ // Prepare next tensor map (prefetch for next group)
192216 if (scheduler.next_group_idx < kNumGroups ) {
193- tensor_map_replace_global_addr_in_smem (smem_tensor_map_a[next_stage_idx], gmem_a_ptr + static_cast <uint64_t >(sum_k) * shape_m);
194- tensor_map_replace_global_addr_in_smem (smem_tensor_map_b[next_stage_idx], gmem_b_ptr + static_cast <uint64_t >(sum_k) * shape_n);
217+ // Calculate next group's k offset using scheduler-provided information
218+ // This ensures consistency even when groups are skipped
219+ const uint64_t next_k_offset = static_cast <uint64_t >(scheduler.current_k_cumsum ) + scheduler.current_shape_k ;
220+ tensor_map_replace_global_addr_in_smem (smem_tensor_map_a[next_stage_idx], gmem_a_ptr + next_k_offset * shape_m);
221+ tensor_map_replace_global_addr_in_smem (smem_tensor_map_b[next_stage_idx], gmem_b_ptr + next_k_offset * shape_n);
195222 tensor_map_replace_global_inner_dim_stride_in_smem (smem_tensor_map_a[next_stage_idx], scheduler.next_shape_k , scheduler.next_shape_k );
196223 tensor_map_replace_global_inner_dim_stride_in_smem (smem_tensor_map_b[next_stage_idx], scheduler.next_shape_k , scheduler.next_shape_k );
197224 *(gmem_tensor_map_a[next_stage_idx]) = *(smem_tensor_map_a[next_stage_idx]);
198225 *(gmem_tensor_map_b[next_stage_idx]) = *(smem_tensor_map_b[next_stage_idx]);
199226 tensor_map_release_cta ();
227+ prefetched_next_group_idx = scheduler.next_group_idx ; // Record which group was prefetched
228+ } else {
229+ prefetched_next_group_idx = kNumGroups ; // No more groups to prefetch
200230 }
201231
202232 // Get current tensor map