fix: correct fused KV cache stride in paged MQA logits #311

Open
JasonOA888 wants to merge 1 commit into deepseek-ai:main from JasonOA888:fix/paged-mqa-kv-cache-stride


Bug

fp8_fp4_paged_mqa_logits decomposes the fused KV cache into separate kv_cache and kv_cache_sf tensors via torch::from_blob. The strides passed to from_blob assume a flat layout (all values first, then all scale factors, SF), but the API assertion enforces an interleaved layout ([values | SF] per row):

DG_HOST_ASSERT(fused_kv_cache.stride(1) == fp4_with_sf_bytes);  // FP4: 68
DG_HOST_ASSERT(fused_kv_cache.stride(1) == head_dim_with_sf);  // FP8: 132
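The two layouts can be compared with a small byte-offset sketch. This is illustrative Python, not code from the repository: the helper names are hypothetical, head_dim = 128 and a 4-byte scale factor per row are inferred from the 68 and 132 byte row sizes in the assertions, and the flat layout is modelled as "all value rows, then all SF entries".

```python
HEAD_DIM = 128   # assumed; consistent with the 132-byte FP8 row in the assertion
SF_BYTES = 4     # assumed; 132 - 128 for FP8 and 68 - 64 for FP4


def row_sizes(fmt):
    """Return (value_bytes, row_stride_bytes) for one KV row."""
    value_bytes = HEAD_DIM if fmt == "fp8" else HEAD_DIM // 2
    return value_bytes, value_bytes + SF_BYTES


def interleaved_offsets(fmt, row):
    """Byte offsets of (values, SF) for `row` when each row is [values | SF]."""
    value_bytes, stride = row_sizes(fmt)
    base = row * stride
    return base, base + value_bytes


def flat_offsets(fmt, row, block_kv):
    """Byte offsets of (values, SF) for `row` when all values precede all SF."""
    value_bytes, _ = row_sizes(fmt)
    return row * value_bytes, block_kv * value_bytes + row * SF_BYTES
```

Under these assumptions the two layouts agree only on where row 0's values start; every other offset differs, which is why strides derived from the flat layout cannot satisfy the interleaved assertion.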

What was wrong

kv_cache stride(1):

  • FP4: was head_dim/2 (64), should be fp4_with_sf_bytes (68)
  • FP8: was head_dim (128), should be head_dim_with_sf (132)

kv_cache_sf offset:

  • Was block_kv * head_dim, pointing past all data rows into undefined territory
  • Should be head_dim (or head_dim/2 for FP4), pointing to the SF of the first row

kv_cache_sf stride(1):

  • Was 1 (contiguous), but SF values are interleaved, one every fp4_with_sf_bytes (FP4) or head_dim_with_sf (FP8) bytes
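The effect of the three mistakes listed above can be reproduced on a toy FP8 block. The following is a minimal sketch in plain Python, not the actual from_blob code: the buffer contents, the block size of 4, the 4-byte float32 scale factor, and the helper names are all assumptions made for illustration.

```python
import struct

HEAD_DIM = 128                     # assumed FP8 value bytes per row
SF_BYTES = 4                       # assumed scale-factor bytes per row
ROW_STRIDE = HEAD_DIM + SF_BYTES   # 132, the head_dim_with_sf row stride
BLOCK_KV = 4                       # toy block size, chosen for illustration


def build_block():
    """Interleaved block: row r holds HEAD_DIM bytes of (r + 1), then SF 100 + r."""
    buf = bytearray()
    for r in range(BLOCK_KV):
        buf += bytes([r + 1]) * HEAD_DIM
        buf += struct.pack("<f", 100.0 + r)
    return bytes(buf)


def read_values(buf, row, stride):
    """Read one row of values using the given byte stride between rows."""
    start = row * stride
    return buf[start:start + HEAD_DIM]


def read_sf(buf, row, offset, stride):
    """Read one scale factor using the given base offset and byte stride."""
    start = offset + row * stride
    return struct.unpack("<f", buf[start:start + SF_BYTES])[0]


block = build_block()

# Corrected view: row stride of 132 bytes, SF starts right after row 0's values.
good_vals = [read_values(block, r, ROW_STRIDE) for r in range(BLOCK_KV)]
good_sf = [read_sf(block, r, HEAD_DIM, ROW_STRIDE) for r in range(BLOCK_KV)]

# Flat-layout view: row stride of 128 bytes, SF base at block_kv * head_dim,
# contiguous SF entries (stride of one 4-byte element).
bad_vals = [read_values(block, r, HEAD_DIM) for r in range(BLOCK_KV)]
bad_sf = [read_sf(block, r, BLOCK_KV * HEAD_DIM, SF_BYTES) for r in range(BLOCK_KV)]
```

In this toy setup the flat-layout view still reads row 0's values correctly, but from row 1 onward each value row starts inside the previous row's scale factor, and the first scale factor is read from value bytes of the last row.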

Impact

TMA descriptors constructed with wrong strides cause misaligned reads for any row > 0 in the paged KV cache. This silently corrupts attention logits for multi-row blocks.

The test helper had the same flat-layout assumption, masking the bug for small block sizes where corruption fell in the last 1-2 rows.

Fix

  • Use the correct row stride (fp4_with_sf_bytes for FP4, head_dim_with_sf for FP8) for both the kv_cache and kv_cache_sf from_blob calls
  • Fix the kv_cache_sf offset to head_dim (FP8) or head_dim / 2 (FP4), the start of the first row's SF
  • Fix test helpers to construct proper interleaved layout
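As a rough illustration of what building the interleaved layout in a test helper involves, the sketch below packs separate value and scale-factor buffers into [values | SF] rows. It is plain Python over bytes, not the repository's test helper; the function name, its parameters, and the 4-byte default SF size are hypothetical.

```python
def interleave_block(values, scale_factors, block_kv, value_bytes, sf_bytes=4):
    """Pack flat `values` and `scale_factors` buffers into [values | SF] rows.

    Hypothetical helper: `value_bytes` is 128 for FP8 or 64 for FP4 under the
    sizes quoted in this PR, and `sf_bytes` is assumed to be 4.
    """
    if len(values) != block_kv * value_bytes:
        raise ValueError("values buffer has the wrong size")
    if len(scale_factors) != block_kv * sf_bytes:
        raise ValueError("scale_factors buffer has the wrong size")

    out = bytearray()
    for r in range(block_kv):
        # Row r: its values first, immediately followed by its scale factor.
        out += values[r * value_bytes:(r + 1) * value_bytes]
        out += scale_factors[r * sf_bytes:(r + 1) * sf_bytes]
    return bytes(out)


# Usage: a two-row FP4 block (64 value bytes + 4 SF bytes = 68 bytes per row).
fp4_block = interleave_block(
    values=bytes(64 * 2),
    scale_factors=b"\x01\x02\x03\x04\x05\x06\x07\x08",
    block_kv=2,
    value_bytes=64,
)
```

Each row of the result occupies value_bytes + sf_bytes bytes, so the row stride matches the 68 or 132 bytes that the assertions expect.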

The from_blob calls for kv_cache and kv_cache_sf used wrong strides
that assumed a flat layout (all values then all SF), but the API
assertion enforces an interleaved layout ([values|SF] per row).

FP4 path: stride(1) was head_dim/2 (64), should be fp4_with_sf_bytes (68)
FP8 path: stride(1) was head_dim (128), should be head_dim_with_sf (132)

kv_cache_sf offset was also wrong (block_kv * head_dim instead of
just head_dim), and stride(1) was 1 instead of the row stride
divided by element size.

Also fixes the test helpers to construct the correct interleaved
layout instead of the flat layout that only worked by coincidence
for small block sizes.