豆豆友情提示:这是一个非官方 GitHub 代理镜像,主要用于网络测试或访问加速。请勿在此进行登录、注册或处理任何敏感信息。进行这些操作请务必访问官方网站 github.com。 Raw 内容也通过此代理提供。
Skip to content

Commit b29302e

Browse files
authored
fix b200 cu128 (#4)
1 parent eda35b4 commit b29302e

File tree

7 files changed

+19
-17
lines changed

7 files changed

+19
-17
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@ cat develop.sh
6666

6767
# Test all GEMM implementations
6868
python tests/test_layout.py
69-
python tests/test_core.py
69+
python tests/test_bf16.py
70+
python tests/test_fp8.py
71+
python tests/test_lazy_init.py
7072
```
7173

7274
### Installation

csrc/jit/compiler.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ class NVCCCompiler final: public Compiler {
155155
signature = fmt::format("NVCC{}.{}", nvcc_major, nvcc_minor);
156156

157157
// Override the compiler flags
158-
flags = fmt::format("{} -I{} --gpu-architecture=sm_{} "
158+
flags = fmt::format("{} -I{} --gpu-architecture=sm_{}a "
159159
"--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi "
160160
"-cubin -O3 --expt-relaxed-constexpr --expt-extended-lambda",
161161
flags, library_include_path.c_str(), device_runtime->get_arch());
@@ -205,7 +205,7 @@ class NVRTCCompiler final: public Compiler {
205205
}
206206

207207
// Override the compiler flags
208-
flags = fmt::format("{} {}--gpu-architecture=sm_{} -default-device {}",
208+
flags = fmt::format("{} {}--gpu-architecture=sm_{}a -default-device {}",
209209
flags, include_dirs, device_runtime->get_arch(), pch_flags);
210210
}
211211

csrc/jit/device_runtime.hpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,9 @@ class DeviceRuntime {
2525
return {prop->major, prop->minor};
2626
}
2727

28-
std::string get_arch() {
28+
int get_arch() {
2929
const auto& [major, minor] = get_arch_pair();
30-
if (major == 10 and minor != 1)
31-
return "100f";
32-
return std::to_string(major * 10 + minor) + "a";
30+
return major * 10 + minor;
3331
}
3432

3533
int get_arch_major() {

csrc/jit/kernel_runtime.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ class KernelRuntime final {
4646
std::istringstream iss(symbols);
4747
std::vector<std::string> symbol_names;
4848
for (std::string line; std::getline(iss, line); ) {
49-
if (line.find("STT_FUNC") == 0 and std::none_of(illegal_names.begin(), illegal_names.end(),
49+
if (line.find("STT_FUNC") == 0 and line.find("STO_ENTRY") != std::string::npos and
50+
std::none_of(illegal_names.begin(), illegal_names.end(),
5051
[&](const auto& name) { return line.find(name) != std::string::npos; })) {
5152
const auto& last_space = line.rfind(' ');
5253
symbol_names.push_back(line.substr(last_space + 1));

deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ sm100_bf16_gemm_impl(int* grouped_layout,
3232
const __grid_constant__ cute::TmaDescriptor tensor_map_d) {
3333
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
3434
using Barrier = cutlass::arch::ClusterTransactionBarrier;
35-
using Allocator = cute::conditional_t<kNumMulticast == 1, cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>;
3635

3736
// GEMM with accumulation must have FP32 output
3837
if constexpr (kWithAccumulation)
@@ -142,7 +141,7 @@ sm100_bf16_gemm_impl(int* grouped_layout,
142141
cutlass::arch::fence_barrier_init();
143142
} else if (threadIdx.x >= 32 and threadIdx.x < 64) {
144143
// Allocate tensor memory
145-
Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
144+
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
146145
}
147146
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
148147

@@ -473,13 +472,15 @@ sm100_bf16_gemm_impl(int* grouped_layout,
473472
}
474473

475474
// Flush all stages in the pipeline to make TMA stores visible to the next kernel
475+
// TODO: do we actually need this?
476476
if (epilogue_thread_idx == 0)
477477
cute::tma_store_wait<0>();
478478

479479
// Deallocate tensor memory by warp 1
480480
// NOTES: warp 0 is waiting TMA store
481+
// TODO: do we need 2 SM allocation?
481482
if (epilogue_warp_idx == 1)
482-
Allocator().free(0, kNumTmemCols);
483+
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
483484
}
484485

485486
// To safely deconstruct all barriers, we need a cluster sync

deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
3333
const __grid_constant__ cute::TmaDescriptor tensor_map_d) {
3434
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
3535
using Barrier = cutlass::arch::ClusterTransactionBarrier;
36-
using Allocator = cute::conditional_t<kNumMulticast == 1, cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>;
3736

3837
// GEMM with accumulation must have FP32 output
3938
if constexpr (kWithAccumulation)
@@ -170,7 +169,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
170169
cutlass::arch::fence_barrier_init();
171170
} else if (threadIdx.x >= 32 and threadIdx.x < 64) {
172171
// Allocate tensor memory
173-
Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
172+
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
174173
}
175174
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
176175

@@ -578,13 +577,15 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
578577
}
579578

580579
// Flush all stages in the pipeline to make TMA stores visible to the next kernel
580+
// TODO: do we actually need this?
581581
if (epilogue_thread_idx == 0)
582582
cute::tma_store_wait<0>();
583583

584584
// Deallocate tensor memory by warp 1
585585
// NOTES: warp 0 is waiting TMA store
586+
// TODO: do we need 2 SM allocation?
586587
if (epilogue_warp_idx == 1)
587-
Allocator().free(0, kNumTmemCols);
588+
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
588589
}
589590

590591
// To safely deconstruct all barriers, we need a cluster sync

deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
3232
const __grid_constant__ cute::TmaDescriptor tensor_map_sfa) {
3333
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
3434
using Barrier = cutlass::arch::ClusterTransactionBarrier;
35-
using Allocator = cute::conditional_t<kNumMulticast == 1, cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>;
3635

3736
// Scaling checks
3837
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
@@ -153,7 +152,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
153152
cutlass::arch::fence_barrier_init();
154153
} else if (threadIdx.x >= 32 and threadIdx.x < 64) {
155154
// Allocate tensor memory
156-
Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
155+
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
157156
}
158157
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
159158

@@ -519,7 +518,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
519518
// Deallocate tensor memory by warp 1
520519
// NOTES: warp 0 is waiting TMA store
521520
if (epilogue_warp_idx == 1)
522-
Allocator().free(0, kNumTmemCols);
521+
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
523522
}
524523

525524
// To safely deconstruct all barriers, we need a cluster sync

0 commit comments

Comments (0)