[feat] enable jit_kernels to skip launch#182
[feat] enable jit_kernels to skip launch#182Alcanderian wants to merge 2 commits intodeepseek-ai:mainfrom
Conversation
|
Thanks for this! My only comment is to skip launching in |
zhyncs
left a comment
There was a problem hiding this comment.
LGTM and this works well on sgl-kernel. Please fix @LyricZhao's comment. Thanks!
|
@zhyncs @Alcanderian After diving deep to the sglang code and DeepGEMM code, I believe we tend to spread the logic from sglang to DeepGEMM.
|
Actually, the reason we want to do warmup for the M=1-32K range is that the get_best_config function ( ) determines tiling and some other strategies based on the value of M. Different ranges of M trigger the compilation of different kernels. If this compilation were to occur during the service process, it would cause stuttering. |
I get it, but from my testing the warm up only generate 1 kernel for 1-32k range with N K binding. Maybe there is some issue with the get_best_config result ? |
|
Just run a test by printing out m,k,n with the code's hash and the combination is ~ grep 'hash:' result.log | awk -F',' '{print $2","$3","$4}' |sort -n |uniq
n:2112, k:7168, hash:ff553c76cf29a593aa64d3978324b464
n:6144, k:1536, hash:7fcb2cba293de06aed40e104a84329ed
n:7168, k:4096, hash:cd6abf55b6918d860e44b6eda2227393
n:7168, k:4608, hash:2162a9ac1ef3ec57bcb61026f14a22c1
n:8192, k:512, hash:178c4989723bf2cf6b216774fe798a1f
n:9216, k:7168, hash:d7b3bb3a047b35716ae4dc8a3f73b678For original file, check the attachment result.log. And the code diff to print out the result.log is --- a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp
+++ b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp
@@ -8,6 +8,7 @@
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../../utils/math.hpp"
+#include "../../utils/hash.hpp"
#include "../heuristics/sm100.hpp"
#include "runtime_utils.hpp"
@@ -142,6 +143,7 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa
.tensor_map_d = tensor_map_d
};
const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
+ std::cout << fmt::format("m:{}, n:{}, k:{}, hash:{}", m, n, k, get_hex_digest(code)) << std::endl;
const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code);
MAYBE_LAUNCH(SM100FP8Gemm1D1DRuntime::launch(runtime, args));
} |
|
@Alcanderian just curious whats the status on this? :) |
* Refactor SM100 files * Make SM90 work * Minor fix * Lint * Minor fix --------- Co-authored-by: Zhean Xu <xza@deepseek.com>
No description provided.