support group_gemm_offset, group_gemm_offset_swapAB #116
Wangzheee wants to merge 5 commits into deepseek-ai:main
Conversation
Thanks for your contribution! We will merge it after the refactor #112.

Thank you for your reply.
I reproduced the benchmark results on H20. Note that the TFLOPS figures above were measured with all matrix elements set to 1; as described in strangely-matrix-multiplications, the initialization method influences the benchmark result. I also tested on H100: the results above show that on H100 this approach brings no benefit, and there might be a bug related to the block assumption.
This test already initializes the inputs with uniform random values in (0.7, 1.3).
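For reference, a minimal NumPy sketch (names hypothetical, not from this PR) of the two initialization styles being compared: constant ones versus uniform values in (0.7, 1.3). Data-dependent effects (e.g. value distributions and GPU clock behavior) can make the constant-input benchmark report higher TFLOPS than realistic data would.

```python
import numpy as np

def init_constant(shape):
    # All elements set to 1: unrealistically uniform data.
    return np.ones(shape, dtype=np.float32)

def init_uniform(shape, lo=0.7, hi=1.3, seed=0):
    # Uniform values in [lo, hi), as in the test above.
    rng = np.random.default_rng(seed)
    return rng.random(shape, dtype=np.float32) * (hi - lo) + lo

a = init_constant((128, 256))
b = init_uniform((128, 256))
print(a.min(), a.max())                 # 1.0 1.0
print(b.min() >= 0.7, b.max() <= 1.3)   # True True
```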
As for how I run the test: I run the original code with the following entry point (the test_* functions are defined earlier in the same file):

import random

import torch

import deep_gemm

if __name__ == '__main__':
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.manual_seed(0)
    random.seed(0)

    print('Library path:')
    print(f' > {deep_gemm.__path__}\n')

    # test_gemm()
    # test_m_grouped_gemm_contiguous()
    test_m_grouped_gemm_masked()
    test_m_grouped_gemm_offset()
    # test_wgrad_gemm()
    # test_k_grouped_wgrad_gemm()
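To illustrate the semantics being tested, here is a minimal NumPy reference of what an offset-style grouped GEMM computes, under the assumption (mine, not stated in this PR) that an offsets array delimits each group's rows in the concatenated A and each group has its own weight matrix:

```python
import numpy as np

def grouped_gemm_offset_ref(a, b, offsets):
    """Reference grouped GEMM: rows offsets[g]:offsets[g+1] of `a`
    are multiplied by group g's weight matrix b[g] of shape (k, n)."""
    n = b.shape[2]
    out = np.empty((a.shape[0], n), dtype=a.dtype)
    for g in range(b.shape[0]):
        lo, hi = offsets[g], offsets[g + 1]
        out[lo:hi] = a[lo:hi] @ b[g]
    return out

# Two groups with 3 and 5 rows; k=4, n=2.
rng = np.random.default_rng(0)
a = rng.standard_normal((8, 4))
b = rng.standard_normal((2, 4, 2))
offsets = np.array([0, 3, 8])
out = grouped_gemm_offset_ref(a, b, offsets)

# Each group's slice matches a plain matmul with its own B.
assert np.allclose(out[0:3], a[0:3] @ b[0])
assert np.allclose(out[3:8], a[3:8] @ b[1])
```

Unlike the masked layout, the offset layout needs no padding between groups, which is why it is attractive when per-group row counts vary at runtime.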
How can the offset GEMM be integrated with the DeepEP low-latency kernels? There seems to be a huge gap. @Wangzheee
Closing this as a duplicate of #192. Thanks!
Support grouped GEMM offset types: group_gemm_offset and group_gemm_offset_swapAB.
Perf (num_groups= 2, expected_m_per_group=16, n=4096, k=7168):  36 us | throughput:  53 TFLOPS, 1665 GB/s
Perf (num_groups= 4, expected_m_per_group=16, n=4096, k=7168):  65 us | throughput:  58 TFLOPS, 1813 GB/s
Perf (num_groups= 2, expected_m_per_group=32, n=4096, k=7168):  35 us | throughput: 106 TFLOPS, 1685 GB/s
Perf (num_groups= 9, expected_m_per_group=32, n=4096, k=7168): 141 us | throughput: 120 TFLOPS, 1900 GB/s
Perf (num_groups= 2, expected_m_per_group=32, n=4096, k=7168):  35 us | throughput: 106 TFLOPS, 1689 GB/s
Perf (num_groups= 4, expected_m_per_group=32, n=4096, k=7168):  66 us | throughput: 115 TFLOPS, 1822 GB/s
Perf (num_groups=32, expected_m_per_group=64, n=4096, k=7168): 485 us | throughput: 248 TFLOPS, 2002 GB/s

Perf (num_groups= 2, expected_m_per_group=16, n=4096, k=7168):  27 us | throughput:  71 TFLOPS, 2226 GB/s
Perf (num_groups= 4, expected_m_per_group=16, n=4096, k=7168):  46 us | throughput:  82 TFLOPS, 2587 GB/s
Perf (num_groups= 2, expected_m_per_group=32, n=4096, k=7168):  28 us | throughput: 134 TFLOPS, 2136 GB/s
Perf (num_groups= 9, expected_m_per_group=32, n=4096, k=7168):  93 us | throughput: 183 TFLOPS, 2902 GB/s
Perf (num_groups= 2, expected_m_per_group=32, n=4096, k=7168):  28 us | throughput: 135 TFLOPS, 2143 GB/s
Perf (num_groups= 4, expected_m_per_group=32, n=4096, k=7168):  49 us | throughput: 152 TFLOPS, 2414 GB/s
Perf (num_groups=32, expected_m_per_group=64, n=4096, k=7168): 479 us | throughput: 251 TFLOPS, 2029 GB/s
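As a sanity check on these figures, the reported throughput follows directly from the problem shape and runtime. A sketch for the last line above (num_groups=32, m=64 per group, n=4096, k=7168, 479 us), assuming FP8 inputs and BF16 output and ignoring the small scale tensors:

```python
# Recompute TFLOPS and GB/s from the problem shape and measured time.
num_groups, m, n, k, t_s = 32, 64, 4096, 7168, 479e-6

flops = 2 * num_groups * m * n * k   # each multiply-add counts as 2 FLOPs
tflops = flops / t_s / 1e12

# Assumed sizes: 1-byte (FP8) A and B, 2-byte (BF16) output; scales ignored.
bytes_moved = num_groups * (m * k * 1 + n * k * 1 + m * n * 2)
gbps = bytes_moved / t_s / 1e9

print(f'{tflops:.0f} TFLOPS, {gbps:.0f} GB/s')   # ~251 TFLOPS, ~2027 GB/s
```

This reproduces the reported 251 TFLOPS; the small gap to the reported 2029 GB/s is consistent with the scale tensors being excluded from this estimate.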