support swapAB for m_grouped_fp8_gemm_nt_masked #192
Wangzheee wants to merge 1 commit into deepseek-ai:main
Conversation
```cpp
block_ns.push_back(i);
if (get_env<int>("ENABLE_SWAPAB")) {
    block_ms = std::vector{32};   // 32, 64
    block_ns = std::vector{256};  // 64, 128, 256
```
Manually set each to a single value. Experiments have found that in most cases, 256 performs the best.
Thanks! Merging it later.
Force-pushed from f2e2357 to 2991c77
Hi~ Do you have a plan to merge this? We have already been using this PR in our online service on H20.
LGTM
Why does the other case with num_groups=1, expected_m_per_group=1024, n=4096, k=7168 also show an improvement? With num_groups=1, isn't this effectively a single 1024×4096×7168 matrix multiply? What is SwapAB's advantage there?
Sorry, we will try to merge this by the end of October. As swapAB introduces non-batch-invariance and determinism issues, we will consider it more carefully and do some refactoring before merging. Also, since most of the code can be reused, we will refactor the epilogue part so that this feature requires fewer code changes. Thanks for your contribution! We will do the refactoring for you, no change request 👍🏻 cc @zheanxu
The swapAB variant swaps the WGMMA tile usage, mapping the original problem's M dimension onto WGMMA's N dimension (which must be a multiple of 8). This enables a smaller BLOCK_M (32). The performance advantage primarily comes from finer tiling granularity and better resource utilization.
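As a rough illustration of the finer-granularity argument (a hypothetical sketch, not DeepGEMM code — the shape 1040 and the helper `padded_rows` are made up for this example): when M % 64 < 32, tiling M with 64-row blocks pads the tail tile heavily, while the 32-row blocks enabled by swapAB cut that padding.

```python
# Hypothetical sketch: wasted rows from padding M up to a multiple of BLOCK_M.
import math

def padded_rows(m: int, block_m: int) -> int:
    """Total rows actually computed after rounding M up to a multiple of block_m."""
    return math.ceil(m / block_m) * block_m

m = 1040  # example shape with M % 64 == 16, i.e. M % 64 < 32

waste_64 = padded_rows(m, 64) - m  # BLOCK_M = 64: tail tile computes 48 unused rows
waste_32 = padded_rows(m, 32) - m  # BLOCK_M = 32: only 16 unused rows
print(waste_64, waste_32)  # 48 16
```

The same per-tile saving repeats across every N-tile, which is why the gain is largest precisely when M % 64 < 32.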
Thanks~ Looking forward to the release of the new version. |
Hi @Wangzheee, I tested it on H100, and SwapAB seems to face performance degradation.
Hi, sorry for not having much context on DeepGEMM. Do you know the difference between this and NVIDIA/TensorRT-LLM#4430? That PR also has perf numbers attached. I see that the H20 numbers align between this PR and PR 4430, and in the same case H100 also sees a benefit. Thanks all!
* Sync cluster before 2-CTA TMEM alloc
* Minor fix
* Minor fix
SwapAB: significantly improves performance when M % 64 < 32
Description
How to use
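A minimal usage sketch, assuming only what the diff above shows: the feature is gated by the `ENABLE_SWAPAB` environment variable, read on the C++ side via `get_env<int>("ENABLE_SWAPAB")`. The call site is illustrative, not an exact API reference.

```python
# Hedged sketch: enable swapAB for the masked grouped GEMM via the env var
# from the diff above. Setting it before the kernels are launched should
# switch block sizes to block_m = 32, block_n = 256.
import os

os.environ["ENABLE_SWAPAB"] = "1"
# ...then call deep_gemm.m_grouped_fp8_gemm_nt_masked(...) as usual
# (function name taken from the PR title; exact signature not shown here).
print(os.environ["ENABLE_SWAPAB"])
```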
Improvements (H20)
Aligned M (desired state): `masked_m[j] = int(expected_m_per_group * random.uniform(1, 1))`
Other case (original test): `masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3))`
TODO