豆豆友情提示:这是一个非官方 GitHub 代理镜像,主要用于网络测试或访问加速。请勿在此进行登录、注册或处理任何敏感信息。进行这些操作请务必访问官方网站 github.com。 Raw 内容也通过此代理提供。
Skip to content

Commit ef5c7ee

Browse files
authored
Sync cluster before 2-CTA TMEM alloc (#192)
* Sync cluster before 2-CTA TMEM alloc
* Minor fix
* Minor fix
1 parent bd74376 commit ef5c7ee

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ sm100_bf16_gemm_impl(int* grouped_layout,
9494
constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<kNumAccumTmemCols>();
9595
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
9696

97+
// Synchronize the cluster before 2-CTA TMEM allocation
98+
kNumMulticast > 1 ? cute::cluster_sync() : void();
99+
97100
// Utils
98101
bool is_leader_cta = cute::block_rank_in_cluster() == 0;
99102
const auto warp_idx = cutlass::canonical_warp_idx_sync();
@@ -416,8 +419,10 @@ sm100_bf16_gemm_impl(int* grouped_layout,
416419
}
417420
}
418421

419-
// Deallocate tensor memory
422+
// TODO: Remove redundant synchronization
420423
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
424+
425+
// Deallocate tensor memory
421426
if (warp_idx == 0)
422427
Allocator().free(0, kNumTmemCols);
423428

deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ sm100_fp8_fp4_gemm_1d1d_impl(int* grouped_layout,
9999
constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols;
100100
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
101101

102+
// Synchronize the cluster before 2-CTA TMEM allocation
103+
kNumMulticast > 1 ? cute::cluster_sync() : void();
104+
102105
// Utils
103106
const bool is_leader_cta = cute::block_rank_in_cluster() == 0;
104107
const auto warp_idx = cutlass::canonical_warp_idx_sync();
@@ -494,8 +497,10 @@ sm100_fp8_fp4_gemm_1d1d_impl(int* grouped_layout,
494497
}
495498
}
496499

497-
// Deallocate tensor memory
500+
// TODO: Remove redundant synchronization
498501
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
502+
503+
// Deallocate tensor memory
499504
if (warp_idx == 0)
500505
Allocator().free(0, kNumTmemCols);
501506

0 commit comments

Comments (0)