豆豆友情提示:这是一个非官方 GitHub 代理镜像,主要用于网络测试或访问加速。请勿在此进行登录、注册或处理任何敏感信息。进行这些操作请务必访问官方网站 github.com。 Raw 内容也通过此代理提供。
Skip to content

Commit 67c7e2b

Browse files
authored
Maintain symmetric offset diff instead of the whole pointer (#239)
1 parent 8fc8efc commit 67c7e2b

File tree

3 files changed

+10
-10
lines changed

3 files changed

+10
-10
lines changed

csrc/apis/mega.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ static void fp8_fp4_mega_moe(
124124
const std::tuple<torch::Tensor, torch::Tensor>& l1_weights_,
125125
const std::tuple<torch::Tensor, torch::Tensor>& l2_weights_,
126126
const torch::Tensor& sym_buffer,
127-
const std::vector<uint64_t>& sym_buffer_ptrs, const int& rank_idx,
127+
const std::vector<int64_t>& sym_buffer_ptrs, const int& rank_idx,
128128
const int& num_max_tokens_per_rank,
129129
const int& num_experts, const int& num_topk,
130130
const std::tuple<int, int, int>& recipe,

csrc/jit_kernels/impls/sm100_fp8_fp4_mega_moe.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ static void sm100_fp8_fp4_mega_moe(
107107
const torch::Tensor& l2_acts, const torch::Tensor& l2_acts_sf,
108108
const torch::Tensor& l1_weights, const torch::Tensor& l2_weights,
109109
const torch::Tensor& l1_weights_sf, const torch::Tensor& l2_weights_sf,
110-
const std::vector<uint64_t>& sym_buffer_ptrs,
110+
const std::vector<int64_t>& sym_buffer_ptrs,
111111
const int& rank_idx, const int& num_max_tokens_per_rank,
112112
const int& num_experts_per_rank,
113113
const int& num_tokens, const int& num_topk,

deep_gemm/include/deep_gemm/layout/sym_buffer.cuh

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,31 +8,31 @@ constexpr static uint32_t kNumMaxRanks = 64;
88

99
template <uint32_t kNumRanks = kNumMaxRanks>
1010
struct SymBuffer {
11-
uint64_t offsets[kNumMaxRanks];
12-
13-
uint32_t rank_idx = 0;
11+
int64_t base;
12+
int64_t offsets[kNumMaxRanks];
13+
uint32_t rank_idx;
1414

1515
DG_STATIC_ASSERT(kNumRanks <= kNumMaxRanks, "Too many ranks");
1616

1717
SymBuffer() = default;
1818

1919
template <typename Container>
20-
explicit SymBuffer(const Container& c, const uint32_t& rank_idx = 0): rank_idx(rank_idx) {
20+
explicit SymBuffer(const Container& c, const uint32_t& rank_idx): rank_idx(rank_idx) {
2121
const auto size = static_cast<uint32_t>(c.size());
22+
base = c[rank_idx];
2223
for (uint32_t i = 0; i < kNumMaxRanks; ++ i)
23-
offsets[i] = i < size ? c[i] : 0;
24+
offsets[i] = i < size ? (c[i] - base) : 0;
2425
}
2526

2627
#if defined(__CUDA_ARCH__) or defined(__CLION_IDE__)
2728
template <typename ptr_t = void*>
2829
CUTLASS_DEVICE ptr_t get_base_ptr() const {
29-
return reinterpret_cast<ptr_t>(offsets[rank_idx]);
30+
return reinterpret_cast<ptr_t>(base);
3031
}
3132

3233
template <typename ptr_t>
3334
CUTLASS_DEVICE ptr_t map(const ptr_t& ptr, const uint32_t& dst_rank_idx) const {
34-
uint64_t mapped_ptr = offsets[dst_rank_idx] +
35-
(reinterpret_cast<uint64_t>(ptr) - offsets[rank_idx]);
35+
int64_t mapped_ptr = offsets[dst_rank_idx] + reinterpret_cast<int64_t>(ptr);
3636
return *reinterpret_cast<ptr_t*>(&mapped_ptr);
3737
}
3838
#endif

0 commit comments

Comments
 (0)