豆豆友情提示:这是一个非官方 GitHub 代理镜像,主要用于网络测试或访问加速。请勿在此进行登录、注册或处理任何敏感信息。进行这些操作请务必访问官方网站 github.com。 Raw 内容也通过此代理提供。
Skip to content

Commit 2b71c00

Browse files
committed
Remove a redundant register
1 parent 67c7e2b commit 2b71c00

File tree

4 files changed: 6 additions and 9 deletions.

csrc/jit_kernels/impls/sm100_fp8_fp4_mega_moe.hpp

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,6 @@ class SM100FP8FP4MegaMoERuntime final : public LaunchRuntime<SM100FP8FP4MegaMoER
31 | 31 |   void* y;
32 | 32 |   int num_tokens;
33 | 33 |   layout::SymBuffer<> sym_buffer_ptrs;
34 |    | - int rank_idx;
35 | 34 |
36 | 35 |   // Tensormap
37 | 36 |   CUtensorMap tensor_map_l1_acts;
@@ -87,7 +86,7 @@ static void __instantiate_kernel() {{
87 | 86 |   DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
88 | 87 |   args.y,
89 | 88 |   args.num_tokens,
90 |    | - args.sym_buffer_ptrs, args.rank_idx,
   | 89 | + args.sym_buffer_ptrs,
91 | 90 |   args.tensor_map_l1_acts,
92 | 91 |   args.tensor_map_l1_acts_sf,
93 | 92 |   args.tensor_map_l1_weights,
@@ -185,7 +184,6 @@ static void sm100_fp8_fp4_mega_moe(
185 | 184 |   .y = y.data_ptr(),
186 | 185 |   .num_tokens = num_tokens,
187 | 186 |   .sym_buffer_ptrs = layout::SymBuffer<>(sym_buffer_ptrs, rank_idx),
188 |     | - .rank_idx = rank_idx,
189 | 187 |   .tensor_map_l1_acts = tensor_map_l1_acts,
190 | 188 |   .tensor_map_l1_acts_sf = tensor_map_l1_acts_sf,
191 | 189 |   .tensor_map_l1_weights = tensor_map_l1_weights,

deep_gemm/include/deep_gemm/comm/barrier.cuh

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -57,8 +57,8 @@ CUTLASS_DEVICE void nvlink_barrier(const layout::Workspace& workspace,
57 | 57 |   const auto start_clock = clock64();
58 | 58 |   while (ptx::ld_acq_sys(signal_ptr) != target) {
59 | 59 |   if (clock64() - start_clock >= kNumTimeoutCycles) {
60 |    | - printf("DeepGEMM NVLink barrier timeout (30s): signal=%d, target=%d, phase=%d, sign=%d\n",
61 |    | - ptx::ld_acq_sys(signal_ptr), target, signal_phase, signal_sign);
   | 60 | + printf("DeepGEMM NVLink barrier timeout (30s): rank=%d, signal=%d, target=%d, phase=%d, sign=%d\n",
   | 61 | + sym_buffer.rank_idx, ptx::ld_acq_sys(signal_ptr), target, signal_phase, signal_sign);
62 | 62 |   DG_DEVICE_ASSERT(false and "NVLink barrier timeout");
63 | 63 |   }
64 | 64 |   }
}

deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,6 @@ CUTLASS_GLOBAL __launch_bounds__(kNumThreads, 1) void
48 | 48 |   sm100_fp8_fp4_mega_moe_impl(void* y,
49 | 49 |   const uint32_t num_tokens,
50 | 50 |   const __grid_constant__ layout::SymBuffer<kNumRanks> sym_buffer,
51 |    | - const uint32_t rank_idx,
52 | 51 |   const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts,
53 | 52 |   const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts_sf,
54 | 53 |   const __grid_constant__ cute::TmaDescriptor tensor_map_l1_weights,
@@ -390,7 +389,7 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
390 | 389 |   const auto dst_rank_idx = expert_idx / kNumExpertsPerRank;
391 | 390 |   const auto dst_slot_idx = atomicAdd_block(smem_expert_count + expert_idx, 1);
392 | 391 |   const auto dst_ptr = workspace.get_src_token_topk_idx_ptr(
393 |     | - expert_idx % kNumExpertsPerRank, rank_idx, dst_slot_idx);
    | 392 | + expert_idx % kNumExpertsPerRank, sym_buffer.rank_idx, dst_slot_idx);
394 | 393 |   *sym_buffer.map(dst_ptr, dst_rank_idx) = token_topk_idx;
395 | 394 |   });
396 | 395 |   cutlass::arch::NamedBarrier::sync(kNumDispatchThreads, kDispatchBarrierIdx);
@@ -409,7 +408,7 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
409 | 408 |   const auto dst_local_expert_idx = i % kNumExpertsPerRank;
410 | 409 |   const auto expert_status = *workspace.get_expert_send_count_ptr(i);
411 | 410 |   *sym_buffer.map(
412 |     | - workspace.get_expert_recv_count_ptr(rank_idx, dst_local_expert_idx),
    | 411 | + workspace.get_expert_recv_count_ptr(sym_buffer.rank_idx, dst_local_expert_idx),
413 | 412 |   dst_rank_idx) = expert_status & 0xffffffff;
414 | 413 |   ptx::atomic_add_sys(
415 | 414 |   sym_buffer.map(workspace.get_expert_recv_count_sum_ptr(dst_local_expert_idx), dst_rank_idx),

deep_gemm/include/deep_gemm/layout/sym_buffer.cuh

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
 4 |  4 |
 5 |  5 |   namespace deep_gemm::layout {
 6 |  6 |
 7 |    | - constexpr static uint32_t kNumMaxRanks = 64;
   |  7 | + constexpr static uint32_t kNumMaxRanks = 72;
 8 |  8 |
 9 |  9 |   template <uint32_t kNumRanks = kNumMaxRanks>
10 | 10 |   struct SymBuffer {

0 commit comments

Comments
 (0)