豆豆友情提示:这是一个非官方 GitHub 代理镜像,主要用于网络测试或访问加速。请勿在此进行登录、注册或处理任何敏感信息。进行这些操作请务必访问官方网站 github.com。 Raw 内容也通过此代理提供。
Skip to content

Commit 88398ce

Browse files
LyricZhaozheanxu
andauthored
Refactor device file structures (#182)
* Refactor SM100 files * Make SM90 work * Minor fix * Lint * Minor fix --------- Co-authored-by: Zhean Xu <xza@deepseek.com>
1 parent 912c3ca commit 88398ce

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

60 files changed

+1753
-1612
lines changed

csrc/apis/attention.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
66
#include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp"
77
#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
8-
#include "../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp"
8+
#include "../jit_kernels/impls/sm100_fp8_fp4_gemm_1d1d.hpp"
99
#include "../jit_kernels/impls/smxx_fp8_mqa_logits.hpp"
1010
#include "../jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp"
1111
#include "../jit_kernels/impls/smxx_clean_logits.hpp"
@@ -64,7 +64,7 @@ static void fp8_gemm_nt_skip_head_mid(const std::pair<torch::Tensor, torch::Tens
6464

6565
// Dispatch into different implements
6666
const auto arch_major = device_runtime->get_arch_major();
67-
const auto epilogue_type = fmt::format("EpilogueHeadSplits<{}, {}, {}>", left, mid, right);
67+
const auto epilogue_type = fmt::format("epilogue::transform::EpilogueHeadSplits<{}, {}, {}>", left, mid, right);
6868
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat and std::get<1>(recipe.value()) != 1) {
6969
const auto major_sfb = get_major_type_ab(sfb);
7070
sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, major_sfb, compiled_dims, epilogue_type);

csrc/apis/gemm.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp"
77
#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
88
#include "../jit_kernels/impls/sm90_bf16_gemm.hpp"
9-
#include "../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp"
9+
#include "../jit_kernels/impls/sm100_fp8_fp4_gemm_1d1d.hpp"
1010
#include "../jit_kernels/impls/sm100_bf16_gemm.hpp"
1111
#endif
1212

csrc/indexing/main.cu

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
#include <deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh>
44
#include <deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh>
55
#include <deep_gemm/impls/sm100_bf16_gemm.cuh>
6-
#include <deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh>
6+
#include <deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh>
77

88
// Attention kernels
99
#include <deep_gemm/impls/sm90_fp8_mqa_logits.cuh>
1010
#include <deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh>
11+
#include <deep_gemm/impls/sm100_fp4_mqa_logits.cuh>
1112
#include <deep_gemm/impls/sm100_fp8_mqa_logits.cuh>
13+
#include <deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh>
1214
#include <deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh>
1315
/* oss-ignore-begin */
1416
#include <deep_gemm/impls/sm100_sparse_mqa_logits_bwd.cuh>

csrc/jit_kernels/heuristics/common.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#pragma once
22

33
#include <unordered_set>
4-
#include <deep_gemm/common/types.hpp>
4+
#include <deep_gemm/common/types.cuh>
55

66
#include "config.hpp"
77
#include "runtime.hpp"

csrc/jit_kernels/heuristics/config.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
#include <cute/arch/mma_sm100_desc.hpp>
44
#include <c10/core/ScalarType.h>
5-
#include <deep_gemm/common/types.hpp>
5+
#include <deep_gemm/common/types.cuh>
66

77
#include "../../utils/math.hpp"
88

csrc/jit_kernels/heuristics/sm100.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
#include <cute/arch/mma_sm100_desc.hpp>
44
// Reuse some types in the JIT modules
5-
#include <deep_gemm/common/types.hpp>
5+
#include <deep_gemm/common/types.cuh>
66

77
#include "common.hpp"
88
#include "runtime.hpp"

csrc/jit_kernels/heuristics/sm90.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
#include <cute/arch/mma_sm100_desc.hpp>
44
// Reuse some types in the JIT modules
5-
#include <deep_gemm/common/types.hpp>
5+
#include <deep_gemm/common/types.cuh>
66

77
#include "common.hpp"
88
#include "utils.hpp"

csrc/jit_kernels/heuristics/utils.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
#include <cute/arch/mma_sm100_desc.hpp>
44
// Reuse some types in the JIT modules
5-
#include <deep_gemm/common/types.hpp>
5+
#include <deep_gemm/common/types.cuh>
66

77
#include "common.hpp"
88
#include "../../utils/exception.hpp"

csrc/jit_kernels/impls/epilogue.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
namespace deep_gemm {
77

88
static std::string get_default_epilogue_type(const std::optional<std::string>& epilogue_type) {
9-
return epilogue_type.value_or("EpilogueIdentity");
9+
return epilogue_type.value_or("epilogue::transform::EpilogueIdentity");
1010
}
1111

1212
} // namespace deep_gemm

csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp renamed to csrc/jit_kernels/impls/sm100_fp8_fp4_gemm_1d1d.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ class SM100FP8FP4Gemm1D1DRuntime final: public LaunchRuntime<SM100FP8FP4Gemm1D1D
3838
static std::string generate_impl(const Args& args) {
3939
// TODO: rename files
4040
return fmt::format(R"(
41-
#include <deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh>
41+
#include <deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh>
4242
4343
using namespace deep_gemm;
4444
4545
static void __instantiate_kernel() {{
46-
auto ptr = reinterpret_cast<void*>(&sm100_fp8_gemm_1d1d_impl<
46+
auto ptr = reinterpret_cast<void*>(&sm100_fp8_fp4_gemm_1d1d_impl<
4747
{}, {},
4848
{}, {},
4949
{}, {}, {},

0 commit comments

Comments
 (0)