豆豆友情提示:这是一个非官方 GitHub 代理镜像,主要用于网络测试或访问加速。请勿在此进行登录、注册或处理任何敏感信息。进行这些操作请务必访问官方网站 github.com。 Raw 内容也通过此代理提供。
Skip to content

Commit 8fc8efc

Browse files
bucket-xvChenhao Xu
andauthored
fix: add nvlink barrier to ensure the workspace is cleaned before next stage (#238)
Co-authored-by: Chenhao Xu <xch@deepseek.com>
1 parent 6bef332 commit 8fc8efc

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,14 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
635635
}
636636
}
637637
}
638+
639+
// Wait for all ranks to finish cleaning
640+
comm::nvlink_barrier<kNumRanks, kNumSMs, kNumDispatchThreads, kDispatchGridSyncIndex>(
641+
workspace, sym_buffer, sm_idx, thread_idx,
642+
[=]() { cutlass::arch::NamedBarrier::sync(kNumDispatchThreads, kDispatchBarrierIdx); },
643+
/* Before the NVLink barrier, there is a grid sync */ true,
644+
/* One block is sufficient to block the next stage from launching */ false
645+
);
638646
} else if (warp_idx == kNumDispatchWarps) {
639647
// Adjust registers
640648
cutlass::arch::warpgroup_reg_dealloc<kNumNonEpilogueRegisters>();

0 commit comments

Comments
 (0)