豆豆友情提示:这是一个非官方 GitHub 代理镜像,主要用于网络测试或访问加速。请勿在此进行登录、注册或处理任何敏感信息。进行这些操作请务必访问官方网站 github.com。 Raw 内容也通过此代理提供。
Skip to content

Commit 198d857

Browse files
authored
Merge branch 'sgl-kernel' into main
2 parents e8e9ea4 + b29302e commit 198d857

20 files changed

+488
-189
lines changed

csrc/apis/gemm.hpp

Lines changed: 1 addition & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -403,69 +403,4 @@ static void m_grouped_bf16_gemm_nt_masked(const torch::Tensor& a, const torch::T
403403
}
404404
}
405405

406-
static void register_apis(pybind11::module_& m) {
407-
// FP8 GEMMs
408-
m.def("fp8_gemm_nt", &fp8_gemm_nt,
409-
py::arg("a"), py::arg("b"), py::arg("d"),
410-
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
411-
py::arg("compiled_dims") = "nk",
412-
py::arg("disable_ue8m0_cast") = false);
413-
m.def("fp8_gemm_nn", &fp8_gemm_nn,
414-
py::arg("a"), py::arg("b"), py::arg("d"),
415-
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
416-
py::arg("compiled_dims") = "nk",
417-
py::arg("disable_ue8m0_cast") = false);
418-
m.def("fp8_gemm_tn", &fp8_gemm_tn,
419-
py::arg("a"), py::arg("b"), py::arg("d"),
420-
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
421-
py::arg("compiled_dims") = "mn",
422-
py::arg("disable_ue8m0_cast") = false);
423-
m.def("fp8_gemm_tt", &fp8_gemm_tt,
424-
py::arg("a"), py::arg("b"), py::arg("d"),
425-
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
426-
py::arg("compiled_dims") = "mn",
427-
py::arg("disable_ue8m0_cast") = false);
428-
m.def("m_grouped_fp8_gemm_nt_contiguous", &m_grouped_fp8_gemm_nt_contiguous,
429-
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"),
430-
py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk",
431-
py::arg("disable_ue8m0_cast") = false);
432-
m.def("m_grouped_fp8_gemm_nn_contiguous", &m_grouped_fp8_gemm_nn_contiguous,
433-
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"),
434-
py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk",
435-
py::arg("disable_ue8m0_cast") = false);
436-
m.def("m_grouped_fp8_gemm_nt_masked", &m_grouped_fp8_gemm_nt_masked,
437-
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"),
438-
py::arg("expected_m"), py::arg("recipe") = std::nullopt,
439-
py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false);
440-
m.def("k_grouped_fp8_gemm_tn_contiguous", &k_grouped_fp8_gemm_tn_contiguous,
441-
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"),
442-
py::arg("ks_tensor"), py::arg("c") = std::nullopt,
443-
py::arg("recipe") = std::make_tuple(1, 1, 128),
444-
py::arg("compiled_dims") = "mn");
445-
446-
// BF16 GEMMs
447-
m.def("bf16_gemm_nt", &bf16_gemm_nt,
448-
py::arg("a"), py::arg("b"), py::arg("d"),
449-
py::arg("c") = std::nullopt,
450-
py::arg("compiled_dims") = "nk");
451-
m.def("bf16_gemm_nn", &bf16_gemm_nn,
452-
py::arg("a"), py::arg("b"), py::arg("d"),
453-
py::arg("c") = std::nullopt,
454-
py::arg("compiled_dims") = "nk");
455-
m.def("bf16_gemm_tn", &bf16_gemm_tn,
456-
py::arg("a"), py::arg("b"), py::arg("d"),
457-
py::arg("c") = std::nullopt,
458-
py::arg("compiled_dims") = "mn");
459-
m.def("bf16_gemm_tt", &bf16_gemm_tt,
460-
py::arg("a"), py::arg("b"), py::arg("d"),
461-
py::arg("c") = std::nullopt,
462-
py::arg("compiled_dims") = "mn");
463-
m.def("m_grouped_bf16_gemm_nt_contiguous", &m_grouped_bf16_gemm_nt_contiguous,
464-
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"),
465-
py::arg("compiled_dims") = "nk");
466-
m.def("m_grouped_bf16_gemm_nt_masked", &m_grouped_bf16_gemm_nt_masked,
467-
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"),
468-
py::arg("expected_m"), py::arg("compiled_dims") = "nk");
469-
}
470-
471-
} // namespace deep_gemm::gemm
406+
} // namespace deep_gemm::gemm

csrc/apis/layout.hpp

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -69,17 +69,4 @@ static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Te
6969
DG_HOST_UNREACHABLE("Unknown cases");
7070
}
7171

72-
static void register_apis(pybind11::module_& m) {
73-
m.def("transform_sf_into_required_layout", &transform_sf_into_required_layout,
74-
py::arg("sf"), py::arg("mn"), py::arg("k"), py::arg("recipe"),
75-
py::arg("num_groups") = std::nullopt, py::arg("is_sfa") = false,
76-
py::arg("disable_ue8m0_cast") = false);
77-
78-
m.def("get_tma_aligned_size", &get_tma_aligned_size);
79-
m.def("get_mk_alignment_for_contiguous_layout", &get_mk_alignment_for_contiguous_layout);
80-
m.def("get_mn_major_tma_aligned_tensor", &get_mn_major_tma_aligned_tensor);
81-
m.def("get_mn_major_tma_aligned_packed_ue8m0_tensor", &get_mn_major_tma_aligned_packed_ue8m0_tensor);
82-
m.def("get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor", &get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor);
83-
}
84-
8572
} // namespace deep_gemm::layout

csrc/apis/runtime.hpp

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,6 @@
55

66
namespace deep_gemm::runtime {
77

8-
static void register_apis(pybind11::module_& m) {
9-
m.def("set_num_sms", [&](const int& new_num_sms) {
10-
device_runtime->set_num_sms(new_num_sms);
11-
});
12-
m.def("get_num_sms", [&]() {
13-
return device_runtime->get_num_sms();
14-
});
15-
m.def("set_tc_util", [&](const int& new_tc_util) {
16-
device_runtime->set_tc_util(new_tc_util);
17-
});
18-
m.def("get_tc_util", [&]() {
19-
return device_runtime->get_tc_util();
20-
});
21-
22-
m.def("init", [&](const std::string& library_root_path, const std::string& cuda_home_path_by_python) {
23-
Compiler::prepare_init(library_root_path, cuda_home_path_by_python);
24-
KernelRuntime::prepare_init(cuda_home_path_by_python);
25-
});
26-
}
8+
// The init and other functions are now exposed via TORCH_LIBRARY in python_api.cpp
279

2810
} // namespace deep_gemm::runtime

csrc/jit/device_runtime.hpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,9 @@ class DeviceRuntime {
2525
return {prop->major, prop->minor};
2626
}
2727

28-
std::string get_arch() {
28+
int get_arch() {
2929
const auto& [major, minor] = get_arch_pair();
30-
if (major == 10 and minor != 1)
31-
return "100f";
32-
return std::to_string(major * 10 + minor) + "a";
30+
return major * 10 + minor;
3331
}
3432

3533
int get_arch_major() {

csrc/jit_kernels/impls/runtime_utils.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#pragma once
22

33
#include <cuda.h>
4-
#include <torch/python.h>
54

65
#include "../../utils/math.hpp"
76
#include "../../utils/exception.hpp"

csrc/jit_kernels/impls/sm100_bf16_gemm.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#pragma once
22

3-
#include <torch/python.h>
4-
53
#include "../../jit/compiler.hpp"
64
#include "../../jit/device_runtime.hpp"
75
#include "../../jit/kernel_runtime.hpp"

csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#pragma once
22

3-
#include <torch/python.h>
4-
53
#include "../../jit/compiler.hpp"
64
#include "../../jit/device_runtime.hpp"
75
#include "../../jit/kernel_runtime.hpp"

csrc/jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#pragma once
22

3-
#include <torch/python.h>
4-
53
#include "../../jit/compiler.hpp"
64
#include "../../jit/device_runtime.hpp"
75
#include "../../jit/kernel_runtime.hpp"

csrc/jit_kernels/impls/sm90_bf16_gemm.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#pragma once
22

3-
#include <torch/python.h>
4-
53
#include "../../jit/compiler.hpp"
64
#include "../../jit/kernel_runtime.hpp"
75
#include "../../utils/exception.hpp"

csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#pragma once
22

3-
#include <torch/python.h>
4-
53
#include "../../jit/compiler.hpp"
64
#include "../../jit/device_runtime.hpp"
75
#include "../../jit/kernel_runtime.hpp"

0 commit comments

Comments (0)