豆豆友情提示:这是一个非官方 GitHub 代理镜像,主要用于网络测试或访问加速。请勿在此进行登录、注册或处理任何敏感信息。进行这些操作请务必访问官方网站 github.com。 Raw 内容也通过此代理提供。
Skip to content

Commit 4ff3f54

Browse files
committed
Fix misalign issue when kv_block is not aligned
1 parent 2ce865d commit 4ff3f54

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -229,11 +229,15 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
229229
// TODO: deal with `-1`?
230230
if (kv_idx == 0 or kv_block_idx_ptr == 32) {
231231
kv_block_idx_ptr = 0;
232-
kv_block_idx_storage = (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups < num_kv ?
233-
__ldg(reinterpret_cast<const idx_storage_t*>(block_table) + q_idx * block_table_stride / kNumBlocksPerMMA
234-
+ (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups))
235-
: idx_storage_t{0});
232+
uint32_t compute_block_kv_offset = (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups);
233+
const auto* kv_block_idx_global = block_table + q_idx * block_table_stride + compute_block_kv_offset * kNumBlocksPerMMA;
234+
auto* kv_block_idx_reg = reinterpret_cast<uint32_t*>(&kv_block_idx_storage);
235+
#pragma unroll
236+
for (uint32_t i = 0; i < kNumBlocksPerMMA; ++ i) {
237+
kv_block_idx_reg[i] = compute_block_kv_offset < num_kv ? __ldg(kv_block_idx_global + i) : 0;
238+
}
236239
}
240+
237241
idx_storage_t kv_block_idx = shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++);
238242

239243
// Wait KV consumer release

0 commit comments

Comments
 (0)