豆豆友情提示:这是一个非官方 GitHub 代理镜像,主要用于网络测试或访问加速。请勿在此进行登录、注册或处理任何敏感信息。进行这些操作请务必访问官方网站 github.com。 Raw 内容也通过此代理提供。
Skip to content

Commit d30fc36

Browse files
authored
Fix sync issue of TMEM alloc/dealloc (#292)
1 parent 35c4bc8 commit d30fc36

File tree

3 files changed

+19
-10
lines changed

3 files changed

+19
-10
lines changed

deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,9 @@ sm100_bf16_gemm_impl(int* grouped_layout,
132132
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2 + 1);
133133
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
134134

135+
if (kNumMulticast > 1)
136+
cute::cluster_sync();
137+
135138
// Initialize barriers
136139
if (warp_idx == 1 and cute::elect_one_sync()) {
137140
#pragma unroll
@@ -465,12 +468,13 @@ sm100_bf16_gemm_impl(int* grouped_layout,
465468
}
466469
}
467470
}
468-
469-
// Deallocate tensor memory by the last UMMA store warp
470-
// NOTES: warp 0 is waiting TMA store
471-
if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1)
472-
Allocator().free(0, kNumTmemCols);
473471
}
472+
473+
// Deallocate tensor memory
474+
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
475+
if (warp_idx == 0)
476+
Allocator().free(0, kNumTmemCols);
477+
474478
#else
475479
if (blockIdx.x == 0 and threadIdx.x == 0)
476480
DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");

deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
251251
}
252252
}
253253

254+
__syncthreads();
254255
// Deallocate tensor memory by warp 1
255256
// NOTES: warp 0 is doing TMA stores
256257
if (warp_idx == 1)

deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,9 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
155155
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2);
156156
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
157157

158+
if (kNumMulticast > 1)
159+
cute::cluster_sync();
160+
158161
// Initialize barriers
159162
if (warp_idx == 1 and cute::elect_one_sync()) {
160163
#pragma unroll
@@ -546,12 +549,13 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
546549
}
547550
}
548551
}
549-
550-
// Deallocate tensor memory by the last UMMA store warp
551-
// NOTES: warp 0 is waiting TMA store
552-
if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1)
553-
Allocator().free(0, kNumTmemCols);
554552
}
553+
554+
// Deallocate tensor memory
555+
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
556+
if (warp_idx == 0)
557+
Allocator().free(0, kNumTmemCols);
558+
555559
#else
556560
if (blockIdx.x == 0 and threadIdx.x == 0)
557561
DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");

0 commit comments

Comments
 (0)