豆豆友情提示:这是一个非官方 GitHub 代理镜像,主要用于网络测试或访问加速。请勿在此进行登录、注册或处理任何敏感信息。进行这些操作请务必访问官方网站 github.com。 Raw 内容也通过此代理提供。
Skip to content

Commit 4069fbe

Browse files
LyricZhao and zheanxu authored
SM100 Mega MoE (MMA part 2) (#183)
* Minor fix * Fix workspace API usages * Minor fix * Successful compilation * Align tokens * Dump profiling traces * Load token into shared memory * Store into remote buffers * Use 2 GiB to debug * Mega MoE Scheduler Update (#184) * Update scheduler * Fix ptx * Add scheduler init * Minor update * Update scheduler * Minor fixes * Refactor specialized ld/st PTX * Use scheduler in the kernel * Allocate slots indices together * Rename scheduler namespace * Process the last token block count * Fix bugs * Fix __fns * Fix dispatch and scheduler * Code lint * Add reference and fix CUDA 13 compilation --------- Co-authored-by: Zhean Xu <94977922+zheanxu@users.noreply.github.com> Co-authored-by: Zhean Xu <xza@deepseek.com>
1 parent 4fdadc8 commit 4069fbe

26 files changed

+647
-291
lines changed

CMakeLists.txt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@ cmake_minimum_required(VERSION 3.10)
33
project(deep_gemm LANGUAGES CXX CUDA)
44
set(CMAKE_VERBOSE_MAKEFILE ON)
55

6-
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC -Wno-psabi")
7-
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -Wno-psabi")
6+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -Wno-psabi -Wno-deprecated-declarations")
87
set(CUDA_SEPARABLE_COMPILATION ON)
98
list(APPEND CUDA_NVCC_FLAGS "-DENABLE_FAST_DEBUG")
109
list(APPEND CUDA_NVCC_FLAGS "-O3")
@@ -22,7 +21,7 @@ set(CMAKE_CXX_STANDARD 17)
2221
set(CMAKE_CUDA_STANDARD 17)
2322

2423
include_directories(deep_gemm/include third-party/cutlass/include third-party/cutlass/tools/util/include third-party/fmt/include)
25-
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS})
24+
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include/cccl ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS})
2625
link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs)
2726

2827
# The main Python API entrance

csrc/apis/mega.hpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@ static int64_t get_symm_buffer_size_for_mega_moe(
1313
const int& hidden, const int& intermediate_hidden,
1414
const bool& use_fp8_dispatch, const std::string& activation) {
1515
// TODO: implement
16-
return 4096;
16+
// Currently, we use 16 GiB to debug
17+
return 16ll * 1024ll * 1024ll * 1024ll;
1718
}
1819

1920
static void fp8_fp4_mega_moe(
20-
const std::tuple<torch::Tensor, torch::Tensor>& hidden_states_,
21+
const std::tuple<torch::Tensor, torch::Tensor>& x_,
22+
const torch::Tensor& y,
2123
const std::tuple<torch::Tensor, torch::Tensor>& l1_weights_,
2224
const std::tuple<torch::Tensor, torch::Tensor>& l2_weights_,
2325
const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
@@ -26,7 +28,7 @@ static void fp8_fp4_mega_moe(
2628
const int& num_max_tokens_per_rank,
2729
const std::tuple<int, int, int>& recipe,
2830
const std::string& activation) {
29-
const auto [hidden_states, hidden_states_sf] = hidden_states_;
31+
const auto [x, x_sf] = x_;
3032
const auto [l1_weights, l1_weights_sf] = l1_weights_;
3133
const auto [l2_weights, l2_weights_sf] = l2_weights_;
3234

@@ -42,7 +44,7 @@ static void fp8_fp4_mega_moe(
4244
// Dispatch into different architectures
4345
const auto arch_major = device_runtime->get_arch_major();
4446
if (arch_major == 10) {
45-
sm100_fp8_fp4_mega_moe(hidden_states, hidden_states_sf,
47+
sm100_fp8_fp4_mega_moe(x, x_sf, y,
4648
l1_weights, l1_weights_sf,
4749
l2_weights, l2_weights_sf,
4850
topk_idx, topk_weights,

csrc/jit/handle.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ static void* get_driver_handle() {
2424
#define DECL_LAZY_CUDA_DRIVER_FUNCTION(name) \
2525
template <typename... Args> \
2626
static auto lazy_##name(Args&&... args) -> decltype(name(args...)) { \
27-
using FuncType = decltype(&name); \
27+
using FuncType = decltype(&(name)); \
2828
static FuncType func = nullptr; \
2929
if (func == nullptr) { \
3030
func = reinterpret_cast<FuncType>(dlsym(get_driver_handle(), #name)); \

csrc/jit_kernels/impls/sm100_fp8_fp4_mega_moe.hpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@ class SM100FP8FP4MegaMoERuntime final : public LaunchRuntime<SM100FP8FP4MegaMoER
2525
int num_ranks;
2626

2727
// Runtime arguments
28-
int num_tokens;
28+
void* x;
2929
int64_t* topk_idx;
30+
int num_tokens;
3031
layout::SymBuffer<> sym_buffer_ptrs;
3132
int rank_idx;
3233

@@ -64,18 +65,20 @@ static void __instantiate_kernel() {{
6465
// TODO: optimize `args` copy
6566
// TODO: tensor maps are missing
6667
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
67-
args.num_tokens, args.topk_idx, args.sym_buffer_ptrs, args.rank_idx));
68+
args.x, args.topk_idx,
69+
args.num_tokens,
70+
args.sym_buffer_ptrs, args.rank_idx));
6871
}
6972
};
7073

7174
static void sm100_fp8_fp4_mega_moe(
72-
const torch::Tensor& hidden_states, const torch::Tensor& hidden_states_sf,
75+
const torch::Tensor& x, const torch::Tensor& x_sf, const torch::Tensor& y,
7376
const torch::Tensor& l1_weights, const torch::Tensor& l1_weights_sf,
7477
const torch::Tensor& l2_weights, const torch::Tensor& l2_weights_sf,
7578
const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
7679
const std::vector<uint64_t>& sym_buffer_ptrs, const int& rank_idx,
7780
const int& num_max_tokens_per_rank) {
78-
const auto [num_tokens, hidden] = get_shape<2>(hidden_states);
81+
const auto [num_tokens, hidden] = get_shape<2>(x);
7982
const auto [num_experts_per_rank, intermediate_hidden, _] = get_shape<3>(l2_weights);
8083
const auto [__, num_topk] = get_shape<2>(topk_idx);
8184
const auto num_ranks = static_cast<int>(sym_buffer_ptrs.size());
@@ -92,11 +95,12 @@ static void sm100_fp8_fp4_mega_moe(
9295
.num_stages = 5,
9396
.num_dispatch_threads = 128, .num_mma_non_epilogue_threads = 128, .num_mma_epilogue_threads = 128,
9497
.num_ranks = num_ranks,
95-
.num_tokens = num_tokens,
98+
.x = x.data_ptr(),
9699
.topk_idx = topk_idx.data_ptr<int64_t>(),
100+
.num_tokens = num_tokens,
97101
.sym_buffer_ptrs = layout::SymBuffer<>(sym_buffer_ptrs),
98102
.rank_idx = rank_idx,
99-
.launch_args = LaunchArgs(num_sms, 256, 16384, 2)
103+
.launch_args = LaunchArgs(num_sms, 256, 232448, 2)
100104
};
101105
const auto code = SM100FP8FP4MegaMoERuntime::generate(args);
102106
const auto runtime = compiler->build("sm100_fp8_fp4_mega_moe", code);

csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class SMXXPagedMQALogitsMetadataRuntime final: public LaunchRuntime<SMXXPagedMQA
3131
using namespace deep_gemm;
3232
3333
static void __instantiate_kernel() {{
34-
auto ptr = reinterpret_cast<void*>(&scheduler::smxx_paged_mqa_logits_metadata<
34+
auto ptr = reinterpret_cast<void*>(&sched::smxx_paged_mqa_logits_metadata<
3535
{}, {}, {}
3636
>);
3737
}};

deep_gemm/include/deep_gemm/common/math.cuh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ __device__ __host__ constexpr T constexpr_ceil_div(T a, T b) {
2323
return (a + b - 1) / b;
2424
}
2525

26-
template <typename T>
27-
__device__ __host__ T align(T a, T b) {
28-
return ceil_div(a, b) * b;
26+
template <typename T, bool kDoCeilAlignment = true>
27+
__forceinline__ __device__ __host__ T align(T a, T b) {
28+
return (kDoCeilAlignment ? ceil_div(a, b) : (a / b)) * b;
2929
}
3030

3131
template <typename T>

deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ sm100_bf16_gemm_impl(int* grouped_layout,
166166

167167
// Block scheduler
168168
uint32_t m_block_idx, n_block_idx;
169-
auto scheduler = scheduler::Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(
169+
auto scheduler = sched::Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(
170170
shape_m, shape_n, shape_k, grouped_layout);
171171

172172
// Pipeline and TMA phases
@@ -195,19 +195,19 @@ sm100_bf16_gemm_impl(int* grouped_layout,
195195

196196
// Compute offsets
197197
// NOTES: the group is always concatenated with the outer dimension
198-
uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), scheduler::IndexType::MN> (
198+
uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), sched::IndexType::MN> (
199199
shape_m, BLOCK_M, m_block_idx);
200-
uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), scheduler::IndexType::MN> (
200+
uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), sched::IndexType::MN> (
201201
shape_n, BLOCK_N, n_block_idx, m_block_idx);
202202

203203
// NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
204204
// And for all m-grouped GEMMs, A must be K-majored
205205
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or
206206
kMajorA == cute::UMMA::Major::K, "Invalid major");
207207
uint32_t k_idx = k_block_idx * BLOCK_K;
208-
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), scheduler::IndexType::K> (
208+
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), sched::IndexType::K> (
209209
shape_k, BLOCK_K, k_block_idx, m_block_idx);
210-
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), scheduler::IndexType::K> (
210+
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), sched::IndexType::K> (
211211
shape_k, BLOCK_K, k_block_idx, m_block_idx);
212212

213213
// Add 2 CTA offsets
@@ -384,7 +384,7 @@ sm100_bf16_gemm_impl(int* grouped_layout,
384384
// Load from tensor memory into registers, and write shared memory with STSM
385385
const auto tmem_base_addr = accum_stage_idx * UMMA_N;
386386
const auto base_m_idx = scheduler.template get_global_idx<
387-
(not is_m_grouped_contiguous(kGemmType)), scheduler::IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
387+
(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
388388
const auto base_n_idx = n_block_idx * BLOCK_N;
389389

390390
if constexpr (kSwapAB) {

deep_gemm/include/deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
152152

153153
// Scheduler
154154
constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV;
155-
auto scheduler = scheduler::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumBlocksPerSplit>(
155+
auto scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumBlocksPerSplit>(
156156
batch_size, blockIdx.x, context_lens, schedule_meta);
157157
DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`");
158158

deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ sm100_fp8_fp4_gemm_1d1d_impl(int* grouped_layout,
180180

181181
// Block scheduler
182182
uint32_t m_block_idx, n_block_idx;
183-
auto scheduler = scheduler::Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(
183+
auto scheduler = sched::Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(
184184
shape_m, shape_n, shape_k, grouped_layout);
185185

186186
// Pipeline and TMA phases
@@ -209,19 +209,19 @@ sm100_fp8_fp4_gemm_1d1d_impl(int* grouped_layout,
209209

210210
// Compute offsets
211211
// NOTES: the group is always concatenated with the outer dimension
212-
uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), scheduler::IndexType::MN> (
212+
uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), sched::IndexType::MN> (
213213
shape_m, BLOCK_M, m_block_idx);
214-
uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), scheduler::IndexType::MN> (
214+
uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), sched::IndexType::MN> (
215215
shape_n, BLOCK_N, n_block_idx, m_block_idx);
216216

217217
// NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
218218
// And for all m-grouped GEMMs, A must be K-majored
219219
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or
220220
kMajorA == cute::UMMA::Major::K, "Invalid major");
221221
uint32_t k_idx = k_block_idx * BLOCK_K;
222-
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), scheduler::IndexType::K> (
222+
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), sched::IndexType::K> (
223223
shape_k, BLOCK_K, k_block_idx, m_block_idx);
224-
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), scheduler::IndexType::K> (
224+
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), sched::IndexType::K> (
225225
shape_k, BLOCK_K, k_block_idx, m_block_idx);
226226

227227
// Add 2 CTA offsets
@@ -252,14 +252,14 @@ sm100_fp8_fp4_gemm_1d1d_impl(int* grouped_layout,
252252
// No swizzling, so one TMA for one SF is enough
253253
if (k_block_idx % kNumSFAStagesPerLoad == 0) {
254254
uint32_t sfa_m_idx = m_block_idx * BLOCK_M;
255-
uint32_t sfa_k_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), scheduler::IndexType::SF_K>(
255+
uint32_t sfa_k_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::SF_K>(
256256
shape_sfa_k, 1, math::ceil_div(k_idx, BLOCK_K * kNumSFAStagesPerLoad));
257257
tma::copy<BLOCK_M, 1, 0>(&tensor_map_sfa, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx);
258258
num_arrival_bytes += BLOCK_M * sizeof(uint32_t);
259259
}
260260
if (k_block_idx % kNumSFBStagesPerLoad == 0) {
261261
uint32_t sfb_n_idx = n_block_idx * BLOCK_N;
262-
uint32_t sfb_k_idx = scheduler.template get_global_idx<true, scheduler::IndexType::SF_K>(
262+
uint32_t sfb_k_idx = scheduler.template get_global_idx<true, sched::IndexType::SF_K>(
263263
shape_sfb_k, 1, math::ceil_div(k_idx, BLOCK_K * kNumSFBStagesPerLoad), m_block_idx);
264264
tma::copy<BLOCK_N, 1, 0>(&tensor_map_sfb, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx);
265265
num_arrival_bytes += BLOCK_N * sizeof(uint32_t);
@@ -460,7 +460,7 @@ sm100_fp8_fp4_gemm_1d1d_impl(int* grouped_layout,
460460
ptx::tcgen05_after_thread_sync();
461461

462462
const auto tmem_base_addr = accum_stage_idx * UMMA_N;
463-
const auto base_m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), scheduler::IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
463+
const auto base_m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
464464
const auto base_n_idx = n_block_idx * BLOCK_N;
465465

466466
if constexpr (kSwapAB) {

0 commit comments

Comments (0)