@@ -24,8 +24,7 @@ class SMXXFP8MQALogitsRuntime final: public LaunchRuntime<SMXXFP8MQALogitsRuntim
2424 int block_q;
2525 int block_kv;
2626
27- int * cu_seq_len_k_start;
28- int * cu_seq_len_k_end;
27+ int * cu_seq_len_k_start_and_end;
2928 float * logits;
3029 float softmax_scale;
3130
@@ -72,7 +71,7 @@ static void __instantiate_kernel() {{
7271 DG_CUDA_UNIFIED_CHECK (launch_kernel (kernel, config,
7372 args.seq_len , args.seq_len_kv ,
7473 args.max_seqlen_k , static_cast <int64_t >(args.stride_logits ),
75- args.cu_seq_len_k_start , args. cu_seq_len_k_end ,
74+ args.cu_seq_len_k_start_and_end ,
7675 args.logits ,
7776 args.tensor_map_q , args.tensor_map_kv ,
7877 args.tensor_map_kv_scales , args.tensor_map_weights
@@ -83,8 +82,7 @@ static void __instantiate_kernel() {{
8382static void smxx_fp8_mqa_logits (const torch::Tensor& q,
8483 const torch::Tensor& kv, const torch::Tensor& kv_scales,
8584 const torch::Tensor& weights,
86- const torch::Tensor& cu_seq_len_k_start,
87- const torch::Tensor& cu_seq_len_k_end,
85+ const torch::Tensor& cu_seq_len_k_start_and_end,
8886 const torch::Tensor& logits,
8987 const int & seq_len, const int & seq_len_kv,
9088 const int & max_seqlen_k, const int & stride_logits,
@@ -93,8 +91,15 @@ static void smxx_fp8_mqa_logits(const torch::Tensor& q,
9391 constexpr int block_qh = 128 ;
9492 constexpr int block_kv = 256 ;
9593 constexpr int num_specialized_threads = 128 ;
96- constexpr int num_q_stages = 3 , num_kv_stages = 3 ;
97- const int num_math_threads = (device_runtime->get_arch_major () == 10 ? 256 : 512 );
94+ bool is_sm100 = device_runtime->get_arch_major () == 10 ;
95+ int num_q_stages = 3 , num_kv_stages = 3 ;
96+ int num_splits = 1 ;
97+ if (is_sm100) {
98+ num_q_stages = 5 ;
99+ num_kv_stages = 8 ;
100+ num_splits = 2 ;
101+ }
102+ const int num_math_threads = (is_sm100 ? 256 : 512 );
98103 const int block_q = block_qh / num_heads;
99104 DG_HOST_ASSERT (block_qh % num_heads == 0 );
100105 DG_HOST_ASSERT (seq_len_alignment % block_q == 0 );
@@ -107,27 +112,28 @@ static void smxx_fp8_mqa_logits(const torch::Tensor& q,
107112 const auto & tensor_map_q = make_tma_2d_desc (q, head_dim, seq_len * num_heads,
108113 head_dim, block_qh, head_dim, head_dim);
109114 const auto & tensor_map_kv = make_tma_2d_desc (kv, head_dim, seq_len_kv,
110- head_dim, block_kv, head_dim, head_dim);
115+ head_dim, block_kv / num_splits , head_dim, head_dim);
111116 // According to the driver API, the minimal alignment is 256 bytes
112117 // So it is safe for us to do a 16-byte OOB
113118 const auto & tensor_map_kv_scales = make_tma_2d_desc (kv_scales,
114119 get_tma_aligned_size (seq_len_kv, static_cast <int >(kv_scales.element_size ())),
115- 1 , block_kv, 1 , 0 , 0 );
120+ 1 , block_kv / num_splits , 1 , 0 , 0 );
116121 const auto & tensor_map_weights = make_tma_2d_desc (weights, num_heads, seq_len,
117122 num_heads, block_q, num_heads, 0 );
118123
119124 // Calculate shared memory size
120125 int smem_size = 0 ;
121126 const int smem_q_size_per_stage = block_q * num_heads * head_dim * static_cast <int >(q.element_size ());
122- const int smem_weight_size_per_stage = block_q * num_heads * static_cast <int >(weights.element_size ());
123- const int smem_kv_size_per_stage = block_kv * head_dim * static_cast <int >(kv.element_size ());
124- const int kv_scale_size_per_stage = block_kv * static_cast <int >(kv_scales.element_size ());
127+ const int smem_weight_size_per_stage = num_splits * block_q * num_heads * static_cast <int >(weights.element_size ());
128+ const int smem_kv_size_per_stage = ( block_kv / num_splits) * head_dim * static_cast <int >(kv.element_size ());
129+ const int kv_scale_size_per_stage = ( block_kv / num_splits) * static_cast <int >(kv_scales.element_size ());
125130 smem_size += num_q_stages * smem_q_size_per_stage;
126131 smem_size += num_kv_stages * smem_kv_size_per_stage;
127132 smem_size += num_q_stages * smem_weight_size_per_stage;
128133 smem_size += num_kv_stages * kv_scale_size_per_stage;
129- smem_size += (num_q_stages * 2 + num_kv_stages * 2 + (num_math_threads / 128 ) * 2 ) * 8 ;
130- smem_size += 4 ;
134+ const int num_mma_stages = 2 ;
135+ smem_size += (num_q_stages * 2 + num_kv_stages * 2 + num_mma_stages * 2 ) * 8 ;
136+ smem_size += 256 ;
131137 DG_HOST_ASSERT (smem_size <= SM90ArchSpec::smem_capacity);
132138 DG_HOST_ASSERT (smem_size <= SM100ArchSpec::smem_capacity);
133139
@@ -143,8 +149,7 @@ static void smxx_fp8_mqa_logits(const torch::Tensor& q,
143149 .num_kv_stages = num_kv_stages,
144150 .block_q = block_q,
145151 .block_kv = block_kv,
146- .cu_seq_len_k_start = cu_seq_len_k_start.data_ptr <int >(),
147- .cu_seq_len_k_end = cu_seq_len_k_end.data_ptr <int >(),
152+ .cu_seq_len_k_start_and_end = cu_seq_len_k_start_and_end.data_ptr <int >(),
148153 .logits = logits.data_ptr <float >(),
149154 .tensor_map_q = tensor_map_q,
150155 .tensor_map_kv = tensor_map_kv,
@@ -154,7 +159,7 @@ static void smxx_fp8_mqa_logits(const torch::Tensor& q,
154159 .num_math_threads = num_math_threads,
155160 .launch_args = LaunchArgs (device_runtime->get_num_sms (),
156161 num_specialized_threads + num_math_threads,
157- smem_size)
162+ smem_size, num_splits )
158163 };
159164 const auto & code = SMXXFP8MQALogitsRuntime::generate (args);
160165 const auto & runtime = compiler->build (" smxx_fp8_mqa_logits" , code);
0 commit comments