豆豆友情提示:这是一个非官方 GitHub 代理镜像,主要用于网络测试或访问加速。请勿在此进行登录、注册或处理任何敏感信息。进行这些操作请务必访问官方网站 github.com。 Raw 内容也通过此代理提供。
Skip to content

Commit 2ce865d

Browse files
committed
Fix H=32
1 parent 1389d59 commit 2ce865d

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
278278
const auto& v_offset = lane_idx;
279279

280280
// Preload weights
281-
constexpr uint32_t kNumWeightsInReg = 52;
281+
constexpr uint32_t kNumWeightsInReg = cute::min(52, kNumHeads);
282282
float weights[BLOCK_Q][kNumWeightsInReg];
283283
DG_STATIC_ASSERT(kNumWeightsInReg <= kNumHeads and kNumWeightsInReg % 4 == 0, "Invalid kNumWeightsInReg");
284284

0 commit comments

Comments
 (0)