豆豆友情提示:这是一个非官方 GitHub 代理镜像,主要用于网络测试或访问加速。请勿在此进行登录、注册或处理任何敏感信息。进行这些操作请务必访问官方网站 github.com。 Raw 内容也通过此代理提供。
Skip to content

Commit 0f7aaa5

Browse files
committed
initial implementation
1 parent 4ff3f54 commit 0f7aaa5

File tree

5 files changed

+299
-217
lines changed

5 files changed

+299
-217
lines changed

csrc/apis/attention.hpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,15 @@ static torch::Tensor fp8_mqa_logits(const torch::Tensor& q,
121121
}
122122

123123
// Dispatch implementation
124+
torch::Tensor cu_seq_len_k_start_and_end = torch::stack({cu_seq_len_k_start, cu_seq_len_k_end}, 1).reshape({-1});
125+
cu_seq_len_k_start_and_end = cu_seq_len_k_start_and_end.contiguous();
124126
const auto& arch_major = device_runtime->get_arch_major();
125-
if (arch_major == 9 or arch_major == 10) {
126-
smxx_fp8_mqa_logits(q, kv.first, kv.second, weights, cu_seq_len_k_start, cu_seq_len_k_end, logits,
127+
if (arch_major == 9) {
128+
smxx_fp8_mqa_logits(q, kv.first, kv.second, weights, cu_seq_len_k_start_and_end, logits,
129+
seq_len, seq_len_kv, max_seqlen_k, stride_logits, num_heads, head_dim, seq_len_alignment);
130+
} else if (arch_major == 10) {
131+
auto weights_fp16 = weights.to(torch::kFloat16).contiguous();
132+
smxx_fp8_mqa_logits(q, kv.first, kv.second, weights_fp16, cu_seq_len_k_start_and_end, logits,
127133
seq_len, seq_len_kv, max_seqlen_k, stride_logits, num_heads, head_dim, seq_len_alignment);
128134
} else {
129135
DG_HOST_UNREACHABLE("Unsupported architecture");

csrc/jit_kernels/impls/runtime_utils.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType&
6565
case torch::kFloat: return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
6666
case torch::kBFloat16: return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
6767
case torch::kFloat8_e4m3fn: return CU_TENSOR_MAP_DATA_TYPE_UINT8;
68+
case torch::kFloat16: return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
6869
default: DG_HOST_UNREACHABLE("Unsupported dtype");
6970
}
7071
}

csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ class SMXXFP8MQALogitsRuntime final: public LaunchRuntime<SMXXFP8MQALogitsRuntim
2424
int block_q;
2525
int block_kv;
2626

27-
int* cu_seq_len_k_start;
28-
int* cu_seq_len_k_end;
27+
int* cu_seq_len_k_start_and_end;
2928
float* logits;
3029
float softmax_scale;
3130

@@ -72,7 +71,7 @@ static void __instantiate_kernel() {{
7271
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
7372
args.seq_len, args.seq_len_kv,
7473
args.max_seqlen_k, static_cast<int64_t>(args.stride_logits),
75-
args.cu_seq_len_k_start, args.cu_seq_len_k_end,
74+
args.cu_seq_len_k_start_and_end,
7675
args.logits,
7776
args.tensor_map_q, args.tensor_map_kv,
7877
args.tensor_map_kv_scales, args.tensor_map_weights
@@ -83,8 +82,7 @@ static void __instantiate_kernel() {{
8382
static void smxx_fp8_mqa_logits(const torch::Tensor& q,
8483
const torch::Tensor& kv, const torch::Tensor& kv_scales,
8584
const torch::Tensor& weights,
86-
const torch::Tensor& cu_seq_len_k_start,
87-
const torch::Tensor& cu_seq_len_k_end,
85+
const torch::Tensor& cu_seq_len_k_start_and_end,
8886
const torch::Tensor& logits,
8987
const int& seq_len, const int& seq_len_kv,
9088
const int& max_seqlen_k, const int& stride_logits,
@@ -93,8 +91,15 @@ static void smxx_fp8_mqa_logits(const torch::Tensor& q,
9391
constexpr int block_qh = 128;
9492
constexpr int block_kv = 256;
9593
constexpr int num_specialized_threads = 128;
96-
constexpr int num_q_stages = 3, num_kv_stages = 3;
97-
const int num_math_threads = (device_runtime->get_arch_major() == 10 ? 256 : 512);
94+
bool is_sm100 = device_runtime->get_arch_major() == 10;
95+
int num_q_stages = 3, num_kv_stages = 3;
96+
int num_splits = 1;
97+
if (is_sm100) {
98+
num_q_stages = 5;
99+
num_kv_stages = 8;
100+
num_splits = 2;
101+
}
102+
const int num_math_threads = (is_sm100 ? 256 : 512);
98103
const int block_q = block_qh / num_heads;
99104
DG_HOST_ASSERT(block_qh % num_heads == 0);
100105
DG_HOST_ASSERT(seq_len_alignment % block_q == 0);
@@ -107,27 +112,28 @@ static void smxx_fp8_mqa_logits(const torch::Tensor& q,
107112
const auto& tensor_map_q = make_tma_2d_desc(q, head_dim, seq_len * num_heads,
108113
head_dim, block_qh, head_dim, head_dim);
109114
const auto& tensor_map_kv = make_tma_2d_desc(kv, head_dim, seq_len_kv,
110-
head_dim, block_kv, head_dim, head_dim);
115+
head_dim, block_kv / num_splits, head_dim, head_dim);
111116
// According to the driver API, the minimal alignment is 256 bytes
112117
// So it is safe for us to do a 16-byte OOB
113118
const auto& tensor_map_kv_scales = make_tma_2d_desc(kv_scales,
114119
get_tma_aligned_size(seq_len_kv, static_cast<int>(kv_scales.element_size())),
115-
1, block_kv, 1, 0, 0);
120+
1, block_kv / num_splits, 1, 0, 0);
116121
const auto& tensor_map_weights = make_tma_2d_desc(weights, num_heads, seq_len,
117122
num_heads, block_q, num_heads, 0);
118123

119124
// Calculate shared memory size
120125
int smem_size = 0;
121126
const int smem_q_size_per_stage = block_q * num_heads * head_dim * static_cast<int>(q.element_size());
122-
const int smem_weight_size_per_stage = block_q * num_heads * static_cast<int>(weights.element_size());
123-
const int smem_kv_size_per_stage = block_kv * head_dim * static_cast<int>(kv.element_size());
124-
const int kv_scale_size_per_stage = block_kv * static_cast<int>(kv_scales.element_size());
127+
const int smem_weight_size_per_stage = num_splits * block_q * num_heads * static_cast<int>(weights.element_size());
128+
const int smem_kv_size_per_stage = (block_kv / num_splits) * head_dim * static_cast<int>(kv.element_size());
129+
const int kv_scale_size_per_stage = (block_kv / num_splits) * static_cast<int>(kv_scales.element_size());
125130
smem_size += num_q_stages * smem_q_size_per_stage;
126131
smem_size += num_kv_stages * smem_kv_size_per_stage;
127132
smem_size += num_q_stages * smem_weight_size_per_stage;
128133
smem_size += num_kv_stages * kv_scale_size_per_stage;
129-
smem_size += (num_q_stages * 2 + num_kv_stages * 2 + (num_math_threads / 128) * 2) * 8;
130-
smem_size += 4;
134+
const int num_mma_stages = 2;
135+
smem_size += (num_q_stages * 2 + num_kv_stages * 2 + num_mma_stages * 2) * 8;
136+
smem_size += 256;
131137
DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity);
132138
DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity);
133139

@@ -143,8 +149,7 @@ static void smxx_fp8_mqa_logits(const torch::Tensor& q,
143149
.num_kv_stages = num_kv_stages,
144150
.block_q = block_q,
145151
.block_kv = block_kv,
146-
.cu_seq_len_k_start = cu_seq_len_k_start.data_ptr<int>(),
147-
.cu_seq_len_k_end = cu_seq_len_k_end.data_ptr<int>(),
152+
.cu_seq_len_k_start_and_end = cu_seq_len_k_start_and_end.data_ptr<int>(),
148153
.logits = logits.data_ptr<float>(),
149154
.tensor_map_q = tensor_map_q,
150155
.tensor_map_kv = tensor_map_kv,
@@ -154,7 +159,7 @@ static void smxx_fp8_mqa_logits(const torch::Tensor& q,
154159
.num_math_threads = num_math_threads,
155160
.launch_args = LaunchArgs(device_runtime->get_num_sms(),
156161
num_specialized_threads + num_math_threads,
157-
smem_size)
162+
smem_size, num_splits)
158163
};
159164
const auto& code = SMXXFP8MQALogitsRuntime::generate(args);
160165
const auto& runtime = compiler->build("smxx_fp8_mqa_logits", code);

0 commit comments

Comments (0)