Why does fp8_mqa_logits use four MATH WGs (warpgroups) for computation? Typical Hopper implementations use only two MATH WGs, as do the DeepGEMM 1d1d and 1d2d implementations.
Is this design choice related to the KV block dimension being 256? If we change the KV block size to 128, would it impact performance?