1111
1212namespace deep_gemm ::mega {
1313
14- static std::tuple<int64_t , std::function<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>(const torch::Tensor&)>>
14+ static std::tuple<int64_t , std::function<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor >(const torch::Tensor&)>>
1515get_symm_buffer_size_for_mega_moe (
1616 const int & num_ranks, const int & num_experts,
1717 const int & num_max_tokens_per_rank, const int & num_topk,
@@ -27,6 +27,8 @@ get_symm_buffer_size_for_mega_moe(
2727 const auto fp8_token_layout = layout::Data (hidden);
2828 const auto bf16_token_layout = layout::Data (hidden * 2 );
2929 const auto fp8_intermediate_token_layout = layout::Data (intermediate_hidden);
30+ const auto fp8_sf_layout = layout::Data (hidden / 32 );
31+ const auto fp8_intermediate_sf_layout = layout::Data (intermediate_hidden / 32 );
3032 const auto input_topk_idx_layout = layout::Data (num_topk * sizeof (int64_t ), false );
3133 const auto input_topk_weights_layout = layout::Data (num_topk * sizeof (float ), false );
3234 const auto l1_topk_weights_layout = layout::Data (sizeof (float ), false );
@@ -35,10 +37,12 @@ get_symm_buffer_size_for_mega_moe(
3537 const auto input_token_buffer = layout::Buffer (
3638 fp8_token_layout, 1 , num_max_tokens_per_rank,
3739 workspace.get_end_ptr ());
38- // TODO: add `input_sf_buffer`
40+ const auto input_sf_buffer = layout::Buffer (
41+ fp8_sf_layout, 1 , num_max_tokens_per_rank,
42+ input_token_buffer.get_end_ptr ());
3943 const auto input_topk_idx_buffer = layout::Buffer (
4044 input_topk_idx_layout, 1 , num_max_tokens_per_rank,
41- input_token_buffer .get_end_ptr ());
45+ input_sf_buffer .get_end_ptr ());
4246 const auto input_topk_weights_buffer = layout::Buffer (
4347 input_topk_weights_layout, 1 , num_max_tokens_per_rank,
4448 input_topk_idx_buffer.get_end_ptr ());
@@ -49,29 +53,41 @@ get_symm_buffer_size_for_mega_moe(
4953 const auto l1_token_buffer = layout::Buffer (
5054 fp8_token_layout, num_experts_per_rank, num_max_recv_tokens_per_expert,
5155 input_topk_weights_buffer.get_end_ptr ());
52- // TODO: add `l1_input_sf_buffer`
56+ const auto l1_sf_buffer = layout::Buffer (
57+ fp8_sf_layout, num_experts_per_rank, num_max_recv_tokens_per_expert,
58+ l1_token_buffer.get_end_ptr ());
5359 const auto l1_topk_weights_buffer = layout::Buffer (
5460 l1_topk_weights_layout, num_experts_per_rank, num_max_recv_tokens_per_expert,
55- l1_token_buffer .get_end_ptr ());
61+ l1_sf_buffer .get_end_ptr ());
5662
5763 // L2 input buffer
5864 const auto l2_token_buffer = layout::Buffer (
5965 fp8_intermediate_token_layout, num_experts_per_rank, num_max_recv_tokens_per_expert,
6066 l1_topk_weights_buffer.get_end_ptr ());
67+ const auto l2_sf_buffer = layout::Buffer (
68+ fp8_intermediate_sf_layout, num_experts_per_rank, num_max_recv_tokens_per_expert,
69+ l2_token_buffer.get_end_ptr ());
6170
6271 // Combine input buffer: BF16 tokens for cross-rank combine
6372 const auto combine_token_buffer = layout::Buffer (
6473 bf16_token_layout, num_topk, num_max_tokens_per_rank,
65- l2_token_buffer .get_end_ptr ());
74+ l2_sf_buffer .get_end_ptr ());
6675
67- // Slice function: creates `(x, x_sf, topk_weights, topk_idx, l1_acts, l2_acts)` tensor views from the raw buffer
76+ // Check SF buffer requirements
77+ DG_HOST_ASSERT (hidden % 128 == 0 and intermediate_hidden % 128 == 0 );
78+ DG_HOST_ASSERT (num_max_recv_tokens_per_expert % 4 == 0 );
79+
80+ // Slice function: creates `(x, x_sf, topk_weights, topk_idx, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf)` tensor views from the raw buffer
81+ // NOTES: `x_sf` is K-major, while `l1_acts_sf` and `l2_acts_sf` are M-major
6882 auto slice_input_buffers = [=](const torch::Tensor& buffer) {
6983 auto x = torch::from_blob (
7084 math::advance_ptr (buffer.data_ptr (), reinterpret_cast <int64_t >(input_token_buffer.base )),
7185 {num_max_tokens_per_rank, hidden},
7286 torch::TensorOptions ().dtype (torch::kFloat8_e4m3fn ).device (buffer.device ()));
73- // TODO: create `x_sf` from buffer
74- auto x_sf = torch::empty (0 , torch::TensorOptions ().device (buffer.device ()));
87+ auto x_sf = torch::from_blob (
88+ math::advance_ptr (buffer.data_ptr (), reinterpret_cast <int64_t >(input_sf_buffer.base )),
89+ {num_max_tokens_per_rank, hidden / 128 },
90+ torch::TensorOptions ().dtype (torch::kInt ).device (buffer.device ()));
7591 auto topk_idx = torch::from_blob (
7692 math::advance_ptr (buffer.data_ptr (), reinterpret_cast <int64_t >(input_topk_idx_buffer.base )),
7793 {num_max_tokens_per_rank, num_topk},
@@ -84,11 +100,21 @@ get_symm_buffer_size_for_mega_moe(
84100 math::advance_ptr (buffer.data_ptr (), reinterpret_cast <int64_t >(l1_token_buffer.base )),
85101 {num_experts_per_rank * num_max_recv_tokens_per_expert, hidden},
86102 torch::TensorOptions ().dtype (torch::kFloat8_e4m3fn ).device (buffer.device ()));
103+ auto l1_acts_sf = torch::from_blob (
104+ math::advance_ptr (buffer.data_ptr (), reinterpret_cast <int64_t >(l1_sf_buffer.base )),
105+ {num_max_recv_tokens_per_expert, hidden / 128 * num_experts_per_rank},
106+ {1 , num_max_recv_tokens_per_expert},
107+ torch::TensorOptions ().dtype (torch::kInt ).device (buffer.device ()));
87108 auto l2_acts = torch::from_blob (
88109 math::advance_ptr (buffer.data_ptr (), reinterpret_cast <int64_t >(l2_token_buffer.base )),
89110 {num_experts_per_rank * num_max_recv_tokens_per_expert, intermediate_hidden},
90111 torch::TensorOptions ().dtype (torch::kFloat8_e4m3fn ).device (buffer.device ()));
91- return std::make_tuple (x, x_sf, topk_idx, topk_weights, l1_acts, l2_acts);
112+ auto l2_acts_sf = torch::from_blob (
113+ math::advance_ptr (buffer.data_ptr (), reinterpret_cast <int64_t >(l2_sf_buffer.base )),
114+ {num_max_recv_tokens_per_expert, intermediate_hidden / 128 * num_experts_per_rank},
115+ {1 , num_max_recv_tokens_per_expert},
116+ torch::TensorOptions ().dtype (torch::kInt ).device (buffer.device ()));
117+ return std::make_tuple (x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf);
92118 };
93119 return {reinterpret_cast <int64_t >(combine_token_buffer.get_end_ptr ()), slice_input_buffers};
94120}
@@ -133,6 +159,13 @@ static void fp8_fp4_mega_moe(
133159 DG_HOST_ASSERT (intermediate_hidden_2 == 2 * intermediate_hidden);
134160 DG_HOST_ASSERT (l1_weights.is_contiguous () and l2_weights.is_contiguous ());
135161
162+ // Check weight SF layout for UE8M0 packing, MN-major, and TMA alignment
163+ constexpr int kGranMN = 1 , kGranK = 32 ;
164+ check_sf_layout (l1_weights_sf, intermediate_hidden * 2 , hidden, kGranMN , kGranK ,
165+ num_experts_per_rank, true , false , torch::kInt );
166+ check_sf_layout (l2_weights_sf, hidden, intermediate_hidden, kGranMN , kGranK ,
167+ num_experts_per_rank, true , false , torch::kInt );
168+
136169 // Check buffer bytes
137170 const auto num_ranks = static_cast <int >(sym_buffer_ptrs.size ());
138171 const auto num_experts_ = num_experts_per_rank * num_ranks;
@@ -145,12 +178,13 @@ static void fp8_fp4_mega_moe(
145178 DG_HOST_ASSERT (num_experts == num_experts_);
146179
147180 // Already registered tensors
148- const auto [x, x_sf, topk_idx, topk_weights, l1_acts, l2_acts] = slice (sym_buffer);
181+ const auto [x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf ] = slice (sym_buffer);
149182
150183 // Dispatch into different architectures
151184 if (arch_major == 10 ) {
152185 sm100_fp8_fp4_mega_moe (y,
153- l1_acts, l2_acts,
186+ l1_acts, l1_acts_sf,
187+ l2_acts, l2_acts_sf,
154188 l1_weights, l2_weights,
155189 l1_weights_sf, l2_weights_sf,
156190 sym_buffer_ptrs,
0 commit comments