豆豆友情提示:这是一个非官方 GitHub 代理镜像,主要用于网络测试或访问加速。请勿在此进行登录、注册或处理任何敏感信息。进行这些操作请务必访问官方网站 github.com。 Raw 内容也通过此代理提供。
Skip to content

Commit bd74376

Browse files
xay5421, zheanxu, LyricZhao
authored
PDL (#172)
* Add PDL for sm90 gemm
* Remove __ldg
* Fix PDL
* Add PDL for sm100 gemm
* Add explicit cudaTriggerProgrammaticLaunchCompletion
* Minor Fix
* remove __ldg
* Minor fix
* Add enable_pdl in launch config
* Set enable_pdl for gemm
* Add DG_DISABLE_PDL
* Minor fix
* Set enable_pdl for smxx_layout
* Add PDL to all kernels except backward kernels
* Remove cudaTriggerProgrammaticLaunchCompletion
* Minor fix
* Minor fix
* Minor fix
* Minor fix
* Remove useless PDL
* revert some change
* revert some changes
* Minor fix
* Refactor pointer arithmetic to array indexing
* Add grid-dependency sync and remove __ldg in kernels
* Minor fix
* Add runtime API PDL toggle
* Minor fix
* Drop explicit LaunchArgs enable_pdl arguments
* Simplify LaunchArgs construction
* Update comments
* Minor fix
* Minor fix
* Update device_runtime.hpp

---------

Co-authored-by: Zhean Xu <xza@deepseek.com>
Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
1 parent 0df67dc commit bd74376

30 files changed: +173 additions, -58 deletions

csrc/apis/runtime.hpp

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,12 @@ static void register_apis(pybind11::module_& m) {
2121
m.def("get_tc_util", [&]() {
2222
return device_runtime->get_tc_util();
2323
});
24+
m.def("set_pdl", [&](const bool& new_enable_pdl) {
25+
device_runtime->set_pdl(new_enable_pdl);
26+
});
27+
m.def("get_pdl", [&]() {
28+
return device_runtime->get_pdl();
29+
});
2430
m.def("set_ignore_compile_dims", [&](const bool& new_value) {
2531
heuristics_runtime->set_ignore_compile_dims(new_value);
2632
});

csrc/jit/device_runtime.hpp

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@ namespace deep_gemm {
1313

1414
class DeviceRuntime {
1515
int num_sms = 0, tc_util = 0;
16+
bool enable_pdl = false;
1617
std::shared_ptr<cudaDeviceProp> cached_prop;
1718

1819
// cuBLASLt utils
@@ -114,6 +115,14 @@ class DeviceRuntime {
114115
int get_tc_util() const {
115116
return tc_util == 0 ? 100 : tc_util;
116117
}
118+
119+
void set_pdl(const bool& new_enable_pdl) {
120+
enable_pdl = new_enable_pdl;
121+
}
122+
123+
bool get_pdl() const {
124+
return enable_pdl;
125+
}
117126
};
118127

119128
static auto device_runtime = LazyInit<DeviceRuntime>([](){ return std::make_shared<DeviceRuntime>(); });

csrc/jit/handle.hpp

Lines changed: 34 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -74,7 +74,7 @@ static void unload_library(const LibraryHandle& library) {
7474

7575
static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel,
7676
const cudaStream_t& stream, const int& smem_size,
77-
const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim) {
77+
const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim, const bool& enable_pdl) {
7878
if (smem_size > 0)
7979
DG_CUDA_RUNTIME_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
8080

@@ -83,17 +83,27 @@ static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel,
8383
config.blockDim = block_dim;
8484
config.dynamicSmemBytes = smem_size;
8585
config.stream = stream;
86-
config.numAttrs = 0;
87-
config.attrs = nullptr;
8886

87+
// Create attributes
8988
// NOTES: must use `static` or the `attr` will be deconstructed
90-
static LaunchAttrHandle attr;
89+
static LaunchAttrHandle attrs[2];
90+
config.numAttrs = 0;
91+
config.attrs = attrs;
92+
93+
// Cluster size
9194
if (cluster_dim > 1) {
95+
auto& attr = attrs[config.numAttrs ++];
9296
attr.id = cudaLaunchAttributeClusterDimension;
9397
attr.val.clusterDim = {static_cast<unsigned>(cluster_dim), 1, 1};
94-
config.attrs = &attr;
95-
config.numAttrs = 1;
9698
}
99+
100+
// Dependent kernel launch
101+
if (enable_pdl) {
102+
auto& attr = attrs[config.numAttrs ++];
103+
attr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
104+
attr.val.programmaticStreamSerializationAllowed = 1;
105+
}
106+
97107
return config;
98108
}
99109

@@ -155,8 +165,8 @@ static void unload_library(const LibraryHandle& library) {
155165
}
156166

157167
static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel,
158-
const cudaStream_t& stream, const int& smem_size,
159-
const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim) {
168+
const cudaStream_t& stream, const int& smem_size,
169+
const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim, const bool& enable_pdl) {
160170
if (smem_size > 0)
161171
DG_CUDA_DRIVER_CHECK(lazy_cuFuncSetAttribute(kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size));
162172

@@ -169,19 +179,29 @@ static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel,
169179
config.blockDimZ = block_dim.z;
170180
config.sharedMemBytes = smem_size;
171181
config.hStream = stream;
182+
183+
// Create attributes
184+
// NOTES: must use `static` or the `attr` will be deconstructed
185+
static LaunchAttrHandle attrs[2];
172186
config.numAttrs = 0;
173-
config.attrs = nullptr;
187+
config.attrs = attrs;
174188

175-
// NOTES: must use `static` or the `attr` will be deconstructed
176-
static LaunchAttrHandle attr;
189+
// Cluster size
177190
if (cluster_dim > 1) {
191+
auto& attr = attrs[config.numAttrs ++];
178192
attr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
179-
attr.value.clusterDim.x = cluster_dim;
193+
attr.value.clusterDim.x = static_cast<unsigned>(cluster_dim);
180194
attr.value.clusterDim.y = 1;
181195
attr.value.clusterDim.z = 1;
182-
config.attrs = &attr;
183-
config.numAttrs = 1;
184196
}
197+
198+
// Dependent kernel launch
199+
if (enable_pdl) {
200+
auto& attr = attrs[config.numAttrs ++];
201+
attr.id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION;
202+
attr.value.programmaticStreamSerializationAllowed = 1;
203+
}
204+
185205
return config;
186206
}
187207

csrc/jit/kernel_runtime.hpp

Lines changed: 13 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -16,12 +16,13 @@ struct LaunchArgs {
1616
int num_threads;
1717
int smem_size;
1818
int cluster_dim;
19+
bool enable_pdl;
1920

20-
LaunchArgs(const int& grid_dim_x, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1):
21-
grid_dim({grid_dim_x, 1}), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim) {}
21+
LaunchArgs(const int& grid_dim_x, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1, const bool& enable_pdl = true):
22+
grid_dim({grid_dim_x, 1}), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim), enable_pdl(enable_pdl) {}
2223

23-
LaunchArgs(const std::pair<int, int>& grid_dim, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1):
24-
grid_dim(grid_dim), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim) {}
24+
LaunchArgs(const std::pair<int, int>& grid_dim, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1, const bool& enable_pdl = true):
25+
grid_dim(grid_dim), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim), enable_pdl(enable_pdl) {}
2526
};
2627

2728
class KernelRuntime final {
@@ -127,20 +128,24 @@ class LaunchRuntime {
127128
static void launch(const std::shared_ptr<KernelRuntime>& kernel_runtime, const Args& args) {
128129
const auto kernel = kernel_runtime->kernel;
129130
const auto stream = at::cuda::getCurrentCUDAStream();
130-
const LaunchArgs launch_args = args.launch_args;
131+
LaunchArgs launch_args = args.launch_args;
132+
133+
// Allow runtime override from Python.
134+
// NOTES: the default is enabled.
135+
launch_args.enable_pdl = device_runtime->get_pdl();
131136

132137
const dim3 grid_dim = {static_cast<unsigned>(launch_args.grid_dim.first),
133138
static_cast<unsigned>(launch_args.grid_dim.second),
134139
1};
135140
const dim3 block_dim = {static_cast<unsigned>(launch_args.num_threads), 1, 1};
136141
auto config = construct_launch_config(kernel, stream, launch_args.smem_size,
137-
grid_dim, block_dim, launch_args.cluster_dim);
142+
grid_dim, block_dim, launch_args.cluster_dim, launch_args.enable_pdl);
138143

139144
// Launch in the derived class
140145
if (get_env<int>("DG_JIT_DEBUG")) {
141-
printf("Launch kernel with {%d, %d} x %d, shared memory: %d bytes, cluster: %d, stream: %ld\n",
146+
printf("Launch kernel with {%d, %d} x %d, shared memory: %d bytes, cluster: %d, enable_pdl: %d, stream: %ld\n",
142147
launch_args.grid_dim.first, launch_args.grid_dim.second, launch_args.num_threads,
143-
launch_args.smem_size, launch_args.cluster_dim, stream.id());
148+
launch_args.smem_size, launch_args.cluster_dim, launch_args.enable_pdl, stream.id());
144149
}
145150
Derived::launch_impl(kernel, config, args);
146151
}

csrc/jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -135,7 +135,7 @@ static void sm100_tf32_hc_prenorm_gemm(const torch::Tensor& a,
135135
.num_stages = num_stages,
136136
.num_mma_threads = num_mma_threads,
137137
.num_cast_and_reduce_threads = num_cast_and_reduce_threads,
138-
.launch_args = LaunchArgs(num_splits * ceil_div(m, block_m), num_mma_threads + num_cast_and_reduce_threads, smem_size, 1),
138+
.launch_args = LaunchArgs(num_splits * ceil_div(m, block_m), num_mma_threads + num_cast_and_reduce_threads, smem_size),
139139
.tensor_map_a = tensor_map_a,
140140
.tensor_map_b = tensor_map_b,
141141
.tensor_map_d = tensor_map_d,

csrc/jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -138,7 +138,7 @@ static void sm90_tf32_hc_prenorm_gemm(const torch::Tensor& a,
138138
.num_stages = num_stages,
139139
.num_math_threads = num_math_threads,
140140
.num_tma_threads = num_tma_threads,
141-
.launch_args = LaunchArgs(num_splits * ceil_div(m, block_m), num_threads, smem_size, 1),
141+
.launch_args = LaunchArgs(num_splits * ceil_div(m, block_m), num_threads, smem_size),
142142
.tensor_map_a = tensor_map_a,
143143
.tensor_map_b = tensor_map_b,
144144
.tensor_map_d = tensor_map_d,

deep_gemm/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,8 @@
2121
get_tc_util,
2222
set_ignore_compile_dims,
2323
set_block_size_multiple_of,
24+
set_pdl,
25+
get_pdl,
2426
)
2527

2628
# cuBLASLt Kernels

deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -164,6 +164,9 @@ sm100_bf16_gemm_impl(int* grouped_layout,
164164
}
165165
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
166166

167+
// Wait for primary kernel completion
168+
cudaGridDependencySynchronize();
169+
167170
// Block scheduler
168171
uint32_t m_block_idx, n_block_idx;
169172
auto scheduler = sched::Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(

deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -102,6 +102,9 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
102102
const uint32_t m_block_idx = mn_block_idx / num_n_blocks;
103103
const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor);
104104

105+
// Wait for primary kernel completion
106+
cudaGridDependencySynchronize();
107+
105108
if (warp_idx == 0) {
106109
// TMA load warp
107110
for (uint32_t s = 0; s < num_total_stages; ++ s) {

deep_gemm/include/deep_gemm/impls/sm100_fp4_mqa_logits.cuh

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -157,8 +157,8 @@ void sm100_fp4_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
157157
#pragma unroll
158158
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
159159
const auto row_idx = min(q_idx * BLOCK_Q + i, seq_len - 1);
160-
seq_k_start[i] = min(__ldg(cu_seq_len_k_start + row_idx), seq_len_kv);
161-
seq_k_end[i] = min(__ldg(cu_seq_len_k_end + row_idx), seq_len_kv);
160+
seq_k_start[i] = min(cu_seq_len_k_start[row_idx], seq_len_kv);
161+
seq_k_end[i] = min(cu_seq_len_k_end[row_idx], seq_len_kv);
162162
start = min(start, seq_k_start[i]);
163163
end = max(end, seq_k_end[i]);
164164
}
@@ -184,6 +184,9 @@ void sm100_fp4_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
184184
constexpr uint32_t kNumSpecializedRegisters = 40;
185185
constexpr uint32_t kNumMathRegisters = 232;
186186

187+
// Wait for primary kernel completion
188+
cudaGridDependencySynchronize();
189+
187190
if (is_tma_q_warp) {
188191
// TMA warp for loading Q
189192
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();

0 commit comments

Comments (0)