豆豆友情提示:这是一个非官方 GitHub 代理镜像,主要用于网络测试或访问加速。请勿在此进行登录、注册或处理任何敏感信息。进行这些操作请务必访问官方网站 github.com。 Raw 内容也通过此代理提供。
Skip to content

Commit 35c4bc8

Browse files
authored
fix: k_grouped_fp8_gemm_nt_contiguous crashes with n = 768 on H100 (#238)
1 parent 477618c commit 35c4bc8

File tree

1 file changed

+35
-5
lines changed

1 file changed

+35
-5
lines changed

deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,8 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
167167
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
168168
const cute::TmaDescriptor* current_tensor_map_a = &tensor_map_a_base;
169169
const cute::TmaDescriptor* current_tensor_map_b = &tensor_map_b_base;
170-
uint32_t last_group_idx = kNumGroups, sum_k = 0;
170+
uint32_t last_group_idx = kNumGroups;
171+
uint32_t prefetched_next_group_idx = kNumGroups; // Track which group was prefetched
171172

172173
// Persistently schedule over blocks
173174
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
@@ -187,16 +188,45 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
187188
const uint32_t& next_stage_idx = stage_idx ^ 1;
188189
last_group_idx = scheduler.current_group_idx;
189190

190-
// Prepare next tensor map
191-
sum_k += scheduler.current_shape_k;
191+
// Check if the current group matches the prefetched group
192+
// If not, we need to prepare the correct tensor map for the current group
193+
if (scheduler.current_num_valid_groups > 0 &&
194+
scheduler.current_group_idx != prefetched_next_group_idx) {
195+
// The prefetched tensor map doesn't match current group
196+
// This happens when block count is small (< num_SMs) and scheduler skips groups
197+
// Need to prepare the correct tensor map for current group
198+
// Use scheduler.current_k_cumsum which correctly tracks k offset even when groups are skipped
199+
const uint64_t current_k_offset = scheduler.current_k_cumsum;
200+
tensor_map_replace_global_addr_in_smem(smem_tensor_map_a[stage_idx],
201+
gmem_a_ptr + current_k_offset * shape_m);
202+
tensor_map_replace_global_addr_in_smem(smem_tensor_map_b[stage_idx],
203+
gmem_b_ptr + current_k_offset * shape_n);
204+
tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a[stage_idx],
205+
scheduler.current_shape_k, scheduler.current_shape_k);
206+
tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b[stage_idx],
207+
scheduler.current_shape_k, scheduler.current_shape_k);
208+
*(gmem_tensor_map_a[stage_idx]) = *(smem_tensor_map_a[stage_idx]);
209+
*(gmem_tensor_map_b[stage_idx]) = *(smem_tensor_map_b[stage_idx]);
210+
// NOTE: Don't call tensor_map_release_cta() here!
211+
// We're preparing the current tensor map, not the next one.
212+
// It will be acquired immediately in the "Get current tensor map" section below.
213+
}
214+
215+
// Prepare next tensor map (prefetch for next group)
192216
if (scheduler.next_group_idx < kNumGroups) {
193-
tensor_map_replace_global_addr_in_smem(smem_tensor_map_a[next_stage_idx], gmem_a_ptr + static_cast<uint64_t>(sum_k) * shape_m);
194-
tensor_map_replace_global_addr_in_smem(smem_tensor_map_b[next_stage_idx], gmem_b_ptr + static_cast<uint64_t>(sum_k) * shape_n);
217+
// Calculate next group's k offset using scheduler-provided information
218+
// This ensures consistency even when groups are skipped
219+
const uint64_t next_k_offset = static_cast<uint64_t>(scheduler.current_k_cumsum) + scheduler.current_shape_k;
220+
tensor_map_replace_global_addr_in_smem(smem_tensor_map_a[next_stage_idx], gmem_a_ptr + next_k_offset * shape_m);
221+
tensor_map_replace_global_addr_in_smem(smem_tensor_map_b[next_stage_idx], gmem_b_ptr + next_k_offset * shape_n);
195222
tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k);
196223
tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k);
197224
*(gmem_tensor_map_a[next_stage_idx]) = *(smem_tensor_map_a[next_stage_idx]);
198225
*(gmem_tensor_map_b[next_stage_idx]) = *(smem_tensor_map_b[next_stage_idx]);
199226
tensor_map_release_cta();
227+
prefetched_next_group_idx = scheduler.next_group_idx; // Record which group was prefetched
228+
} else {
229+
prefetched_next_group_idx = kNumGroups; // No more groups to prefetch
200230
}
201231

202232
// Get current tensor map

0 commit comments

Comments (0)