豆豆友情提示:这是一个非官方 GitHub 代理镜像,主要用于网络测试或访问加速。请勿在此进行登录、注册或处理任何敏感信息。进行这些操作请务必访问官方网站 github.com。 Raw 内容也通过此代理提供。
Skip to content

Commit eda35b4

Browse files
zhyncs and FlamingoPg authored
update (#3)
Co-authored-by: PGFLMG <1106310035@qq.com>
1 parent 2da871e commit eda35b4

File tree

16 files changed

+476
-176
lines changed

16 files changed

+476
-176
lines changed

csrc/apis/gemm.hpp

Lines changed: 1 addition & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -403,69 +403,4 @@ static void m_grouped_bf16_gemm_nt_masked(const torch::Tensor& a, const torch::T
403403
}
404404
}
405405

406-
static void register_apis(pybind11::module_& m) {
407-
// FP8 GEMMs
408-
m.def("fp8_gemm_nt", &fp8_gemm_nt,
409-
py::arg("a"), py::arg("b"), py::arg("d"),
410-
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
411-
py::arg("compiled_dims") = "nk",
412-
py::arg("disable_ue8m0_cast") = false);
413-
m.def("fp8_gemm_nn", &fp8_gemm_nn,
414-
py::arg("a"), py::arg("b"), py::arg("d"),
415-
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
416-
py::arg("compiled_dims") = "nk",
417-
py::arg("disable_ue8m0_cast") = false);
418-
m.def("fp8_gemm_tn", &fp8_gemm_tn,
419-
py::arg("a"), py::arg("b"), py::arg("d"),
420-
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
421-
py::arg("compiled_dims") = "mn",
422-
py::arg("disable_ue8m0_cast") = false);
423-
m.def("fp8_gemm_tt", &fp8_gemm_tt,
424-
py::arg("a"), py::arg("b"), py::arg("d"),
425-
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
426-
py::arg("compiled_dims") = "mn",
427-
py::arg("disable_ue8m0_cast") = false);
428-
m.def("m_grouped_fp8_gemm_nt_contiguous", &m_grouped_fp8_gemm_nt_contiguous,
429-
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"),
430-
py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk",
431-
py::arg("disable_ue8m0_cast") = false);
432-
m.def("m_grouped_fp8_gemm_nn_contiguous", &m_grouped_fp8_gemm_nn_contiguous,
433-
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"),
434-
py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk",
435-
py::arg("disable_ue8m0_cast") = false);
436-
m.def("m_grouped_fp8_gemm_nt_masked", &m_grouped_fp8_gemm_nt_masked,
437-
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"),
438-
py::arg("expected_m"), py::arg("recipe") = std::nullopt,
439-
py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false);
440-
m.def("k_grouped_fp8_gemm_tn_contiguous", &k_grouped_fp8_gemm_tn_contiguous,
441-
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"),
442-
py::arg("ks_tensor"), py::arg("c") = std::nullopt,
443-
py::arg("recipe") = std::make_tuple(1, 1, 128),
444-
py::arg("compiled_dims") = "mn");
445-
446-
// BF16 GEMMs
447-
m.def("bf16_gemm_nt", &bf16_gemm_nt,
448-
py::arg("a"), py::arg("b"), py::arg("d"),
449-
py::arg("c") = std::nullopt,
450-
py::arg("compiled_dims") = "nk");
451-
m.def("bf16_gemm_nn", &bf16_gemm_nn,
452-
py::arg("a"), py::arg("b"), py::arg("d"),
453-
py::arg("c") = std::nullopt,
454-
py::arg("compiled_dims") = "nk");
455-
m.def("bf16_gemm_tn", &bf16_gemm_tn,
456-
py::arg("a"), py::arg("b"), py::arg("d"),
457-
py::arg("c") = std::nullopt,
458-
py::arg("compiled_dims") = "mn");
459-
m.def("bf16_gemm_tt", &bf16_gemm_tt,
460-
py::arg("a"), py::arg("b"), py::arg("d"),
461-
py::arg("c") = std::nullopt,
462-
py::arg("compiled_dims") = "mn");
463-
m.def("m_grouped_bf16_gemm_nt_contiguous", &m_grouped_bf16_gemm_nt_contiguous,
464-
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"),
465-
py::arg("compiled_dims") = "nk");
466-
m.def("m_grouped_bf16_gemm_nt_masked", &m_grouped_bf16_gemm_nt_masked,
467-
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"),
468-
py::arg("expected_m"), py::arg("compiled_dims") = "nk");
469-
}
470-
471-
} // namespace deep_gemm::gemm
406+
} // namespace deep_gemm::gemm

csrc/apis/layout.hpp

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -69,17 +69,4 @@ static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Te
6969
DG_HOST_UNREACHABLE("Unknown cases");
7070
}
7171

72-
static void register_apis(pybind11::module_& m) {
73-
m.def("transform_sf_into_required_layout", &transform_sf_into_required_layout,
74-
py::arg("sf"), py::arg("mn"), py::arg("k"), py::arg("recipe"),
75-
py::arg("num_groups") = std::nullopt, py::arg("is_sfa") = false,
76-
py::arg("disable_ue8m0_cast") = false);
77-
78-
m.def("get_tma_aligned_size", &get_tma_aligned_size);
79-
m.def("get_mk_alignment_for_contiguous_layout", &get_mk_alignment_for_contiguous_layout);
80-
m.def("get_mn_major_tma_aligned_tensor", &get_mn_major_tma_aligned_tensor);
81-
m.def("get_mn_major_tma_aligned_packed_ue8m0_tensor", &get_mn_major_tma_aligned_packed_ue8m0_tensor);
82-
m.def("get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor", &get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor);
83-
}
84-
8572
} // namespace deep_gemm::layout

csrc/apis/runtime.hpp

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,6 @@
55

66
namespace deep_gemm::runtime {
77

8-
static void register_apis(pybind11::module_& m) {
9-
m.def("set_num_sms", [&](const int& new_num_sms) {
10-
device_runtime->set_num_sms(new_num_sms);
11-
});
12-
m.def("get_num_sms", [&]() {
13-
return device_runtime->get_num_sms();
14-
});
15-
m.def("set_tc_util", [&](const int& new_tc_util) {
16-
device_runtime->set_tc_util(new_tc_util);
17-
});
18-
m.def("get_tc_util", [&]() {
19-
return device_runtime->get_tc_util();
20-
});
21-
22-
m.def("init", [&](const std::string& library_root_path, const std::string& cuda_home_path_by_python) {
23-
Compiler::prepare_init(library_root_path, cuda_home_path_by_python);
24-
KernelRuntime::prepare_init(cuda_home_path_by_python);
25-
});
26-
}
8+
// The init and other functions are now exposed via TORCH_LIBRARY in python_api.cpp
279

2810
} // namespace deep_gemm::runtime

csrc/jit_kernels/impls/runtime_utils.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#pragma once
22

33
#include <cuda.h>
4-
#include <torch/python.h>
54

65
#include "../../utils/math.hpp"
76
#include "../../utils/exception.hpp"

csrc/jit_kernels/impls/sm100_bf16_gemm.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#pragma once
22

3-
#include <torch/python.h>
4-
53
#include "../../jit/compiler.hpp"
64
#include "../../jit/device_runtime.hpp"
75
#include "../../jit/kernel_runtime.hpp"

csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#pragma once
22

3-
#include <torch/python.h>
4-
53
#include "../../jit/compiler.hpp"
64
#include "../../jit/device_runtime.hpp"
75
#include "../../jit/kernel_runtime.hpp"

csrc/jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#pragma once
22

3-
#include <torch/python.h>
4-
53
#include "../../jit/compiler.hpp"
64
#include "../../jit/device_runtime.hpp"
75
#include "../../jit/kernel_runtime.hpp"

csrc/jit_kernels/impls/sm90_bf16_gemm.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#pragma once
22

3-
#include <torch/python.h>
4-
53
#include "../../jit/compiler.hpp"
64
#include "../../jit/kernel_runtime.hpp"
75
#include "../../utils/exception.hpp"

csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#pragma once
22

3-
#include <torch/python.h>
4-
53
#include "../../jit/compiler.hpp"
64
#include "../../jit/device_runtime.hpp"
75
#include "../../jit/kernel_runtime.hpp"

csrc/jit_kernels/impls/smxx_layout.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#pragma once
22

3-
#include <torch/python.h>
4-
53
#include "../../jit/kernel_runtime.hpp"
64
#include "../../utils/exception.hpp"
75
#include "../../utils/format.hpp"

0 commit comments

Comments (0)