豆豆友情提示:这是一个非官方 GitHub 代理镜像,主要用于网络测试或访问加速。请勿在此进行登录、注册或处理任何敏感信息。进行这些操作请务必访问官方网站 github.com。 Raw 内容也通过此代理提供。
Skip to content

Commit 6bef332

Browse files
zheanxu and LyricZhao authored
Mega MoE Scaling Factors (#227)
* Support SFB * Minor fix * Minor fix * Simplify * Simplify * Support SFA * Optimize * Minor fix * Optimize * Minor fix * Minor fix * Minor fix * Minor fix * Minor fix * Minor fix * Better overlapping * Change block size * Refactor weight interleaving * More stages * Print 2-digit speedup * Lint --------- Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
1 parent 3f05a98 commit 6bef332

File tree

9 files changed

+324
-110
lines changed

9 files changed

+324
-110
lines changed

CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ find_package(CUDAToolkit REQUIRED)
1717
find_package(pybind11 REQUIRED)
1818
find_package(Torch REQUIRED)
1919

20-
set(CMAKE_CXX_STANDARD 17)
21-
set(CMAKE_CUDA_STANDARD 17)
20+
set(CMAKE_CXX_STANDARD 20)
21+
set(CMAKE_CUDA_STANDARD 20)
2222

2323
include_directories(deep_gemm/include third-party/cutlass/include third-party/cutlass/tools/util/include third-party/fmt/include)
2424
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include/cccl ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS})

csrc/apis/mega.hpp

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
namespace deep_gemm::mega {
1313

14-
static std::tuple<int64_t, std::function<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>(const torch::Tensor&)>>
14+
static std::tuple<int64_t, std::function<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>(const torch::Tensor&)>>
1515
get_symm_buffer_size_for_mega_moe(
1616
const int& num_ranks, const int& num_experts,
1717
const int& num_max_tokens_per_rank, const int& num_topk,
@@ -27,6 +27,8 @@ get_symm_buffer_size_for_mega_moe(
2727
const auto fp8_token_layout = layout::Data(hidden);
2828
const auto bf16_token_layout = layout::Data(hidden * 2);
2929
const auto fp8_intermediate_token_layout = layout::Data(intermediate_hidden);
30+
const auto fp8_sf_layout = layout::Data(hidden / 32);
31+
const auto fp8_intermediate_sf_layout = layout::Data(intermediate_hidden / 32);
3032
const auto input_topk_idx_layout = layout::Data(num_topk * sizeof(int64_t), false);
3133
const auto input_topk_weights_layout = layout::Data(num_topk * sizeof(float), false);
3234
const auto l1_topk_weights_layout = layout::Data(sizeof(float), false);
@@ -35,10 +37,12 @@ get_symm_buffer_size_for_mega_moe(
3537
const auto input_token_buffer = layout::Buffer(
3638
fp8_token_layout, 1, num_max_tokens_per_rank,
3739
workspace.get_end_ptr());
38-
// TODO: add `input_sf_buffer`
40+
const auto input_sf_buffer = layout::Buffer(
41+
fp8_sf_layout, 1, num_max_tokens_per_rank,
42+
input_token_buffer.get_end_ptr());
3943
const auto input_topk_idx_buffer = layout::Buffer(
4044
input_topk_idx_layout, 1, num_max_tokens_per_rank,
41-
input_token_buffer.get_end_ptr());
45+
input_sf_buffer.get_end_ptr());
4246
const auto input_topk_weights_buffer = layout::Buffer(
4347
input_topk_weights_layout, 1, num_max_tokens_per_rank,
4448
input_topk_idx_buffer.get_end_ptr());
@@ -49,29 +53,41 @@ get_symm_buffer_size_for_mega_moe(
4953
const auto l1_token_buffer = layout::Buffer(
5054
fp8_token_layout, num_experts_per_rank, num_max_recv_tokens_per_expert,
5155
input_topk_weights_buffer.get_end_ptr());
52-
// TODO: add `l1_input_sf_buffer`
56+
const auto l1_sf_buffer = layout::Buffer(
57+
fp8_sf_layout, num_experts_per_rank, num_max_recv_tokens_per_expert,
58+
l1_token_buffer.get_end_ptr());
5359
const auto l1_topk_weights_buffer = layout::Buffer(
5460
l1_topk_weights_layout, num_experts_per_rank, num_max_recv_tokens_per_expert,
55-
l1_token_buffer.get_end_ptr());
61+
l1_sf_buffer.get_end_ptr());
5662

5763
// L2 input buffer
5864
const auto l2_token_buffer = layout::Buffer(
5965
fp8_intermediate_token_layout, num_experts_per_rank, num_max_recv_tokens_per_expert,
6066
l1_topk_weights_buffer.get_end_ptr());
67+
const auto l2_sf_buffer = layout::Buffer(
68+
fp8_intermediate_sf_layout, num_experts_per_rank, num_max_recv_tokens_per_expert,
69+
l2_token_buffer.get_end_ptr());
6170

6271
// Combine input buffer: BF16 tokens for cross-rank combine
6372
const auto combine_token_buffer = layout::Buffer(
6473
bf16_token_layout, num_topk, num_max_tokens_per_rank,
65-
l2_token_buffer.get_end_ptr());
74+
l2_sf_buffer.get_end_ptr());
6675

67-
// Slice function: creates `(x, x_sf, topk_weights, topk_idx, l1_acts, l2_acts)` tensor views from the raw buffer
76+
// Check SF buffer requirements
77+
DG_HOST_ASSERT(hidden % 128 == 0 and intermediate_hidden % 128 == 0);
78+
DG_HOST_ASSERT(num_max_recv_tokens_per_expert % 4 == 0);
79+
80+
// Slice function: creates `(x, x_sf, topk_weights, topk_idx, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf)` tensor views from the raw buffer
81+
// NOTES: `x_sf` is K-major, while `l1_acts_sf` and `l2_acts_sf` are M-major
6882
auto slice_input_buffers = [=](const torch::Tensor& buffer) {
6983
auto x = torch::from_blob(
7084
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_token_buffer.base)),
7185
{num_max_tokens_per_rank, hidden},
7286
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
73-
// TODO: create `x_sf` from buffer
74-
auto x_sf = torch::empty(0, torch::TensorOptions().device(buffer.device()));
87+
auto x_sf = torch::from_blob(
88+
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_sf_buffer.base)),
89+
{num_max_tokens_per_rank, hidden / 128},
90+
torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
7591
auto topk_idx = torch::from_blob(
7692
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_topk_idx_buffer.base)),
7793
{num_max_tokens_per_rank, num_topk},
@@ -84,11 +100,21 @@ get_symm_buffer_size_for_mega_moe(
84100
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_token_buffer.base)),
85101
{num_experts_per_rank * num_max_recv_tokens_per_expert, hidden},
86102
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
103+
auto l1_acts_sf = torch::from_blob(
104+
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_sf_buffer.base)),
105+
{num_max_recv_tokens_per_expert, hidden / 128 * num_experts_per_rank},
106+
{1, num_max_recv_tokens_per_expert},
107+
torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
87108
auto l2_acts = torch::from_blob(
88109
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_token_buffer.base)),
89110
{num_experts_per_rank * num_max_recv_tokens_per_expert, intermediate_hidden},
90111
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
91-
return std::make_tuple(x, x_sf, topk_idx, topk_weights, l1_acts, l2_acts);
112+
auto l2_acts_sf = torch::from_blob(
113+
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_sf_buffer.base)),
114+
{num_max_recv_tokens_per_expert, intermediate_hidden / 128 * num_experts_per_rank},
115+
{1, num_max_recv_tokens_per_expert},
116+
torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
117+
return std::make_tuple(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf);
92118
};
93119
return {reinterpret_cast<int64_t>(combine_token_buffer.get_end_ptr()), slice_input_buffers};
94120
}
@@ -133,6 +159,13 @@ static void fp8_fp4_mega_moe(
133159
DG_HOST_ASSERT(intermediate_hidden_2 == 2 * intermediate_hidden);
134160
DG_HOST_ASSERT(l1_weights.is_contiguous() and l2_weights.is_contiguous());
135161

162+
// Check weight SF layout for UE8M0 packing, MN-major, and TMA alignment
163+
constexpr int kGranMN = 1, kGranK = 32;
164+
check_sf_layout(l1_weights_sf, intermediate_hidden * 2, hidden, kGranMN, kGranK,
165+
num_experts_per_rank, true, false, torch::kInt);
166+
check_sf_layout(l2_weights_sf, hidden, intermediate_hidden, kGranMN, kGranK,
167+
num_experts_per_rank, true, false, torch::kInt);
168+
136169
// Check buffer bytes
137170
const auto num_ranks = static_cast<int>(sym_buffer_ptrs.size());
138171
const auto num_experts_ = num_experts_per_rank * num_ranks;
@@ -145,12 +178,13 @@ static void fp8_fp4_mega_moe(
145178
DG_HOST_ASSERT(num_experts == num_experts_);
146179

147180
// Already registered tensors
148-
const auto [x, x_sf, topk_idx, topk_weights, l1_acts, l2_acts] = slice(sym_buffer);
181+
const auto [x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf] = slice(sym_buffer);
149182

150183
// Dispatch into different architectures
151184
if (arch_major == 10) {
152185
sm100_fp8_fp4_mega_moe(y,
153-
l1_acts, l2_acts,
186+
l1_acts, l1_acts_sf,
187+
l2_acts, l2_acts_sf,
154188
l1_weights, l2_weights,
155189
l1_weights_sf, l2_weights_sf,
156190
sym_buffer_ptrs,

csrc/jit_kernels/heuristics/mega_moe.hpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,13 @@ static std::pair<int, int> get_pipeline_config_for_mega_moe(
117117
// Tensor memory pointer
118118
const int smem_tmem_ptr = 4;
119119

120-
// Per-stage: A tile + B tile + full/empty barriers
121-
const int smem_per_stage = load_block_m * block_k + block_n * block_k + 2 * 8;
120+
// SF is aligned to UTCCP 128-element granularity
121+
const auto [sf_block_m, sf_block_n] = SM100ArchSpec::get_sf_uttcp_aligned_block_sizes(block_m, block_n, MmaKind::MXFP8FP4);
122+
const int smem_sfa_per_stage = sf_block_m * 4;
123+
const int smem_sfb_per_stage = sf_block_n * 4;
124+
125+
// Per-stage: A tile + B tile + SFA tile + SFB tile + full/empty/with_sf_full barriers
126+
const int smem_per_stage = load_block_m * block_k + block_n * block_k + smem_sfa_per_stage + smem_sfb_per_stage + 3 * 8;
122127

123128
// Fixed total
124129
const int smem_fixed = smem_dispatch_size + smem_cd + smem_amax_reduction + smem_barriers + smem_tmem_ptr;
@@ -140,7 +145,7 @@ static MegaMoEConfig get_mega_moe_config(
140145
const int block_k = 128;
141146
const int load_block_m = block_m / 2;
142147
const int load_block_n = block_n;
143-
const int store_block_m = 48;
148+
const int store_block_m = 32;
144149
// NOTES: FP8 activations and FP4 weights (unpacked to 8-bit in smem) both use 128B swizzle
145150
const int swizzle_acts_mode = 128;
146151
const int swizzle_weights_mode = 128;

csrc/jit_kernels/impls/sm100_fp8_fp4_mega_moe.hpp

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,14 @@ class SM100FP8FP4MegaMoERuntime final : public LaunchRuntime<SM100FP8FP4MegaMoER
3535

3636
// Tensormap
3737
CUtensorMap tensor_map_l1_acts;
38+
CUtensorMap tensor_map_l1_acts_sf;
3839
CUtensorMap tensor_map_l1_weights;
40+
CUtensorMap tensor_map_l1_weights_sf;
3941
CUtensorMap tensor_map_l1_output;
4042
CUtensorMap tensor_map_l2_acts;
43+
CUtensorMap tensor_map_l2_acts_sf;
4144
CUtensorMap tensor_map_l2_weights;
45+
CUtensorMap tensor_map_l2_weights_sf;
4246

4347
// Launch configs
4448
LaunchArgs launch_args;
@@ -85,17 +89,22 @@ static void __instantiate_kernel() {{
8589
args.num_tokens,
8690
args.sym_buffer_ptrs, args.rank_idx,
8791
args.tensor_map_l1_acts,
92+
args.tensor_map_l1_acts_sf,
8893
args.tensor_map_l1_weights,
94+
args.tensor_map_l1_weights_sf,
8995
args.tensor_map_l1_output,
9096
args.tensor_map_l2_acts,
91-
args.tensor_map_l2_weights
97+
args.tensor_map_l2_acts_sf,
98+
args.tensor_map_l2_weights,
99+
args.tensor_map_l2_weights_sf
92100
));
93101
}
94102
};
95103

96104
static void sm100_fp8_fp4_mega_moe(
97105
const torch::Tensor& y,
98-
const torch::Tensor& l1_acts, const torch::Tensor& l2_acts,
106+
const torch::Tensor& l1_acts, const torch::Tensor& l1_acts_sf,
107+
const torch::Tensor& l2_acts, const torch::Tensor& l2_acts_sf,
99108
const torch::Tensor& l1_weights, const torch::Tensor& l2_weights,
100109
const torch::Tensor& l1_weights_sf, const torch::Tensor& l2_weights_sf,
101110
const std::vector<uint64_t>& sym_buffer_ptrs,
@@ -116,17 +125,26 @@ static void sm100_fp8_fp4_mega_moe(
116125
num_max_tokens_per_rank, num_topk, hidden, intermediate_hidden);
117126

118127
// Make tensormap
128+
constexpr int kGranK = 32;
119129
const auto num_max_recv_tokens = num_experts_per_rank * num_max_recv_tokens_per_expert;
120130
const auto tensor_map_l1_acts = make_tma_2d_desc(l1_acts,
121131
hidden, num_max_recv_tokens,
122132
config.block_k, config.load_block_m,
123133
static_cast<int>(l1_acts.stride(-2)),
124134
config.swizzle_acts_mode);
135+
const auto tensor_map_l1_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l1_acts_sf,
136+
num_max_recv_tokens_per_expert, hidden,
137+
config.block_m, kGranK,
138+
num_experts_per_rank, 0);
125139
const auto tensor_map_l1_weights = make_tma_2d_desc(l1_weights,
126140
hidden, num_experts_per_rank * intermediate_hidden * 2,
127141
config.block_k, config.load_block_n,
128142
static_cast<int>(l1_weights.stride(-2)),
129143
config.swizzle_weights_mode);
144+
const auto tensor_map_l1_weights_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l1_weights_sf,
145+
intermediate_hidden * 2, hidden,
146+
config.block_n, kGranK,
147+
num_experts_per_rank, 0);
130148
// NOTES: L1 output and L2 activations are essentially the same tensor.
131149
// Post-SwiGLU output has half the N width (`BLOCK_N / 2` per input tile),
132150
// so the swizzle mode is also halved (128 -> 64).
@@ -140,11 +158,19 @@ static void sm100_fp8_fp4_mega_moe(
140158
config.block_k, config.load_block_m,
141159
static_cast<int>(l2_acts.stride(-2)),
142160
config.swizzle_acts_mode);
161+
const auto tensor_map_l2_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l2_acts_sf,
162+
num_max_recv_tokens_per_expert, intermediate_hidden,
163+
config.block_m, kGranK,
164+
num_experts_per_rank, 0);
143165
const auto tensor_map_l2_weights = make_tma_2d_desc(l2_weights,
144166
intermediate_hidden, num_experts_per_rank * hidden,
145167
config.block_k, config.load_block_n,
146168
static_cast<int>(l2_weights.stride(-2)),
147169
config.swizzle_weights_mode);
170+
const auto tensor_map_l2_weights_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l2_weights_sf,
171+
hidden, intermediate_hidden,
172+
config.block_n, kGranK,
173+
num_experts_per_rank, 0);
148174

149175
// Launch
150176
const auto num_sms = device_runtime->get_num_sms();
@@ -161,10 +187,14 @@ static void sm100_fp8_fp4_mega_moe(
161187
.sym_buffer_ptrs = layout::SymBuffer<>(sym_buffer_ptrs, rank_idx),
162188
.rank_idx = rank_idx,
163189
.tensor_map_l1_acts = tensor_map_l1_acts,
190+
.tensor_map_l1_acts_sf = tensor_map_l1_acts_sf,
164191
.tensor_map_l1_weights = tensor_map_l1_weights,
192+
.tensor_map_l1_weights_sf = tensor_map_l1_weights_sf,
165193
.tensor_map_l1_output = tensor_map_l1_output,
166194
.tensor_map_l2_acts = tensor_map_l2_acts,
195+
.tensor_map_l2_acts_sf = tensor_map_l2_acts_sf,
167196
.tensor_map_l2_weights = tensor_map_l2_weights,
197+
.tensor_map_l2_weights_sf = tensor_map_l2_weights_sf,
168198
.launch_args = LaunchArgs(num_sms,
169199
config.num_dispatch_threads + config.num_non_epilogue_threads + config.num_epilogue_threads,
170200
config.smem_size, 2)

0 commit comments

Comments
 (0)