豆豆友情提示:这是一个非官方 GitHub 代理镜像,主要用于网络测试或访问加速。请勿在此进行登录、注册或处理任何敏感信息。进行这些操作请务必访问官方网站 github.com。 Raw 内容也通过此代理提供。
Skip to content

Commit f844541

Browse files
authored
Merge pull request #22 from Fridge003/refactor-draft
[Draft] Add tvm-ffi support
2 parents 212b900 + 36ece13 commit f844541

31 files changed

+1062
-249
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,4 +21,5 @@ deep_gemm/include/cutlass
2121
stubs/
2222

2323
# Symlinks to compiled extensions
24-
deep_gemm/*.so
24+
deep_gemm/*.so
25+
deep_gemm/_C_build

CMakeLists.txt

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@ set(CMAKE_VERBOSE_MAKEFILE ON)
66
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC -Wno-psabi")
77
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -Wno-psabi")
88
set(CUDA_SEPARABLE_COMPILATION ON)
9+
910
list(APPEND CUDA_NVCC_FLAGS "-DENABLE_FAST_DEBUG")
1011
list(APPEND CUDA_NVCC_FLAGS "-O3")
1112
list(APPEND CUDA_NVCC_FLAGS "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage")
@@ -17,13 +18,14 @@ set(TORCH_CUDA_ARCH_LIST "${CUDA_ARCH_LIST}")
1718
find_package(CUDAToolkit REQUIRED)
1819
find_package(pybind11 REQUIRED)
1920
find_package(Torch REQUIRED)
21+
find_package(tvm_ffi REQUIRED)
2022

2123
set(CMAKE_CXX_STANDARD 17)
2224
set(CMAKE_CUDA_STANDARD 17)
2325

2426
include_directories(deep_gemm/include third-party/cutlass/include third-party/cutlass/tools/util/include third-party/fmt/include)
25-
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS})
26-
link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs)
27+
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS} ${tvm_ffi_INCLUDE_DIR} ${tvm_ffi_DLPACK_INCLUDE_DIR})
28+
link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs ${tvm_ffi_ROOT_DIR}/lib)
2729

2830
# The main Python API entrance
2931
pybind11_add_module(_C csrc/python_api.cpp)

build.sh

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,10 @@ original_dir=$(pwd)
33
script_dir=$(realpath "$(dirname "$0")")
44
cd "$script_dir"
55

6+
# Link CUTLASS includes
7+
ln -sf $script_dir/third-party/cutlass/include/cutlass deep_gemm/include
8+
ln -sf $script_dir/third-party/cutlass/include/cute deep_gemm/include
9+
610
# Remove old dist file, build files, and install
711
rm -rf build dist
812
rm -rf *.egg-info

csrc/apis/attention.hpp

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -255,6 +255,8 @@ static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q,
255255

256256
#endif
257257

258+
#if 0
259+
258260
static void register_apis(pybind11::module_& m) {
259261
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
260262
m.def("fp8_gemm_nt_skip_head_mid", &fp8_gemm_nt_skip_head_mid,
@@ -276,4 +278,6 @@ static void register_apis(pybind11::module_& m) {
276278
#endif
277279
}
278280

281+
#endif
282+
279283
} // namespace deep_gemm::attention

csrc/apis/einsum.hpp

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,5 @@
11
#pragma once
22

3-
#include <pybind11/pybind11.h>
4-
#include <torch/python.h>
5-
63
#include "../utils/exception.hpp"
74
#include "../utils/format.hpp"
85
#include "../utils/layout.hpp"
@@ -214,6 +211,8 @@ static void fp8_einsum(const std::string& expr,
214211
}
215212
#endif
216213

214+
#if 0
215+
217216
static void register_apis(pybind11::module_& m) {
218217
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
219218
m.def("einsum", &einsum,
@@ -227,4 +226,6 @@ static void register_apis(pybind11::module_& m) {
227226
#endif
228227
}
229228

229+
#endif
230+
230231
} // namespace deep_gemm::einsum

csrc/apis/gemm.hpp

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -608,6 +608,8 @@ static void cublaslt_gemm_tt(const torch::Tensor& a, const torch::Tensor& b,
608608
cublaslt_gemm_nt(a.transpose(0, 1), b, d, c);
609609
}
610610

611+
#if 0
612+
611613
static void register_apis(pybind11::module_& m) {
612614

613615
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
@@ -725,4 +727,6 @@ static void register_apis(pybind11::module_& m) {
725727
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt);
726728
}
727729

730+
#endif
731+
728732
} // namespace deep_gemm::gemm

csrc/apis/hyperconnection.hpp

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -59,6 +59,8 @@ static void tf32_hc_prenorm_gemm(const torch::Tensor& a,
5959

6060
#endif
6161

62+
#if 0
63+
6264
static void register_apis(pybind11::module_& m) {
6365
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
6466
m.def("tf32_hc_prenorm_gemm", &tf32_hc_prenorm_gemm,
@@ -67,4 +69,6 @@ static void register_apis(pybind11::module_& m) {
6769
#endif
6870
}
6971

72+
#endif
73+
7074
} // namespace deep_gemm::hyperconnection

csrc/apis/layout.hpp

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -99,6 +99,8 @@ static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Te
9999

100100
#endif
101101

102+
#if 0
103+
102104
static void register_apis(pybind11::module_& m) {
103105

104106
#if DG_TENSORMAP_COMPATIBLE
@@ -117,4 +119,6 @@ static void register_apis(pybind11::module_& m) {
117119
m.def("get_mk_alignment_for_contiguous_layout", &get_mk_alignment_for_contiguous_layout);
118120
}
119121

122+
#endif
123+
120124
} // namespace deep_gemm::layout

csrc/apis/runtime.hpp

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,8 @@
77

88
namespace deep_gemm::runtime {
99

10+
#if 0
11+
1012
static void register_apis(pybind11::module_& m) {
1113
m.def("set_num_sms", [&](const int& new_num_sms) {
1214
device_runtime->set_num_sms(new_num_sms);
@@ -34,4 +36,6 @@ static void register_apis(pybind11::module_& m) {
3436
});
3537
}
3638

39+
#endif
40+
3741
} // namespace deep_gemm::runtime

csrc/jit_kernels/impls/runtime_utils.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
#pragma once
22

33
#include <cuda.h>
4-
#include <torch/python.h>
4+
#include <torch/torch.h>
55

66
#include "../heuristics/sm90.hpp"
77
#include "../../jit/handle.hpp"

0 commit comments

Comments
 (0)