豆豆友情提示:这是一个非官方 GitHub 代理镜像,主要用于网络测试或访问加速。请勿在此进行登录、注册或处理任何敏感信息。进行这些操作请务必访问官方网站 github.com。 Raw 内容也通过此代理提供。
Skip to content

Commit 0df67dc

Browse files
authored
Fix 2CTA TMEM Free (#189)
* Fix 2CTA TMEM
* Fix BF16 as well
1 parent e45e35d commit 0df67dc

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -411,12 +411,13 @@ sm100_bf16_gemm_impl(int* grouped_layout,
411411
tensor_map_cd);
412412
}
413413
}
414-
415-
// Deallocate tensor memory by the last UMMA store warp
416-
// NOTES: warp 0 is waiting TMA store
417-
if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1)
418-
Allocator().free(0, kNumTmemCols);
419414
}
415+
416+
// Deallocate tensor memory
417+
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
418+
if (warp_idx == 0)
419+
Allocator().free(0, kNumTmemCols);
420+
420421
#else
421422
if (blockIdx.x == 0 and threadIdx.x == 0)
422423
DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");

deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -489,12 +489,13 @@ sm100_fp8_fp4_gemm_1d1d_impl(int* grouped_layout,
489489
tensor_map_cd);
490490
}
491491
}
492-
493-
// Deallocate tensor memory by the last UMMA store warp
494-
// NOTES: warp 0 is waiting TMA store
495-
if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1)
496-
Allocator().free(0, kNumTmemCols);
497492
}
493+
494+
// Deallocate tensor memory
495+
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
496+
if (warp_idx == 0)
497+
Allocator().free(0, kNumTmemCols);
498+
498499
#else
499500
if (blockIdx.x == 0 and threadIdx.x == 0)
500501
DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");

0 commit comments

Comments (0)